automatic/modules/hashes.py

import hashlib
import os.path
from collections import defaultdict
from typing import Literal, TypeAlias, TypedDict
from rich import progress, errors
from modules.logger import console
from modules.logger import log
from modules.json_helpers import readfile, writefile
from modules.paths import data_path


class HashEntry(TypedDict):
    mtime: float
    sha256: str


class HashStore(dict[str, HashEntry]):
    def __init__(self, *args):
        super().__init__(*args)

    def add_hash(self, key: str, mtime: float = 0, sha256: str | None = None): # pylint: disable=redefined-outer-name
        self.__setitem__(key, {"mtime": mtime, "sha256": sha256 or ""})


default_hash_store = "hashes"
KnownHashStores: TypeAlias = Literal["hashes", "hashes-addnet"] # For autocomplete in IDE
cache_filename = os.path.join(data_path, "data", "cache.json")
progress_ok = True
# defaultdict allows for easily using new stores without needing to define them ahead of time
_data: defaultdict[str, HashStore] = defaultdict(HashStore)
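# Illustrative shape of cache.json as written by save_cache() below; the store names,
# titles and values are made-up examples, not real entries:
#   {
#       "hashes": {"some-model": {"mtime": 1700000000.0, "sha256": "abc123..."}},
#       "hashes-addnet": {"some-lora": {"mtime": 1700000000.0, "sha256": "def456..."}}
#   }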
def load_cache():
    if os.path.isfile(cache_filename):
        for store, data in readfile(cache_filename, lock=True, as_type="dict").items():
            _data[store] = HashStore(data)


def save_cache():
    # Don't include empty hash stores
    filtered = filter(lambda item: len(item[1]) > 0, _data.items())
    writefile(dict(filtered), cache_filename)


def cache(store: KnownHashStores | str | None = None) -> HashStore:
    if store is None:
        return _data[default_hash_store]
    return _data[store]
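# Minimal usage sketch of the store API (comments only, not executed at import time);
# the key and values are illustrative assumptions, not entries that exist in this repo:
#   load_cache()                                            # populate _data from cache.json if present
#   cache().add_hash("my-model", 1700000000.0, "ab12...")   # default "hashes" store
#   cache("hashes-addnet")                                  # separate store, created on first access
#   save_cache()                                            # persist all non-empty stores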
def calculate_sha256(filename, quiet=False):
    """Stream a file through SHA-256, showing a rich progress bar unless quiet or unavailable"""
    global progress_ok # pylint: disable=global-statement
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024
    if not quiet:
        if progress_ok:
            try:
                with progress.open(filename, 'rb', description=f'[cyan]Calculating hash: [yellow]{filename}', auto_refresh=True, console=console) as f:
                    for chunk in iter(lambda: f.read(blksize), b""):
                        hash_sha256.update(chunk)
            except errors.LiveError: # rich refuses to start a second live display, typically when called from a thread
                log.warning('Hash: attempting to use function in a thread')
                progress_ok = False
        if not progress_ok:
            with open(filename, 'rb') as f:
                for chunk in iter(lambda: f.read(blksize), b""):
                    hash_sha256.update(chunk)
    else:
        with open(filename, 'rb') as f:
            for chunk in iter(lambda: f.read(blksize), b""):
                hash_sha256.update(chunk)
    return hash_sha256.hexdigest()
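# Example call (illustrative path, not part of this repo): calculate_sha256("/models/model.safetensors")
# returns a 64-character hex digest; pass quiet=True to skip the progress bar entirely.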
def sha256_from_cache(filename: str, title: str, *, store: KnownHashStores | str | None = None):
    """Return the cached hash for title, or None if it is missing, empty, or older than the file on disk"""
    hashes = cache(store)
    if title not in hashes:
        return None
    cached = hashes[title]
    ondisk_mtime = os.path.getmtime(filename) if os.path.isfile(filename) else 0
    if ondisk_mtime > cached["mtime"] or not cached["sha256"]:
        return None
    return cached["sha256"]
def sha256(filename: str, title: str, *, store: KnownHashStores | str | None = None):
    """Return the hash for filename, keyed in the cache by title, recalculating only when needed"""
    from modules import shared
    global progress_ok # pylint: disable=global-statement
    sha256_value = sha256_from_cache(filename, title, store=store)
    if sha256_value is not None:
        return sha256_value
    if shared.cmd_opts.no_hashing:
        return None
    if not os.path.isfile(filename):
        return None
    jobid = shared.state.begin("Hash")
    if store == "hashes-addnet":
        if progress_ok:
            try:
                with progress.open(filename, 'rb', description=f'[cyan]Calculating hash: [yellow]{filename}', auto_refresh=True, console=console) as f:
                    sha256_value = addnet_hash_safetensors(f)
            except errors.LiveError:
                log.warning('Hash: attempting to use function in a thread')
                progress_ok = False
        if not progress_ok:
            with open(filename, 'rb') as f:
                sha256_value = addnet_hash_safetensors(f)
    else:
        sha256_value = calculate_sha256(filename)
    cache(store).add_hash(title, os.path.getmtime(filename), sha256_value)
    shared.state.end(jobid)
    save_cache()
    return sha256_value
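# Usage sketch (made-up path and key, not from this repo): title is only the cache key,
# filename is what actually gets hashed and mtime-checked:
#   digest = sha256("/models/my-lora.safetensors", "my-lora", store="hashes-addnet")
# Cached entries are reused unless the file on disk is newer than the stored mtime.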
def addnet_hash_safetensors(b):
    """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024
    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")
    offset = n + 8
    b.seek(offset)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)
    return hash_sha256.hexdigest()
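# Background note (based on the safetensors file format, not code in this module): a
# .safetensors file begins with an 8-byte little-endian length n, followed by an n-byte
# JSON header and then the raw tensor data. Seeking to n + 8 above means the digest
# covers only the tensor payload, so it stays stable when only header metadata changes.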