472 lines
16 KiB
Python
472 lines
16 KiB
Python
""" for handling ui settings """
|
|
|
|
from typing import List, Dict, Tuple, Callable
|
|
import os
|
|
from pathlib import Path
|
|
from glob import glob
|
|
from math import ceil
|
|
from re import compile as re_comp, sub as re_sub, match as re_match, IGNORECASE
|
|
from json import dumps, loads
|
|
from PIL import Image
|
|
from modules import shared
|
|
from modules.deepbooru import re_special as tag_escape_pattern
|
|
from functools import partial
|
|
|
|
from tagger import format as tags_format
|
|
|
|
from tagger import settings
|
|
Its = settings.InterrogatorSettings
|
|
|
|
# PIL.Image.registered_extensions() returns only PNG if you call early;
# evaluated here at module import time, keeping every extension whose
# format PIL knows how to open.
supported_extensions = {
    ext for ext, fmt in Image.registered_extensions().items()
    if fmt in Image.OPEN
}
|
|
|
|
# interrogator return type:
# (tags as string, rating confidences, tag confidences, error message)
it_ret_tp = Tuple[str, Dict[str, float], Dict[str, float], str]
|
|
|
|
|
|
class IOData:
    """ data class for input and output paths """
    # previous input glob, used to skip redundant re-scans
    last_input_glob = None
    # root directory (str) derived from the input glob
    base_dir = None
    # root directory (Path) for written tag files, or None until first set
    output_root = None
    # list of [image_path, tags_output_path, dir_to_create] entries
    paths = []
    # whether a tags file is written next to each interrogated image
    save_tags = True

    @classmethod
    def flip_save_tags(cls) -> Callable[[], None]:
        """ return a closure that toggles save_tags (for UI callbacks) """
        def toggle():
            cls.save_tags = not cls.save_tags
        return toggle

    @classmethod
    def toggle_save_tags(cls) -> None:
        """ toggle save_tags directly """
        cls.save_tags = not cls.save_tags

    @classmethod
    def update_output_dir(cls, output_dir: str) -> str:
        """ update output directory, and set input and output paths

        Returns an error message, or '' on success / no change.
        """
        pout = Path(output_dir)
        if pout != cls.output_root:
            # keep the input images, recompute their outputs under the new root
            paths = [x[0] for x in cls.paths]
            cls.paths = []
            cls.output_root = pout
            err = cls.set_batch_io(paths)
            return err
        return ''

    @classmethod
    def update_input_glob(cls, input_glob: str) -> str:
        """ update input glob pattern, and set input and output paths

        Returns an error message, or '' on success.
        """
        input_glob = input_glob.strip()
        if input_glob == cls.last_input_glob:
            print('input glob did not change')
            return ''
        # remember the raw pattern; only committed to cls on success below
        last_input_glob = input_glob

        cls.paths = []

        # if there is no glob pattern, insert it automatically
        if not input_glob.endswith('*'):
            if not input_glob.endswith(os.sep):
                input_glob += os.sep
            input_glob += '*'
        # get root directory of input glob pattern
        base_dir = input_glob.replace('?', '*')
        base_dir = base_dir.split(os.sep + '*').pop(0)
        if not os.path.isdir(base_dir):
            return 'Invalid input directory'

        if cls.output_root is None:
            # default the output root to the input root
            output_dir = base_dir

            cls.output_root = Path(output_dir)
        elif not cls.output_root or cls.output_root == Path(cls.base_dir):
            # NOTE(review): cls.base_dir may still be None if update_output_dir
            # ran first — Path(None) would raise TypeError; confirm call order.
            cls.output_root = Path(base_dir)

        # last path component of the input root; set_batch_io uses it to
        # locate the input-relative part of each image path
        cls.base_dir_last = Path(base_dir).parts[-1]
        cls.base_dir = base_dir
        err = QData.read_json(cls.output_root)
        if err != '':
            return err

        # default '' is falsy -> non-recursive glob
        recursive = getattr(shared.opts, 'tagger_batch_recursive', '')
        paths = glob(input_glob, recursive=recursive)
        print(f'found {len(paths)} image(s)')
        err = cls.set_batch_io(paths)
        if err == '':
            # only remember the pattern when the scan fully succeeded
            cls.last_input_glob = last_input_glob

        return err

    @classmethod
    def set_batch_io(cls, paths: List[Path]) -> str:
        """ set input and output paths for batch mode

        Fills cls.paths with [image, tags_out, dir_to_create] triples.
        Returns an error message, or '' on success.
        """
        checked_dirs = set()
        for path in paths:
            ext = os.path.splitext(path)[1].lower()
            if ext in supported_extensions:
                path = Path(path)
                if not cls.save_tags:
                    # no tags file will be written for this image
                    cls.paths.append([path, '', ''])
                    continue

                # guess the output path
                # NOTE(review): .index picks the first occurrence of the
                # directory name; duplicated names in the path would mislead it
                base_dir_last_idx = path.parts.index(cls.base_dir_last)
                # format output filename

                info = tags_format.Info(path, 'txt')
                # bind info now; the pattern sub calls fm(match) per match
                fm = partial(lambda info, m: tags_format.parse(m, info), info)

                try:
                    formatted_output_filename = tags_format.pattern.sub(
                        fm,
                        Its.output_filename_format
                    )
                except (TypeError, ValueError) as error:
                    return f"{path}: output format: {str(error)}"

                # mirror the input-relative directory under the output root
                output_dir = cls.output_root.joinpath(
                    *path.parts[base_dir_last_idx + 1:]).parent

                tags_out = output_dir.joinpath(formatted_output_filename)

                if output_dir in checked_dirs:
                    cls.paths.append([path, tags_out, ''])
                else:
                    checked_dirs.add(output_dir)
                    if os.path.exists(output_dir):
                        if os.path.isdir(output_dir):
                            cls.paths.append([path, tags_out, ''])
                        else:
                            return f"{output_dir}: not a directory."
                    else:
                        # third entry signals the directory must be created
                        cls.paths.append([path, tags_out, output_dir])
            elif ext != '.txt' and 'db.json' not in path:
                print(f'{path}: not an image extension: "{ext}"')
        return ''
|
|
|
|
|
|
def get_i_wt(stored: float) -> Tuple[int, float]:
    """
    Decode a combined db.json value: one float carries both the
    interrogation index (integer part, via ceil - 1) and the tag weight
    (fractional remainder). Same scheme for the "query" dict: the
    increment is added per filestamp-interrogation.
    """
    index = ceil(stored) - 1
    weight = stored - index
    return index, weight
|
|
|
|
|
|
class QData:
    """ Query data: contains parameters for the query """
    # tags always appended to the output with full confidence
    add_tags = []
    # tags never filtered out
    keep_tags = set()
    # literal tags dropped from the output (used when rexcl is None)
    exclude_tags = set()
    # compiled regex alternative to exclude_tags (single exclusion term)
    rexcl = None
    # NOTE(review): maps index -> search tag (paired with replace_tags by
    # position); apply_filters tests `ent in search_tags` against these int
    # keys — multi-term replace looks dead. Kept as-is; confirm intent.
    search_tags = {}
    replace_tags = []
    # compiled regex used when there is exactly one search term
    re_search = None
    # minimum confidence for a tag to be kept
    threshold = 0.35
    # maximum number of tags written per tags file
    count_threshold = 100

    # Path to db.json, or None when serialization is disabled
    json_db = None
    # (ratings weights, tag weights) persisted in db.json; each value list
    # encodes interrogation index + confidence in one float (see get_i_wt)
    weighed = ({}, {})
    # fi_key -> (path, interrogation index)
    query = {}
    data = None
    # accumulated rating / tag confidences for the current batch
    ratings = {}
    tags = {}
    # when True, display the tags marked for exclusion instead
    inverse = False

    @classmethod
    def set(cls, key: str) -> Callable[[str], Tuple[str]]:
        """ return a setter callback for the given class attribute """
        def setter(val) -> Tuple[str]:
            setattr(cls, key, val)
            return ('',)
        return setter

    @classmethod
    def update_keep(cls, keep: str) -> str:
        """ parse a comma-separated list into keep_tags; returns '' """
        cls.keep_tags = {x for x in map(str.strip, keep.split(',')) if x != ''}
        return ''

    @classmethod
    def update_add(cls, add: str) -> str:
        """ parse a comma-separated list into add_tags; returns '' """
        cls.add_tags = [x for x in map(str.strip, add.split(',')) if x != '']
        return ''

    @classmethod
    def update_exclude(cls, exclude: str) -> str:
        """ parse the exclusion input: a comma-separated literal list, or a
        single regex term; returns '' """
        exclude = exclude.strip()
        if ',' in exclude:
            # first filter empty strings
            filtered = [x for x in map(str.strip, exclude.split(',')) if x != '']
            cls.exclude_tags = set(filtered)
            cls.rexcl = None
        elif exclude != '':
            # single term: interpret as an anchored, case-insensitive regex.
            # FIX: also drop the literal set so stale entries cannot linger.
            cls.exclude_tags = set()
            cls.rexcl = re_comp('^'+exclude+'$', flags=IGNORECASE)
        else:
            # FIX: clearing the field previously left old filters active
            cls.exclude_tags = set()
            cls.rexcl = None
        return ''

    @classmethod
    def update_search(cls, search: str) -> str:
        """ parse the comma-separated search terms; a single term becomes an
        anchored regex. Returns an error message or '' """
        search = [x for x in map(str.strip, search.split(',')) if x != '']
        cls.search_tags = dict(enumerate(search))
        # FIX: reset first so an outdated single-term regex cannot survive
        # a change to zero or multiple terms
        cls.re_search = None
        slen = len(cls.search_tags)
        if slen == 1:
            cls.re_search = re_comp('^'+search[0]+'$', flags=IGNORECASE)
        elif slen != len(cls.replace_tags):
            return 'search, replace: unequal len, replacements > 1.'
        return ''

    @classmethod
    def update_replace(cls, replace: str) -> str:
        """ parse the comma-separated replacement terms; returns an error
        message or '' """
        repl_tag_map = [x for x in map(str.strip, replace.split(',')) if x != '']
        cls.replace_tags = list(repl_tag_map)
        if cls.re_search is None and len(cls.search_tags) != len(cls.replace_tags):
            return 'search, replace: unequal len, replacements > 1.'
        # FIX: siblings return '' on success; this returned None implicitly
        return ''

    @classmethod
    def read_json(cls, outdir) -> str:
        """ read db.json if it exists; returns an error message or '' """
        cls.json_db = None
        if getattr(shared.opts, 'tagger_auto_serde_json', True):
            cls.json_db = outdir.joinpath('db.json')
            if cls.json_db.is_file():
                try:
                    data = loads(cls.json_db.read_text())
                    if any(x not in data for x in ["tag", "rating", "query"]):
                        raise TypeError
                except Exception as err:
                    return f'Error reading {cls.json_db}: {repr(err)}'
                # FIX: older files stored replacements under "repl", which the
                # loop below never read; accept it as a fallback
                if "replace" not in data and "repl" in data:
                    data["replace"] = data["repl"]
                for key in ["add", "keep", "exclude", "search", "replace"]:
                    if key in data:
                        err = getattr(cls, f"update_{key}")(data[key])
                        if err:
                            return err
                cls.weighed = (data["tag"], data["rating"])
                cls.query = data["query"]
        return ''

    @classmethod
    def write_json(cls) -> None:
        """ write db.json """
        if cls.json_db is not None:
            search = sorted(cls.search_tags.items(), key=lambda x: x[0])
            # NOTE(review): weighed[0] holds ratings and weighed[1] holds tags
            # (see apply_filters), yet they are serialized under the swapped
            # labels. The swap is round-trip consistent, so it is kept as-is.
            data = {
                "tag": cls.weighed[0],
                "rating": cls.weighed[1],
                "query": cls.query,
                "add": ','.join(cls.add_tags),
                "keep": ','.join(cls.keep_tags),
                "exclude": ','.join(cls.exclude_tags),
                "search": ','.join([x[1] for x in search]),
                # FIX: was written as "repl", which read_json never consumed
                "replace": ','.join(cls.replace_tags)
            }
            cls.json_db.write_text(dumps(data, indent=2))

    @classmethod
    def move_filter_to_exclude(cls) -> None:
        """ move filter tags to exclude tags """
        # NOTE(review): update() with no arguments is a no-op; this looks
        # unfinished — confirm which filter tags were meant to move here.
        cls.exclude_tags.update()

    @classmethod
    def get_index(cls, fi_key: str, path='') -> int:
        """ get index for filestamp-interrogator; updates the stored path
        when the same checksum reappears under a new name """
        if path and path != cls.query[fi_key][0]:
            if cls.query[fi_key][0] != '':
                # FIX: second line lacked the f prefix, printing the
                # placeholder literally instead of the stored path
                print(f'Dup or rename: Identical checksums for {path}\n'
                      f'and: {cls.query[fi_key][0]} (path updated)')
            cls.query[fi_key] = (path, cls.query[fi_key][1])

        # this file was already queried for this interrogator.
        return cls.query[fi_key][1]

    @classmethod
    def get_single_data(cls, fi_key: str) -> Tuple[Dict[str, float], Dict[str, float]]:
        """ get tags and ratings for filestamp-interrogator """
        index = cls.query.get(fi_key)[1]
        data = [{}, {}]
        for j in range(2):
            for ent, lst in cls.weighed[j].items():
                for i, val in map(get_i_wt, lst):
                    if i == index:
                        data[j][ent] = val

        return tuple(data)

    @classmethod
    def init_query(cls) -> None:
        """ reset the per-batch accumulators """
        cls.tags.clear()
        cls.ratings.clear()

    @classmethod
    def is_excluded(cls, ent: str) -> bool:
        """ check if tag is excluded """
        if cls.rexcl:
            # FIX: coerce the re.Match | None result to an actual bool,
            # matching the annotated return type
            return re_match(cls.rexcl, ent) is not None
        return ent in cls.exclude_tags

    @classmethod
    def apply_filters(
        cls,
        data,
        fi_key: str,
        on_avg: bool,
    ):
        """ apply filters to query data, store in db.json if required

        data: (path, tags_out_path, ratings dict, tags dict);
        fi_key: filestamp-interrogator key, '' means do not store in db;
        on_avg: defer threshold filtering to the averaging in finalize.
        """
        replace_underscore = getattr(shared.opts, 'tagger_repl_us', True)

        tags = sorted(data[3].items(), key=lambda x: x[1], reverse=True)
        if cls.inverse:
            # inverse: accumulate only the tags marked for exclusion
            for ent, val in tags:
                if replace_underscore and ent not in Its.kamojis:
                    ent = ent.replace('_', ' ')

                if getattr(shared.opts, 'tagger_escape', False):
                    ent = tag_escape_pattern.sub(r'\\\1', ent)

                if cls.re_search:
                    ent = re_sub(cls.re_search, cls.replace_tags[0], ent, 1)
                elif ent in cls.search_tags:
                    ent = cls.replace_tags[cls.search_tags[ent]]

                if ent in cls.keep_tags or ent in cls.add_tags:
                    continue
                if on_avg or cls.is_excluded(ent) or val < cls.threshold:
                    cls.tags[ent] = cls.tags.get(ent, 0.0) + val
            return

        # not inverse: display all tags marked for inclusion
        for_tags_file = ""
        do_store = fi_key != ''
        count = 0
        max_ct = cls.count_threshold - len(cls.add_tags)
        ratings = sorted(data[2].items(), key=lambda x: x[1], reverse=True)
        # loop over ratings
        for ent, val in ratings:
            if do_store:
                if ent not in cls.weighed[0]:
                    cls.weighed[0][ent] = []
                # interrogation index in the integer part, confidence in the
                # fraction (decoded by get_i_wt)
                cls.weighed[0][ent].append(val + len(cls.query))
            cls.ratings[ent] = cls.ratings.get(ent, 0.0) + val

        # loop over tags with db update
        for ent, val in tags:
            if isinstance(ent, float):
                print(f'float: {ent} {val}')
                continue
            # only store tags with a minimal confidence
            if do_store and val > 0.005:
                if ent not in cls.weighed[1]:
                    cls.weighed[1][ent] = []

                cls.weighed[1][ent].append(val + len(cls.query))

            if count < max_ct:
                if replace_underscore and ent not in Its.kamojis:
                    ent = ent.replace('_', ' ')

                if getattr(shared.opts, 'tagger_escape', False):
                    ent = tag_escape_pattern.sub(r'\\\1', ent)

                if cls.re_search:
                    ent = re_sub(cls.re_search, cls.replace_tags[0], ent, 1)
                elif ent in cls.search_tags:
                    ent = cls.replace_tags[cls.search_tags[ent]]
                if ent not in cls.keep_tags:
                    if cls.is_excluded(ent):
                        continue
                    if not on_avg and val < cls.threshold:
                        continue
                for_tags_file += ", " + ent
                count += 1
            elif not do_store:
                # file quota reached and nothing left to store: stop early
                break
            cls.tags[ent] = cls.tags.get(ent, 0.0) + val

        for tag in cls.add_tags:
            # added tags are always present with full confidence
            cls.tags[tag] = 1.0

        if getattr(shared.opts, 'tagger_verbose', True):
            print(f'{data[0]}: {count}/{len(tags)} tags kept')

        if do_store:
            cls.query[fi_key] = (data[0], len(cls.query))

        if data[1]:
            # write the tags file next to the image (strip leading ", ")
            data[1].write_text(for_tags_file[2:], encoding='utf-8')

    @classmethod
    def finalize_batch(
        cls,
        in_db,
        ct: int,
        on_avg: bool
    ) -> "it_ret_tp":
        """ finalize the batch query: merge db-stored results, then average """
        if cls.json_db and ct > 0:
            cls.write_json()

        # collect the weights per file/interrogation of the prior in db stored.
        for index in range(2):
            for ent, lst in cls.weighed[index].items():
                for i, val in map(get_i_wt, lst):
                    if i in in_db:
                        in_db[i][2+index][ent] = val

        # process the retrieved from db and add them to the stats
        for got in in_db.values():
            cls.apply_filters(got, '', on_avg)

        # average
        return cls.finalize(ct + len(in_db), on_avg)

    @classmethod
    def finalize(cls, count: int, on_avg: bool) -> "it_ret_tp":
        """ finalize the query, return the results

        Returns (clickable tags html string, averaged ratings, averaged
        tags, error message).
        """
        tags_str, ratings, tags = '', {}, {}

        def averager(x):
            return x[0], x[1] / count

        js_bool = 'true' if cls.inverse else 'false'

        if on_avg:
            if cls.inverse:
                def inverse_filt(x):
                    return cls.is_excluded(x[0]) or x[1] < cls.threshold and \
                           x[0] not in cls.keep_tags
                # FIX: renamed from `iter`, which shadowed the builtin
                entries = filter(inverse_filt, map(averager, cls.tags.items()))
            else:
                def filt(x):
                    return not cls.is_excluded(x[0]) and \
                           (x[1] >= cls.threshold or x[0] in cls.keep_tags)
                entries = filter(filt, map(averager, cls.tags.items()))
        else:
            entries = map(averager, cls.tags.items())

        for k, already_averaged_val in entries:
            tags[k] = already_averaged_val
            # trigger an event to place the tag in the active tags list
            tags_str += f""", <a href='javascript:tag_clicked("{k}", {js_bool})'>{k}</a>"""

        for ent, val in cls.ratings.items():
            ratings[ent] = val / count

        print('all done :)')
        return (tags_str[2:], ratings, tags, '')
|