diff --git a/scripts/t2v_model.py b/scripts/t2v_model.py
index 200d9b8..b0d6330 100644
--- a/scripts/t2v_model.py
+++ b/scripts/t2v_model.py
@@ -1474,7 +1474,8 @@ class GaussianDiffusion(object):
                          condition_fn=None,
                          guide_scale=None,
                          ddim_timesteps=20,
-                         eta=0.0):
+                         eta=0.0,
+                         unet_lowvram=False):
         # prepare input
         b = noise.size(0)
         xt = noise
diff --git a/scripts/t2v_pipeline.py b/scripts/t2v_pipeline.py
index f058c78..bccaafc 100644
--- a/scripts/t2v_pipeline.py
+++ b/scripts/t2v_pipeline.py
@@ -102,7 +102,7 @@ class TextToVideoSynthesis():
             temporal_attention=cfg['temporal_attention'])
         self.sd_model.load_state_dict(
             torch.load(
-                osp.join(self.model_dir, self.config.model["model_args"]["ckpt_unet"])),
+                osp.join(self.model_dir, self.config.model["model_args"]["ckpt_unet"]), map_location='cpu'),
             strict=True)
         self.sd_model.eval()
         self.sd_model.half()
@@ -150,7 +150,7 @@ class TextToVideoSynthesis():
         self.clip_encoder.to("cpu")
 
     #@torch.compile()
-    def infer(self, prompt, n_prompt, steps, frames, scale, width=256, height=256, eta=0.0, cpu_vae='GPU (half precision)', device = torch.device('cpu'), latents=None):
+    def infer(self, prompt, n_prompt, steps, frames, scale, width=256, height=256, eta=0.0, cpu_vae='GPU (half precision)', device = torch.device('cpu'), latents=None, unet_lowvram=True):
         r"""
         The entry function of text to image synthesis task.
         1. Using diffusion model to generate the video's latent representation.
@@ -165,7 +165,9 @@ class TextToVideoSynthesis():
         #print(self.sd_model.use_fps_condition)
         self.sd_model.use_fps_condition = False
         self.device = device
+        print('CLIP-encoding')
         self.clip_encoder.to(self.device)
+        self.clip_encoder.half()
         y, zero_y = self.preprocess(prompt, n_prompt)
         self.clip_encoder.to("cpu")
         #self.clip_encoder = None
@@ -179,15 +181,19 @@ class TextToVideoSynthesis():
         num_sample = 1
         max_frames = frames
         latent_h, latent_w = height // 8, width // 8
-        self.sd_model.to(self.device)
+        print('PREPARING LATENTS')
+        #self.sd_model.to(self.device)
         if latents == None:
             latents = torch.randn(num_sample, 4, max_frames, latent_h, latent_w).to(
                 self.device)
         else:
             latents.to(self.device)
-        with amp.autocast(enabled=True):
+            latents.half()
+        print('STARTING THE DIFFUSION MODEL')
+        with amp.autocast(enabled=True, dtype=torch.float16):
             self.sd_model.to(self.device)
+            print('SD MODEL ON DEVICE')
             x0 = self.diffusion.ddim_sample_loop(
                 noise=latents, # shape: b c f h w
                 model=self.sd_model,
@@ -200,7 +206,8 @@ class TextToVideoSynthesis():
                 }],
                 guide_scale=scale,
                 ddim_timesteps=steps,
-                eta=eta)
+                eta=eta,
+                unet_lowvram=unet_lowvram)
         self.last_tensor = x0
         self.last_tensor.cpu()
         self.sd_model.to("cpu")
@@ -291,7 +298,7 @@ class TextToVideoSynthesis():
     def cleanup(self):
         pass
     def preprocess(self, prompt, n_prompt, offload=True):
-        self.clip_encoder.to(self.device)
+        #self.clip_encoder.to(self.device)
         text_emb = self.clip_encoder(prompt)
         text_emb_zero = self.clip_encoder(n_prompt)
         if offload:
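
Throughout these hunks, infer() keeps only one sub-model on the GPU at a time: the CLIP encoder is moved to the device just for text conditioning and then parked on CPU, the UNet is moved to the device inside the fp16 autocast block and offloaded again after sampling, and torch.load(..., map_location='cpu') keeps the checkpoint off the GPU at load time. Below is a minimal, self-contained sketch of that stage-by-stage offload pattern; it is not code from the patch, and the nn.Linear / nn.Conv3d modules are dummy stand-ins for the pipeline's real CLIP encoder and UNet.

# Sketch of the offload pattern only -- not part of the patch above.
import torch
import torch.nn as nn
from torch.cuda import amp

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_fp16 = device.type == 'cuda'       # fp16 autocast on GPU, plain fp32 on CPU

text_encoder = nn.Linear(16, 16)       # stand-in for self.clip_encoder
unet = nn.Conv3d(4, 4, 3, padding=1)   # stand-in for self.sd_model

# Stage 1: text conditioning on the GPU, then offload back to CPU.
text_encoder.to(device)
with torch.no_grad():
    cond = text_encoder(torch.randn(1, 16, device=device))
text_encoder.to('cpu')

# Stage 2: diffusion model on the GPU under autocast, then offload.
latents = torch.randn(1, 4, 8, 32, 32, device=device)
unet.to(device)
with amp.autocast(enabled=use_fp16, dtype=torch.float16):
    with torch.no_grad():
        noise_pred = unet(latents)
unet.to('cpu')
if device.type == 'cuda':
    torch.cuda.empty_cache()           # release cached VRAM before the next stage (VAE decode)
print(cond.shape, noise_pred.shape)

Note that these hunks only add the unet_lowvram flag and thread it from infer() into ddim_sample_loop; whatever additional offloading it gates inside the sampler is not part of this excerpt.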