stable-diffusion-webui-wd14.../tagger/uiset.py

499 lines
17 KiB
Python

""" for handling ui settings """
from typing import List, Dict, Tuple, Callable, Set
import os
from pathlib import Path
from glob import glob
from math import ceil
from hashlib import sha256
from re import compile as re_comp, sub as re_sub, match as re_match, IGNORECASE
from json import dumps, loads, JSONDecodeError
from functools import partial
from html import escape as html_escape
from collections import defaultdict
from PIL import Image
from modules import shared
from modules.deepbooru import re_special as tag_escape_pattern
from tagger import format as tags_format
from tagger import settings
# short alias for the interrogator settings class
Its = settings.InterrogatorSettings

# File extensions (as reported by PIL, e.g. '.png') for which PIL has a
# registered opener. Original author's note: when this is evaluated too
# early, PIL.Image.registered_extensions() only reports PNG.
supported_extensions = {
    ext
    for ext, fmt in Image.registered_extensions().items()
    if fmt in Image.OPEN
}
# Interrogator return type: the 4-tuple handed back to the UI by
# QData.finalize() / QData.finalize_batch().
ItRetTP = Tuple[
    str,               # tags as string (HTML links in QData.finalize)
    Dict[str, float],  # rating confidences
    Dict[str, float],  # tag confidences
    str,               # error message ('' when there is no error)
]
class IOData:
    """ data class for input and output paths

    All state lives on the class itself; every method is a classmethod.

    Attributes:
        last_input_glob: last glob pattern that was processed successfully.
        base_dir: directory part of the current input glob.
        base_dir_last: last path component of ``base_dir``.
        output_root: ``Path`` below which tag files are placed.
        paths: list of ``[input_path, tags_out, output_dir]`` entries;
            ``output_dir`` is only filled in when that directory does not
            exist yet (presumably so a caller can create it). ``get_hashes``
            appends a checksum as fourth element.
        save_tags: whether output paths are computed for tag files.
    """
    last_input_glob = None
    base_dir = None
    # declared here so that reading it before update_input_glob() has run
    # yields None instead of raising AttributeError
    base_dir_last = None
    output_root = None
    paths = []
    save_tags = True

    @classmethod
    def flip_save_tags(cls) -> callable:
        """ return a zero-argument callable that toggles save_tags """
        def toggle():
            cls.save_tags = not cls.save_tags
        return toggle

    @classmethod
    def toggle_save_tags(cls) -> None:
        """ invert the save_tags flag """
        cls.save_tags = not cls.save_tags

    @classmethod
    def update_output_dir(cls, output_dir: str) -> str:
        """ update output directory, and set input and output paths

        Returns an error message, or '' on success / when nothing changed.
        """
        pout = Path(output_dir)
        if pout == cls.output_root:
            return ''
        # re-evaluate the output paths of all known inputs for the new root
        paths = [x[0] for x in cls.paths]
        cls.paths = []
        cls.output_root = pout
        return cls.set_batch_io(paths)

    @classmethod
    def get_bytes_hash(cls, data) -> str:
        """ get sha256 hex digest of the given bytes """
        # Note: the checksum from an image is not the same as from file
        return sha256(data).hexdigest()

    @classmethod
    def get_hashes(cls):
        """ get hashes of all files

        Entries of ``paths`` that lack a checksum get one appended, computed
        from the decoded image bytes.
        """
        ret = set()
        for entries in cls.paths:
            if len(entries) == 4:
                ret.add(entries[3])
            else:
                # if there is no checksum, calculate it; the context manager
                # closes the file handle that Image.open() keeps open
                with Image.open(entries[0]) as image:
                    checksum = cls.get_bytes_hash(image.tobytes())
                entries.append(checksum)
                ret.add(checksum)
        return ret

    @classmethod
    def _follow_input_dir(cls, base_dir: str) -> None:
        """ let output_root follow the input directory unless set explicitly

        ``output_root`` is (re)set to ``base_dir`` when it is unset, or when
        it still equals the previous input directory.
        """
        if cls.output_root is None:
            cls.output_root = Path(base_dir)
        elif not cls.output_root or (
            # base_dir is None until the first input glob was accepted;
            # Path(None) would raise TypeError
            cls.base_dir is not None
            and cls.output_root == Path(cls.base_dir)
        ):
            cls.output_root = Path(base_dir)

    @classmethod
    def update_input_glob(cls, input_glob: str) -> str:
        """ update input glob pattern, and set input and output paths

        Returns an error message, or '' on success / when nothing changed.
        """
        input_glob = input_glob.strip()
        if input_glob == cls.last_input_glob:
            print('input glob did not change')
            return ''
        last_input_glob = input_glob

        cls.paths = []

        # if there is no glob pattern, insert it automatically
        if not input_glob.endswith('*'):
            if not input_glob.endswith(os.sep):
                input_glob += os.sep
            input_glob += '*'

        # get root directory of input glob pattern
        base_dir = input_glob.replace('?', '*')
        base_dir = base_dir.split(os.sep + '*').pop(0)
        if not os.path.isdir(base_dir):
            return 'Invalid input directory'

        cls._follow_input_dir(base_dir)
        cls.base_dir_last = Path(base_dir).parts[-1]
        cls.base_dir = base_dir

        err = QData.read_json(cls.output_root)
        if err != '':
            return err

        recursive = bool(getattr(shared.opts, 'tagger_batch_recursive', ''))
        paths = glob(input_glob, recursive=recursive)
        print(f'found {len(paths)} image(s)')
        err = cls.set_batch_io(paths)
        if err == '':
            # only remember the pattern once it was processed successfully
            cls.last_input_glob = last_input_glob
        return err

    @classmethod
    def set_batch_io(cls, paths: List[Path]) -> str:
        """ set input and output paths for batch mode

        Appends one entry to ``cls.paths`` per path with a supported image
        extension. Returns an error message, or '' on success.
        """
        checked_dirs = set()
        for path in paths:
            ext = os.path.splitext(path)[1].lower()
            if ext in supported_extensions:
                path = Path(path)
                if not cls.save_tags:
                    cls.paths.append([path, '', ''])
                    continue

                # guess the output path
                base_dir_last_idx = path.parts.index(cls.base_dir_last)
                # format output filename
                info = tags_format.Info(path, 'txt')
                fmt = partial(lambda info, m: tags_format.parse(m, info), info)

                try:
                    formatted_output_filename = tags_format.pattern.sub(
                        fmt,
                        Its.output_filename_format
                    )
                except (TypeError, ValueError) as error:
                    return f"{path}: output format: {str(error)}"

                # mirror the sub-directory structure below the output root
                output_dir = cls.output_root.joinpath(
                    *path.parts[base_dir_last_idx + 1:]).parent

                tags_out = output_dir.joinpath(formatted_output_filename)

                if output_dir in checked_dirs:
                    cls.paths.append([path, tags_out, ''])
                else:
                    checked_dirs.add(output_dir)
                    if os.path.exists(output_dir):
                        if os.path.isdir(output_dir):
                            cls.paths.append([path, tags_out, ''])
                        else:
                            return f"{output_dir}: not a directory."
                    else:
                        # directory does not exist yet: pass it along
                        cls.paths.append([path, tags_out, output_dir])
            # str(): path may be a Path when called via update_output_dir
            elif ext != '.txt' and 'db.json' not in str(path):
                print(f'{path}: not an image extension: "{ext}"')
        return ''
def get_i_wt(stored: float) -> Tuple[int, float]:
    """
    Split a stored "index + weight" float into (index, weight).

    In db.json and in QData.weighed each list element carries the weight of
    one filestamp-interrogation, incremented by the index of that
    interrogation (the index is the one kept in the "query" dict).
    The weight comes back in the range (0, 1].
    """
    index = ceil(stored) - 1
    weight = stored - index
    return index, weight
class QData:
    """ Query data: contains parameters for the query

    All state lives on the class itself; every method is a classmethod.

    Attributes:
        add_tags: tags that are always added.
        keep_tags: tags that bypass the exclusion and threshold filters.
        exclude_tags: tags to exclude (used when ``rexcl`` is None).
        rexcl: compiled, anchored regex to exclude tags, or None.
        search_tags: position -> compiled regex; a match is replaced by the
            entry of ``replace_tags`` at the same position.
        replace_tags: replacement strings for ``search_tags``.
        threshold: minimum weight for a tag to be kept.
        tag_frac_threshold: minimum fraction of interrogations in which a
            tag must occur to be reported by ``finalize``.
        json_db: path of db.json, or None when it is not used.
        weighed: (ratings, tags); name -> list of "index + weight" floats,
            see ``get_i_wt``.
        query: filestamp-interrogator key -> (path, index).
        ratings, tags: accumulated over the current interrogations.
        in_db: index -> ('', '', '', ratings, tags) taken from the db.
        for_tags_file: tags file path -> set of tags to write to it.
        inverse: list the tags marked for exclusion instead.
        had_new: whether there is something new to write to db.json.
    """
    add_tags = []
    keep_tags = set()
    exclude_tags = set()
    rexcl = None
    search_tags = {}
    replace_tags = []
    threshold = 0.35
    tag_frac_threshold = 0.05

    # read from db.json, update with what should be written to db.json:
    json_db = None
    weighed = (defaultdict(list), defaultdict(list))
    query = {}

    # representing the (cumulative) current interrogations
    ratings = defaultdict(float)
    tags = defaultdict(list)
    in_db = {}
    for_tags_file = defaultdict(set)

    inverse = False
    had_new = False

    # NOTE: `set` below shadows the builtin for the remainder of the class
    # body (not inside the methods), so keep the attributes above it.
    @classmethod
    def set(cls, key: str) -> Callable[[str], Tuple[str]]:
        """ return a callable that assigns its argument to attribute key """
        def setter(val) -> Tuple[str]:
            setattr(cls, key, val)
        return setter

    @classmethod
    def update_keep(cls, keep: str) -> str:
        """ set keep_tags from a comma separated string; returns '' """
        cls.keep_tags = {x for x in map(str.strip, keep.split(',')) if x != ''}
        return ''

    @classmethod
    def update_add(cls, add: str) -> str:
        """ set add_tags from a comma separated string; returns '' """
        cls.add_tags = [x for x in map(str.strip, add.split(',')) if x != '']
        count_threshold = getattr(shared.opts, 'tagger_count_threshold', 100)
        if len(cls.add_tags) > count_threshold:
            # secretly raise count threshold to avoid issue in apply_filters
            shared.opts.tagger_count_threshold = len(cls.add_tags)
        return ''

    @classmethod
    def update_exclude(cls, exclude: str) -> str:
        """ set the tags to exclude

        With a comma the string is a list of literal tags; without one it is
        a case-insensitive regular expression that has to match the whole
        tag. An empty string clears all exclusions.

        Returns an error message for an invalid expression, '' otherwise.
        """
        from re import error as re_error

        excl = exclude.strip()
        if ',' in excl:
            # first filter empty strings
            filtered = [x for x in map(str.strip, excl.split(',')) if x != '']
            cls.exclude_tags = set(filtered)
            cls.rexcl = None
        elif excl != '':
            try:
                rexcl = re_comp('^' + excl + '$', flags=IGNORECASE)
            except re_error as err:
                return f'exclude: invalid regular expression: {err}'
            cls.rexcl = rexcl
            # is_excluded() ignores the tag list while a regex is set; drop
            # it so that no stale list survives
            cls.exclude_tags = set()
        else:
            cls.exclude_tags = set()
            cls.rexcl = None
        return ''

    @classmethod
    def update_search(cls, search: str) -> str:
        """ set the search expressions from a comma separated string

        Each item is compiled case-insensitively and anchored with ^...$
        unless it already is. On an invalid expression the current search
        expressions are left untouched.

        Returns an error message, or '' on success.
        """
        from re import error as re_error

        patterns = []
        for x in map(str.strip, search.split(',')):
            if x == '':
                continue
            if not (x[0] == '^' and x[-1] == '$'):
                x = '^' + x + '$'
            try:
                patterns.append(re_comp(x, flags=IGNORECASE))
            except re_error as err:
                return f'search: invalid regular expression "{x}": {err}'
        cls.search_tags = dict(enumerate(patterns))
        if len(cls.search_tags) != len(cls.replace_tags):
            return 'search, replace: unequal len, replacements > 1.'
        return ''

    @classmethod
    def update_replace(cls, replace: str) -> str:
        """ set the replacements from a comma separated string

        Returns an error message, or '' on success.
        """
        repl = [x for x in map(str.strip, replace.split(',')) if x != '']
        cls.replace_tags = list(repl)
        if len(cls.search_tags) != len(cls.replace_tags):
            return 'search, replace: unequal len, replacements > 1.'
        return ''

    @classmethod
    def read_json(cls, outdir) -> str:
        """ read db.json if it exists

        Returns an error message, or '' on success or when there is nothing
        to read.
        """
        cls.json_db = None
        if not getattr(shared.opts, 'tagger_auto_serde_json', True):
            return ''
        cls.json_db = outdir.joinpath('db.json')
        if not cls.json_db.is_file():
            return ''

        cls.had_new = False
        try:
            data = loads(cls.json_db.read_text(encoding='utf-8'))
        except (OSError, UnicodeDecodeError, JSONDecodeError) as err:
            return f'Error reading {cls.json_db}: {repr(err)}'
        if not isinstance(data, dict):
            return f'{cls.json_db}: not a JSON object.'

        for key in ["tag", "rating", "query"]:
            if key not in data:
                return f'{cls.json_db}: missing {key} key.'

        for key in ["add", "keep", "exclude"]:
            if key in data:
                err = getattr(cls, f"update_{key}")(data[key])
                if err:
                    return err

        # Search and replace are validated against each other, so load the
        # replacements first; update_search() then checks the lengths with
        # both freshly loaded. Older files stored the replacements as "repl".
        err = ''
        replace = data.get("replace", data.get("repl"))
        if replace is not None:
            err = cls.update_replace(replace)
        if "search" in data:
            err = cls.update_search(data["search"])
        if err:
            return err

        cls.weighed = (
            defaultdict(list, data["rating"]),
            defaultdict(list, data["tag"])
        )
        cls.query = data["query"]
        print(f'Read {cls.json_db}: {len(cls.query)} interrogations, '
              f'{len(cls.weighed[1])} tags.')
        return ''

    @classmethod
    def write_json(cls) -> None:
        """ write db.json, in the form that read_json() reads back """
        if cls.json_db is None:
            return
        # compiled patterns are not serializable: store their source, in
        # the order of their position
        search = [
            getattr(cls.search_tags[i], 'pattern', cls.search_tags[i])
            for i in sorted(cls.search_tags)
        ]
        if cls.rexcl is not None:
            # strip the '^' and '$' that update_exclude() added
            exclude = cls.rexcl.pattern[1:-1]
        else:
            exclude = ','.join(cls.exclude_tags)
        data = {
            "rating": cls.weighed[0],
            "tag": cls.weighed[1],
            "query": cls.query,
            "add": ','.join(cls.add_tags),
            "keep": ','.join(cls.keep_tags),
            "exclude": exclude,
            "search": ','.join(search),
            "replace": ','.join(cls.replace_tags)
        }
        cls.json_db.write_text(dumps(data, indent=2), encoding='utf-8')
        print(f'Wrote {cls.json_db}: {len(cls.query)} interrogations, '
              f'{len(cls.weighed[1])} tags.')

    @classmethod
    def get_index(cls, fi_key: str, path='') -> int:
        """ get index for filestamp-interrogator

        When a path is given that differs from the stored one, the stored
        path is updated and had_new is set.
        """
        if path and path != cls.query[fi_key][0]:
            if cls.query[fi_key][0] != '':
                print(f'Dup or rename: Identical checksums for {path}\n'
                      f'and: {cls.query[fi_key][0]} (path updated)')
                cls.had_new = True
            cls.query[fi_key] = (path, cls.query[fi_key][1])

        return cls.query[fi_key][1]

    @classmethod
    def single_data(
        cls, fi_key: str
    ) -> Tuple[Dict[str, float], Dict[str, float]]:
        """ get tags and ratings for filestamp-interrogator

        Collects the weights stored for the index of fi_key, registers them
        in in_db and returns them as (ratings, tags).
        """
        index = cls.query.get(fi_key)[1]
        data = ({}, {})
        for j in range(2):
            for ent, lst in cls.weighed[j].items():
                for i, val in map(get_i_wt, lst):
                    if i == index:
                        data[j][ent] = val

        # same layout as the data passed to apply_filters()
        cls.in_db[index] = ('', '', '') + data
        return data

    @classmethod
    def is_excluded(cls, ent: str) -> bool:
        """ check if tag is excluded """
        if cls.rexcl:
            return cls.rexcl.match(ent) is not None
        return ent in cls.exclude_tags

    @classmethod
    def _replace_first_match(cls, tag: str) -> str:
        """ apply the replacement of the first search expression that matches
        """
        for i, regex in cls.search_tags.items():
            if regex.match(tag):
                # The patterns are compiled with IGNORECASE already; passing
                # flags along with a compiled pattern raises ValueError.
                if i < len(cls.replace_tags):
                    tag = regex.sub(cls.replace_tags[i], tag, count=1)
                break
        return tag

    @classmethod
    def correct_tag(cls, tag: str) -> str:
        """ correct tag for display """
        replace_underscore = getattr(shared.opts, 'tagger_repl_us', True)
        if replace_underscore and tag not in Its.kamojis:
            tag = tag.replace('_', ' ')

        if getattr(shared.opts, 'tagger_escape', False):
            tag = tag_escape_pattern.sub(r'\\\1', tag)

        return cls._replace_first_match(tag)

    @classmethod
    def inverse_apply_filters(cls, tags: List[Tuple[str, float]]) -> None:
        """ inverse: List all tags marked for exclusion """
        for tag, val in tags:
            tag = cls.correct_tag(tag)

            if tag in cls.keep_tags or tag in cls.add_tags:
                continue

            if cls.is_excluded(tag) or val < cls.threshold:
                cls.tags[tag].append(val)

    @classmethod
    def apply_filters(cls, data) -> Set[str]:
        """ apply filters to query data, store in db.json if required

        data is indexed as: 0 path, 1 tags file path ('' for none),
        2 filestamp-interrogator key, 3 ratings, 4 tags.
        The results are accumulated in class attributes.
        """
        # data[2] (fi_key) != '' means this is a new file or interrogation
        # for that file: only then the weights are recorded for db.json
        tags = sorted(data[4].items(), key=lambda x: x[1], reverse=True)
        if cls.inverse:
            cls.inverse_apply_filters(tags)
            return

        # not inverse: display all tags marked for inclusion

        fi_key = data[2]
        index = len(cls.query)

        ratings = sorted(data[3].items(), key=lambda x: x[1], reverse=True)
        # loop over ratings
        for rating, val in ratings:
            if fi_key != '':
                cls.weighed[0][rating].append(val + index)
            cls.ratings[rating] += val

        count_threshold = getattr(shared.opts, 'tagger_count_threshold', 100)
        max_ct = count_threshold - len(cls.add_tags)
        count = 0
        # loop over tags with db update
        for tag, val in tags:
            if isinstance(tag, float):
                print(f'bad return from interrogator, float: {tag} {val}')
                # FIXME: why does this happen? what does it mean?
                continue

            if fi_key != '' and val >= 0.005:
                cls.weighed[1][tag].append(val + index)

            if count < max_ct:
                tag = cls.correct_tag(tag)
                if tag not in cls.keep_tags:
                    if cls.is_excluded(tag) or val < cls.threshold:
                        continue
                if data[1] != '':
                    cls.for_tags_file[data[1]].add(tag)
                count += 1
            elif fi_key == '':
                # nothing left to record for the db: done
                break
            # NOTE(review): beyond max_ct (with fi_key set) the tag gets
            # here uncorrected and unfiltered - confirm this is intended.
            if tag not in cls.add_tags:
                # those are already added
                cls.tags[tag].append(val)

        for tag in cls.add_tags:
            if data[1] != '':
                cls.for_tags_file[data[1]].add(tag)
            cls.tags[tag] = [1.0 for _ in range(len(cls.query))]

        if getattr(shared.opts, 'tagger_verbose', True):
            print(f'{data[0]}: {count}/{len(tags)} tags kept')

        if fi_key != '':
            cls.query[fi_key] = (data[0], index)

    @classmethod
    def finalize_batch(cls, count: int) -> 'ItRetTP':
        """ finalize the batch query """
        if cls.json_db and cls.had_new:
            cls.write_json()
            cls.had_new = False

        # collect the weights per file/interrogation of the prior in db stored.
        for index in range(2):
            for ent, lst in cls.weighed[index].items():
                for i, val in map(get_i_wt, lst):
                    if i not in cls.in_db:
                        continue
                    cls.in_db[i][3+index][ent] = val

        # process the retrieved from db and add them to the stats
        for got in cls.in_db.values():
            cls.apply_filters(got)

        # average
        return cls.finalize(count)

    @classmethod
    def finalize(cls, count: int) -> 'ItRetTP':
        """ finalize the query, return the results

        Averages the accumulated tag and rating weights over count plus the
        number of interrogations taken from the db, writes the tags files
        and returns (tags as html links, ratings, tags, error message).
        """
        count += len(cls.in_db)
        if count == 0:
            return [None, None, None, 'no results for query']

        links, ratings, tags = [], {}, {}
        js_bool = 'true' if cls.inverse else 'false'

        for k, lst in cls.tags.items():
            # len(!) fraction of the all interrogations was above the threshold
            fraction_of_queries = len(lst) / count

            if fraction_of_queries >= cls.tag_frac_threshold:
                # store the average of those interrogations sum(!) / count
                tags[k] = sum(lst) / count
                # trigger an event to place the tag in the active tags list
                # Escaped for the link text as well as for the attribute:
                # the browser decodes the entities before it runs the
                # script, and a quote in the tag can no longer end the
                # attribute.
                # NOTE(review): a '"' or '\' in the tag still ends up
                # unescaped in the javascript string literal.
                escaped = html_escape(k)
                links.append(
                    f"""<a href='javascript:tag_clicked("{escaped}", """
                    f"""{js_bool})'>{escaped}</a>"""
                )
            else:
                for remaining_tags in cls.for_tags_file.values():
                    remaining_tags.discard(k)

        for ent, val in cls.ratings.items():
            ratings[ent] = val / count

        for file, remaining_tags in cls.for_tags_file.items():
            file.write_text(', '.join(remaining_tags), encoding='utf-8')

        print('all done :)')
        return (', '.join(links), ratings, tags, '')