mirror of https://github.com/vladmandic/automatic

Do not ignore offload with FramePack

parent 4e449ca9b2
commit 3087b8d8cd

@@ -71,7 +71,7 @@ def vae_decode_remote(latents):
 def vae_decode_full(latents):
     with devices.inference_context():
         vae = shared.sd_model.vae
-        latents = (latents / vae.config.scaling_factor).to(device=vae.device, dtype=vae.dtype)
+        latents = (latents / vae.config.scaling_factor).to(device=devices.device, dtype=devices.dtype)
         images = vae.decode(latents).sample
         return images

@@ -91,6 +91,6 @@ def vae_decode(latents, vae_type):
 def vae_encode(image):
     with devices.inference_context():
         vae = shared.sd_model.vae
-        latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
+        latents = vae.encode(image.to(device=devices.device, dtype=devices.dtype)).latent_dist.sample()
         latents = latents * vae.config.scaling_factor
         return latents
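
Note on the two hunks above: with balanced offload the VAE's weights may be parked on CPU between calls, so vae.device can report cpu at call time; sending latents to devices.device (the compute device) lets the offload hook move the module to match. A minimal runnable sketch of the failure mode, with hypothetical names:

import torch

compute_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class TinyVAE(torch.nn.Module):  # stand-in for the offloaded diffusers VAE
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    @property
    def device(self):
        return next(self.parameters()).device

    def decode(self, latents):
        return self.proj(latents)

vae = TinyVAE()                       # offloaded: parameters parked on CPU
latents = torch.randn(1, 4)
print(vae.device)                     # cpu -> '.to(vae.device)' would keep latents on CPU
latents = latents.to(compute_device)  # target the compute device instead
vae.to(compute_device)                # emulates the offload hook's just-in-time move
print(vae.decode(latents).device)     # decode runs on the compute device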

@@ -91,8 +91,6 @@ def worker(
     shared.state.textinfo = 'Text encode'
     stream.output_queue.push(('progress', (None, 'Text encoding...')))
     sd_models.apply_balanced_offload(shared.sd_model)
-    sd_models.move_model(text_encoder, devices.device, force=True) # required as hunyuan.encode_prompt_conds checks device before calling model
-    sd_models.move_model(text_encoder_2, devices.device, force=True)
     framepack_hijack.set_prompt_template(prompt, system_prompt, optimized_prompt, unmodified_prompt)
     llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
     metadata['comment'] = prompt

@@ -102,6 +100,7 @@ def worker(
     llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
     llama_vec, llama_attention_mask = utils.crop_or_pad_yield_mask(llama_vec, length=512)
     llama_vec_n, llama_attention_mask_n = utils.crop_or_pad_yield_mask(llama_vec_n, length=512)
+    sd_models.apply_balanced_offload(shared.sd_model)
     timer.process.add('prompt', time.time()-t0)
     return llama_vec, llama_vec_n, llama_attention_mask, llama_attention_mask_n, clip_l_pooler, clip_l_pooler_n
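
Taken together, the two hunks above replace per-model force moves with a consistent stage pattern: trust the offload hooks to page weights in on demand, then re-apply balanced offload when the stage finishes so the next stage starts with free VRAM. A hedged sketch of that pattern (run_stage and the stub are illustrative; the stub stands in for sd_models.apply_balanced_offload):

def apply_balanced_offload(model):
    # stub standing in for sd_models.apply_balanced_offload:
    # rebalance/evict module weights between GPU and CPU
    pass

def run_stage(model, stage_fn, *args):
    try:
        return stage_fn(*args)         # offload hooks move weights as needed
    finally:
        apply_balanced_offload(model)  # release this stage's weights for the next one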

@@ -112,7 +111,6 @@ def worker(
     torch.manual_seed(seed)
     stream.output_queue.push(('progress', (None, 'VAE encoding...')))
     sd_models.apply_balanced_offload(shared.sd_model)
-    sd_models.move_model(vae, devices.device, force=True)
     if input_image is not None:
         input_image_pt = torch.from_numpy(input_image).float() / 127.5 - 1
         input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
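
The last two lines above normalize an HWC uint8 frame to [-1, 1] and reshape it to the [B, C, T, H, W] layout the video VAE expects; a small self-contained example (the input array is a placeholder):

import numpy as np
import torch

image = np.zeros((480, 640, 3), dtype=np.uint8)   # HWC uint8 placeholder frame
pt = torch.from_numpy(image).float() / 127.5 - 1  # scale 0..255 to -1..1
pt = pt.permute(2, 0, 1)[None, :, None]           # HWC -> CHW -> [1, C, 1, H, W]
print(pt.shape)                                   # torch.Size([1, 3, 1, 480, 640])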

@@ -126,6 +124,7 @@ def worker(
         end_latent = framepack_vae.vae_encode(end_image_pt)
     else:
         end_latent = None
+    sd_models.apply_balanced_offload(shared.sd_model)
     timer.process.add('encode', time.time()-t0)
     return start_latent, end_latent

@@ -136,6 +135,7 @@ def worker(
     shared.state.textinfo = 'Vision encode'
     stream.output_queue.push(('progress', (None, 'Vision encoding...')))
     sd_models.apply_balanced_offload(shared.sd_model)
+    # siglip doesn't work with offload
     sd_models.move_model(feature_extractor, devices.device, force=True)
     sd_models.move_model(image_encoder, devices.device, force=True)
     preprocessed = feature_extractor.preprocess(images=input_image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)

@@ -146,8 +146,9 @@ def worker(
         end_image_encoder_output = image_encoder(**preprocessed)
         end_image_encoder_last_hidden_state = end_image_encoder_output.last_hidden_state
         image_encoder_last_hidden_state = (image_encoder_last_hidden_state * start_weight) + (end_image_encoder_last_hidden_state * end_weight) / (start_weight + end_weight) # use weighted approach
-    timer.process.add('vision', time.time()-t0)
     image_encoder_last_hidden_state = image_encoder_last_hidden_state * vision_weight
+    sd_models.apply_balanced_offload(shared.sd_model)
+    timer.process.add('vision', time.time()-t0)
     return image_encoder_last_hidden_state

 def step_callback(d):
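
For reference, the "weighted approach" in the context line above is a blend of the start- and end-image hidden states; written as an explicit weighted average (with the whole sum divided by the weight total) it looks like this hedged sketch, shapes assumed:

import torch

def blend_hidden_states(start, end, start_weight=1.0, end_weight=1.0):
    # convex blend of two vision-encoder outputs, normalized by total weight
    return (start * start_weight + end * end_weight) / (start_weight + end_weight)

start = torch.randn(1, 729, 1152)  # SigLIP-style [batch, tokens, dim]
end = torch.randn(1, 729, 1152)
blended = blend_hidden_states(start, end, 1.0, 0.5)
print(blended.shape)               # torch.Size([1, 729, 1152])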

@@ -284,7 +285,6 @@ def worker(

     t_vae = time.time()
     sd_models.apply_balanced_offload(shared.sd_model)
-    sd_models.move_model(vae, devices.device, force=True)
     if history_pixels is None:
         history_pixels = framepack_vae.vae_decode(real_history_latents, vae_type=vae_type).cpu()
     else:

@@ -297,6 +297,7 @@ def worker(
         section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
         current_pixels = framepack_vae.vae_decode(real_history_latents[:, :, :section_latent_frames], vae_type=vae_type).cpu()
         history_pixels = utils.soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
+    sd_models.apply_balanced_offload(shared.sd_model)
     timer.process.add('vae', time.time()-t_vae)

     if is_last_section:
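
utils.soft_append_bcthw blends the chunk boundary rather than hard-cutting it. A plausible re-implementation of the idea (hypothetical, not the repo's code): cross-fade the trailing frames of the earlier chunk into the leading frames of the later one, then concatenate the rest along the time axis.

import torch

def soft_append(current, history, overlap):
    # current/history: [B, C, T, H, W]; current precedes history in time
    w = torch.linspace(1, 0, overlap, device=current.device).view(1, 1, -1, 1, 1)
    blended = w * current[:, :, -overlap:] + (1 - w) * history[:, :, :overlap]
    return torch.cat([current[:, :, :-overlap], blended, history[:, :, overlap:]], dim=2)

demo = soft_append(torch.randn(1, 3, 8, 4, 4), torch.randn(1, 3, 8, 4, 4), 3)
print(demo.shape)  # torch.Size([1, 3, 13, 4, 4]): 8 + 8 - 3 overlapped frames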

@@ -1,5 +1,6 @@
 import torch
 from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
+from modules import devices


 @torch.no_grad()

@@ -24,8 +25,8 @@ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokeniz
         return_attention_mask=True,
     )

-    llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
-    llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
+    llama_input_ids = llama_inputs.input_ids.to(devices.device)
+    llama_attention_mask = llama_inputs.attention_mask.to(devices.device)
     llama_attention_length = int(llama_attention_mask.sum())

     llama_outputs = text_encoder(

@@ -51,7 +52,7 @@ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokeniz
         return_length=False,
         return_tensors="pt",
     ).input_ids
-    clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
+    clip_l_pooler = text_encoder_2(clip_l_input_ids.to(devices.device), output_hidden_states=False).pooler_output

     return llama_vec, clip_l_pooler
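
A side note on the surrounding code: the prompt is tokenized with padding, and llama_attention_length (the mask sum) is what lets padded positions be sliced off the encoder output. A tiny illustration with fabricated shapes:

import torch

hidden = torch.randn(1, 512, 4096)    # padded LLaMA hidden states; shapes assumed
mask = torch.zeros(1, 512, dtype=torch.long)
mask[:, :77] = 1                      # pretend the prompt is 77 real tokens
length = int(mask.sum())
trimmed = hidden[:, :length]          # drop padding before downstream use
print(trimmed.shape)                  # torch.Size([1, 77, 4096])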

@@ -93,9 +94,9 @@ def vae_decode(latents, vae, image_mode=False):
     latents = latents / vae.config.scaling_factor

     if not image_mode:
-        image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
+        image = vae.decode(latents.to(device=devices.device, dtype=devices.dtype)).sample
     else:
-        latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
+        latents = latents.to(device=devices.device, dtype=devices.dtype).unbind(2)
         image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
         image = torch.cat(image, dim=2)

@@ -104,6 +105,6 @@ def vae_decode(latents, vae, image_mode=False):

 @torch.no_grad()
 def vae_encode(image, vae):
-    latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
+    latents = vae.encode(image.to(device=devices.device, dtype=devices.dtype)).latent_dist.sample()
     latents = latents * vae.config.scaling_factor
     return latents
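
The image_mode branch above decodes a video latent one frame at a time to bound peak VAE memory, then re-stacks along the time axis; a self-contained sketch of that pattern (vae_decode_fn stands in for the model call):

import torch

def decode_per_frame(vae_decode_fn, latents):
    # latents: [B, C, T, H, W] -> decode each single-frame latent, re-concat on T
    frames = [vae_decode_fn(l.unsqueeze(2)) for l in latents.unbind(2)]
    return torch.cat(frames, dim=2)

identity = lambda x: x  # placeholder decoder for a quick shape check
print(decode_per_frame(identity, torch.randn(1, 16, 3, 30, 52)).shape)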