mirror of https://github.com/vladmandic/automatic
Add HunyuanImage3Wrapper
parent
027a793cb5
commit
9533b7e847
|
|
@ -1,3 +1,4 @@
|
|||
import torch
|
||||
import transformers
|
||||
import diffusers
|
||||
from modules import shared, sd_models, devices, model_quant, sd_hijack_te, sd_hijack_vae
|
||||
|
|
@ -47,6 +48,7 @@ def load_hyimage3(checkpoint_info, diffusers_load_config={}): # pylint: disable=
|
|||
from modules import sdnq # pylint: disable=unused-import # register to diffusers and transformers
|
||||
sd_models.allow_post_quant = False # we already handled it
|
||||
allow_quant = False
|
||||
|
||||
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True, allow_quant=allow_quant)
|
||||
pipe = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
repo_id,
|
||||
|
|
@ -57,8 +59,56 @@ def load_hyimage3(checkpoint_info, diffusers_load_config={}): # pylint: disable=
|
|||
**quant_args,
|
||||
)
|
||||
pipe.load_tokenizer(repo_id)
|
||||
pipe.__call__ = pipe.generate_image
|
||||
pipe.task_args = {'diff_infer_steps': 20}
|
||||
|
||||
pipe.pipeline # call it to set up pipeline
|
||||
pipe = HunyuanImage3Wrapper(pipe)
|
||||
|
||||
devices.torch_gc(force=True, reason='load')
|
||||
return pipe
|
||||
|
||||
|
||||
class HunyuanImage3Wrapper(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int = None,
|
||||
width: int = None,
|
||||
num_inference_steps: int = 50,
|
||||
num_images_per_prompt: int = 1,
|
||||
guidance_scale: float = 7.5,
|
||||
guidance_rescale: float = 0.0,
|
||||
callback_on_step_end = None,
|
||||
callback_on_step_end_tensor_inputs = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
if hasattr(self.model._pipeline.model, "_hf_hook"):
|
||||
self.model._pipeline.model._hf_hook.execution_device = torch.device(devices.device)
|
||||
|
||||
if num_inference_steps > 1:
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
prompt = prompt * num_images_per_prompt
|
||||
|
||||
if height is None and width is None:
|
||||
image_size = "auto"
|
||||
if height is None:
|
||||
image_size = (width, width)
|
||||
if width is None:
|
||||
image_size = (height, height)
|
||||
else:
|
||||
image_size = (height, width)
|
||||
|
||||
return self.model.generate_image(
|
||||
prompt,
|
||||
image_size=(height, width),
|
||||
diff_infer_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
guidance_rescale=guidance_rescale,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue