""" for handling ui settings """

from typing import List, Dict, Tuple, Callable
import os
from pathlib import Path
from glob import glob
from math import ceil
from re import compile as re_comp, sub as re_sub, match as re_match, IGNORECASE
from json import dumps, loads
from PIL import Image
from modules import shared
from modules.deepbooru import re_special as tag_escape_pattern
from functools import partial

from tagger import format as tags_format
from tagger import settings

Its = settings.InterrogatorSettings

# PIL.Image.registered_extensions() returns only PNG if you call early
supported_extensions = {
    e
    for e, f in Image.registered_extensions().items()
    if f in Image.OPEN
}

# interrogator return type
it_ret_tp = Tuple[
    str,               # tags as string
    Dict[str, float],  # rating confidences
    Dict[str, float],  # tag confidences
    str,               # error message
]


class IOData:
    """ data class for input and output paths """
    # glob pattern of the last successful update_input_glob() call
    last_input_glob = None
    # root directory derived from the input glob pattern
    base_dir = None
    # last path component of base_dir, used to locate it inside globbed paths
    base_dir_last = None
    # directory below which the tag files are written
    output_root = None
    # entries are [input path, tags output path, output dir to create]
    paths = []
    save_tags = True

    @classmethod
    def flip_save_tags(cls) -> Callable[[], None]:
        """ return a callback that toggles save_tags when called """
        def toggle():
            cls.save_tags = not cls.save_tags
        return toggle

    @classmethod
    def toggle_save_tags(cls) -> None:
        """ toggle whether tag files are written """
        cls.save_tags = not cls.save_tags

    @classmethod
    def update_output_dir(cls, output_dir: str) -> str:
        """
        update output directory, and set input and output paths

        Returns an error message, or '' on success or when nothing changed.
        """
        pout = Path(output_dir)
        if pout == cls.output_root:
            return ''

        # recompute the output paths for the inputs that were already found
        paths = [x[0] for x in cls.paths]
        cls.paths = []
        cls.output_root = pout
        return cls.set_batch_io(paths)

    @classmethod
    def update_input_glob(cls, input_glob: str) -> str:
        """
        update input glob pattern, and set input and output paths

        Returns an error message, or '' on success or when the pattern did
        not change since the last successful call.
        """
        input_glob = input_glob.strip()
        if input_glob == cls.last_input_glob:
            print('input glob did not change')
            return ''
        last_input_glob = input_glob

        cls.paths = []

        # if there is no glob pattern, insert it automatically
        if not input_glob.endswith('*'):
            if not input_glob.endswith(os.sep):
                input_glob += os.sep
            input_glob += '*'

        # get root directory of input glob pattern
        base_dir = input_glob.replace('?', '*')
        base_dir = base_dir.split(os.sep + '*').pop(0)
        if not os.path.isdir(base_dir):
            return 'Invalid input directory'

        if cls.output_root is None:
            cls.output_root = Path(base_dir)
        elif not cls.output_root or (
            # base_dir is still None when the output directory was set
            # before any input glob; Path(None) would raise a TypeError.
            cls.base_dir is not None
            and cls.output_root == Path(cls.base_dir)
        ):
            # the output followed the previous input directory: keep following
            cls.output_root = Path(base_dir)

        cls.base_dir_last = Path(base_dir).parts[-1]
        cls.base_dir = base_dir

        err = QData.read_json(cls.output_root)
        if err != '':
            return err

        recursive = getattr(shared.opts, 'tagger_batch_recursive', '')
        paths = glob(input_glob, recursive=recursive)
        print(f'found {len(paths)} image(s)')
        err = cls.set_batch_io(paths)
        if err == '':
            # only remember the pattern once it was processed successfully
            cls.last_input_glob = last_input_glob

        return err

    @classmethod
    def set_batch_io(cls, paths: List[Path]) -> str:
        """
        set input and output paths for batch mode

        Appends one [input, tags output, dir to create] entry to cls.paths
        for every path with a supported image extension. Returns an error
        message, or '' on success.
        """
        checked_dirs = set()
        for path in paths:
            ext = os.path.splitext(path)[1].lower()
            if ext not in supported_extensions:
                if ext != '.txt' and 'db.json' not in str(path):
                    print(f'{path}: not an image extension: "{ext}"')
                continue

            path = Path(path)
            if not cls.save_tags:
                cls.paths.append([path, '', ''])
                continue

            # guess the output path
            # NOTE(review): this finds the first path component *named* like
            # the base directory, which may be an ancestor with the same name
            # - confirm whether a relative_to(base_dir) is intended.
            base_dir_last_idx = path.parts.index(cls.base_dir_last)

            # format output filename
            info = tags_format.Info(path, 'txt')
            fm = partial(lambda info, m: tags_format.parse(m, info), info)

            try:
                formatted_output_filename = tags_format.pattern.sub(
                    fm,
                    Its.output_filename_format
                )
            except (TypeError, ValueError) as error:
                return f"{path}: output format: {str(error)}"

            # mirror the sub-directories below the base dir in the output root
            output_dir = cls.output_root.joinpath(
                *path.parts[base_dir_last_idx + 1:]).parent

            tags_out = output_dir.joinpath(formatted_output_filename)

            if output_dir in checked_dirs:
                cls.paths.append([path, tags_out, ''])
                continue

            checked_dirs.add(output_dir)
            if not os.path.exists(output_dir):
                # third element: the directory does not exist yet
                cls.paths.append([path, tags_out, output_dir])
            elif os.path.isdir(output_dir):
                cls.paths.append([path, tags_out, ''])
            else:
                return f"{output_dir}: not a directory."
        return ''


def get_i_wt(stored: float) -> Tuple[int, float]:
    """
    split a stored number into its interrogation index and its weight

    In db.json or QData.weighed the lists contain weight + index, where the
    index is the same for all entries of one filestamp-interrogation (see
    the "query" dict). Returns (index, weight), with index = ceil(stored) - 1.
    """
    i = ceil(stored) - 1
    return i, stored - i


class QData:
    """ Query data: contains parameters for the query """
    add_tags = []
    keep_tags = set()
    exclude_tags = set()
    # compiled regex, set when the exclude field contains no comma
    rexcl = None
    # {position: search tag}, as built by update_search()
    search_tags = {}
    replace_tags = []
    # compiled regex, set when there is exactly one search term
    re_search = None
    threshold = 0.35
    count_threshold = 100

    json_db = None
    # apply_filters() appends ratings to weighed[0] and tags to weighed[1].
    # NOTE(review): write_json()/read_json() label these "tag" and "rating"
    # respectively; the labels look swapped but round-trip consistently.
    weighed = ({}, {})
    # {filestamp-interrogator key: (path, index)}
    query = {}

    data = None
    ratings = {}
    tags = {}
    inverse = False

    @classmethod
    def set(cls, key: str) -> Callable[[str], Tuple[str]]:
        """ return a setter for class attribute `key`; it returns ('',) """
        def setter(val) -> Tuple[str]:
            setattr(cls, key, val)
            return ('',)
        return setter

    @classmethod
    def update_keep(cls, keep: str) -> str:
        """ set keep_tags from a comma separated string """
        cls.keep_tags = {x for x in map(str.strip, keep.split(',')) if x != ''}
        return ''

    @classmethod
    def update_add(cls, add: str) -> str:
        """ set add_tags from a comma separated string """
        cls.add_tags = [x for x in map(str.strip, add.split(',')) if x != '']
        return ''

    @classmethod
    def update_exclude(cls, exclude: str) -> str:
        """
        set the exclusions: a comma separated list of tags, or, without a
        comma, a case-insensitive regex that has to match the entire tag.
        An empty string leaves the current exclusions untouched.
        """
        exclude = exclude.strip()
        if ',' in exclude:
            # first filter empty strings
            filtered = [x for x in map(str.strip, exclude.split(',')) if x != '']
            cls.exclude_tags = set(filtered)
            cls.rexcl = None
        elif exclude != '':
            cls.rexcl = re_comp('^'+exclude+'$', flags=IGNORECASE)
        return ''

    @classmethod
    def update_search(cls, search: str) -> str:
        """
        set the search tags from a comma separated string

        A single term is compiled as a case-insensitive regex that has to
        match the entire tag; several terms are paired by position with the
        replace tags. Returns an error message, or '' on success.
        """
        search = [x for x in map(str.strip, search.split(',')) if x != '']
        cls.search_tags = dict(enumerate(search))
        if len(search) == 1:
            cls.re_search = re_comp('^'+search[0]+'$', flags=IGNORECASE)
            return ''

        # drop a regex left over from a previous single-term search, it
        # would otherwise take precedence over the literal search tags
        cls.re_search = None
        if len(search) != len(cls.replace_tags):
            return 'search, replace: unequal len, replacements > 1.'
        return ''

    @classmethod
    def update_replace(cls, replace: str) -> str:
        """
        set the replace tags from a comma separated string

        Returns an error message, or '' on success.
        """
        cls.replace_tags = [
            x for x in map(str.strip, replace.split(',')) if x != ''
        ]
        if cls.re_search is None and \
           len(cls.search_tags) != len(cls.replace_tags):
            return 'search, replace: unequal len, replacements > 1.'
        return ''

    @classmethod
    def read_json(cls, outdir) -> str:
        """
        read db.json if it exists

        `outdir` is the directory (a Path) that contains db.json. Returns an
        error message, or '' on success or when there is nothing to read.
        """
        cls.json_db = None
        if not getattr(shared.opts, 'tagger_auto_serde_json', True):
            return ''

        cls.json_db = outdir.joinpath('db.json')
        if not cls.json_db.is_file():
            return ''

        try:
            data = loads(cls.json_db.read_text())
            if any(x not in data for x in ["tag", "rating", "query"]):
                raise TypeError
        except Exception as err:
            return f'Error reading {cls.json_db}: {repr(err)}'

        for key in ["add", "keep", "exclude", "search", "replace"]:
            stored_key = key
            if key == "replace" and key not in data:
                # earlier versions wrote the replacements as "repl"
                stored_key = "repl"
            if stored_key in data:
                err = getattr(cls, f"update_{key}")(data[stored_key])
                if err:
                    return err

        cls.weighed = (data["tag"], data["rating"])
        cls.query = data["query"]
        return ''

    @classmethod
    def write_json(cls) -> None:
        """ write db.json, if read_json() configured a location for it """
        if cls.json_db is None:
            return

        search = sorted(cls.search_tags.items(), key=lambda x: x[0])
        data = {
            "tag": cls.weighed[0],
            "rating": cls.weighed[1],
            "query": cls.query,
            "add": ','.join(cls.add_tags),
            "keep": ','.join(cls.keep_tags),
            "exclude": ','.join(cls.exclude_tags),
            "search": ','.join([x[1] for x in search]),
            # has to be the key read_json() looks for
            "replace": ','.join(cls.replace_tags)
        }
        cls.json_db.write_text(dumps(data, indent=2))

    @classmethod
    def move_filter_to_exclude(cls) -> None:
        """ move filter tags to exclude tags """
        # NOTE(review): update() without arguments does not change the set;
        # the tags to move are presumably missing here - confirm.
        cls.exclude_tags.update()

    @classmethod
    def get_index(cls, fi_key: str, path='') -> int:
        """
        get index for filestamp-interrogator

        If `path` is given and differs from the non-empty path stored for
        `fi_key`, the stored path is replaced. Raises KeyError when `fi_key`
        was not queried before.
        """
        stored_path = cls.query[fi_key][0]
        if path and path != stored_path and stored_path != '':
            print(f'Dup or rename: Identical checksums for {path}\n'
                  f'and: {stored_path} (path updated)')
            cls.query[fi_key] = (path, cls.query[fi_key][1])

        # this file was already queried for this interrogator.
        return cls.query[fi_key][1]

    @classmethod
    def get_single_data(cls, fi_key: str) -> Tuple[Dict[str, float], Dict[str, float]]:
        """
        get the stored weights for filestamp-interrogator

        Returns one dict per element of cls.weighed, containing the weights
        that were stored with the index of `fi_key`.
        """
        index = QData.query.get(fi_key)[1]
        data = [{}, {}]
        for j in range(2):
            for ent, lst in cls.weighed[j].items():
                for i, val in map(get_i_wt, lst):
                    if i == index:
                        data[j][ent] = val

        return tuple(data)

    @classmethod
    def init_query(cls) -> None:
        """ clear the accumulated tags and ratings """
        cls.tags.clear()
        cls.ratings.clear()

    @classmethod
    def is_excluded(cls, ent: str) -> bool:
        """ check if tag is excluded """
        if cls.rexcl:
            return bool(re_match(cls.rexcl, ent))
        return ent in cls.exclude_tags

    @classmethod
    def _search_replacements(cls) -> Dict[str, str]:
        """ map each literal search tag to the replace tag at its position """
        limit = len(cls.replace_tags)
        return {
            tag: cls.replace_tags[idx]
            for idx, tag in cls.search_tags.items()
            # the lists can differ in length after a rejected update
            if idx < limit
        }

    @classmethod
    def _display_tag(
        cls,
        ent: str,
        replace_underscore: bool,
        escape: bool,
        replacements: Dict[str, str],
    ) -> str:
        """ apply underscore replacement, escaping and search/replace """
        if replace_underscore and ent not in Its.kamojis:
            ent = ent.replace('_', ' ')

        if escape:
            ent = tag_escape_pattern.sub(r'\\\1', ent)

        if cls.re_search:
            # a single search term is a regex, replaced by the first
            # replace tag; without a replace tag the tag is left as it is
            if cls.replace_tags:
                ent = re_sub(cls.re_search, cls.replace_tags[0], ent, count=1)
        elif ent in replacements:
            ent = replacements[ent]
        return ent

    @classmethod
    def apply_filters(
        cls,
        data,
        fi_key: str,
        on_avg: bool,
    ):
        """
        apply filters to query data, store in db.json if required

        `data` is indexed as: [0] the path used for messages and the query
        entry, [1] the tags file to write (skipped when falsy), [2] the
        ratings dict and [3] the tags dict. An empty `fi_key` means the
        result is not added to the weights and the query that go to db.json.
        """
        replace_underscore = getattr(shared.opts, 'tagger_repl_us', True)
        escape = getattr(shared.opts, 'tagger_escape', False)
        replacements = cls._search_replacements()

        tags = sorted(data[3].items(), key=lambda x: x[1], reverse=True)
        if cls.inverse:
            # inverse: display all tags marked for exclusion
            for ent, val in tags:
                ent = cls._display_tag(
                    ent, replace_underscore, escape, replacements)

                if ent in cls.keep_tags or ent in cls.add_tags:
                    continue

                if on_avg or cls.is_excluded(ent) or val < cls.threshold:
                    cls.tags[ent] = cls.tags.get(ent, 0.0) + val
            return

        # not inverse: display all tags marked for inclusion

        kept = []
        do_store = fi_key != ''
        count = 0
        max_ct = QData.count_threshold - len(cls.add_tags)
        ratings = sorted(data[2].items(), key=lambda x: x[1], reverse=True)
        # loop over ratings
        for ent, val in ratings:
            if do_store:
                # stored as weight + index, see get_i_wt()
                cls.weighed[0].setdefault(ent, []).append(val + len(cls.query))
            cls.ratings[ent] = cls.ratings.get(ent, 0.0) + val

        # loop over tags with db update
        for ent, val in tags:
            if isinstance(ent, float):
                print(f'float: {ent} {val}')
                continue

            if do_store and val > 0.005:
                cls.weighed[1].setdefault(ent, []).append(val + len(cls.query))

            if count < max_ct:
                ent = cls._display_tag(
                    ent, replace_underscore, escape, replacements)
                if ent not in cls.keep_tags:
                    if cls.is_excluded(ent):
                        continue
                    if not on_avg and val < cls.threshold:
                        continue
                kept.append(ent)
                count += 1
            elif not do_store:
                # nothing left to display and nothing to store
                break
            cls.tags[ent] = cls.tags[ent] + val if ent in cls.tags else val

        for tag in cls.add_tags:
            cls.tags[tag] = 1.0

        if getattr(shared.opts, 'tagger_verbose', True):
            print(f'{data[0]}: {count}/{len(tags)} tags kept')

        if do_store:
            cls.query[fi_key] = (data[0], len(cls.query))

        if data[1]:
            data[1].write_text(", ".join(kept), encoding='utf-8')

    @classmethod
    def finalize_batch(
        cls,
        in_db,
        ct: int,
        on_avg: bool
    ) -> it_ret_tp:
        """
        finalize the batch query

        `in_db` maps interrogation indexes to data entries as accepted by
        apply_filters(); `ct` is the number of new interrogations.
        """
        if cls.json_db and ct > 0:
            cls.write_json()

        # collect the weights per file/interrogation of the prior in db stored.
        for index in range(2):
            for ent, lst in cls.weighed[index].items():
                for i, val in map(get_i_wt, lst):
                    if i in in_db:
                        in_db[i][2+index][ent] = val

        # process the retrieved from db and add them to the stats
        for got in in_db.values():
            cls.apply_filters(got, '', on_avg)

        # average
        return cls.finalize(ct + len(in_db), on_avg)

    @classmethod
    def finalize(cls, count: int, on_avg: bool) -> it_ret_tp:
        """
        finalize the query, return the results

        Divides the accumulated tag and rating weights by `count` and, when
        `on_avg` is set, filters the tags on their averaged weight.
        """
        def averager(x):
            return x[0], x[1] / count

        def inverse_filt(x):
            return cls.is_excluded(x[0]) or (
                x[1] < cls.threshold and x[0] not in cls.keep_tags)

        def filt(x):
            return not cls.is_excluded(x[0]) and \
                   (x[1] >= cls.threshold or x[0] in cls.keep_tags)

        averaged = map(averager, cls.tags.items())
        if on_avg:
            averaged = filter(inverse_filt if cls.inverse else filt, averaged)

        tags = dict(averaged)
        # the order of this string is the order of the active tags list
        tags_str = ', '.join(tags)

        ratings = {ent: val / count for ent, val in cls.ratings.items()}

        print('all done :)')
        return (tags_str, ratings, tags, '')