From f54e4b53dc48748072cdc36512b7100933dfc998 Mon Sep 17 00:00:00 2001
From: d8ahazard
Date: Mon, 2 Jan 2023 21:00:30 -0600
Subject: [PATCH] Better doing of the thingz

---
 scripts/patch_fixer.py | 48 +++++++++++++++++++++++++++---------------
 1 file changed, 31 insertions(+), 17 deletions(-)

diff --git a/scripts/patch_fixer.py b/scripts/patch_fixer.py
index c35d728..514b6bb 100644
--- a/scripts/patch_fixer.py
+++ b/scripts/patch_fixer.py
@@ -16,7 +16,7 @@ from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
 from modules.sd_models import select_checkpoint, load_model_weights
 
-print("Fixing all the things that could have just been a pull request!")
+print("Fixing all the things!")
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -73,37 +73,31 @@ def get_config(checkpoint_info):
     path = checkpoint_info[0]
     model_config = checkpoint_info.config
     checkpoint_dir = os.path.join(shared.script_path, "extensions", "sd_auto_fix", "configs")
-    is_v21 = False
     if model_config == shared.cmd_opts.config:
         try:
             checkpoint = torch.load(path)
             c_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
             v2_key = "cond_stage_model.model.ln_final.weight"
             if v2_key in c_dict:
+                model_config = os.path.join(checkpoint_dir, "v2-inference-v.yaml")
                 if "global_step" in checkpoint and checkpoint_info.config == shared.cmd_opts.config:
                     if checkpoint["global_step"] == 875000 or checkpoint["global_step"] == 220000:
                         model_config = os.path.join(checkpoint_dir, "v2-inference.yaml")
-                    else:
-                        model_config = os.path.join(checkpoint_dir, "v2-inference-v.yaml")
-                print(f"V2 Model detected, selecting model config: {model_config}")
-                v21_keys = ["callbacks", "lr_schedulers", "native_amp_scaling_state"]
-                for v22key in v21_keys:
-                    if v22key in checkpoint:
-                        is_v21 = True
-                        break
+
+                print(f"V2 Model detected, selecting model config: {model_config.replace(checkpoint_dir, '')}")
             del checkpoint
         except Exception as e:
             print(f"Exception: {e}")
             traceback.print_exc()
             pass
 
-    return model_config, is_v21
+    return model_config
 
 
 def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
-    model_config, is_v21 = get_config(checkpoint_info)
+    model_config = get_config(checkpoint_info)
 
     if model_config != shared.cmd_opts.config:
         print(f"Loading config from: {model_config}")
@@ -137,6 +131,16 @@ def load_model(checkpoint_info=None):
         sd_model.to(shared.device)
 
     sd_hijack.model_hijack.hijack(sd_model)
+    is_v21 = False
+    try:
+        is_v21 = sd_config.model.params.parameterization == "v"
+    except:
+        pass
+
+    try:
+        is_v21 = sd_config.model.params.unet_config.params.num_head_channels == 64
+    except:
+        pass
 
     if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half:
         print("Fixing attention for v21 model.")
@@ -158,10 +162,11 @@ def reload_model_weights(sd_model=None, info=None):
     if not sd_model:
         sd_model = shared.sd_model
 
-    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
-        return
+    if sd_model is not None and checkpoint_info is not None:
+        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+            return
 
-    model_config, is_v21 = get_config(checkpoint_info)
+    model_config = get_config(checkpoint_info)
     checkpoint_info = checkpoint_info._replace(config=model_config)
 
     if sd_model.sd_checkpoint_info.config != model_config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
@@ -180,6 +185,17 @@ def reload_model_weights(sd_model=None, info=None):
         load_model_weights(sd_model, checkpoint_info)
 
     sd_hijack.model_hijack.hijack(sd_model)
+    sd_config = OmegaConf.load(model_config)
+    is_v21 = False
+    try:
+        is_v21 = sd_config.model.params.parameterization == "v"
+    except:
+        pass
+
+    try:
+        is_v21 = sd_config.model.params.unet_config.params.num_head_channels == 64
+    except:
+        pass
 
     if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half:
         print("Fixing attention for v21 model.")
@@ -194,7 +210,5 @@ def reload_model_weights(sd_model=None, info=None):
 
     return sd_model
 
-# This is so effing ridiculous that we have to do this $hit.
-
 sd_models.load_model = load_model
 sd_models.reload_model_weights = reload_model_weights
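For reference, the v2-detection heuristic that get_config() settles on after this patch boils down to the minimal standalone sketch below. It is not part of the patch itself: the function name pick_v2_config and both path arguments are illustrative placeholders, while the state-dict key and the global_step values are taken directly from the diff above.

import os

import torch


def pick_v2_config(checkpoint_path, config_dir):
    """Return a v2 config path if the checkpoint looks like SD 2.x, else None."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)

    # This CLIP text-encoder key is only present in SD 2.x checkpoints.
    if "cond_stage_model.model.ln_final.weight" not in state_dict:
        return None

    # Default to the v-prediction config; the two known epsilon-prediction
    # base checkpoints (global_step 875000 or 220000) get the plain v2 config.
    config = os.path.join(config_dir, "v2-inference-v.yaml")
    if checkpoint.get("global_step") in (875000, 220000):
        config = os.path.join(config_dir, "v2-inference.yaml")
    return config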