diff --git a/pipelines/model_hyimage.py b/pipelines/model_hyimage.py index 9a52cd2c0..6ea7f174a 100644 --- a/pipelines/model_hyimage.py +++ b/pipelines/model_hyimage.py @@ -1,3 +1,4 @@ +import torch import transformers import diffusers from modules import shared, sd_models, devices, model_quant, sd_hijack_te, sd_hijack_vae @@ -47,6 +48,7 @@ def load_hyimage3(checkpoint_info, diffusers_load_config={}): # pylint: disable= from modules import sdnq # pylint: disable=unused-import # register to diffusers and transformers sd_models.allow_post_quant = False # we already handled it allow_quant = False + load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True, allow_quant=allow_quant) pipe = transformers.AutoModelForCausalLM.from_pretrained( repo_id, @@ -57,8 +59,56 @@ def load_hyimage3(checkpoint_info, diffusers_load_config={}): # pylint: disable= **quant_args, ) pipe.load_tokenizer(repo_id) - pipe.__call__ = pipe.generate_image - pipe.task_args = {'diff_infer_steps': 20} + + pipe.pipeline # call it to set up pipeline + pipe = HunyuanImage3Wrapper(pipe) devices.torch_gc(force=True, reason='load') return pipe + + +class HunyuanImage3Wrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def __call__( + self, + prompt: str, + height: int = None, + width: int = None, + num_inference_steps: int = 50, + num_images_per_prompt: int = 1, + guidance_scale: float = 7.5, + guidance_rescale: float = 0.0, + callback_on_step_end = None, + callback_on_step_end_tensor_inputs = ["latents"], + **kwargs, + ): + if hasattr(self.model._pipeline.model, "_hf_hook"): + self.model._pipeline.model._hf_hook.execution_device = torch.device(devices.device) + + if num_inference_steps > 1: + if isinstance(prompt, str): + prompt = [prompt] + prompt = prompt * num_images_per_prompt + + if height is None and width is None: + image_size = "auto" + if height is None: + image_size = (width, width) + if width is None: + image_size = (height, height) + else: + image_size = (height, width) + + return self.model.generate_image( + prompt, + image_size=(height, width), + diff_infer_steps=num_inference_steps, + guidance_scale=guidance_scale, + guidance_rescale=guidance_rescale, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + **kwargs, + )