mirror of https://github.com/vladmandic/automatic
parent
ca2e5ae610
commit
4c684143a4
|
|
@ -113,6 +113,8 @@ def guess_by_name(fn, current_guess):
|
|||
return 'Kandinsky 2.2'
|
||||
elif 'kandinsky-3' in fn.lower():
|
||||
return 'Kandinsky 3.0'
|
||||
elif 'hunyuanimage' in fn.lower():
|
||||
return 'HunyuanImage'
|
||||
return current_guess
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -394,6 +394,10 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
|
|||
from pipelines.model_hdm import load_hdm
|
||||
sd_model = load_hdm(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['HunyuanImage']:
|
||||
from pipelines.model_hyimage import load_hyimage
|
||||
sd_model = load_hyimage(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
except Exception as e:
|
||||
shared.log.error(f'Load {op}: path="{checkpoint_info.path}" {e}')
|
||||
if debug_load:
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ pipelines = {
|
|||
'FLite': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
'Bria': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
'hdm': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
'HunyuanImage': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
from modules import shared, sd_models, devices, model_quant, errors # pylint: disable=unused-import
|
||||
|
||||
|
||||
def load_hyimage(checkpoint_info, diffusers_load_config={}):
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info)
|
||||
sd_models.hf_auth_check(checkpoint_info)
|
||||
|
||||
shared.log.error(f'Load model: type=NextStep model="{checkpoint_info.name}" repo="{repo_id}" not supported')
|
||||
|
||||
"""
|
||||
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config)
|
||||
shared.log.debug(f'Load model: type=HunyuanImage repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
from pipelines.hyimage.diffusion.pipelines.hunyuanimage_pipeline import HunyuanImagePipeline, HunyuanImagePipelineConfig
|
||||
from pipelines.hyimage.common.config import instantiate
|
||||
|
||||
use_distilled = 'distilled' in repo_id.lower()
|
||||
config = HunyuanImagePipelineConfig.create_default(version='v2.1', use_distilled=use_distilled)
|
||||
config.torch_dtype = devices.dtype
|
||||
config.device = devices.device
|
||||
config.enable_dit_offloading = False
|
||||
config.enable_reprompt_model_offloading = False
|
||||
config.enable_refiner_offloading = False
|
||||
|
||||
pipe = HunyuanImagePipeline(config=config)
|
||||
|
||||
snapshot_folder = snapshot_download(repo_id, cache_dir=shared.opts.hfcache_dir, allow_patterns='vae/vae_2_1/*')
|
||||
pipe.config.vae_config.load_from = os.path.join(snapshot_folder, 'vae/vae_2_1')
|
||||
pipe.vae = instantiate(pipe.config.vae_config.model, vae_path=pipe.config.vae_config.load_from)
|
||||
|
||||
pipe._load_dit() # pylint: disable=protected-access
|
||||
pipe._load_byt5() # pylint: disable=protected-access
|
||||
pipe._load_text_encoder() # pylint: disable=protected-access
|
||||
pipe._load_reprompt_model() # pylint: disable=protected-access
|
||||
|
||||
pipe.task_args = {
|
||||
'use_reprompt': False,
|
||||
'use_refiner': False,
|
||||
}
|
||||
|
||||
devices.torch_gc(force=True, reason='load')
|
||||
"""
|
||||
return None
|
||||
Loading…
Reference in New Issue