various (not so successful) attempts to lower vram

Branch: lower-vram
kabachuha 2023-03-22 13:30:37 +03:00
parent 557a017ec5
commit 765a76a3aa
2 changed files with 15 additions and 7 deletions

Changed file 1 (name not captured; contains class GaussianDiffusion):

@@ -1474,7 +1474,8 @@ class GaussianDiffusion(object):
                          condition_fn=None,
                          guide_scale=None,
                          ddim_timesteps=20,
-                         eta=0.0):
+                         eta=0.0,
+                         unet_lowvram=False):
         # prepare input
         b = noise.size(0)
         xt = noise
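The sampler's new `unet_lowvram` flag (default `False`) is only threaded through the signature here; this diff does not show how `ddim_sample_loop` consumes it. A minimal sketch of the pattern such a flag typically gates, with hypothetical helper and argument names (the real sampler passes more conditioning inputs):

import torch

def unet_forward_lowvram(model, xt, t, unet_lowvram=False, device='cuda'):
    # Hypothetical sketch: with unet_lowvram the UNet lives on CPU between
    # denoising steps and visits the GPU only for the forward pass, trading
    # a per-step PCIe transfer for a lower peak VRAM footprint.
    if unet_lowvram:
        model.to(device)
    with torch.no_grad():
        out = model(xt.to(device), t.to(device))
    if unet_lowvram:
        model.to('cpu')           # vacate VRAM between steps
        torch.cuda.empty_cache()  # release the cached, now-unused blocks
    return out

Shuttling the whole UNet over PCIe on every step is slow, which would fit the commit message's "not so successful" verdict.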

Changed file 2 (name not captured; contains class TextToVideoSynthesis):

@@ -102,7 +102,7 @@ class TextToVideoSynthesis():
             temporal_attention=cfg['temporal_attention'])
         self.sd_model.load_state_dict(
             torch.load(
-                osp.join(self.model_dir, self.config.model["model_args"]["ckpt_unet"])),
+                osp.join(self.model_dir, self.config.model["model_args"]["ckpt_unet"]), map_location='cpu'),
             strict=True)
         self.sd_model.eval()
         self.sd_model.half()
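`map_location='cpu'` matters for peak VRAM because `torch.load` otherwise restores each tensor to the device it was saved from: a checkpoint serialized on GPU would be materialized in VRAM first and only then copied into the model. A self-contained illustration (hypothetical checkpoint path):

import torch

ckpt_path = 'model.pth'  # hypothetical path
state = torch.load(ckpt_path, map_location='cpu')  # tensors land in system RAM
# load_state_dict(state, strict=True) then copies the values into the model's
# existing parameters, wherever the model itself currently resides.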
@@ -150,7 +150,7 @@
         self.clip_encoder.to("cpu")

     #@torch.compile()
-    def infer(self, prompt, n_prompt, steps, frames, scale, width=256, height=256, eta=0.0, cpu_vae='GPU (half precision)', device = torch.device('cpu'), latents=None):
+    def infer(self, prompt, n_prompt, steps, frames, scale, width=256, height=256, eta=0.0, cpu_vae='GPU (half precision)', device = torch.device('cpu'), latents=None, unet_lowvram=True):
         r"""
         The entry function of text to image synthesis task.
         1. Using diffusion model to generate the video's latent representation.
@@ -165,7 +165,9 @@
         #print(self.sd_model.use_fps_condition)
         self.sd_model.use_fps_condition = False
         self.device = device
+        print('CLIP-encoding')
         self.clip_encoder.to(self.device)
+        self.clip_encoder.half()
         y, zero_y = self.preprocess(prompt, n_prompt)
         self.clip_encoder.to("cpu")
         #self.clip_encoder = None
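The sequence above is an encode-then-offload round trip: the (now half-precision) CLIP text encoder is moved to the GPU, both prompts are embedded via `preprocess`, and the encoder is pushed back to CPU before the UNet needs the VRAM, so the two large models never occupy the GPU at once. A standalone sketch of the same pattern, with hypothetical names:

import torch

def encode_then_offload(clip_encoder, prompt, n_prompt, device='cuda'):
    # Hypothetical sketch: only one large model occupies VRAM at a time.
    clip_encoder.to(device)
    with torch.no_grad():
        y = clip_encoder(prompt)         # conditional text embedding
        zero_y = clip_encoder(n_prompt)  # negative / unconditional embedding
    clip_encoder.to('cpu')  # vacate VRAM before the UNet moves in
    return y, zero_y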
@@ -179,15 +181,19 @@
         num_sample = 1
         max_frames = frames
         latent_h, latent_w = height // 8, width // 8
-        self.sd_model.to(self.device)
+        print('PREPARING LATENTS')
+        #self.sd_model.to(self.device)
         if latents == None:
             latents = torch.randn(num_sample, 4, max_frames, latent_h,
                                   latent_w).to(
                                       self.device)
         else:
             latents.to(self.device)
-        with amp.autocast(enabled=True):
             latents.half()
+        print('STARTING THE DIFFUSION MODEL')
+        with amp.autocast(enabled=True, dtype=torch.float16):
+            self.sd_model.to(self.device)
+            print('SD MODEL ON DEVICE')
             x0 = self.diffusion.ddim_sample_loop(
                 noise=latents, # shape: b c f h w
                 model=self.sd_model,
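A caveat in the `else` branch: `Tensor.to` and `Tensor.half` are not in-place operations; each returns a new tensor, so `latents.to(self.device)` and `latents.half()` discard their results and user-supplied latents stay on their original device in fp32. (`latents is None` is also the idiomatic spelling of the `latents == None` test.) For the conversion to take effect, the results must be reassigned:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latents = torch.randn(1, 4, 16, 32, 32)  # example b c f h w latents

latents.to(device)                   # no effect: returned tensor is discarded
latents = latents.to(device).half()  # correct: copy to device, cast to fp16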
@@ -200,7 +206,8 @@
                 }],
                 guide_scale=scale,
                 ddim_timesteps=steps,
-                eta=eta)
+                eta=eta,
+                unet_lowvram=unet_lowvram)
         self.last_tensor = x0
         self.last_tensor.cpu()
         self.sd_model.to("cpu")
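`self.last_tensor.cpu()` after the sampling call has the same pitfall: `Tensor.cpu()` returns a CPU copy without moving `self.last_tensor`, so the sampled latents keep holding VRAM. By contrast, `nn.Module.to` does move a module in place, so `self.sd_model.to("cpu")` on the next line works as written. A sketch of the offload as presumably intended (hypothetical `pipeline` handle standing in for `self`):

import torch

def offload_after_sampling(pipeline, x0):
    # Tensor.cpu() returns a copy, so reassign; nn.Module.to() moves in place.
    pipeline.last_tensor = x0.cpu()
    pipeline.sd_model.to('cpu')
    torch.cuda.empty_cache()  # release cached blocks for other GPU users
    return pipeline.last_tensor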
@@ -291,7 +298,7 @@
     def cleanup(self):
         pass

     def preprocess(self, prompt, n_prompt, offload=True):
-        self.clip_encoder.to(self.device)
+        #self.clip_encoder.to(self.device)
         text_emb = self.clip_encoder(prompt)
         text_emb_zero = self.clip_encoder(n_prompt)
         if offload:
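With its `to(self.device)` commented out, `preprocess` no longer places the encoder itself: callers must put `clip_encoder` on the target device beforehand (as `infer` above now does), and the `offload` branch, truncated here, presumably still returns it to CPU afterwards. The resulting calling contract, sketched with a hypothetical `pipeline` handle:

# Caller is now responsible for device placement (and precision):
pipeline.clip_encoder.to(device)
pipeline.clip_encoder.half()
# preprocess() itself only runs the two forward passes; with offload=True
# it then moves the encoder back to CPU.
y, zero_y = pipeline.preprocess(prompt, n_prompt)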