"""Checkpoint registry: discovery, lookup, hashing and cached metadata for model files."""

import base64
import collections
import hashlib
import html
import io
import json
import os
import re
import time
from typing import Optional

from PIL import Image

from modules import shared, paths, modelloader, hashes
from modules.logger import log
from modules.json_helpers import writefile


checkpoints_list = {}  # title -> CheckpointInfo
checkpoint_aliases = {}  # name / filename / shorthash / title -> CheckpointInfo
checkpoints_loaded = collections.OrderedDict()
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
sd_metadata_file = os.path.join(paths.data_path, "data", "metadata.json")
sd_metadata = None  # lazily loaded cache: filename -> metadata dict
sd_metadata_pending = 0  # cache entries added since the last write_metadata()
sd_metadata_timer = 0  # cumulative seconds spent reading safetensors headers
warn_once = False  # set once the "no models found" hints have been logged


class CheckpointInfo:
    """Describes one selectable model.

    The entry is classified in ``__init__`` as one of:

    * the literal name ``none`` (type ``unknown``),
    * an existing file on disk (type is the lowercase file extension, or
      ``transformer`` when the filename contains ``nf4``),
    * a repo listed in ``modelloader.diffuser_repos`` (type ``diffusers``),
    * anything else (type ``unknown``, no hash).

    Callers may overwrite ``type`` and ``subfolder`` after construction.
    """

    def __init__(self, filename, sha=None, subfolder=None):
        """Build the descriptor for ``filename``.

        Args:
            filename: path to a model file, a diffusers repo name, or ``none``.
            sha: optional hash; when given, diffusers repos are matched on
                their ``hash`` field instead of their ``name`` field.
            subfolder: optional subfolder stored as-is on the instance.
        """
        self.name = None
        self.hash = sha
        self.filename = filename
        self.type = ''
        self.subfolder = subfolder

        relname, ext = os.path.splitext(self._relative_name(filename))
        ext = ext.lower()[1:]

        if filename.lower() == 'none':
            self.name = 'none'
            self.relname = 'none'
            self.sha256 = None
            self.type = 'unknown'
        elif os.path.isfile(filename):  # ckpt or safetensor
            self.name = relname
            self.filename = filename
            self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{relname}")
            self.type = ext
            if 'nf4' in filename:
                self.type = 'transformer'
        else:  # maybe a diffuser
            if self.hash is None:
                repo = [r for r in modelloader.diffuser_repos if self.filename == r['name']]
            else:
                repo = [r for r in modelloader.diffuser_repos if self.hash == r['hash']]
            if not repo:
                self.name = filename
                self.filename = filename
                self.sha256 = None
                self.type = 'unknown'
            else:
                self.name = os.path.join(os.path.basename(shared.opts.diffusers_dir), repo[0]['name'])
                self.filename = repo[0]['path']
                self.sha256 = repo[0]['hash']
                self.type = 'diffusers'

        self.shorthash = self.sha256[0:10] if self.sha256 else None
        self.title = self.name if self.shorthash is None else f'{self.name} [{self.shorthash}]'
        self.path = self.filename
        self.model_name = os.path.basename(self.name)
        # Looked up by the name that was passed in, not by the resolved repo path.
        self.metadata = read_metadata_from_safetensors(filename)

    @staticmethod
    def _relative_name(filename):
        """Return ``filename`` relative to the first known root it starts with.

        Roots are tried in a fixed order; when none matches, the absolute path
        is returned instead.
        """
        app_path = os.path.abspath(paths.script_path)
        candidate = filename
        if candidate.startswith('..'):
            candidate = os.path.abspath(candidate)
        roots = (shared.opts.ckpt_dir, shared.opts.diffusers_dir, model_path, paths.script_path, app_path)
        for root in roots:
            if candidate.startswith(root):
                try:
                    return os.path.relpath(filename, root)
                except Exception:
                    # Best effort: fall back to the name as given when relpath fails.
                    return filename
        return os.path.abspath(candidate)

    def register(self):
        """Add this entry to ``checkpoints_list`` and to every non-None alias."""
        checkpoints_list[self.title] = self
        for alias in (self.name, self.filename, self.shorthash, self.title):
            if alias is not None:
                checkpoint_aliases[alias] = self

    def calculate_shorthash(self):
        """Compute the sha256, refresh the title and re-register.

        Returns:
            The 10-character short hash, or None when no hash was produced.
        """
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return None
        self.shorthash = self.sha256[0:10]
        # The title is about to change, so drop the entry keyed by the old one.
        checkpoints_list.pop(self.title, None)
        self.title = f'{self.name} [{self.shorthash}]'
        self.register()
        return self.shorthash

    def __str__(self):
        return f'CheckpointInfo(name="{self.name}" filename="{self.filename}" hash={self.shorthash} type={self.type} title="{self.title}" path="{self.path}" subfolder="{self.subfolder}")'


def setup_model():
    """Populate the checkpoint registry."""
    list_models()


def checkpoint_titles(use_short=False):
    """Return registered checkpoint titles in natural (digit-aware) order.

    Args:
        use_short: when True, strip everything up to the last ``\\`` or ``/``
            from each title before sorting.
    """
    def alphanumeric_key(key):
        # re.split with a capturing group alternates text, digits, text, ...
        # so odd positions are exactly the ASCII digit runs. Keying on position
        # rather than str.isdigit() keeps every key the same shape and avoids
        # int() on characters such as superscript digits.
        parts = re.split(r"([0-9]+)", key)
        return [int(part) if idx % 2 else part.lower() for idx, part in enumerate(parts)]

    if use_short:
        titles = [x.title.rsplit("\\", 1)[-1].rsplit("/", 1)[-1] for x in checkpoints_list.values()]
    else:
        titles = [x.title for x in checkpoints_list.values()]
    return sorted(titles, key=alphanumeric_key)


def list_models():
    """Rebuild ``checkpoints_list`` and ``checkpoint_aliases`` from scratch.

    Registers safetensors files, diffusers repos and, when given, the
    checkpoint named on the command line, then orders the registry by filename.
    """
    t0 = time.time()
    checkpoints_list.clear()
    checkpoint_aliases.clear()
    ext_filter = [".safetensors"]
    model_list = modelloader.load_models(
        model_path=model_path,
        model_url=None,
        command_path=shared.opts.ckpt_dir,
        ext_filter=ext_filter,
        download_name=None,
        ext_blacklist=[".vae.ckpt", ".vae.safetensors"],
    )
    safetensors_list = []
    for filename in sorted(model_list, key=str.lower):
        checkpoint_info = CheckpointInfo(filename)
        safetensors_list.append(checkpoint_info)
        if checkpoint_info.name is not None:
            checkpoint_info.register()
    diffusers_list = []
    for repo in modelloader.load_diffusers_models(clear=True):
        checkpoint_info = CheckpointInfo(repo['name'], sha=repo['hash'])
        diffusers_list.append(checkpoint_info)
        if checkpoint_info.name is not None:
            checkpoint_info.register()
    if shared.cmd_opts.ckpt is not None:
        checkpoint_info = CheckpointInfo(shared.cmd_opts.ckpt)
        if checkpoint_info.name is not None and os.path.exists(checkpoint_info.filename):
            checkpoint_info.register()
            shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
        elif shared.cmd_opts.ckpt != shared.default_sd_model_file:
            log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found')
    log.info(f'Available Models: safetensors="{shared.opts.ckpt_dir}":{len(safetensors_list)} diffusers="{shared.opts.diffusers_dir}":{len(diffusers_list)} reference={len(list(shared.reference_models))} items={len(checkpoints_list)} time={time.time()-t0:.2f}')
    # Reorder in place so that references to this dict held elsewhere stay valid.
    ordered = sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename)
    checkpoints_list.clear()
    checkpoints_list.update(ordered)


def update_model_hashes():
    """Fill in missing hashes, yielding an HTML progress table after each one.

    First sets the legacy ``hash`` on entries that lack it, then computes the
    sha256 and short hash for entries missing either. This is a generator:
    each yielded string is a table of all entries processed so far.
    """
    def update_model_hashes_table(rows):
        # NOTE(review): the markup of this template was not recoverable from the
        # file as received (only the Name/Type/Hash header text survived), so a
        # minimal table is used here - confirm against the UI's expected styling.
        template = "<table><thead><tr><th>Name</th><th>Type</th><th>Hash</th></tr></thead><tbody>{tbody}</tbody></table>"
        body_rows = []
        for row in rows:
            try:
                cells = (row.name, row.type, row.shorthash)
                body_rows.append('<tr>' + ''.join(f'<td>{html.escape(str(cell))}</td>' for cell in cells) + '</tr>')
            except Exception as e:
                log.error(f'Model list: row={row} {e}')
        return template.format(tbody=''.join(body_rows))

    for ckpt in [c for c in checkpoints_list.values() if c.hash is None]:
        ckpt.hash = model_hash(ckpt.filename)
    missing = [c for c in checkpoints_list.values() if c.sha256 is None or c.shorthash is None]
    log.info(f'Models list: hash missing={len(missing)} total={len(checkpoints_list)}')
    updated = []
    for ckpt in missing:
        ckpt.sha256 = hashes.sha256(ckpt.filename, f"checkpoint/{ckpt.name}")
        ckpt.shorthash = ckpt.sha256[0:10] if ckpt.sha256 is not None else None
        updated.append(ckpt)
        yield update_model_hashes_table(updated)


def remove_hash(s):
    """Strip every ``[...]`` group, with its leading whitespace, from ``s``."""
    return re.sub(r'\s*\[.*?\]', '', s)


def get_closest_checkpoint_match(s: str) -> Optional[CheckpointInfo]:
    """Resolve a user-supplied model reference to a ``CheckpointInfo``.

    Strategies are tried in order and the first hit wins: Hugging Face URL or
    ``huggingface/`` prefix, exact alias, title basename, title without hash,
    existing ``.safetensors`` path, reference models, Hugging Face search and
    Civitai download (the last two only when
    ``shared.opts.sd_checkpoint_autodownload`` is set).

    Returns:
        The matching (possibly newly created, unregistered) entry, or None.
    """
    # direct hf url / explicit prefix: create a virtual model info
    hf_url = 'https://huggingface.co/'
    if s.startswith(hf_url):
        checkpoint_info = CheckpointInfo(s[len(hf_url):])  # strip the leading prefix only
        checkpoint_info.type = 'huggingface'
        log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=huggingface')
        return checkpoint_info
    hf_prefix = 'huggingface/'
    if s.startswith(hf_prefix):
        checkpoint_info = CheckpointInfo(s[len(hf_prefix):])
        checkpoint_info.type = 'huggingface'
        return checkpoint_info

    # alias search
    checkpoint_info = checkpoint_aliases.get(s, None)
    if checkpoint_info is not None:
        log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=alias')
        return checkpoint_info

    # models search: accepted only when the match is unambiguous
    wanted = s.lower()
    found = [info for info in checkpoints_list.values() if os.path.basename(info.title).lower() == wanted]
    if len(found) == 1:
        checkpoint_info = found[0]
        log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=hash')
        return checkpoint_info

    # nohash search: accepted only when the match is unambiguous
    wanted = remove_hash(s).lower()
    found = [info for info in checkpoints_list.values() if remove_hash(info.title).lower() == wanted]
    if len(found) == 1:
        checkpoint_info = found[0]
        log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=model')
        return checkpoint_info

    # absolute path
    if s.endswith('.safetensors') and os.path.isfile(s):
        checkpoint_info = CheckpointInfo(s)
        checkpoint_info.type = 'safetensors'
        log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=safetensors')
        return checkpoint_info

    # reference search: "path+subfolder" first, then bare path
    ref = [v for v in shared.reference_models.values() if f"{v.get('path', '')}+{v.get('subfolder', '')}" == s]
    if not ref:
        ref = [v for v in shared.reference_models.values() if v.get('path', '') == s]
    if ref:
        info = ref[0]
        checkpoint_info = CheckpointInfo(s)
        checkpoint_info.subfolder = info.get('subfolder', None)
        checkpoint_info.type = 'reference'
        log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=reference')
        return checkpoint_info

    # huggingface search
    slashes = s.count('/')
    if shared.opts.sd_checkpoint_autodownload and slashes in (1, 2):
        subfolder = None
        if slashes == 2:
            parts = s.split('/')
            subfolder = '/'.join(parts[2:])  # subfolder
            s = '/'.join(parts[:2])  # only user/model
        modelloader.hf_login()
        found = modelloader.find_diffuser(s, full=True)
        if found is None:
            return None
        found = [f for f in found if f == s]
        log.info(f'HF search: model="{s}" results={found}')
        if len(found) == 1:
            checkpoint_info = CheckpointInfo(s)
            checkpoint_info.type = 'huggingface'
            if subfolder:
                checkpoint_info.subfolder = subfolder
            log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=huggingface')
            return checkpoint_info

    # civitai search
    if shared.opts.sd_checkpoint_autodownload and s.startswith("https://civitai.com/api/download/models"):
        from modules.civitai.download_civitai import download_civit_model
        fn = download_civit_model(model_url=s, model_name='', model_path='', model_type='Model', token=shared.opts.civitai_token)
        if fn is not None:
            checkpoint_info = CheckpointInfo(fn)
            log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=civitai')
            return checkpoint_info

    return None


def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions

    Hashes 0x10000 bytes starting at offset 0x100000 and returns the first 8
    hex digits; returns 'NOFILE' when the file is missing and 'NOHASH' on any
    other error.
    """
    try:
        with open(filename, "rb") as file:
            m = hashlib.sha256()
            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'
    except Exception:
        return 'NOHASH'


def select_checkpoint(op='model', sd_model_checkpoint=None):
    """Resolve the checkpoint to load for ``op``.

    Args:
        op: operation label used in log messages; ``'refiner'`` reads the
            ``sd_model_refiner`` option, anything else ``sd_model_checkpoint``.
        sd_model_checkpoint: explicit model reference; overrides the option
            when truthy.

    Returns:
        The matched ``CheckpointInfo``, or None when nothing is configured or
        nothing matched.
    """
    global warn_once  # pylint: disable=global-statement
    if sd_model_checkpoint:
        model_checkpoint = sd_model_checkpoint
    else:
        option = 'sd_model_refiner' if op == 'refiner' else 'sd_model_checkpoint'
        model_checkpoint = shared.opts.data.get(option, None)
    if model_checkpoint is None or model_checkpoint == 'None' or len(model_checkpoint) < 3:
        return None
    checkpoint_info = get_closest_checkpoint_match(model_checkpoint)
    if checkpoint_info is not None:
        log.info(f'Load {op}: select="{checkpoint_info.title}"')
        return checkpoint_info
    if len(checkpoints_list) == 0:
        log.error("No models found")
        if not warn_once:
            warn_once = True
            log.info("Set system paths to use existing folders")
            log.info(" or use --models-dir to specify base folder with all models")
            log.info(" or use --ckpt to force using specific model")
        return None
    if model_checkpoint not in ('model.safetensors', 'stabilityai/stable-diffusion-xl-base-1.0'):
        log.error(f'Load {op}: search="{model_checkpoint}" not found')
    else:
        # NOTE(review): despite this message nothing is selected here; the
        # function returns None, exactly as before. Confirm what callers expect.
        log.info("Selecting first available checkpoint")
    return None


def init_metadata():
    """Load the on-disk metadata cache into ``sd_metadata`` on first use."""
    global sd_metadata  # pylint: disable=global-statement
    if sd_metadata is None:
        sd_metadata = shared.readfile(sd_metadata_file, lock=True, as_type="dict") if os.path.isfile(sd_metadata_file) else {}


def extract_thumbnail(filename, data):
    """Decode a base64 data URI and save it as ``<filename stem>.thumb.jpg``.

    The image is converted to RGB and resized to 512x512. Any failure is
    logged and otherwise ignored.
    """
    try:
        encoded = data.split(",")[1]
        image = Image.open(io.BytesIO(base64.b64decode(encoded)))
        image = image.convert("RGB").resize((512, 512), Image.Resampling.HAMMING)
        fn = os.path.splitext(filename)[0]
        image.save(f"{fn}.thumb.jpg", quality=50)
    except Exception as e:
        log.error(f"Error extracting thumbnail: {filename} {e}")


def read_metadata_from_safetensors(filename):
    """Return the ``__metadata__`` of a safetensors file, using the cache.

    Cached results are returned directly. Files that do not end in
    ``.safetensors``, or any file when ``shared.cmd_opts.no_metadata`` is set,
    yield an empty dict without touching the cache. Otherwise the header is
    parsed, oversized or embedded-data values are trimmed, and the result is
    cached and counted as pending for ``write_metadata``.
    """
    global sd_metadata_pending, sd_metadata_timer  # pylint: disable=global-statement
    init_metadata()
    res = sd_metadata.get(filename, None)
    if res is not None:
        return res
    if not filename.endswith(".safetensors"):
        return {}
    if shared.cmd_opts.no_metadata:
        return {}
    res = {}
    skip_when_large = ('ss_datasets', 'workflow', 'prompt', 'ss_bucket_info', 'sd_metadata_file')
    t0 = time.time()
    try:
        with open(filename, mode="rb") as file:
            # Header layout: 8-byte little-endian length, then that many bytes of JSON.
            metadata_len = int.from_bytes(file.read(8), "little")
            json_start = file.read(2)
            if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
                log.error(f'Model metadata invalid: file="{filename}" len={metadata_len} start={json_start}')
                return res
            json_obj = json.loads(json_start + file.read(metadata_len - 2))
            for k, v in (json_obj.get("__metadata__") or {}).items():
                if not isinstance(v, str):
                    # Keep unexpected non-string values as-is instead of letting
                    # a str method call abort parsing of the remaining keys.
                    res[k] = v
                    continue
                if v.startswith("data:"):
                    if k == 'modelspec.thumbnail':
                        extract_thumbnail(filename, v)
                    v = 'data'  # never keep embedded payloads in the cache
                if k == 'format' and v == 'pt':
                    continue
                large = len(v) > 2048
                if large and k in skip_when_large:
                    continue
                if v[0:1] == '{':
                    try:
                        v = json.loads(v)
                        if large and k == 'ss_tag_frequency':
                            v = {i: len(j) for i, j in v.items()}
                        if large and k == 'sd_merge_models':
                            scrub_dict(v, ['sd_merge_recipe'])
                    except Exception:
                        # Deliberate best effort: keep whatever v holds at this point.
                        pass
                res[k] = v
    except Exception as e:
        log.error(f'Model metadata: file="{filename}" {e}')
        from modules import errors
        errors.display(e, 'Model metadata')
    sd_metadata[filename] = res
    sd_metadata_pending += 1
    sd_metadata_timer += (time.time() - t0)
    return res


def scrub_dict(dict_obj, keys):
    """Recursively remove ``keys`` from ``dict_obj`` in place.

    Descends into nested dicts and into lists; anything that is not a dict
    (for example a string or number inside a list) is left untouched.
    """
    if not isinstance(dict_obj, dict):
        return
    for key in list(dict_obj.keys()):
        value = dict_obj[key]
        if key in keys:
            dict_obj.pop(key, None)
        elif isinstance(value, dict):
            scrub_dict(value, keys)
        elif isinstance(value, list):
            for item in value:
                scrub_dict(item, keys)


def write_metadata():
    """Persist the metadata cache when entries were added since the last write."""
    global sd_metadata_pending  # pylint: disable=global-statement
    if sd_metadata_pending == 0:
        log.debug(f'Model metadata: file="{sd_metadata_file}" no changes')
        return
    writefile(sd_metadata, sd_metadata_file)
    log.info(f'Model metadata saved: file="{sd_metadata_file}" items={sd_metadata_pending} time={sd_metadata_timer:.2f}')
    sd_metadata_pending = 0