diff --git a/pipelines/model_wanai.py b/pipelines/model_wanai.py index 1be7b0cc7..1ae481226 100644 --- a/pipelines/model_wanai.py +++ b/pipelines/model_wanai.py @@ -60,13 +60,13 @@ def load_wan(checkpoint_info, diffusers_load_config={}): sd_models.hf_auth_check(checkpoint_info) if 'a14b' in repo_id.lower(): - if shared.opts.model_wan_stage == 'high noise': + if shared.opts.model_wan_stage == 'high noise' or shared.opts.model_wan_stage == 'first': transformer = load_transformer(repo_id, diffusers_load_config, 'transformer') transformer_2 = None - elif shared.opts.model_wan_stage == 'low noise': + elif shared.opts.model_wan_stage == 'low noise' or shared.opts.model_wan_stage == 'second': transformer = load_transformer(repo_id, diffusers_load_config, 'transformer_2') transformer_2 = None - elif shared.opts.model_wan_stage == 'combined': + elif shared.opts.model_wan_stage == 'combined' or shared.opts.model_wan_stage == 'both': transformer = load_transformer(repo_id, diffusers_load_config, 'transformer') transformer_2 = load_transformer(repo_id, diffusers_load_config, 'transformer_2') else: