Better doing of the thingz

main
d8ahazard 2023-01-02 21:00:30 -06:00
parent 5201d457e3
commit f54e4b53dc
1 changed files with 31 additions and 17 deletions

View File

@ -16,7 +16,7 @@ from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
from modules.sd_models import select_checkpoint, load_model_weights from modules.sd_models import select_checkpoint, load_model_weights
print("Fixing all the things that could have just been a pull request!") print("Fixing all the things!")
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir)) model_path = os.path.abspath(os.path.join(models_path, model_dir))
@ -73,37 +73,31 @@ def get_config(checkpoint_info):
path = checkpoint_info[0] path = checkpoint_info[0]
model_config = checkpoint_info.config model_config = checkpoint_info.config
checkpoint_dir = os.path.join(shared.script_path, "extensions", "sd_auto_fix", "configs") checkpoint_dir = os.path.join(shared.script_path, "extensions", "sd_auto_fix", "configs")
is_v21 = False
if model_config == shared.cmd_opts.config: if model_config == shared.cmd_opts.config:
try: try:
checkpoint = torch.load(path) checkpoint = torch.load(path)
c_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint c_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
v2_key = "cond_stage_model.model.ln_final.weight" v2_key = "cond_stage_model.model.ln_final.weight"
if v2_key in c_dict: if v2_key in c_dict:
model_config = os.path.join(checkpoint_dir, "v2-inference-v.yaml")
if "global_step" in checkpoint and checkpoint_info.config == shared.cmd_opts.config: if "global_step" in checkpoint and checkpoint_info.config == shared.cmd_opts.config:
if checkpoint["global_step"] == 875000 or checkpoint["global_step"] == 220000: if checkpoint["global_step"] == 875000 or checkpoint["global_step"] == 220000:
model_config = os.path.join(checkpoint_dir, "v2-inference.yaml") model_config = os.path.join(checkpoint_dir, "v2-inference.yaml")
else:
model_config = os.path.join(checkpoint_dir, "v2-inference-v.yaml") print(f"V2 Model detected, selecting model config: {model_config.replace(checkpoint_dir, '')}")
print(f"V2 Model detected, selecting model config: {model_config}")
v21_keys = ["callbacks", "lr_schedulers", "native_amp_scaling_state"]
for v22key in v21_keys:
if v22key in checkpoint:
is_v21 = True
break
del checkpoint del checkpoint
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}")
traceback.print_exc() traceback.print_exc()
pass pass
return model_config, is_v21 return model_config
def load_model(checkpoint_info=None): def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
model_config, is_v21 = get_config(checkpoint_info) model_config = get_config(checkpoint_info)
if model_config != shared.cmd_opts.config: if model_config != shared.cmd_opts.config:
print(f"Loading config from: {model_config}") print(f"Loading config from: {model_config}")
@ -137,6 +131,16 @@ def load_model(checkpoint_info=None):
sd_model.to(shared.device) sd_model.to(shared.device)
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
is_v21 = False
try:
is_v21 = sd_config.model.params.parameterization == "v"
except:
pass
try:
is_v21 = sd_config.model.params.unet_config.params.num_head_channels == 64
except:
pass
if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half: if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half:
print("Fixing attention for v21 model.") print("Fixing attention for v21 model.")
@ -158,10 +162,11 @@ def reload_model_weights(sd_model=None, info=None):
if not sd_model: if not sd_model:
sd_model = shared.sd_model sd_model = shared.sd_model
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model is not None and checkpoint_info is not None:
return if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
model_config, is_v21 = get_config(checkpoint_info) model_config = get_config(checkpoint_info)
checkpoint_info = checkpoint_info._replace(config=model_config) checkpoint_info = checkpoint_info._replace(config=model_config)
if sd_model.sd_checkpoint_info.config != model_config or should_hijack_inpainting( if sd_model.sd_checkpoint_info.config != model_config or should_hijack_inpainting(
checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
@ -180,6 +185,17 @@ def reload_model_weights(sd_model=None, info=None):
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
sd_config = OmegaConf.load(model_config)
is_v21 = False
try:
is_v21 = sd_config.model.params.parameterization == "v"
except:
pass
try:
is_v21 = sd_config.model.params.unet_config.params.num_head_channels == 64
except:
pass
if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half: if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half:
print("Fixing attention for v21 model.") print("Fixing attention for v21 model.")
@ -194,7 +210,5 @@ def reload_model_weights(sd_model=None, info=None):
return sd_model return sd_model
# This is so effing ridiculous that we have to do this $hit.
sd_models.load_model = load_model sd_models.load_model = load_model
sd_models.reload_model_weights = reload_model_weights sd_models.reload_model_weights = reload_model_weights