diff --git a/modules/sd_detect.py b/modules/sd_detect.py index dc488bb1d..b945a6457 100644 --- a/modules/sd_detect.py +++ b/modules/sd_detect.py @@ -113,6 +113,8 @@ def guess_by_name(fn, current_guess): return 'Kandinsky 2.2' elif 'kandinsky-3' in fn.lower(): return 'Kandinsky 3.0' + elif 'hunyuanimage' in fn.lower(): + return 'HunyuanImage' return current_guess diff --git a/modules/sd_models.py b/modules/sd_models.py index 67f9ef118..789b292d2 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -394,6 +394,10 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op=' from pipelines.model_hdm import load_hdm sd_model = load_hdm(checkpoint_info, diffusers_load_config) allow_post_quant = False + elif model_type in ['HunyuanImage']: + from pipelines.model_hyimage import load_hyimage + sd_model = load_hyimage(checkpoint_info, diffusers_load_config) + allow_post_quant = False except Exception as e: shared.log.error(f'Load {op}: path="{checkpoint_info.path}" {e}') if debug_load: diff --git a/modules/shared_items.py b/modules/shared_items.py index 938d9064a..8eac5e070 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -55,6 +55,7 @@ pipelines = { 'FLite': getattr(diffusers, 'DiffusionPipeline', None), 'Bria': getattr(diffusers, 'DiffusionPipeline', None), 'hdm': getattr(diffusers, 'DiffusionPipeline', None), + 'HunyuanImage': getattr(diffusers, 'DiffusionPipeline', None), } diff --git a/pipelines/model_hyimage.py b/pipelines/model_hyimage.py new file mode 100644 index 000000000..5f33c12fd --- /dev/null +++ b/pipelines/model_hyimage.py @@ -0,0 +1,44 @@ +from modules import shared, sd_models, devices, model_quant, errors # pylint: disable=unused-import + + +def load_hyimage(checkpoint_info, diffusers_load_config={}): + repo_id = sd_models.path_to_repo(checkpoint_info) + sd_models.hf_auth_check(checkpoint_info) + + shared.log.error(f'Load model: type=NextStep model="{checkpoint_info.name}" repo="{repo_id}" not supported') + + """ + load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config) + shared.log.debug(f'Load model: type=HunyuanImage repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}') + + sys.path.append(os.path.dirname(__file__)) + from pipelines.hyimage.diffusion.pipelines.hunyuanimage_pipeline import HunyuanImagePipeline, HunyuanImagePipelineConfig + from pipelines.hyimage.common.config import instantiate + + use_distilled = 'distilled' in repo_id.lower() + config = HunyuanImagePipelineConfig.create_default(version='v2.1', use_distilled=use_distilled) + config.torch_dtype = devices.dtype + config.device = devices.device + config.enable_dit_offloading = False + config.enable_reprompt_model_offloading = False + config.enable_refiner_offloading = False + + pipe = HunyuanImagePipeline(config=config) + + snapshot_folder = snapshot_download(repo_id, cache_dir=shared.opts.hfcache_dir, allow_patterns='vae/vae_2_1/*') + pipe.config.vae_config.load_from = os.path.join(snapshot_folder, 'vae/vae_2_1') + pipe.vae = instantiate(pipe.config.vae_config.model, vae_path=pipe.config.vae_config.load_from) + + pipe._load_dit() # pylint: disable=protected-access + pipe._load_byt5() # pylint: disable=protected-access + pipe._load_text_encoder() # pylint: disable=protected-access + pipe._load_reprompt_model() # pylint: disable=protected-access + + pipe.task_args = { + 'use_reprompt': False, + 'use_refiner': False, + } + + devices.torch_gc(force=True, reason='load') + """ + return None