diff --git a/modules/framepack/framepack_vae.py b/modules/framepack/framepack_vae.py index d9bc5c3b9..05b2b7154 100644 --- a/modules/framepack/framepack_vae.py +++ b/modules/framepack/framepack_vae.py @@ -71,7 +71,7 @@ def vae_decode_remote(latents): def vae_decode_full(latents): with devices.inference_context(): vae = shared.sd_model.vae - latents = (latents / vae.config.scaling_factor).to(device=vae.device, dtype=vae.dtype) + latents = (latents / vae.config.scaling_factor).to(device=devices.device, dtype=devices.dtype) images = vae.decode(latents).sample return images @@ -91,6 +91,6 @@ def vae_decode(latents, vae_type): def vae_encode(image): with devices.inference_context(): vae = shared.sd_model.vae - latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample() + latents = vae.encode(image.to(device=devices.device, dtype=devices.dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor return latents diff --git a/modules/framepack/framepack_worker.py b/modules/framepack/framepack_worker.py index 8ce562b8e..ab29f219e 100644 --- a/modules/framepack/framepack_worker.py +++ b/modules/framepack/framepack_worker.py @@ -91,8 +91,6 @@ def worker( shared.state.textinfo = 'Text encode' stream.output_queue.push(('progress', (None, 'Text encoding...'))) sd_models.apply_balanced_offload(shared.sd_model) - sd_models.move_model(text_encoder, devices.device, force=True) # required as hunyuan.encode_prompt_conds checks device before calling model - sd_models.move_model(text_encoder_2, devices.device, force=True) framepack_hijack.set_prompt_template(prompt, system_prompt, optimized_prompt, unmodified_prompt) llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2) metadata['comment'] = prompt @@ -102,6 +100,7 @@ def worker( llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler) llama_vec, llama_attention_mask = utils.crop_or_pad_yield_mask(llama_vec, length=512) llama_vec_n, llama_attention_mask_n = utils.crop_or_pad_yield_mask(llama_vec_n, length=512) + sd_models.apply_balanced_offload(shared.sd_model) timer.process.add('prompt', time.time()-t0) return llama_vec, llama_vec_n, llama_attention_mask, llama_attention_mask_n, clip_l_pooler, clip_l_pooler_n @@ -112,7 +111,6 @@ def worker( torch.manual_seed(seed) stream.output_queue.push(('progress', (None, 'VAE encoding...'))) sd_models.apply_balanced_offload(shared.sd_model) - sd_models.move_model(vae, devices.device, force=True) if input_image is not None: input_image_pt = torch.from_numpy(input_image).float() / 127.5 - 1 input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None] @@ -126,6 +124,7 @@ def worker( end_latent = framepack_vae.vae_encode(end_image_pt) else: end_latent = None + sd_models.apply_balanced_offload(shared.sd_model) timer.process.add('encode', time.time()-t0) return start_latent, end_latent @@ -136,6 +135,7 @@ def worker( shared.state.textinfo = 'Vision encode' stream.output_queue.push(('progress', (None, 'Vision encoding...'))) sd_models.apply_balanced_offload(shared.sd_model) + # siglip doesn't work with offload sd_models.move_model(feature_extractor, devices.device, force=True) sd_models.move_model(image_encoder, devices.device, force=True) preprocessed = feature_extractor.preprocess(images=input_image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype) @@ -146,8 +146,9 @@ def worker( end_image_encoder_output = image_encoder(**preprocessed) end_image_encoder_last_hidden_state = end_image_encoder_output.last_hidden_state image_encoder_last_hidden_state = (image_encoder_last_hidden_state * start_weight) + (end_image_encoder_last_hidden_state * end_weight) / (start_weight + end_weight) # use weighted approach - timer.process.add('vision', time.time()-t0) image_encoder_last_hidden_state = image_encoder_last_hidden_state * vision_weight + sd_models.apply_balanced_offload(shared.sd_model) + timer.process.add('vision', time.time()-t0) return image_encoder_last_hidden_state def step_callback(d): @@ -284,7 +285,6 @@ def worker( t_vae = time.time() sd_models.apply_balanced_offload(shared.sd_model) - sd_models.move_model(vae, devices.device, force=True) if history_pixels is None: history_pixels = framepack_vae.vae_decode(real_history_latents, vae_type=vae_type).cpu() else: @@ -297,6 +297,7 @@ def worker( section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2) current_pixels = framepack_vae.vae_decode(real_history_latents[:, :, :section_latent_frames], vae_type=vae_type).cpu() history_pixels = utils.soft_append_bcthw(current_pixels, history_pixels, overlapped_frames) + sd_models.apply_balanced_offload(shared.sd_model) timer.process.add('vae', time.time()-t_vae) if is_last_section: diff --git a/modules/framepack/pipeline/hunyuan.py b/modules/framepack/pipeline/hunyuan.py index d9fc1e77c..bd36c0e62 100644 --- a/modules/framepack/pipeline/hunyuan.py +++ b/modules/framepack/pipeline/hunyuan.py @@ -1,5 +1,6 @@ import torch from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE +from modules import devices @torch.no_grad() @@ -24,8 +25,8 @@ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokeniz return_attention_mask=True, ) - llama_input_ids = llama_inputs.input_ids.to(text_encoder.device) - llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device) + llama_input_ids = llama_inputs.input_ids.to(devices.device) + llama_attention_mask = llama_inputs.attention_mask.to(devices.device) llama_attention_length = int(llama_attention_mask.sum()) llama_outputs = text_encoder( @@ -51,7 +52,7 @@ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokeniz return_length=False, return_tensors="pt", ).input_ids - clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output + clip_l_pooler = text_encoder_2(clip_l_input_ids.to(devices.device), output_hidden_states=False).pooler_output return llama_vec, clip_l_pooler @@ -93,9 +94,9 @@ def vae_decode(latents, vae, image_mode=False): latents = latents / vae.config.scaling_factor if not image_mode: - image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample + image = vae.decode(latents.to(device=devices.device, dtype=devices.dtype)).sample else: - latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2) + latents = latents.to(device=devices.device, dtype=devices.dtype).unbind(2) image = [vae.decode(l.unsqueeze(2)).sample for l in latents] image = torch.cat(image, dim=2) @@ -104,6 +105,6 @@ def vae_decode(latents, vae, image_mode=False): @torch.no_grad() def vae_encode(image, vae): - latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample() + latents = vae.encode(image.to(device=devices.device, dtype=devices.dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor return latents