Add HunyuanImage3Wrapper

pull/4325/head
Disty0 2025-10-29 22:07:25 +03:00
parent 027a793cb5
commit 9533b7e847
1 changed files with 52 additions and 2 deletions

View File

@ -1,3 +1,4 @@
import torch
import transformers
import diffusers
from modules import shared, sd_models, devices, model_quant, sd_hijack_te, sd_hijack_vae
@ -47,6 +48,7 @@ def load_hyimage3(checkpoint_info, diffusers_load_config={}): # pylint: disable=
from modules import sdnq # pylint: disable=unused-import # register to diffusers and transformers
sd_models.allow_post_quant = False # we already handled it
allow_quant = False
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True, allow_quant=allow_quant)
pipe = transformers.AutoModelForCausalLM.from_pretrained(
repo_id,
@ -57,8 +59,56 @@ def load_hyimage3(checkpoint_info, diffusers_load_config={}): # pylint: disable=
**quant_args,
)
pipe.load_tokenizer(repo_id)
pipe.__call__ = pipe.generate_image
pipe.task_args = {'diff_infer_steps': 20}
pipe.pipeline # call it to set up pipeline
pipe = HunyuanImage3Wrapper(pipe)
devices.torch_gc(force=True, reason='load')
return pipe
class HunyuanImage3Wrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def __call__(
self,
prompt: str,
height: int = None,
width: int = None,
num_inference_steps: int = 50,
num_images_per_prompt: int = 1,
guidance_scale: float = 7.5,
guidance_rescale: float = 0.0,
callback_on_step_end = None,
callback_on_step_end_tensor_inputs = ["latents"],
**kwargs,
):
if hasattr(self.model._pipeline.model, "_hf_hook"):
self.model._pipeline.model._hf_hook.execution_device = torch.device(devices.device)
if num_inference_steps > 1:
if isinstance(prompt, str):
prompt = [prompt]
prompt = prompt * num_images_per_prompt
if height is None and width is None:
image_size = "auto"
if height is None:
image_size = (width, width)
if width is None:
image_size = (height, height)
else:
image_size = (height, width)
return self.model.generate_image(
prompt,
image_size=(height, width),
diff_infer_steps=num_inference_steps,
guidance_scale=guidance_scale,
guidance_rescale=guidance_rescale,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
**kwargs,
)