Do not ignore offload with FramePack

pull/4058/head
Disty0 2025-07-13 20:41:55 +03:00
parent 4e449ca9b2
commit 3087b8d8cd
3 changed files with 15 additions and 13 deletions

View File

@ -71,7 +71,7 @@ def vae_decode_remote(latents):
def vae_decode_full(latents):
with devices.inference_context():
vae = shared.sd_model.vae
latents = (latents / vae.config.scaling_factor).to(device=vae.device, dtype=vae.dtype)
latents = (latents / vae.config.scaling_factor).to(device=devices.device, dtype=devices.dtype)
images = vae.decode(latents).sample
return images
@ -91,6 +91,6 @@ def vae_decode(latents, vae_type):
def vae_encode(image):
with devices.inference_context():
vae = shared.sd_model.vae
latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
latents = vae.encode(image.to(device=devices.device, dtype=devices.dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
return latents

View File

@ -91,8 +91,6 @@ def worker(
shared.state.textinfo = 'Text encode'
stream.output_queue.push(('progress', (None, 'Text encoding...')))
sd_models.apply_balanced_offload(shared.sd_model)
sd_models.move_model(text_encoder, devices.device, force=True) # required as hunyuan.encode_prompt_conds checks device before calling model
sd_models.move_model(text_encoder_2, devices.device, force=True)
framepack_hijack.set_prompt_template(prompt, system_prompt, optimized_prompt, unmodified_prompt)
llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
metadata['comment'] = prompt
@ -102,6 +100,7 @@ def worker(
llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
llama_vec, llama_attention_mask = utils.crop_or_pad_yield_mask(llama_vec, length=512)
llama_vec_n, llama_attention_mask_n = utils.crop_or_pad_yield_mask(llama_vec_n, length=512)
sd_models.apply_balanced_offload(shared.sd_model)
timer.process.add('prompt', time.time()-t0)
return llama_vec, llama_vec_n, llama_attention_mask, llama_attention_mask_n, clip_l_pooler, clip_l_pooler_n
@ -112,7 +111,6 @@ def worker(
torch.manual_seed(seed)
stream.output_queue.push(('progress', (None, 'VAE encoding...')))
sd_models.apply_balanced_offload(shared.sd_model)
sd_models.move_model(vae, devices.device, force=True)
if input_image is not None:
input_image_pt = torch.from_numpy(input_image).float() / 127.5 - 1
input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
@ -126,6 +124,7 @@ def worker(
end_latent = framepack_vae.vae_encode(end_image_pt)
else:
end_latent = None
sd_models.apply_balanced_offload(shared.sd_model)
timer.process.add('encode', time.time()-t0)
return start_latent, end_latent
@ -136,6 +135,7 @@ def worker(
shared.state.textinfo = 'Vision encode'
stream.output_queue.push(('progress', (None, 'Vision encoding...')))
sd_models.apply_balanced_offload(shared.sd_model)
# siglip doesn't work with offload
sd_models.move_model(feature_extractor, devices.device, force=True)
sd_models.move_model(image_encoder, devices.device, force=True)
preprocessed = feature_extractor.preprocess(images=input_image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
@ -146,8 +146,9 @@ def worker(
end_image_encoder_output = image_encoder(**preprocessed)
end_image_encoder_last_hidden_state = end_image_encoder_output.last_hidden_state
image_encoder_last_hidden_state = (image_encoder_last_hidden_state * start_weight) + (end_image_encoder_last_hidden_state * end_weight) / (start_weight + end_weight) # use weighted approach
timer.process.add('vision', time.time()-t0)
image_encoder_last_hidden_state = image_encoder_last_hidden_state * vision_weight
sd_models.apply_balanced_offload(shared.sd_model)
timer.process.add('vision', time.time()-t0)
return image_encoder_last_hidden_state
def step_callback(d):
@ -284,7 +285,6 @@ def worker(
t_vae = time.time()
sd_models.apply_balanced_offload(shared.sd_model)
sd_models.move_model(vae, devices.device, force=True)
if history_pixels is None:
history_pixels = framepack_vae.vae_decode(real_history_latents, vae_type=vae_type).cpu()
else:
@ -297,6 +297,7 @@ def worker(
section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
current_pixels = framepack_vae.vae_decode(real_history_latents[:, :, :section_latent_frames], vae_type=vae_type).cpu()
history_pixels = utils.soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
sd_models.apply_balanced_offload(shared.sd_model)
timer.process.add('vae', time.time()-t_vae)
if is_last_section:

View File

@ -1,5 +1,6 @@
import torch
from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
from modules import devices
@torch.no_grad()
@ -24,8 +25,8 @@ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokeniz
return_attention_mask=True,
)
llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
llama_input_ids = llama_inputs.input_ids.to(devices.device)
llama_attention_mask = llama_inputs.attention_mask.to(devices.device)
llama_attention_length = int(llama_attention_mask.sum())
llama_outputs = text_encoder(
@ -51,7 +52,7 @@ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokeniz
return_length=False,
return_tensors="pt",
).input_ids
clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
clip_l_pooler = text_encoder_2(clip_l_input_ids.to(devices.device), output_hidden_states=False).pooler_output
return llama_vec, clip_l_pooler
@ -93,9 +94,9 @@ def vae_decode(latents, vae, image_mode=False):
latents = latents / vae.config.scaling_factor
if not image_mode:
image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
image = vae.decode(latents.to(device=devices.device, dtype=devices.dtype)).sample
else:
latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
latents = latents.to(device=devices.device, dtype=devices.dtype).unbind(2)
image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
image = torch.cat(image, dim=2)
@ -104,6 +105,6 @@ def vae_decode(latents, vae, image_mode=False):
@torch.no_grad()
def vae_encode(image, vae):
latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
latents = vae.encode(image.to(device=devices.device, dtype=devices.dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
return latents