Various (not so successful) attempts to lower VRAM usage
parent
557a017ec5
commit
765a76a3aa
|
|
@ -1474,7 +1474,8 @@ class GaussianDiffusion(object):
|
|||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20,
|
||||
eta=0.0):
|
||||
eta=0.0,
|
||||
unet_lowvram=False):
|
||||
# prepare input
|
||||
b = noise.size(0)
|
||||
xt = noise
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ class TextToVideoSynthesis():
|
|||
temporal_attention=cfg['temporal_attention'])
|
||||
self.sd_model.load_state_dict(
|
||||
torch.load(
|
||||
osp.join(self.model_dir, self.config.model["model_args"]["ckpt_unet"])),
|
||||
osp.join(self.model_dir, self.config.model["model_args"]["ckpt_unet"]), map_location='cpu'),
|
||||
strict=True)
|
||||
self.sd_model.eval()
|
||||
self.sd_model.half()
|
||||
|
|
@ -150,7 +150,7 @@ class TextToVideoSynthesis():
|
|||
self.clip_encoder.to("cpu")
|
||||
|
||||
#@torch.compile()
|
||||
def infer(self, prompt, n_prompt, steps, frames, scale, width=256, height=256, eta=0.0, cpu_vae='GPU (half precision)', device = torch.device('cpu'), latents=None):
|
||||
def infer(self, prompt, n_prompt, steps, frames, scale, width=256, height=256, eta=0.0, cpu_vae='GPU (half precision)', device = torch.device('cpu'), latents=None, unet_lowvram=True):
|
||||
r"""
|
||||
The entry function of text to image synthesis task.
|
||||
1. Using diffusion model to generate the video's latent representation.
|
||||
|
|
@ -165,7 +165,9 @@ class TextToVideoSynthesis():
|
|||
#print(self.sd_model.use_fps_condition)
|
||||
self.sd_model.use_fps_condition = False
|
||||
self.device = device
|
||||
print('CLIP-encoding')
|
||||
self.clip_encoder.to(self.device)
|
||||
self.clip_encoder.half()
|
||||
y, zero_y = self.preprocess(prompt, n_prompt)
|
||||
self.clip_encoder.to("cpu")
|
||||
#self.clip_encoder = None
|
||||
|
|
@ -179,15 +181,19 @@ class TextToVideoSynthesis():
|
|||
num_sample = 1
|
||||
max_frames = frames
|
||||
latent_h, latent_w = height // 8, width // 8
|
||||
self.sd_model.to(self.device)
|
||||
print('PREPARING LATENTS')
|
||||
#self.sd_model.to(self.device)
|
||||
if latents == None:
|
||||
latents = torch.randn(num_sample, 4, max_frames, latent_h,
|
||||
latent_w).to(
|
||||
self.device)
|
||||
else:
|
||||
latents.to(self.device)
|
||||
with amp.autocast(enabled=True):
|
||||
latents.half()
|
||||
print('STARTING THE DIFFUSION MODEL')
|
||||
with amp.autocast(enabled=True, dtype=torch.float16):
|
||||
self.sd_model.to(self.device)
|
||||
print('SD MODEL ON DEVICE')
|
||||
x0 = self.diffusion.ddim_sample_loop(
|
||||
noise=latents, # shape: b c f h w
|
||||
model=self.sd_model,
|
||||
|
|
@ -200,7 +206,8 @@ class TextToVideoSynthesis():
|
|||
}],
|
||||
guide_scale=scale,
|
||||
ddim_timesteps=steps,
|
||||
eta=eta)
|
||||
eta=eta,
|
||||
unet_lowvram=unet_lowvram)
|
||||
self.last_tensor = x0
|
||||
self.last_tensor.cpu()
|
||||
self.sd_model.to("cpu")
|
||||
|
|
@ -291,7 +298,7 @@ class TextToVideoSynthesis():
|
|||
def cleanup(self):
|
||||
pass
|
||||
def preprocess(self, prompt, n_prompt, offload=True):
|
||||
self.clip_encoder.to(self.device)
|
||||
#self.clip_encoder.to(self.device)
|
||||
text_emb = self.clip_encoder(prompt)
|
||||
text_emb_zero = self.clip_encoder(n_prompt)
|
||||
if offload:
|
||||
|
|
|
|||
Loading…
Reference in New Issue