Better doing of the thingz
parent
5201d457e3
commit
f54e4b53dc
|
|
@ -16,7 +16,7 @@ from modules.paths import models_path
|
|||
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||
from modules.sd_models import select_checkpoint, load_model_weights
|
||||
|
||||
print("Fixing all the things that could have just been a pull request!")
|
||||
print("Fixing all the things!")
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
|
||||
|
|
@ -73,37 +73,31 @@ def get_config(checkpoint_info):
|
|||
path = checkpoint_info[0]
|
||||
model_config = checkpoint_info.config
|
||||
checkpoint_dir = os.path.join(shared.script_path, "extensions", "sd_auto_fix", "configs")
|
||||
is_v21 = False
|
||||
if model_config == shared.cmd_opts.config:
|
||||
try:
|
||||
checkpoint = torch.load(path)
|
||||
c_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||
v2_key = "cond_stage_model.model.ln_final.weight"
|
||||
if v2_key in c_dict:
|
||||
model_config = os.path.join(checkpoint_dir, "v2-inference-v.yaml")
|
||||
if "global_step" in checkpoint and checkpoint_info.config == shared.cmd_opts.config:
|
||||
if checkpoint["global_step"] == 875000 or checkpoint["global_step"] == 220000:
|
||||
model_config = os.path.join(checkpoint_dir, "v2-inference.yaml")
|
||||
else:
|
||||
model_config = os.path.join(checkpoint_dir, "v2-inference-v.yaml")
|
||||
print(f"V2 Model detected, selecting model config: {model_config}")
|
||||
v21_keys = ["callbacks", "lr_schedulers", "native_amp_scaling_state"]
|
||||
for v22key in v21_keys:
|
||||
if v22key in checkpoint:
|
||||
is_v21 = True
|
||||
break
|
||||
|
||||
print(f"V2 Model detected, selecting model config: {model_config.replace(checkpoint_dir, '')}")
|
||||
|
||||
del checkpoint
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}")
|
||||
traceback.print_exc()
|
||||
pass
|
||||
return model_config, is_v21
|
||||
return model_config
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
model_config, is_v21 = get_config(checkpoint_info)
|
||||
model_config = get_config(checkpoint_info)
|
||||
|
||||
if model_config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {model_config}")
|
||||
|
|
@ -137,6 +131,16 @@ def load_model(checkpoint_info=None):
|
|||
sd_model.to(shared.device)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
is_v21 = False
|
||||
try:
|
||||
is_v21 = sd_config.model.params.parameterization == "v"
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
is_v21 = sd_config.model.params.unet_config.params.num_head_channels == 64
|
||||
except:
|
||||
pass
|
||||
|
||||
if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half:
|
||||
print("Fixing attention for v21 model.")
|
||||
|
|
@ -158,10 +162,11 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
if sd_model is not None and checkpoint_info is not None:
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
model_config, is_v21 = get_config(checkpoint_info)
|
||||
model_config = get_config(checkpoint_info)
|
||||
checkpoint_info = checkpoint_info._replace(config=model_config)
|
||||
if sd_model.sd_checkpoint_info.config != model_config or should_hijack_inpainting(
|
||||
checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||
|
|
@ -180,6 +185,17 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
load_model_weights(sd_model, checkpoint_info)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
sd_config = OmegaConf.load(model_config)
|
||||
is_v21 = False
|
||||
try:
|
||||
is_v21 = sd_config.model.params.parameterization == "v"
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
is_v21 = sd_config.model.params.unet_config.params.num_head_channels == 64
|
||||
except:
|
||||
pass
|
||||
|
||||
if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half:
|
||||
print("Fixing attention for v21 model.")
|
||||
|
|
@ -194,7 +210,5 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
return sd_model
|
||||
|
||||
|
||||
# This is so effing ridiculous that we have to do this $hit.
|
||||
|
||||
sd_models.load_model = load_model
|
||||
sd_models.reload_model_weights = reload_model_weights
|
||||
|
|
|
|||
Loading…
Reference in New Issue