Better doing of the thingz
parent 5201d457e3
commit f54e4b53dc
@@ -16,7 +16,7 @@ from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
 from modules.sd_models import select_checkpoint, load_model_weights
 
-print("Fixing all the things that could have just been a pull request!")
+print("Fixing all the things!")
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
 
@@ -73,37 +73,31 @@ def get_config(checkpoint_info):
     path = checkpoint_info[0]
     model_config = checkpoint_info.config
     checkpoint_dir = os.path.join(shared.script_path, "extensions", "sd_auto_fix", "configs")
-    is_v21 = False
     if model_config == shared.cmd_opts.config:
         try:
             checkpoint = torch.load(path)
             c_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
             v2_key = "cond_stage_model.model.ln_final.weight"
             if v2_key in c_dict:
+                model_config = os.path.join(checkpoint_dir, "v2-inference-v.yaml")
                 if "global_step" in checkpoint and checkpoint_info.config == shared.cmd_opts.config:
                     if checkpoint["global_step"] == 875000 or checkpoint["global_step"] == 220000:
                         model_config = os.path.join(checkpoint_dir, "v2-inference.yaml")
-                    else:
-                        model_config = os.path.join(checkpoint_dir, "v2-inference-v.yaml")
-                print(f"V2 Model detected, selecting model config: {model_config}")
-                v21_keys = ["callbacks", "lr_schedulers", "native_amp_scaling_state"]
-                for v22key in v21_keys:
-                    if v22key in checkpoint:
-                        is_v21 = True
-                        break
+                print(f"V2 Model detected, selecting model config: {model_config.replace(checkpoint_dir, '')}")
 
             del checkpoint
         except Exception as e:
             print(f"Exception: {e}")
             traceback.print_exc()
             pass
-    return model_config, is_v21
+    return model_config
 
 
 def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
-    model_config, is_v21 = get_config(checkpoint_info)
+    model_config = get_config(checkpoint_info)
 
     if model_config != shared.cmd_opts.config:
         print(f"Loading config from: {model_config}")
@@ -137,6 +131,16 @@ def load_model(checkpoint_info=None):
     sd_model.to(shared.device)
 
     sd_hijack.model_hijack.hijack(sd_model)
+    is_v21 = False
+    try:
+        is_v21 = sd_config.model.params.parameterization == "v"
+    except:
+        pass
+
+    try:
+        is_v21 = sd_config.model.params.unet_config.params.num_head_channels == 64
+    except:
+        pass
 
     if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half:
         print("Fixing attention for v21 model.")
@@ -158,10 +162,11 @@ def reload_model_weights(sd_model=None, info=None):
     if not sd_model:
         sd_model = shared.sd_model
 
-    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
-        return
+    if sd_model is not None and checkpoint_info is not None:
+        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+            return
 
-    model_config, is_v21 = get_config(checkpoint_info)
+    model_config = get_config(checkpoint_info)
     checkpoint_info = checkpoint_info._replace(config=model_config)
     if sd_model.sd_checkpoint_info.config != model_config or should_hijack_inpainting(
             checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
@@ -180,6 +185,17 @@ def reload_model_weights(sd_model=None, info=None):
     load_model_weights(sd_model, checkpoint_info)
 
     sd_hijack.model_hijack.hijack(sd_model)
+    sd_config = OmegaConf.load(model_config)
+    is_v21 = False
+    try:
+        is_v21 = sd_config.model.params.parameterization == "v"
+    except:
+        pass
+
+    try:
+        is_v21 = sd_config.model.params.unet_config.params.num_head_channels == 64
+    except:
+        pass
 
     if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half:
         print("Fixing attention for v21 model.")
@@ -194,7 +210,5 @@ def reload_model_weights(sd_model=None, info=None):
     return sd_model
 
 
-# This is so effing ridiculous that we have to do this $hit.
-
 sd_models.load_model = load_model
 sd_models.reload_model_weights = reload_model_weights
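
For reference, a minimal standalone sketch of the v2 detection rule get_config settles on after this commit: a checkpoint is treated as SD 2.x when its state dict contains the v2-only key cond_stage_model.model.ln_final.weight, and the v-prediction config is assumed unless global_step matches one of the known 512-base checkpoints. The helper name guess_v2_config and the map_location argument are illustrative, not part of the extension.

import os

import torch

V2_KEY = "cond_stage_model.model.ln_final.weight"


def guess_v2_config(path, checkpoint_dir):
    # Load on CPU; weights may sit at the top level or under "state_dict".
    checkpoint = torch.load(path, map_location="cpu")
    c_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    if V2_KEY not in c_dict:
        return None  # v1 checkpoint: caller keeps the default config
    # Default to v-prediction; the known 512-base checkpoints
    # (global_step 875000 or 220000) use epsilon prediction instead.
    config = os.path.join(checkpoint_dir, "v2-inference-v.yaml")
    if checkpoint.get("global_step") in (875000, 220000):
        config = os.path.join(checkpoint_dir, "v2-inference.yaml")
    return config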
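Likewise, a sketch of the is_v21 probe now inlined in load_model and reload_model_weights: rather than scanning checkpoint keys, it loads the selected YAML with OmegaConf and checks parameterization == "v" or num_head_channels == 64, fields the stock v2 inference configs set. The wrapper name config_is_v21 is hypothetical, and the two probes are or-combined here so the second cannot undo the first, where the diff simply reassigns.

from omegaconf import OmegaConf


def config_is_v21(model_config_path):
    sd_config = OmegaConf.load(model_config_path)
    is_v21 = False
    try:
        # v2.1 (768) checkpoints are trained with v-prediction.
        is_v21 = sd_config.model.params.parameterization == "v"
    except Exception:
        pass  # field absent on v1-style configs
    try:
        # v2.x UNets set num_head_channels: 64 in their config.
        is_v21 = is_v21 or sd_config.model.params.unet_config.params.num_head_channels == 64
    except Exception:
        pass
    return is_v21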