import collections
import os.path
import sys
import gc
import re
import io
from os import mkdir
from urllib import request
from rich import print, progress # pylint: disable=redefined-builtin
import torch
import safetensors.torch
from omegaconf import OmegaConf
import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd


model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

checkpoints_list = {}  # title -> CheckpointInfo, filled by CheckpointInfo.register()
checkpoint_aliases = {}  # every id in CheckpointInfo.ids -> CheckpointInfo
checkpoints_loaded = collections.OrderedDict()  # CheckpointInfo -> cached state dict; the oldest entry is evicted first


class CheckpointInfo:
    """Names, hashes and lookup ids describing a single checkpoint file."""

    def __init__(self, filename):
        self.filename = filename
        abspath = os.path.abspath(filename)

        # Display name: the path relative to the checkpoint directory when the file lives
        # under it, otherwise just the file name. Only the leading directory prefix is
        # stripped (str.replace would also remove any later occurrence of the same text).
        ckpt_dir = shared.opts.ckpt_dir
        if ckpt_dir is not None and abspath.startswith(ckpt_dir):
            name = abspath[len(ckpt_dir):]
        elif abspath.startswith(model_path):
            name = abspath[len(model_path):]
        else:
            name = os.path.basename(filename)

        if name.startswith(("\\", "/")):
            name = name[1:]

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        # sha256 comes from the hash cache and may be empty here; calculate_shorthash() fills it in.
        self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'

        # Every string under which this checkpoint can be looked up in checkpoint_aliases.
        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]']
        if self.shorthash:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']

    def register(self):
        """Publish this checkpoint in checkpoints_list (by title) and checkpoint_aliases (by every id)."""
        checkpoints_list[self.title] = self
        for i in self.ids:
            checkpoint_aliases[i] = self

    def calculate_shorthash(self):
        """Compute the sha256 of the file and return its first 10 hex chars, or None if hashing failed.

        When the short hash is new, the extra ids are added and the checkpoint is
        re-registered under its hash-qualified title.
        """
        self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
        if self.sha256 is None:
            return None

        self.shorthash = self.sha256[0:10]

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']

            # The info may never have been registered, so tolerate a missing title.
            checkpoints_list.pop(self.title, None)
            self.title = f'{self.name} [{self.shorthash}]'
            self.register()

        return self.shorthash


try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging
    logging.set_verbosity_error()
except Exception:
    pass


def setup_model():
    """Make sure the model directory exists, index the checkpoints and enable MiDaS auto-download."""
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    list_models()
    enable_midas_autodownload()


def checkpoint_tiles():
    """Return all checkpoint titles sorted naturally (digit runs compared as numbers, text case-insensitively)."""
    def convert(name):
        return int(name) if name.isdigit() else name.lower()

    def alphanumeric_key(key):
        return [convert(c) for c in re.split('([0-9]+)', key)]

    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)


def _find_checkpoint_files(model_url=None, download_name=None):
    """Return checkpoint filenames from modelloader.load_models under model_path, with .vae.* files blacklisted."""
    return modelloader.load_models(
        model_path=model_path,
        model_url=model_url,
        command_path=shared.opts.ckpt_dir,
        ext_filter=[".ckpt", ".safetensors"],
        download_name=download_name,
        ext_blacklist=[".vae.ckpt", ".vae.safetensors"],
    )


def _register_checkpoint_files(filenames):
    """Create and register a CheckpointInfo for each filename, in case-insensitive order."""
    for filename in sorted(filenames, key=str.lower):
        CheckpointInfo(filename).register()


def list_models():
    """Rebuild checkpoints_list / checkpoint_aliases from the checkpoint directory.

    A checkpoint given on the command line (shared.cmd_opts.ckpt) is registered first
    and selected. When no checkpoint is found and shared.cmd_opts.no_download_sd_model
    is not set, the user is asked on stdin whether to download the default SD 1.5 model.
    """
    global model_path # pylint: disable=global-statement
    model_path = shared.opts.ckpt_dir
    checkpoints_list.clear()
    checkpoint_aliases.clear()
    model_list = _find_checkpoint_files()

    cmd_ckpt = shared.cmd_opts.ckpt
    if cmd_ckpt is not None and os.path.exists(cmd_ckpt):
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()
        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        # Only warn about a checkpoint that was explicitly requested (not about "None").
        print(f"Checkpoint not found: {shared.cmd_opts.ckpt}", file=sys.stderr)

    _register_checkpoint_files(model_list)
    print(f'Available models: {shared.opts.ckpt_dir} {len(checkpoints_list)}')

    if len(checkpoints_list) == 0 and not shared.cmd_opts.no_download_sd_model:
        key = input('Download the default model? (y/N) ')
        if key.lower().startswith('y'):
            model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
            model_list = _find_checkpoint_files(model_url=model_url, download_name="v1-5-pruned-emaonly.safetensors")
            _register_checkpoint_files(model_list)


def get_closet_checkpoint_match(search_string):
    """Return the checkpoint for search_string.

    An exact alias match wins. Otherwise the match is the checkpoint with the shortest
    title containing search_string. Returns None when nothing matches.
    """
    checkpoint_info = checkpoint_aliases.get(search_string, None)
    if checkpoint_info is not None:
        return checkpoint_info

    candidates = (info for info in checkpoints_list.values() if search_string in info.title)
    # min() keeps the first of equally short titles, like a stable sort would.
    return min(candidates, key=lambda info: len(info.title), default=None)


def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions"""
    # Hashes 64 KiB starting at offset 1 MiB; returns 'NOFILE' if the file does not exist.
    try:
        with open(filename, "rb") as file:
            import hashlib
            m = hashlib.sha256()

            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'


def select_checkpoint():
    """Return the CheckpointInfo for opts.sd_model_checkpoint, falling back to the first known checkpoint.

    Exits the process with status 1 when no checkpoints are available at all.
    """
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if len(checkpoints_list) == 0:
        print("Cannot run without a checkpoint", file=sys.stderr)
        print("Use --ckpt to force using existing checkpoint", file=sys.stderr)
        sys.exit(1)

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info


# Key prefixes rewritten when loading a checkpoint, so text-encoder weights land under "text_model.".
checkpoint_dict_replacements = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}


def transform_checkpoint_dict_key(k):
    """Return k with any prefix from checkpoint_dict_replacements rewritten."""
    for text, replacement in checkpoint_dict_replacements.items():
        if k.startswith(text):
            k = replacement + k[len(text):]

    return k


def get_state_dict_from_checkpoint(pl_sd):
    """Unwrap a checkpoint's optional "state_dict" entry and rename its keys in place.

    Returns the unwrapped dict object itself, with keys passed through transform_checkpoint_dict_key.
    """
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    sd = {transform_checkpoint_dict_key(k): v for k, v in pl_sd.items()}

    pl_sd.clear()
    pl_sd.update(sd)

    return pl_sd


def read_metadata_from_safetensors(filename):
    """Return the "__metadata__" mapping from a .safetensors header.

    Values that look like JSON objects are decoded when possible.
    """
    import json

    with open(filename, mode="rb") as file:
        # Header layout: 8-byte little-endian length, then that many bytes of JSON.
        metadata_len = file.read(8)
        metadata_len = int.from_bytes(metadata_len, "little")
        json_start = file.read(2)

        assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
        json_data = json_start + file.read(metadata_len-2)
        json_obj = json.loads(json_data)

        res = {}
        for k, v in json_obj.get("__metadata__", {}).items():
            res[k] = v
            if isinstance(v, str) and v[0:1] == '{':
                try:
                    res[k] = json.loads(v)
                except Exception:
                    pass

        return res


# Checkpoints read straight from disk; every other file is streamed through a progress-bar reader.
_direct_load_checkpoints = ('v1-5-pruned-emaonly.safetensors', 'vae-ft-mse-840000-ema-pruned.ckpt')


def read_state_dict(checkpoint_file):
    """Load a .safetensors or torch checkpoint onto the CPU and return its normalized state dict.

    Errors are reported via errors.display and None is returned.
    """
    try:
        _, extension = os.path.splitext(checkpoint_file)
        is_safetensors = extension.lower() == ".safetensors"
        # Explicit membership test: a bare `'a' or 'b' in s` is always truthy.
        if any(name in checkpoint_file for name in _direct_load_checkpoints):
            if is_safetensors:
                pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu')
            else:
                pl_sd = torch.load(checkpoint_file, map_location='cpu')
        else:
            # Read the whole file through rich's progress reader so loading shows progress.
            with progress.open(checkpoint_file, 'rb', description=f'Loading weights: [cyan]{checkpoint_file}', auto_refresh=True) as f:
                buffer = f.read()
            if is_safetensors:
                pl_sd = safetensors.torch.load(buffer)
            else:
                pl_sd = torch.load(io.BytesIO(buffer), map_location='cpu')
        sd = get_state_dict_from_checkpoint(pl_sd)
    except Exception as e:
        errors.display(e, f'loading model: {checkpoint_file}')
        sd = None
    return sd


def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    """Return the state dict for checkpoint_info from the in-memory cache, or read it from disk (None on failure)."""
    if checkpoint_info in checkpoints_loaded:
        # use checkpoint cache
        print("Loading weights from cache")
        return checkpoints_loaded[checkpoint_info]

    res = read_state_dict(checkpoint_info.filename)
    timer.record("load")

    return res


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    """Apply a checkpoint's weights to an already constructed model.

    Reads the state dict (cache or disk) when state_dict is None and optionally caches it.
    Converts precision and memory format according to the options, records the
    checkpoint on the model, and loads the matching VAE.

    Raises:
        RuntimeError: if no state dict could be read for checkpoint_info.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("hash")

    shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
    if state_dict is None:
        raise RuntimeError(f"Failed to read weights: {checkpoint_info.filename}")

    if shared.opts.sd_checkpoint_cache > 0:
        # Cache the tensors read from the checkpoint *before* applying them. Tensors from
        # model.state_dict() share storage with the model's parameters, so a later in-place
        # load into the same model would silently overwrite such a cached entry.
        checkpoints_loaded[checkpoint_info] = state_dict.copy()

    model.load_state_dict(state_dict, strict=False)
    del state_dict
    timer.record("apply")

    if shared.opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("channels")

    if not shared.cmd_opts.no_half:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.opts.upcast_sampling and depth_model:
            model.depth_model = None

        model.half()
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model

    devices.set_cuda_params()
    devices.dtype_unet = model.model.diffusion_model.dtype
    model.first_stage_model.to(devices.dtype_vae)

    # clean up cache if limit is reached
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    model.logvar = model.logvar.to(devices.device)  # fix for training

    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("vae")


def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            if not os.path.exists(midas_path):
                mkdir(midas_path)

            print(f"Downloading midas model weights for {model_type} to {path}")
            request.urlretrieve(midas_urls[model_type], path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper


def repair_config(sd_config):
    """Adjust a loaded model config in place for the current options and environment.

    Defaults use_ema to False and sets use_fp16 from no_half / upcast_sampling. Falls
    back from "vanilla-xformers" attention when xformers is unavailable, and points the
    UnCLIP karlo stats path at the models directory.
    """
    if "use_ema" not in sd_config.model.params:
        sd_config.model.params.use_ema = False

    if shared.cmd_opts.no_half:
        sd_config.model.params.unet_config.params.use_fp16 = False
    elif shared.opts.upcast_sampling:
        sd_config.model.params.unet_config.params.use_fp16 = True

    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if "noise_aug_config" in sd_config.model.params and "clip_stats_path" in sd_config.model.params.noise_aug_config.params:
        karlo_path = os.path.join(paths.models_path, 'karlo')
        sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)


# Keys whose presence shows that the checkpoint already contains text-encoder (CLIP) weights.
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'


def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    """Build a new model for checkpoint_info (default: select_checkpoint()) and make it shared.sd_model.

    Returns the new model. If the weights or config cannot be obtained, the previously
    loaded checkpoint (if any) is reloaded and None is returned.
    """
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()
    do_inpainting_hijack()

    timer = Timer()

    current_checkpoint_info = None
    if shared.sd_model:
        current_checkpoint_info = shared.sd_model.sd_checkpoint_info
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        shared.sd_model = None
        gc.collect()
        devices.torch_gc()

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
    else:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    # Only look for a config when the weights were actually read.
    if state_dict is None:
        checkpoint_config = None
    else:
        checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    if state_dict is None or checkpoint_config is None:
        print(f"Failed to load checkpoint: {checkpoint_info.filename}")
        if current_checkpoint_info is not None:
            print(f"Restoring previous checkpoint: {current_checkpoint_info.filename}")
            load_model(current_checkpoint_info, None)
        return None

    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config)
    timer.record("config")

    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
            sd_model = instantiate_from_config(sd_config.model)
    except Exception:
        # Fall back to constructing the model with normal initialization.
        sd_model = instantiate_from_config(sd_config.model)
    sd_model.used_config = checkpoint_config
    timer.record("create")

    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    del state_dict  # drop this frame's reference; a cached copy (if enabled) lives in checkpoints_loaded

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
        sd_model.to(shared.device)
    timer.record("move")

    sd_hijack.model_hijack.hijack(sd_model)
    timer.record("hijack")

    sd_model.eval()
    shared.sd_model = sd_model
    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
    timer.record("embeddings")

    script_callbacks.model_loaded_callback(sd_model)
    timer.record("callbacks")

    print(f"Model loaded in {timer.summary()}")
    return sd_model


def reload_model_weights(sd_model=None, info=None):
    """Switch sd_model (default: shared.sd_model) to the checkpoint info (default: select_checkpoint()).

    Weights are swapped in place when the config matches. Otherwise (or when there is
    no current model) the model is rebuilt via load_model and shared.sd_model is
    returned. On failure, the previous weights are restored and the error re-raised.
    """
    from modules import lowvram, sd_hijack
    checkpoint_info = info or select_checkpoint()

    if not sd_model:
        sd_model = shared.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            return None

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    elif sd_model is not None:
        sd_model.to(devices.cpu)

    # There is nothing to un-hijack when the previous load failed.
    if sd_model is not None:
        sd_hijack.model_hijack.undo_hijack(sd_model)

    timer = Timer()

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

    if sd_model is None or checkpoint_config != sd_model.used_config:
        del sd_model
        checkpoints_loaded.clear()
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return shared.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("callbacks")

        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
            sd_model.to(devices.device)
            timer.record("device")

    print(f"Weights loaded in {timer.summary()}")
    return None


def unload_model_weights(sd_model=None, _info=None):
    """Move shared.sd_model to the CPU, undo its hijack, drop it and free memory. Always returns None."""
    from modules import sd_hijack
    timer = Timer()

    if shared.sd_model:
        # shared.sd_model.cond_stage_model.to(devices.cpu)
        # shared.sd_model.first_stage_model.to(devices.cpu)
        shared.sd_model.to(devices.cpu)
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        shared.sd_model = None
        sd_model = None
        gc.collect()
        devices.torch_gc()
        torch.cuda.empty_cache()

    print(f"Unloaded weights {timer.summary()}")

    return sd_model


def apply_token_merging(sd_model, hr: bool):
    """
    Applies speed and memory optimizations from tomesd.

    Args:
        hr (bool): True if called in the context of a high-res pass
    """
    ratio = shared.opts.token_merging_ratio_hr if hr else shared.opts.token_merging_ratio

    tomesd.apply_patch(
        sd_model,
        ratio=ratio,
        max_downsample=shared.opts.token_merging_maximum_down_sampling,
        sx=shared.opts.token_merging_stride_x,
        sy=shared.opts.token_merging_stride_y,
        use_rand=shared.opts.token_merging_random,
        merge_attn=shared.opts.token_merging_merge_attention,
        merge_crossattn=shared.opts.token_merging_merge_cross_attention,
        merge_mlp=shared.opts.token_merging_merge_mlp
    )