From c3252d832556e050741baf5f68bb85208ff1db08 Mon Sep 17 00:00:00 2001 From: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com> Date: Wed, 8 May 2024 03:58:58 +0900 Subject: [PATCH] Merge changes in standalone version (#93) * Merge changes in standalone version - New Taggers and Custom Tagger - a little bit stable UI --- scripts/config.py | 143 ++++ scripts/dataset_tag_editor/__init__.py | 8 +- scripts/dataset_tag_editor/captioning.py | 47 -- scripts/dataset_tag_editor/custom_scripts.py | 50 ++ scripts/dataset_tag_editor/dte_logic.py | 749 ++++++++++++------ scripts/dataset_tag_editor/interrogator.py | 15 - .../interrogators/__init__.py | 3 +- .../interrogators/blip2_captioning.py | 33 + .../interrogators/git_large_captioning.py | 27 +- .../interrogators/waifu_diffusion_tagger.py | 73 +- scripts/dataset_tag_editor/tagger.py | 106 --- scripts/dataset_tag_editor/taggers_builtin.py | 171 ++++ scripts/dte_instance.py | 2 +- scripts/logger.py | 8 + scripts/main.py | 479 +++++++---- scripts/model_loader.py | 16 + scripts/paths.py | 17 + scripts/singleton.py | 10 +- .../tag_editor_ui/block_dataset_gallery.py | 2 +- scripts/tag_editor_ui/block_gallery_state.py | 2 +- scripts/tag_editor_ui/block_load_dataset.py | 245 ++++-- scripts/tag_editor_ui/block_tag_filter.py | 4 +- .../tag_editor_ui/tab_batch_edit_captions.py | 6 +- .../tab_edit_caption_of_selected_image.py | 6 +- .../tag_editor_ui/tab_filter_by_selection.py | 7 +- .../tag_editor_ui/tab_move_or_delete_files.py | 19 +- scripts/tag_editor_ui/ui_instance.py | 18 +- scripts/tagger.py | 52 ++ userscripts/taggers/aesthetic_shadow.py | 54 ++ .../taggers/cafeai_aesthetic_classifier.py | 54 ++ .../taggers/improved_aesthetic_predictor.py | 75 ++ .../taggers/waifu_aesthetic_classifier.py | 73 ++ 32 files changed, 1839 insertions(+), 735 deletions(-) create mode 100644 scripts/config.py delete mode 100644 scripts/dataset_tag_editor/captioning.py create mode 100644 scripts/dataset_tag_editor/custom_scripts.py delete mode 100644 scripts/dataset_tag_editor/interrogator.py create mode 100644 scripts/dataset_tag_editor/interrogators/blip2_captioning.py delete mode 100644 scripts/dataset_tag_editor/tagger.py create mode 100644 scripts/dataset_tag_editor/taggers_builtin.py create mode 100644 scripts/logger.py create mode 100644 scripts/model_loader.py create mode 100644 scripts/paths.py create mode 100644 scripts/tagger.py create mode 100644 userscripts/taggers/aesthetic_shadow.py create mode 100644 userscripts/taggers/cafeai_aesthetic_classifier.py create mode 100644 userscripts/taggers/improved_aesthetic_predictor.py create mode 100644 userscripts/taggers/waifu_aesthetic_classifier.py diff --git a/scripts/config.py b/scripts/config.py new file mode 100644 index 0000000..3ed7120 --- /dev/null +++ b/scripts/config.py @@ -0,0 +1,143 @@ +from collections import namedtuple +import json + +from scripts import logger +from scripts.paths import paths +from scripts.dte_instance import dte_instance + +SortBy = dte_instance.SortBy +SortOrder = dte_instance.SortOrder + +CONFIG_PATH = paths.base_path / "config.json" + +GeneralConfig = namedtuple( + "GeneralConfig", + [ + "backup", + "dataset_dir", + "caption_ext", + "load_recursive", + "load_caption_from_filename", + "replace_new_line", + "use_interrogator", + "use_interrogator_names", + "use_custom_threshold_booru", + "custom_threshold_booru", + "use_custom_threshold_waifu", + "custom_threshold_waifu", + "custom_threshold_z3d", + "save_kohya_metadata", + "meta_output_path", + "meta_input_path", + "meta_overwrite", + "meta_save_as_caption", + "meta_use_full_path", + ], +) +FilterConfig = namedtuple( + "FilterConfig", + ["sw_prefix", "sw_suffix", "sw_regex", "sort_by", "sort_order", "logic"], +) +BatchEditConfig = namedtuple( + "BatchEditConfig", + [ + "show_only_selected", + "prepend", + "use_regex", + "target", + "sw_prefix", + "sw_suffix", + "sw_regex", + "sory_by", + "sort_order", + "batch_sort_by", + "batch_sort_order", + "token_count", + ], +) +EditSelectedConfig = namedtuple( + "EditSelectedConfig", + [ + "auto_copy", + "sort_on_save", + "warn_change_not_saved", + "use_interrogator_name", + "sort_by", + "sort_order", + ], +) +MoveDeleteConfig = namedtuple( + "MoveDeleteConfig", ["range", "target", "caption_ext", "destination"] +) + +CFG_GENERAL_DEFAULT = GeneralConfig( + True, + "", + ".txt", + False, + True, + False, + "No", + [], + False, + 0.7, + False, + 0.35, + 0.35, + False, + "", + "", + True, + False, + False, +) +CFG_FILTER_P_DEFAULT = FilterConfig( + False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, "AND" +) +CFG_FILTER_N_DEFAULT = FilterConfig( + False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, "OR" +) +CFG_BATCH_EDIT_DEFAULT = BatchEditConfig( + True, + False, + False, + "Only Selected Tags", + False, + False, + False, + SortBy.ALPHA.value, + SortOrder.ASC.value, + SortBy.ALPHA.value, + SortOrder.ASC.value, + 75, +) +CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig( + False, False, False, "", SortBy.ALPHA.value, SortOrder.ASC.value +) +CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig("Selected One", [], ".txt", "") + + +class Config: + def __init__(self): + self.config = dict() + + def load(self): + if not CONFIG_PATH.is_file(): + self.config = dict() + return + try: + self.config = json.loads(CONFIG_PATH.read_text("utf8")) + except: + logger.warn("Error on loading config.json. Default settings will be loaded.") + self.config = dict() + else: + logger.write("Settings has been read from config.json") + + def save(self): + CONFIG_PATH.write_text(json.dumps(self.config, indent=4), "utf8") + + def read(self, name: str): + return self.config.get(name) + + def write(self, cfg: dict, name: str): + self.config[name] = cfg diff --git a/scripts/dataset_tag_editor/__init__.py b/scripts/dataset_tag_editor/__init__.py index f601d0f..a20e95a 100644 --- a/scripts/dataset_tag_editor/__init__.py +++ b/scripts/dataset_tag_editor/__init__.py @@ -1,8 +1,8 @@ -from . import tagger -from . import captioning +from . import taggers_builtin from . import filters from . import dataset as ds +from . import kohya_finetune_metadata -from .dte_logic import DatasetTagEditor, INTERROGATOR_NAMES, interrogate_image +from .dte_logic import DatasetTagEditor -__all__ = ["ds", "tagger", "captioning", "filters", "kohya_metadata", "INTERROGATOR_NAMES", "interrogate_image", "DatasetTagEditor"] +__all__ = ["ds", "taggers_builtin", "filters", "kohya_finetune_metadata", "DatasetTagEditor"] diff --git a/scripts/dataset_tag_editor/captioning.py b/scripts/dataset_tag_editor/captioning.py deleted file mode 100644 index 044a843..0000000 --- a/scripts/dataset_tag_editor/captioning.py +++ /dev/null @@ -1,47 +0,0 @@ -import modules.shared as shared - -from .interrogator import Interrogator -from .interrogators import GITLargeCaptioning - - -class Captioning(Interrogator): - def start(self): - pass - def stop(self): - pass - def predict(self, image): - raise NotImplementedError() - def name(self): - raise NotImplementedError() - - -class BLIP(Captioning): - def start(self): - shared.interrogator.load() - - def stop(self): - shared.interrogator.unload() - - def predict(self, image): - tags = shared.interrogator.generate_caption(image).split(',') - return [t for t in tags if t] - - def name(self): - return 'BLIP' - - -class GITLarge(Captioning): - def __init__(self): - self.interrogator = GITLargeCaptioning() - def start(self): - self.interrogator.load() - - def stop(self): - self.interrogator.unload() - - def predict(self, image): - tags = self.interrogator.apply(image).split(',') - return [t for t in tags if t] - - def name(self): - return 'GIT-large-COCO' \ No newline at end of file diff --git a/scripts/dataset_tag_editor/custom_scripts.py b/scripts/dataset_tag_editor/custom_scripts.py new file mode 100644 index 0000000..6fe8be2 --- /dev/null +++ b/scripts/dataset_tag_editor/custom_scripts.py @@ -0,0 +1,50 @@ +import sys +from pathlib import Path +import importlib.util +from types import ModuleType + +from scripts import logger +from scripts.paths import paths + + +class CustomScripts: + def _load_module_from(self, path:Path): + module_spec = importlib.util.spec_from_file_location(path.stem, path) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + return module + + def _load_derived_classes(self, module:ModuleType, base_class:type): + derived_classes = [] + for name in dir(module): + obj = getattr(module, name) + if isinstance(obj, type) and issubclass(obj, base_class) and obj is not base_class: + derived_classes.append(obj) + + return derived_classes + + def __init__(self, scripts_dir:Path) -> None: + self.scripts = dict() + self.scripts_dir = scripts_dir.absolute() + + def load_derived_classes(self, baseclass:type): + back_syspath = sys.path + if not self.scripts_dir.is_dir(): + logger.warn(f"NOT A DIRECTORY: {self.scripts_dir}") + return [] + + classes = [] + try: + sys.path = [str(paths.base_path)] + sys.path + for path in self.scripts_dir.glob("*.py"): + self.scripts[path.stem] = self._load_module_from(path) + for module in self.scripts.values(): + classes.extend(self._load_derived_classes(module, baseclass)) + except Exception as e: + tb = sys.exc_info()[2] + logger.error(f"Error on loading {path}") + logger.error(e.with_traceback(tb)) + finally: + sys.path = back_syspath + + return classes \ No newline at end of file diff --git a/scripts/dataset_tag_editor/dte_logic.py b/scripts/dataset_tag_editor/dte_logic.py index b2d862c..5e0d4eb 100644 --- a/scripts/dataset_tag_editor/dte_logic.py +++ b/scripts/dataset_tag_editor/dte_logic.py @@ -1,57 +1,45 @@ from pathlib import Path -import re +import re, sys from typing import List, Set, Optional from enum import Enum + from PIL import Image +from tqdm import tqdm from modules import shared from modules.textual_inversion.dataset import re_numbers_at_start from scripts.singleton import Singleton -from . import tagger, captioning, filters, dataset as ds, kohya_finetune_metadata as kohya_metadata +from scripts import logger +from scripts.paths import paths + +from . import ( + filters, + dataset as ds, + kohya_finetune_metadata as kohya_metadata, + taggers_builtin +) +from .custom_scripts import CustomScripts from scripts.tokenizer import clip_tokenizer +from scripts.tagger import Tagger -WD_TAGGER_NAMES = ["wd-v1-4-vit-tagger", "wd-v1-4-convnext-tagger", "wd-v1-4-vit-tagger-v2", "wd-v1-4-convnext-tagger-v2", "wd-v1-4-swinv2-tagger-v2"] -WD_TAGGER_THRESHOLDS = [0.35, 0.35, 0.3537, 0.3685, 0.3771] # v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf - -INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru()] + [tagger.WaifuDiffusion(name, WD_TAGGER_THRESHOLDS[i]) for i, name in enumerate(WD_TAGGER_NAMES)] -INTERROGATOR_NAMES = [it.name() for it in INTERROGATORS] - -re_tags = re.compile(r'^([\s\S]+?)( \[\d+\])?$') -re_newlines = re.compile(r'[\r\n]+') - - -def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshold_wd): - try: - img = Image.open(path).convert('RGB') - except: - return '' - else: - for it in INTERROGATORS: - if it.name() == interrogator_name: - if isinstance(it, tagger.DeepDanbooru): - with it as tg: - res = tg.predict(img, threshold_booru) - elif isinstance(it, tagger.WaifuDiffusion): - with it as tg: - res = tg.predict(img, threshold_wd) - else: - with it as cap: - res = cap.predict(img) - return ', '.join(res) +re_tags = re.compile(r"^([\s\S]+?)( \[\d+\])?$") +re_newlines = re.compile(r"[\r\n]+") +def convert_rgb(data:Image.Image): + return data.convert("RGB") class DatasetTagEditor(Singleton): class SortBy(Enum): - ALPHA = 'Alphabetical Order' - FREQ = 'Frequency' - LEN = 'Length' - TOKEN = 'Token Length' + ALPHA = "Alphabetical Order" + FREQ = "Frequency" + LEN = "Length" + TOKEN = "Token Length" class SortOrder(Enum): - ASC = 'Ascending' - DESC = 'Descending' + ASC = "Ascending" + DESC = "Descending" class InterrogateMethod(Enum): NONE = 0 @@ -59,63 +47,138 @@ class DatasetTagEditor(Singleton): OVERWRITE = 2 PREPEND = 3 APPEND = 4 - + def __init__(self): # from modules.textual_inversion.dataset - self.re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None + self.re_word = ( + re.compile(shared.opts.dataset_filename_word_regex) + if len(shared.opts.dataset_filename_word_regex) > 0 + else None + ) self.dataset = ds.Dataset() self.img_idx = dict() self.tag_counts = {} - self.dataset_dir = '' + self.dataset_dir = "" self.images = {} self.tag_tokens = {} self.raw_clip_token_used = None + + def load_interrogators(self): + custom_tagger_scripts = CustomScripts(paths.userscript_path / "taggers") + custom_taggers:list[Tagger] = custom_tagger_scripts.load_derived_classes(Tagger) + logger.write(f"Custom taggers loaded: {[tagger().name() for tagger in custom_taggers]}") + + self.BLIP2_CAPTIONING_NAMES = [ + "blip2-opt-2.7b", + "blip2-opt-2.7b-coco", + "blip2-opt-6.7b", + "blip2-opt-6.7b-coco", + "blip2-flan-t5-xl", + "blip2-flan-t5-xl-coco", + "blip2-flan-t5-xxl", + ] + + self.WD_TAGGERS = { + "wd-v1-4-vit-tagger" : 0.35, + "wd-v1-4-convnext-tagger" : 0.35, + "wd-v1-4-vit-tagger-v2" : 0.3537, + "wd-v1-4-convnext-tagger-v2" : 0.3685, + "wd-v1-4-convnextv2-tagger-v2" : 0.371, + "wd-v1-4-swinv2-tagger-v2" : 0.3771, + "wd-v1-4-moat-tagger-v2" : 0.3771, + "wd-vit-tagger-v3" : 0.2614, + "wd-convnext-tagger-v3" : 0.2682, + "wd-swinv2-tagger-v3" : 0.2653, + } + # {tagger name : default tagger threshold} + # v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf + + self.INTERROGATORS = ( + [taggers_builtin.BLIP()] + + [taggers_builtin.BLIP2(name) for name in self.BLIP2_CAPTIONING_NAMES] + + [taggers_builtin.GITLarge()] + + [taggers_builtin.DeepDanbooru()] + + [ + taggers_builtin.WaifuDiffusion(name, threshold) + for name, threshold in self.WD_TAGGERS.items() + ] + + [taggers_builtin.Z3D_E621()] + + [cls_tagger() for cls_tagger in custom_taggers] + ) + self.INTERROGATOR_NAMES = [it.name() for it in self.INTERROGATORS] + + def interrogate_image(self, path: str, interrogator_name: str, threshold_booru, threshold_wd, threshold_z3d): + try: + img = Image.open(path).convert("RGB") + except: + return "" + else: + for it in self.INTERROGATORS: + if it.name() == interrogator_name: + if isinstance(it, taggers_builtin.DeepDanbooru): + with it as tg: + res = tg.predict(img, threshold_booru) + elif isinstance(it, taggers_builtin.WaifuDiffusion): + with it as tg: + res = tg.predict(img, threshold_wd) + elif isinstance(it, taggers_builtin.Z3D_E621): + with it as tg: + res = tg.predict(img, threshold_z3d) + else: + with it as cap: + res = cap.predict(img) + return ", ".join(res) def get_tag_list(self): if len(self.tag_counts) == 0: self.construct_tag_infos() return [key for key in self.tag_counts.keys()] - def get_tag_set(self): if len(self.tag_counts) == 0: self.construct_tag_infos() return {key for key in self.tag_counts.keys()} - - def get_tags_by_image_path(self, imgpath:str): + def get_tags_by_image_path(self, imgpath: str): return self.dataset.get_data_tags(imgpath) - - def set_tags_by_image_path(self, imgpath:str, tags:List[str]): - self.dataset.append_data(ds.Data(imgpath, ','.join(tags))) + def set_tags_by_image_path(self, imgpath: str, tags: list[str]): + self.dataset.append_data(ds.Data(imgpath, ",".join(tags))) self.construct_tag_infos() - - def write_tags(self, tags:List[str], sort_by:SortBy=SortBy.FREQ): + def write_tags(self, tags: list[str], sort_by: SortBy = SortBy.FREQ): sort_by = self.SortBy(sort_by) if tags: if sort_by == self.SortBy.FREQ: - return [f'{tag} [{self.tag_counts.get(tag) or 0}]' for tag in tags if tag] + return [ + f"{tag} [{self.tag_counts.get(tag) or 0}]" for tag in tags if tag + ] elif sort_by == self.SortBy.LEN: - return [f'{tag} [{len(tag)}]' for tag in tags if tag] + return [f"{tag} [{len(tag)}]" for tag in tags if tag] elif sort_by == self.SortBy.TOKEN: - return [f'{tag} [{self.tag_tokens.get(tag, (0, 0))[1]}]' for tag in tags if tag] + return [ + f"{tag} [{self.tag_tokens.get(tag, (0, 0))[1]}]" + for tag in tags + if tag + ] else: - return [f'{tag}' for tag in tags if tag] + return [f"{tag}" for tag in tags if tag] else: return [] - - def read_tags(self, tags:List[str]): + def read_tags(self, tags: list[str]): if tags: tags = [re_tags.match(tag).group(1) for tag in tags if tag] return [t for t in tags if t] else: return [] - - def sort_tags(self, tags:List[str], sort_by:SortBy=SortBy.ALPHA, sort_order:SortOrder=SortOrder.ASC): + def sort_tags( + self, + tags: list[str], + sort_by: SortBy = SortBy.ALPHA, + sort_order: SortOrder = SortOrder.ASC, + ): sort_by = self.SortBy(sort_by) sort_order = self.SortOrder(sort_order) if sort_by == self.SortBy.ALPHA: @@ -125,61 +188,77 @@ class DatasetTagEditor(Singleton): return sorted(tags, reverse=True) elif sort_by == self.SortBy.FREQ: if sort_order == self.SortOrder.ASC: - return sorted(tags, key=lambda t:(self.tag_counts.get(t, 0), t), reverse=False) + return sorted( + tags, key=lambda t: (self.tag_counts.get(t, 0), t), reverse=False + ) elif sort_order == self.SortOrder.DESC: - return sorted(tags, key=lambda t:(-self.tag_counts.get(t, 0), t), reverse=False) + return sorted( + tags, key=lambda t: (-self.tag_counts.get(t, 0), t), reverse=False + ) elif sort_by == self.SortBy.LEN: if sort_order == self.SortOrder.ASC: - return sorted(tags, key=lambda t:(len(t), t), reverse=False) + return sorted(tags, key=lambda t: (len(t), t), reverse=False) elif sort_order == self.SortOrder.DESC: - return sorted(tags, key=lambda t:(-len(t), t), reverse=False) + return sorted(tags, key=lambda t: (-len(t), t), reverse=False) elif sort_by == self.SortBy.TOKEN: if sort_order == self.SortOrder.ASC: - return sorted(tags, key=lambda t:(self.tag_tokens.get(t, (0, 0))[1], t), reverse=False) + return sorted( + tags, + key=lambda t: (self.tag_tokens.get(t, (0, 0))[1], t), + reverse=False, + ) elif sort_order == self.SortOrder.DESC: - return sorted(tags, key=lambda t:(-self.tag_tokens.get(t, (0, 0))[1], t), reverse=False) + return sorted( + tags, + key=lambda t: (-self.tag_tokens.get(t, (0, 0))[1], t), + reverse=False, + ) return list(tags) - - def get_filtered_imgpaths(self, filters:List[filters.Filter] = []): + def get_filtered_imgpaths(self, filters: list[filters.Filter] = []): filtered_set = self.dataset.copy() for filter in filters: filtered_set.filter(filter) - + img_paths = sorted(filtered_set.datas.keys()) - + return img_paths - - def get_filtered_imgs(self, filters:List[filters.Filter] = []): + def get_filtered_imgs(self, filters: list[filters.Filter] = []): filtered_set = self.dataset.copy() for filter in filters: filtered_set.filter(filter) - + img_paths = sorted(filtered_set.datas.keys()) - + return [self.images.get(path) for path in img_paths] - - def get_filtered_imgindices(self, filters:List[filters.Filter] = []): + def get_filtered_imgindices(self, filters: list[filters.Filter] = []): filtered_set = self.dataset.copy() for filter in filters: filtered_set.filter(filter) - + img_paths = sorted(filtered_set.datas.keys()) - + return [self.img_idx.get(p) for p in img_paths] - - def get_filtered_tags(self, filters:List[filters.Filter] = [], filter_word:str = '', filter_tags = True, prefix=False, suffix=False, regex=False): + def get_filtered_tags( + self, + filters: list[filters.Filter] = [], + filter_word: str = "", + filter_tags=True, + prefix=False, + suffix=False, + regex=False, + ): if filter_tags: filtered_set = self.dataset.copy() for filter in filters: filtered_set.filter(filter) - tags:Set[str] = filtered_set.get_tagset() + tags: set[str] = filtered_set.get_tagset() else: - tags:Set[str] = self.dataset.get_tagset() - + tags: set[str] = self.dataset.get_tagset() + result = set() try: for tag in tags: @@ -215,115 +294,177 @@ class DatasetTagEditor(Singleton): else: return result - - def cleanup_tags(self, tags:List[str]): + def cleanup_tags(self, tags: list[str]): current_dataset_tags = self.dataset.get_tagset() return [t for t in tags if t in current_dataset_tags] - - def cleanup_tagset(self, tags:Set[str]): + + def cleanup_tagset(self, tags: set[str]): current_dataset_tagset = self.dataset.get_tagset() return tags & current_dataset_tagset - - def get_common_tags(self, filters:List[filters.Filter] = []): + def get_common_tags(self, filters: list[filters.Filter] = []): filtered_set = self.dataset.copy() for filter in filters: filtered_set.filter(filter) - + result = filtered_set.get_tagset() for d in filtered_set.datas.values(): result &= d.tagset return sorted(result) - - def replace_tags(self, search_tags:List[str], replace_tags:List[str], filters:List[filters.Filter] = [], prepend:bool = False): + def replace_tags( + self, + search_tags: list[str], + replace_tags: list[str], + filters: list[filters.Filter] = [], + prepend: bool = False, + ): img_paths = self.get_filtered_imgpaths(filters=filters) - tags_to_append = replace_tags[len(search_tags):] - tags_to_remove = search_tags[len(replace_tags):] + tags_to_append = replace_tags[len(search_tags) :] + tags_to_remove = search_tags[len(replace_tags) :] tags_to_replace = {} for i in range(min(len(search_tags), len(replace_tags))): - if replace_tags[i] is None or replace_tags[i] == '': + if replace_tags[i] is None or replace_tags[i] == "": tags_to_remove.append(search_tags[i]) else: tags_to_replace[search_tags[i]] = replace_tags[i] for img_path in img_paths: - tags_removed = [t for t in self.dataset.get_data_tags(img_path) if t not in tags_to_remove] - tags_replaced = [tags_to_replace.get(t) if t in tags_to_replace.keys() else t for t in tags_removed] - self.set_tags_by_image_path(img_path, tags_to_append + tags_replaced if prepend else tags_replaced + tags_to_append) - + tags_removed = [ + t + for t in self.dataset.get_data_tags(img_path) + if t not in tags_to_remove + ] + tags_replaced = [ + tags_to_replace.get(t) if t in tags_to_replace.keys() else t + for t in tags_removed + ] + self.set_tags_by_image_path( + img_path, + tags_to_append + tags_replaced + if prepend + else tags_replaced + tags_to_append, + ) + self.construct_tag_infos() - def get_replaced_tagset(self, tags:Set[str], search_tags:List[str], replace_tags:List[str]): - tags_to_remove = search_tags[len(replace_tags):] + def get_replaced_tagset( + self, tags: set[str], search_tags: list[str], replace_tags: list[str] + ): + tags_to_remove = search_tags[len(replace_tags) :] tags_to_replace = {} for i in range(min(len(search_tags), len(replace_tags))): - if replace_tags[i] is None or replace_tags[i] == '': + if replace_tags[i] is None or replace_tags[i] == "": tags_to_remove.append(search_tags[i]) else: tags_to_replace[search_tags[i]] = replace_tags[i] tags_removed = {t for t in tags if t not in tags_to_remove} - tags_replaced = {tags_to_replace.get(t) if t in tags_to_replace.keys() else t for t in tags_removed} + tags_replaced = { + tags_to_replace.get(t) if t in tags_to_replace.keys() else t + for t in tags_removed + } return {t for t in tags_replaced if t} - - def search_and_replace_caption(self, search_text:str, replace_text:str, filters:List[filters.Filter] = [], use_regex:bool = False): + def search_and_replace_caption( + self, + search_text: str, + replace_text: str, + filters: list[filters.Filter] = [], + use_regex: bool = False, + ): img_paths = self.get_filtered_imgpaths(filters=filters) - + for img_path in img_paths: - caption = ', '.join(self.dataset.get_data_tags(img_path)) + caption = ", ".join(self.dataset.get_data_tags(img_path)) if use_regex: - caption = [t.strip() for t in re.sub(search_text, replace_text, caption).split(',')] + caption = [ + t.strip() + for t in re.sub(search_text, replace_text, caption).split(",") + ] else: - caption = [t.strip() for t in caption.replace(search_text, replace_text).split(',')] + caption = [ + t.strip() + for t in caption.replace(search_text, replace_text).split(",") + ] caption = [t for t in caption if t] self.set_tags_by_image_path(img_path, caption) - + self.construct_tag_infos() - - def search_and_replace_selected_tags(self, search_text:str, replace_text:str, selected_tags:Optional[Set[str]], filters:List[filters.Filter] = [], use_regex:bool = False): + def search_and_replace_selected_tags( + self, + search_text: str, + replace_text: str, + selected_tags: Optional[set[str]], + filters: list[filters.Filter] = [], + use_regex: bool = False, + ): img_paths = self.get_filtered_imgpaths(filters=filters) for img_path in img_paths: tags = self.dataset.get_data_tags(img_path) - tags = self.search_and_replace_tag_list(search_text, replace_text, tags, selected_tags, use_regex) + tags = self.search_and_replace_tag_list( + search_text, replace_text, tags, selected_tags, use_regex + ) self.set_tags_by_image_path(img_path, tags) - + self.construct_tag_infos() - - def search_and_replace_tag_list(self, search_text:str, replace_text:str, tags:List[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False): + def search_and_replace_tag_list( + self, + search_text: str, + replace_text: str, + tags: list[str], + selected_tags: Optional[set[str]] = None, + use_regex: bool = False, + ): if use_regex: if selected_tags is None: tags = [re.sub(search_text, replace_text, t) for t in tags] else: - tags = [re.sub(search_text, replace_text, t) if t in selected_tags else t for t in tags] + tags = [ + re.sub(search_text, replace_text, t) if t in selected_tags else t + for t in tags + ] else: if selected_tags is None: tags = [t.replace(search_text, replace_text) for t in tags] else: - tags = [t.replace(search_text, replace_text) if t in selected_tags else t for t in tags] - tags = [t2 for t1 in tags for t2 in t1.split(',') if t2] + tags = [ + t.replace(search_text, replace_text) if t in selected_tags else t + for t in tags + ] + tags = [t2 for t1 in tags for t2 in t1.split(",") if t2] return [t for t in tags if t] - - def search_and_replace_tag_set(self, search_text:str, replace_text:str, tags:Set[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False): + def search_and_replace_tag_set( + self, + search_text: str, + replace_text: str, + tags: set[str], + selected_tags: Optional[set[str]] = None, + use_regex: bool = False, + ): if use_regex: if selected_tags is None: tags = {re.sub(search_text, replace_text, t) for t in tags} else: - tags = {re.sub(search_text, replace_text, t) if t in selected_tags else t for t in tags} + tags = { + re.sub(search_text, replace_text, t) if t in selected_tags else t + for t in tags + } else: if selected_tags is None: tags = {t.replace(search_text, replace_text) for t in tags} else: - tags = {t.replace(search_text, replace_text) if t in selected_tags else t for t in tags} - tags = {t2 for t1 in tags for t2 in t1.split(',') if t2} + tags = { + t.replace(search_text, replace_text) if t in selected_tags else t + for t in tags + } + tags = {t2 for t1 in tags for t2 in t1.split(",") if t2} return {t for t in tags if t} - - def remove_duplicated_tags(self, filters:List[filters.Filter] = []): + def remove_duplicated_tags(self, filters: list[filters.Filter] = []): img_paths = self.get_filtered_imgpaths(filters) for path in img_paths: tags = self.dataset.get_data_tags(path) @@ -332,32 +473,33 @@ class DatasetTagEditor(Singleton): if t not in res: res.append(t) self.set_tags_by_image_path(path, res) - - def remove_tags(self, tags:Set[str], filters:List[filters.Filter] = []): + def remove_tags(self, tags: set[str], filters: list[filters.Filter] = []): img_paths = self.get_filtered_imgpaths(filters) for path in img_paths: res = self.dataset.get_data_tags(path) res = [t for t in res if t not in tags] self.set_tags_by_image_path(path, res) - - def sort_filtered_tags(self, filters:List[filters.Filter] = [], **sort_args): + def sort_filtered_tags(self, filters: list[filters.Filter] = [], **sort_args): img_paths = self.get_filtered_imgpaths(filters) for path in img_paths: tags = self.dataset.get_data_tags(path) res = self.sort_tags(tags, **sort_args) self.set_tags_by_image_path(path, res) - print(f'[tag-editor] Tags are sorted by {sort_args.get("sort_by").value} ({sort_args.get("sort_order").value})') + logger.write( + f'Tags are sorted by {sort_args.get("sort_by").value} ({sort_args.get("sort_order").value})' + ) - - def truncate_filtered_tags_by_token_count(self, filters:List[filters.Filter] = [], max_token_count:int = 75): + def truncate_filtered_tags_by_token_count( + self, filters: list[filters.Filter] = [], max_token_count: int = 75 + ): img_paths = self.get_filtered_imgpaths(filters) for path in img_paths: tags = self.dataset.get_data_tags(path) res = [] for tag in tags: - _, token_count = clip_tokenizer.tokenize(', '.join(res + [tag]), shared.opts.dataset_editor_use_raw_clip_token) + _, token_count = clip_tokenizer.tokenize(", ".join(res + [tag])) if token_count <= max_token_count: res.append(tag) else: @@ -365,46 +507,67 @@ class DatasetTagEditor(Singleton): self.set_tags_by_image_path(path, res) self.construct_tag_infos() - print(f'[tag-editor] Tags are truncated into token count <= {max_token_count}') - + logger.write(f"Tags are truncated into token count <= {max_token_count}") def get_img_path_list(self): return [k for k in self.dataset.datas.keys() if k] - def get_img_path_set(self): return {k for k in self.dataset.datas.keys() if k} - - def delete_dataset(self, caption_ext:str, filters:List[filters.Filter], delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False): + def delete_dataset( + self, + caption_ext: str, + filters: list[filters.Filter], + delete_image: bool = False, + delete_caption: bool = False, + delete_backup: bool = False, + ): filtered_set = self.dataset.copy() for filter in filters: filtered_set.filter(filter) for path in filtered_set.datas.keys(): - self.delete_dataset_file(path, caption_ext, delete_image, delete_caption, delete_backup) - + self.delete_dataset_file( + path, caption_ext, delete_image, delete_caption, delete_backup + ) + if delete_image: self.dataset.remove(filtered_set) self.construct_tag_infos() - - def move_dataset(self, dest_dir:str, caption_ext:str, filters:List[filters.Filter], move_image:bool = False, move_caption:bool = False, move_backup:bool = False): + def move_dataset( + self, + dest_dir: str, + caption_ext: str, + filters: list[filters.Filter], + move_image: bool = False, + move_caption: bool = False, + move_backup: bool = False, + ): filtered_set = self.dataset.copy() for filter in filters: filtered_set.filter(filter) for path in filtered_set.datas.keys(): - self.move_dataset_file(path, caption_ext, dest_dir, move_image, move_caption, move_backup) - + self.move_dataset_file( + path, caption_ext, dest_dir, move_image, move_caption, move_backup + ) + if move_image: self.construct_tag_infos() - - def delete_dataset_file(self, img_path:str, caption_ext:str, delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False): + def delete_dataset_file( + self, + img_path: str, + caption_ext: str, + delete_image: bool = False, + delete_caption: bool = False, + delete_backup: bool = False, + ): if img_path not in self.dataset.datas.keys(): return - + img_path_obj = Path(img_path) - + if delete_image: try: if img_path_obj.is_file(): @@ -413,34 +576,41 @@ class DatasetTagEditor(Singleton): del self.images[img_path] img_path_obj.unlink() self.dataset.remove_by_path(img_path) - print(f'[tag-editor] Deleted {img_path_obj.absolute()}') + logger.write(f"Deleted {img_path_obj.absolute()}") except Exception as e: - print(e) - + logger.error(e) + if delete_caption: try: txt_path_obj = img_path_obj.with_suffix(caption_ext) if txt_path_obj.is_file(): txt_path_obj.unlink() - print(f'[tag-editor] Deleted {txt_path_obj.absolute()}') + logger.write(f"Deleted {txt_path_obj.absolute()}") except Exception as e: - print(e) - + logger.error(e) + if delete_backup: try: for extnum in range(1000): - bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}') + bak_path_obj = img_path_obj.with_suffix(f".{extnum:0>3d}") if bak_path_obj.is_file(): bak_path_obj.unlink() - print(f'[tag-editor] Deleted {bak_path_obj.absolute()}') + logger.write(f"Deleted {bak_path_obj.absolute()}") except Exception as e: - print(e) - + logger.error(e) - def move_dataset_file(self, img_path:str, caption_ext:str, dest_dir:str, move_image:bool = False, move_caption:bool = False, move_backup:bool = False): + def move_dataset_file( + self, + img_path: str, + caption_ext: str, + dest_dir: str, + move_image: bool = False, + move_caption: bool = False, + move_backup: bool = False, + ): if img_path not in self.dataset.datas.keys(): return - + img_path_obj = Path(img_path) dest_dir_obj = Path(dest_dir) @@ -456,54 +626,77 @@ class DatasetTagEditor(Singleton): del self.images[img_path] img_path_obj.replace(dst_path_obj) self.dataset.remove_by_path(img_path) - print(f'[tag-editor] Moved {img_path_obj.absolute()} -> {dst_path_obj.absolute()}') + logger.write( + f"Moved {img_path_obj.absolute()} -> {dst_path_obj.absolute()}" + ) except Exception as e: - print(e) - + logger.error(e) + if move_caption: try: txt_path_obj = img_path_obj.with_suffix(caption_ext) dst_path_obj = dest_dir_obj / txt_path_obj.name if txt_path_obj.is_file(): txt_path_obj.replace(dst_path_obj) - print(f'[tag-editor] Moved {txt_path_obj.absolute()} -> {dst_path_obj.absolute()}') + logger.write( + f"Moved {txt_path_obj.absolute()} -> {dst_path_obj.absolute()}" + ) except Exception as e: - print(e) - + logger.error(e) + if move_backup: try: for extnum in range(1000): - bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}') + bak_path_obj = img_path_obj.with_suffix(f".{extnum:0>3d}") dst_path_obj = dest_dir_obj / bak_path_obj.name if bak_path_obj.is_file(): bak_path_obj.replace(dst_path_obj) - print(f'[tag-editor] Moved {bak_path_obj.absolute()} -> {dst_path_obj.absolute()}') + logger.write( + f"Moved {bak_path_obj.absolute()} -> {dst_path_obj.absolute()}" + ) except Exception as e: - print(e) + logger.error(e) - def load_dataset(self, img_dir:str, caption_ext:str, recursive:bool, load_caption_from_filename:bool, replace_new_line:bool, interrogate_method:InterrogateMethod, interrogator_names:List[str], threshold_booru:float, threshold_waifu:float, use_temp_dir:bool, kohya_json_path:Optional[str], max_res:float): + def load_dataset( + self, + img_dir: str, + caption_ext: str, + recursive: bool, + load_caption_from_filename: bool, + replace_new_line: bool, + interrogate_method: InterrogateMethod, + interrogator_names: list[str], + threshold_booru: float, + threshold_waifu: float, + threshold_z3d: float, + use_temp_dir: bool, + kohya_json_path: Optional[str], + max_res:float + ): self.clear() img_dir_obj = Path(img_dir) - print(f'[tag-editor] Loading dataset from {img_dir_obj.absolute()}') + logger.write(f"Loading dataset from {img_dir_obj.absolute()}") if recursive: - print(f'[tag-editor] Also loading from subdirectories.') - + logger.write(f"Also loading from subdirectories.") + try: - filepaths = img_dir_obj.glob('**/*') if recursive else img_dir_obj.glob('*') + filepaths = img_dir_obj.glob("**/*") if recursive else img_dir_obj.glob("*") filepaths = [p for p in filepaths if p.is_file()] except Exception as e: - print(e) - print('[tag-editor] Loading Aborted.') + logger.error(e) + logger.write("Loading Aborted.") return self.dataset_dir = img_dir - print(f'[tag-editor] Total {len(filepaths)} files under the directory including not image files.') + logger.write( + f"Total {len(filepaths)} files under the directory including not image files." + ) - def load_images(filepaths:List[Path]): + def load_images(filepaths: list[Path]): imgpaths = [] images = {} for img_path in filepaths: @@ -521,97 +714,130 @@ class DatasetTagEditor(Singleton): if not use_temp_dir and max_res <= 0: img.already_saved_as = abs_path images[abs_path] = img - + imgpaths.append(abs_path) return imgpaths, images - - def load_captions(imgpaths:List[str]): + + def load_captions(imgpaths: list[str]): taglists = [] for abs_path in imgpaths: img_path = Path(abs_path) text_path = img_path.with_suffix(caption_ext) - caption_text = '' + caption_text = "" if interrogate_method != self.InterrogateMethod.OVERWRITE: # from modules/textual_inversion/dataset.py, modified if text_path.is_file(): - caption_text = text_path.read_text('utf8') + caption_text = text_path.read_text("utf8") elif load_caption_from_filename: caption_text = img_path.stem - caption_text = re.sub(re_numbers_at_start, '', caption_text) + caption_text = re.sub(re_numbers_at_start, "", caption_text) if self.re_word: tokens = self.re_word.findall(caption_text) - caption_text = (shared.opts.dataset_filename_join_string or "").join(tokens) - + caption_text = ( + shared.opts.dataset_filename_join_string or "" + ).join(tokens) + if replace_new_line: - caption_text = re_newlines.sub(',', caption_text) - - caption_tags = [t.strip() for t in caption_text.split(',')] + caption_text = re_newlines.sub(",", caption_text) + + caption_tags = [t.strip() for t in caption_text.split(",")] caption_tags = [t for t in caption_tags if t] taglists.append(caption_tags) return taglists - try: - captionings = [] - taggers = [] - if interrogate_method != self.InterrogateMethod.NONE: - for it in INTERROGATORS: - if it.name() in interrogator_names: - it.start() - if isinstance(it, tagger.Tagger): - if isinstance(it, tagger.DeepDanbooru): - taggers.append((it, threshold_booru)) - if isinstance(it, tagger.WaifuDiffusion): - taggers.append((it, threshold_waifu)) - elif isinstance(it, captioning.Captioning): - captionings.append(it) - - if kohya_json_path: - imgpaths, self.images, taglists = kohya_metadata.read(img_dir, kohya_json_path, use_temp_dir) - else: - imgpaths, self.images = load_images(filepaths) - taglists = load_captions(imgpaths) - - for img_path, tags in zip(imgpaths, taglists): - interrogate_tags = [] - img = self.images.get(img_path) - if interrogate_method != self.InterrogateMethod.NONE and ((interrogate_method != self.InterrogateMethod.PREFILL) or (interrogate_method == self.InterrogateMethod.PREFILL and not tags)): - if img is None: - print(f'Failed to load image {img_path}. Interrogating is aborted.') + tagger_thresholds:list[tuple[Tagger, float]] = [] + if interrogate_method != self.InterrogateMethod.NONE: + for it in self.INTERROGATORS: + if it.name() in interrogator_names: + if isinstance(it, taggers_builtin.DeepDanbooru): + tagger_thresholds.append((it, threshold_booru)) + elif isinstance(it, taggers_builtin.WaifuDiffusion): + tagger_thresholds.append((it, threshold_waifu)) + elif isinstance(it, taggers_builtin.Z3D_E621): + tagger_thresholds.append((it, threshold_z3d)) else: - img = img.convert('RGB') - for cap in captionings: - interrogate_tags += cap.predict(img) - - for tg, threshold in taggers: - interrogate_tags += [t for t in tg.predict(img, threshold).keys()] - - if interrogate_method == self.InterrogateMethod.OVERWRITE: - tags = interrogate_tags - elif interrogate_method == self.InterrogateMethod.PREPEND: - tags = interrogate_tags + tags - else: - tags = tags + interrogate_tags - - self.set_tags_by_image_path(img_path, tags) + tagger_thresholds.append((it, None)) + + if kohya_json_path: + imgpaths, self.images, taglists = kohya_metadata.read( + img_dir, kohya_json_path, use_temp_dir + ) + else: + imgpaths, self.images = load_images(filepaths) + taglists = load_captions(imgpaths) + + interrogate_tags = {img_path : [] for img_path in imgpaths} + if interrogate_method != self.InterrogateMethod.NONE: + logger.write("Preprocess images...") + max_workers = shared.opts.dataset_editor_num_cpu_workers + if max_workers < 0: + import os + max_workers = os.cpu_count() + 1 - finally: - if interrogate_method != self.InterrogateMethod.NONE: - for cap in captionings: - cap.stop() - for tg, _ in taggers: + def gen_data(paths:list[str], images:dict[str, Image.Image]): + for img_path in paths: + yield images.get(img_path) + + from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=max_workers) as executor: + result = list(executor.map(convert_rgb, gen_data(imgpaths, self.images))) + logger.write("Preprocess completed") + + for tg, th in tqdm(tagger_thresholds): + use_pipe = True + tg.start() + + try: + tg.predict_pipe(None) + except NotImplementedError: + use_pipe = False + except Exception as e: + tb = sys.exc_info()[2] + logger.error(e.with_traceback(tb)) + continue + try: + if use_pipe: + for img_path, tags in tqdm(zip(imgpaths, tg.predict_pipe(result, th)), desc=tg.name(), total=len(imgpaths)): + interrogate_tags[img_path] += tags + else: + for img_path, data in tqdm(zip(imgpaths, result), desc=tg.name(), total=len(imgpaths)): + interrogate_tags[img_path] += tg.predict(data, th) + except Exception as e: + tb = sys.exc_info()[2] + logger.error(e.with_traceback(tb)) + finally: tg.stop() + for img_path, tags in zip(imgpaths, taglists): + if (interrogate_method == self.InterrogateMethod.PREFILL and not tags) or (interrogate_method == self.InterrogateMethod.OVERWRITE): + tags = interrogate_tags[img_path] + elif interrogate_method == self.InterrogateMethod.PREPEND: + tags = interrogate_tags[img_path] + tags + else: + tags = tags + interrogate_tags[img_path] + + self.set_tags_by_image_path(img_path, tags) + for i, p in enumerate(sorted(self.dataset.datas.keys())): self.img_idx[p] = i self.construct_tag_infos() - print(f'[tag-editor] Loading Completed: {len(self.dataset)} images found') - + logger.write(f"Loading Completed: {len(self.dataset)} images found") - def save_dataset(self, backup:bool, caption_ext:str, write_kohya_metadata:bool, meta_out_path:str, meta_in_path:Optional[str], meta_overwrite:bool, meta_as_caption:bool, meta_full_path:bool): + def save_dataset( + self, + backup: bool, + caption_ext: str, + write_kohya_metadata: bool, + meta_out_path: str, + meta_in_path: Optional[str], + meta_overwrite: bool, + meta_as_caption: bool, + meta_full_path: bool, + ): if len(self.dataset) == 0: - return (0, 0, '') + return (0, 0, "") saved_num = 0 backup_num = 0 @@ -621,53 +847,70 @@ class DatasetTagEditor(Singleton): # make backup if backup and txt_path.is_file(): for extnum in range(1000): - bak_path = img_path.with_suffix(f'.{extnum:0>3d}') + bak_path = img_path.with_suffix(f".{extnum:0>3d}") if not bak_path.is_file(): break else: bak_path = None if bak_path is None: - print(f"[tag-editor] There are too many backup files with same filename. A backup file of {txt_path} cannot be created.") + logger.write( + f"There are too many backup files with same filename. A backup file of {txt_path} cannot be created." + ) else: try: txt_path.rename(bak_path) except Exception as e: print(e) - print(f"[tag-editor] A backup file of {txt_path} cannot be created.") + logger.write( + f"A backup file of {txt_path} cannot be created." + ) else: backup_num += 1 # save try: - txt_path.write_text(', '.join(tags), 'utf8') + txt_path.write_text(", ".join(tags), "utf8") except Exception as e: print(e) - print(f"[tag-editor] Warning: {txt_path} cannot be saved.") + logger.warn(f"{txt_path} cannot be saved.") else: saved_num += 1 - print(f'[tag-editor] Backup text files: {backup_num}/{len(self.dataset)} under {self.dataset_dir}') - print(f'[tag-editor] Saved text files: {saved_num}/{len(self.dataset)} under {self.dataset_dir}') - - if(write_kohya_metadata): - kohya_metadata.write(dataset=self.dataset, dataset_dir=self.dataset_dir, out_path=meta_out_path, in_path=meta_in_path, overwrite=meta_overwrite, save_as_caption=meta_as_caption, use_full_path=meta_full_path) - print(f'[tag-editor] Saved json metadata file in {meta_out_path}') - return (saved_num, len(self.dataset), self.dataset_dir) + logger.write( + f"Backup text files: {backup_num}/{len(self.dataset)} under {self.dataset_dir}" + ) + logger.write( + f"Saved text files: {saved_num}/{len(self.dataset)} under {self.dataset_dir}" + ) + if write_kohya_metadata: + kohya_metadata.write( + dataset=self.dataset, + dataset_dir=self.dataset_dir, + out_path=meta_out_path, + in_path=meta_in_path, + overwrite=meta_overwrite, + save_as_caption=meta_as_caption, + use_full_path=meta_full_path, + ) + logger.write(f"Saved json metadata file in {meta_out_path}") + return (saved_num, len(self.dataset), self.dataset_dir) def clear(self): self.dataset.clear() self.tag_counts.clear() self.tag_tokens.clear() self.img_idx.clear() - self.dataset_dir = '' + self.dataset_dir = "" for img in self.images: if isinstance(img, Image.Image): img.close() self.images.clear() - def construct_tag_infos(self): self.tag_counts = {} - update_token_count = self.raw_clip_token_used is None or self.raw_clip_token_used != shared.opts.dataset_editor_use_raw_clip_token + update_token_count = ( + self.raw_clip_token_used is None + or self.raw_clip_token_used != shared.opts.dataset_editor_use_raw_clip_token + ) if update_token_count: self.tag_tokens.clear() @@ -679,5 +922,7 @@ class DatasetTagEditor(Singleton): else: self.tag_counts[tag] = 1 if tag not in self.tag_tokens: - self.tag_tokens[tag] = clip_tokenizer.tokenize(tag, shared.opts.dataset_editor_use_raw_clip_token) + self.tag_tokens[tag] = clip_tokenizer.tokenize( + tag, shared.opts.dataset_editor_use_raw_clip_token + ) self.raw_clip_token_used = shared.opts.dataset_editor_use_raw_clip_token diff --git a/scripts/dataset_tag_editor/interrogator.py b/scripts/dataset_tag_editor/interrogator.py deleted file mode 100644 index 38408d4..0000000 --- a/scripts/dataset_tag_editor/interrogator.py +++ /dev/null @@ -1,15 +0,0 @@ -class Interrogator: - def __enter__(self): - self.start() - return self - def __exit__(self, exception_type, exception_value, traceback): - self.stop() - pass - def start(self): - pass - def stop(self): - pass - def predict(self, image, **kwargs): - raise NotImplementedError() - def name(self): - raise NotImplementedError() diff --git a/scripts/dataset_tag_editor/interrogators/__init__.py b/scripts/dataset_tag_editor/interrogators/__init__.py index 726c896..3ab3dbf 100644 --- a/scripts/dataset_tag_editor/interrogators/__init__.py +++ b/scripts/dataset_tag_editor/interrogators/__init__.py @@ -1,6 +1,7 @@ +from .blip2_captioning import BLIP2Captioning from .git_large_captioning import GITLargeCaptioning from .waifu_diffusion_tagger import WaifuDiffusionTagger __all__ = [ - 'GITLargeCaptioning', 'WaifuDiffusionTagger' + "BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger' ] \ No newline at end of file diff --git a/scripts/dataset_tag_editor/interrogators/blip2_captioning.py b/scripts/dataset_tag_editor/interrogators/blip2_captioning.py new file mode 100644 index 0000000..179c3f4 --- /dev/null +++ b/scripts/dataset_tag_editor/interrogators/blip2_captioning.py @@ -0,0 +1,33 @@ +from transformers import Blip2Processor, Blip2ForConditionalGeneration + +from modules import devices, shared +from scripts.paths import paths + + +class BLIP2Captioning: + def __init__(self, model_repo: str): + self.MODEL_REPO = model_repo + self.processor: Blip2Processor = None + self.model: Blip2ForConditionalGeneration = None + + def load(self): + if self.model is None or self.processor is None: + self.processor = Blip2Processor.from_pretrained( + self.MODEL_REPO, cache_dir=paths.setting_model_path + ) + self.model = Blip2ForConditionalGeneration.from_pretrained( + self.MODEL_REPO, cache_dir=paths.setting_model_path + ).to(devices.device) + + def unload(self): + if not shared.opts.interrogate_keep_models_in_memory: + self.model = None + self.processor = None + devices.torch_gc() + + def apply(self, image): + if self.model is None or self.processor is None: + return "" + inputs = self.processor(images=image, return_tensors="pt").to(devices.device) + ids = self.model.generate(**inputs) + return self.processor.batch_decode(ids, skip_special_tokens=True) diff --git a/scripts/dataset_tag_editor/interrogators/git_large_captioning.py b/scripts/dataset_tag_editor/interrogators/git_large_captioning.py index 2a46767..4c8680a 100644 --- a/scripts/dataset_tag_editor/interrogators/git_large_captioning.py +++ b/scripts/dataset_tag_editor/interrogators/git_large_captioning.py @@ -1,26 +1,35 @@ from transformers import AutoProcessor, AutoModelForCausalLM -from modules import shared +from modules import shared, devices, lowvram + # brought from https://huggingface.co/docs/transformers/main/en/model_doc/git and modified -class GITLargeCaptioning(): +class GITLargeCaptioning: MODEL_REPO = "microsoft/git-large-coco" + def __init__(self): - self.processor:AutoProcessor = None - self.model:AutoModelForCausalLM = None + self.processor: AutoProcessor = None + self.model: AutoModelForCausalLM = None def load(self): if self.model is None or self.processor is None: self.processor = AutoProcessor.from_pretrained(self.MODEL_REPO) - self.model = AutoModelForCausalLM.from_pretrained(self.MODEL_REPO).to(shared.device) + self.model = AutoModelForCausalLM.from_pretrained(self.MODEL_REPO).to( + shared.device + ) + lowvram.send_everything_to_cpu() def unload(self): if not shared.opts.interrogate_keep_models_in_memory: self.model = None self.processor = None + devices.torch_gc() def apply(self, image): if self.model is None or self.processor is None: - return '' - inputs = self.processor(images=image, return_tensors='pt').to(shared.device) - ids = self.model.generate(pixel_values=inputs.pixel_values, max_length=shared.opts.interrogate_clip_max_length) - return self.processor.batch_decode(ids, skip_special_tokens=True)[0] \ No newline at end of file + return "" + inputs = self.processor(images=image, return_tensors="pt").to(shared.device) + ids = self.model.generate( + pixel_values=inputs.pixel_values, + max_length=shared.opts.interrogate_clip_max_length, + ) + return self.processor.batch_decode(ids, skip_special_tokens=True)[0] diff --git a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py index 9ba6dd3..75156d5 100644 --- a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py +++ b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py @@ -1,74 +1,103 @@ from PIL import Image import numpy as np from typing import List, Tuple -from modules import shared +from modules import shared, devices import launch -class WaifuDiffusionTagger(): +class WaifuDiffusionTagger: # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified - MODEL_FILENAME = "model.onnx" - LABEL_FILENAME = "selected_tags.csv" - def __init__(self, model_name): + def __init__( + self, + model_name, + model_filename="model.onnx", + label_filename="selected_tags.csv", + ): + self.MODEL_FILENAME = model_filename + self.LABEL_FILENAME = label_filename self.MODEL_REPO = model_name self.model = None self.labels = [] def load(self): import huggingface_hub + if not self.model: path_model = huggingface_hub.hf_hub_download( self.MODEL_REPO, self.MODEL_FILENAME ) - if 'all' in shared.cmd_opts.use_cpu or 'interrogate' in shared.cmd_opts.use_cpu: - providers = ['CPUExecutionProvider'] + if ( + "all" in shared.cmd_opts.use_cpu + or "interrogate" in shared.cmd_opts.use_cpu + ): + providers = ["CPUExecutionProvider"] else: - providers = ['CUDAExecutionProvider', 'DmlExecutionProvider', 'CPUExecutionProvider'] - + providers = [ + "CUDAExecutionProvider", + "DmlExecutionProvider", + "CPUExecutionProvider", + ] + def check_available_device(): import torch + if torch.cuda.is_available(): - return 'cuda' + return "cuda" elif launch.is_installed("torch-directml"): # This code cannot detect DirectML available device without pytorch-directml try: import torch_directml + torch_directml.device() except: pass else: - return 'directml' - return 'cpu' + return "directml" + return "cpu" if not launch.is_installed("onnxruntime"): dev = check_available_device() - if dev == 'cuda': - launch.run_pip("install -U onnxruntime-gpu", "requirements for dataset-tag-editor [onnxruntime-gpu]") - elif dev == 'directml': - launch.run_pip("install -U onnxruntime-directml", "requirements for dataset-tag-editor [onnxruntime-directml]") + if dev == "cuda": + launch.run_pip( + "install -U onnxruntime-gpu", + "requirements for dataset-tag-editor [onnxruntime-gpu]", + ) + elif dev == "directml": + launch.run_pip( + "install -U onnxruntime-directml", + "requirements for dataset-tag-editor [onnxruntime-directml]", + ) else: - print('Your device is not compatible with onnx hardware acceleration. CPU only version will be installed and it may be very slow.') - launch.run_pip("install -U onnxruntime", "requirements for dataset-tag-editor [onnxruntime for CPU]") + print( + "Your device is not compatible with onnx hardware acceleration. CPU only version will be installed and it may be very slow." + ) + launch.run_pip( + "install -U onnxruntime", + "requirements for dataset-tag-editor [onnxruntime for CPU]", + ) import onnxruntime as ort + self.model = ort.InferenceSession(path_model, providers=providers) - + path_label = huggingface_hub.hf_hub_download( self.MODEL_REPO, self.LABEL_FILENAME ) import pandas as pd + self.labels = pd.read_csv(path_label)["name"].tolist() def unload(self): if not shared.opts.interrogate_keep_models_in_memory: self.model = None + devices.torch_gc() # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified def apply(self, image: Image.Image): if not self.model: return dict() - + from modules import images - + _, height, width, _ = self.model.get_inputs()[0].shape # the way to fill empty pixels is quite different from original one; @@ -85,4 +114,4 @@ class WaifuDiffusionTagger(): probs = self.model.run([label_name], {input_name: image_np})[0] labels: List[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float))) - return labels \ No newline at end of file + return labels diff --git a/scripts/dataset_tag_editor/tagger.py b/scripts/dataset_tag_editor/tagger.py deleted file mode 100644 index 5ee520b..0000000 --- a/scripts/dataset_tag_editor/tagger.py +++ /dev/null @@ -1,106 +0,0 @@ -from PIL import Image -import re -import torch -import numpy as np -from typing import Optional, Dict -from modules import devices, shared -from modules import deepbooru as db - -from .interrogator import Interrogator -from .interrogators import WaifuDiffusionTagger - - -class Tagger(Interrogator): - def start(self): - pass - def stop(self): - pass - def predict(self, image: Image.Image, threshold: Optional[float]): - raise NotImplementedError() - def name(self): - raise NotImplementedError() - - -def get_replaced_tag(tag: str): - use_spaces = shared.opts.deepbooru_use_spaces - use_escape = shared.opts.deepbooru_escape - if use_spaces: - tag = tag.replace('_', ' ') - if use_escape: - tag = re.sub(db.re_special, r'\\\1', tag) - return tag - - -def get_arranged_tags(probs: Dict[str, float]): - alpha_sort = shared.opts.deepbooru_sort_alpha - if alpha_sort: - return sorted(probs) - else: - return [tag for tag, _ in sorted(probs.items(), key=lambda x: -x[1])] - - -class DeepDanbooru(Tagger): - def start(self): - db.model.start() - - def stop(self): - db.model.stop() - - # brought from webUI modules/deepbooru.py and modified - def predict(self, image: Image.Image, threshold: Optional[float] = None): - from modules import images - - pic = images.resize_image(2, image.convert("RGB"), 512, 512) - a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 - - with torch.no_grad(), devices.autocast(): - x = torch.from_numpy(a).to(devices.device) - y = db.model.model(x)[0].detach().cpu().numpy() - - probability_dict = dict() - - for tag, probability in zip(db.model.model.tags, y): - if threshold and probability < threshold: - continue - if tag.startswith("rating:"): - continue - probability_dict[get_replaced_tag(tag)] = probability - - return probability_dict - - def name(self): - return 'DeepDanbooru' - - -class WaifuDiffusion(Tagger): - def __init__(self, repo_name, threshold): - self.repo_name = repo_name - self.tagger_inst = WaifuDiffusionTagger("SmilingWolf/" + repo_name) - self.threshold = threshold - - def start(self): - self.tagger_inst.load() - return self - - def stop(self): - self.tagger_inst.unload() - - # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified - # set threshold<0 to use default value for now... - def predict(self, image: Image.Image, threshold: Optional[float] = None): - # may not use ratings - # rating = dict(labels[:4]) - - labels = self.tagger_inst.apply(image) - - if threshold is not None: - if threshold < 0: - threshold = self.threshold - probability_dict = dict([(get_replaced_tag(x[0]), x[1]) for x in labels[4:] if x[1] > threshold]) - else: - probability_dict = dict([(get_replaced_tag(x[0]), x[1]) for x in labels[4:]]) - - return probability_dict - - def name(self): - return self.repo_name \ No newline at end of file diff --git a/scripts/dataset_tag_editor/taggers_builtin.py b/scripts/dataset_tag_editor/taggers_builtin.py new file mode 100644 index 0000000..bd3a1e3 --- /dev/null +++ b/scripts/dataset_tag_editor/taggers_builtin.py @@ -0,0 +1,171 @@ +from typing import Optional + +from PIL import Image +import numpy as np +import torch + +from modules import devices, shared +from modules import deepbooru as db + +from scripts.tagger import Tagger, get_replaced_tag +from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger + + +class BLIP(Tagger): + def start(self): + shared.interrogator.load() + + def stop(self): + shared.interrogator.unload() + + def predict(self, image:Image.Image, threshold=None): + tags = shared.interrogator.generate_caption(image).split(',') + return [t for t in tags if t] + + def name(self): + return 'BLIP' + + + +class BLIP2(Tagger): + def __init__(self, repo_name): + self.interrogator = BLIP2Captioning("Salesforce/" + repo_name) + self.repo_name = repo_name + + def start(self): + self.interrogator.load() + + def stop(self): + self.interrogator.unload() + + def predict(self, image:Image, threshold=None): + tags = self.interrogator.apply(image)[0].split(",") + return [t for t in tags if t] + + # def predict_multi(self, images:list): + # captions = self.interrogator.apply(images) + # return [[t for t in caption.split(',') if t] for caption in captions] + + def name(self): + return self.repo_name + + +class GITLarge(Tagger): + def __init__(self): + self.interrogator = GITLargeCaptioning() + + def start(self): + self.interrogator.load() + + def stop(self): + self.interrogator.unload() + + def predict(self, image:Image, threshold=None): + tags = self.interrogator.apply(image)[0].split(",") + return [t for t in tags if t] + + # def predict_multi(self, images:list): + # captions = self.interrogator.apply(images) + # return [[t for t in caption.split(',') if t] for caption in captions] + + def name(self): + return "GIT-large-COCO" + + +class DeepDanbooru(Tagger): + def start(self): + db.model.start() + + def stop(self): + db.model.stop() + + # brought from webUI modules/deepbooru.py and modified + def predict(self, image: Image.Image, threshold: Optional[float] = None): + from modules import images + + pic = images.resize_image(2, image.convert("RGB"), 512, 512) + a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 + + with torch.no_grad(), devices.autocast(): + x = torch.from_numpy(a).to(devices.device) + y = db.model.model(x)[0].detach().cpu().numpy() + + tags = [] + + for tag, probability in zip(db.model.model.tags, y): + if threshold and probability < threshold: + continue + if not shared.opts.dataset_editor_use_rating and tag.startswith("rating:"): + continue + tags.append(get_replaced_tag(tag)) + + return tags + + def name(self): + return 'DeepDanbooru' + + +class WaifuDiffusion(Tagger): + def __init__(self, repo_name, threshold): + self.repo_name = repo_name + self.tagger_inst = WaifuDiffusionTagger("SmilingWolf/" + repo_name) + self.threshold = threshold + + def start(self): + self.tagger_inst.load() + return self + + def stop(self): + self.tagger_inst.unload() + + # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified + # set threshold<0 to use default value for now... + def predict(self, image: Image.Image, threshold: Optional[float] = None): + # may not use ratings + # rating = dict(labels[:4]) + + labels = self.tagger_inst.apply(image) + + if not shared.opts.dataset_editor_use_rating: + labels = labels[4:] + + if threshold is not None: + if threshold < 0: + threshold = self.threshold + tags = [get_replaced_tag(tag) for tag, value in labels if value > threshold] + else: + tags = [get_replaced_tag(tag) for tag, _ in labels] + + return tags + + def name(self): + return self.repo_name + + +class Z3D_E621(Tagger): + def __init__(self): + self.tagger_inst = WaifuDiffusionTagger("toynya/Z3D-E621-Convnext", label_filename="tags-selected.csv") + + def start(self): + self.tagger_inst.load() + return self + + def stop(self): + self.tagger_inst.unload() + + # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified + # set threshold<0 to use default value for now... + def predict(self, image: Image.Image, threshold: Optional[float] = None): + # may not use ratings + # rating = dict(labels[:4]) + + labels = self.tagger_inst.apply(image) + if threshold is not None: + tags = [get_replaced_tag(tag) for tag, value in labels if value > threshold] + else: + tags = [get_replaced_tag(tag) for tag, _ in labels] + + return tags + + def name(self): + return "Z3D-E621-Convnext" \ No newline at end of file diff --git a/scripts/dte_instance.py b/scripts/dte_instance.py index ed06b26..a174e79 100644 --- a/scripts/dte_instance.py +++ b/scripts/dte_instance.py @@ -1,2 +1,2 @@ import scripts.dataset_tag_editor as dte_module -dte_instance = dte_module.DatasetTagEditor.get_instance() \ No newline at end of file +dte_instance = dte_module.DatasetTagEditor() \ No newline at end of file diff --git a/scripts/logger.py b/scripts/logger.py new file mode 100644 index 0000000..a9f539c --- /dev/null +++ b/scripts/logger.py @@ -0,0 +1,8 @@ +def write(content): + print("[tag-editor] " + content) + +def warn(content): + write(f"[tag-editor:WARNING] {content}") + +def error(content): + write(f"[tag-editor:ERROR] {content}") \ No newline at end of file diff --git a/scripts/main.py b/scripts/main.py index 33c613b..297018f 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -1,197 +1,165 @@ from typing import NamedTuple, Type, Dict, Any -from modules import shared, script_callbacks, scripts +from modules import shared, script_callbacks from modules.shared import opts import gradio as gr -import json -from pathlib import Path -from collections import namedtuple +from scripts.config import * import scripts.tag_editor_ui as ui -from scripts.dte_instance import dte_instance # ================================================================ # General Callbacks # ================================================================ -CONFIG_PATH = Path(scripts.basedir(), 'config.json') - - -SortBy = dte_instance.SortBy -SortOrder = dte_instance.SortOrder - -GeneralConfig = namedtuple('GeneralConfig', [ - 'backup', - 'dataset_dir', - 'caption_ext', - 'load_recursive', - 'load_caption_from_filename', - 'replace_new_line', - 'use_interrogator', - 'use_interrogator_names', - 'use_custom_threshold_booru', - 'custom_threshold_booru', - 'use_custom_threshold_waifu', - 'custom_threshold_waifu', - 'save_kohya_metadata', - 'meta_output_path', - 'meta_input_path', - 'meta_overwrite', - 'meta_save_as_caption', - 'meta_use_full_path' - ]) -FilterConfig = namedtuple('FilterConfig', ['sw_prefix', 'sw_suffix', 'sw_regex','sort_by', 'sort_order', 'logic']) -BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order', 'token_count']) -EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order']) -MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination']) - -CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, False, 'No', [], False, 0.7, False, 0.35, False, '', '', True, False, False) -CFG_FILTER_P_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND') -CFG_FILTER_N_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR') -CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value, 75) -CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value) -CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig('Selected One', [], '.txt', '') - -class Config: - def __init__(self): - self.config = dict() - - def load(self): - if not CONFIG_PATH.is_file(): - self.config = dict() - return - try: - self.config = json.loads(CONFIG_PATH.read_text('utf8')) - except: - print('[tag-editor] Error on loading config.json. Default settings will be loaded.') - self.config = dict() - else: - print('[tag-editor] Settings has been read from config.json') - - def save(self): - CONFIG_PATH.write_text(json.dumps(self.config, indent=4), 'utf8') - - def read(self, name: str): - return self.config.get(name) - - def write(self, cfg: dict, name: str): - self.config[name] = cfg - config = Config() + def write_general_config(*args): cfg = GeneralConfig(*args) - config.write(cfg._asdict(), 'general') + config.write(cfg._asdict(), "general") + def write_filter_config(*args): hlen = len(args) // 2 cfg_p = FilterConfig(*args[:hlen]) cfg_n = FilterConfig(*args[hlen:]) - config.write({'positive':cfg_p._asdict(), 'negative':cfg_n._asdict()}, 'filter') + config.write({"positive": cfg_p._asdict(), "negative": cfg_n._asdict()}, "filter") + def write_batch_edit_config(*args): cfg = BatchEditConfig(*args) - config.write(cfg._asdict(), 'batch_edit') + config.write(cfg._asdict(), "batch_edit") + def write_edit_selected_config(*args): cfg = EditSelectedConfig(*args) - config.write(cfg._asdict(), 'edit_selected') + config.write(cfg._asdict(), "edit_selected") + def write_move_delete_config(*args): cfg = MoveDeleteConfig(*args) - config.write(cfg._asdict(), 'file_move_delete') + config.write(cfg._asdict(), "file_move_delete") -def read_config(name: str, config_type: Type, default: NamedTuple, compat_func = None): + +def read_config(name: str, config_type: Type, default: NamedTuple, compat_func=None): d = config.read(name) cfg = default if d: - if compat_func: d = compat_func(d) + if compat_func: + d = compat_func(d) d = cfg._asdict() | d - d = {k:v for k,v in d.items() if k in cfg._asdict().keys()} + d = {k: v for k, v in d.items() if k in cfg._asdict().keys()} cfg = config_type(**d) return cfg + def read_general_config(): # for compatibility generalcfg_intterogator_names = [ - ('use_blip_to_prefill', 'BLIP'), - ('use_git_to_prefill', 'GIT-large-COCO'), - ('use_booru_to_prefill', 'DeepDanbooru'), - ('use_waifu_to_prefill', 'wd-v1-4-vit-tagger') + ("use_blip_to_prefill", "BLIP"), + ("use_git_to_prefill", "GIT-large-COCO"), + ("use_booru_to_prefill", "DeepDanbooru"), + ("use_waifu_to_prefill", "wd-v1-4-vit-tagger"), ] use_interrogator_names = [] + def compat_func(d: Dict[str, Any]): - if 'use_interrogator_names' in d.keys(): + if "use_interrogator_names" in d.keys(): return d for cfg in generalcfg_intterogator_names: if d.get(cfg[0]): use_interrogator_names.append(cfg[1]) - d['use_interrogator_names'] = use_interrogator_names + d["use_interrogator_names"] = use_interrogator_names return d - return read_config('general', GeneralConfig, CFG_GENERAL_DEFAULT, compat_func) + + return read_config("general", GeneralConfig, CFG_GENERAL_DEFAULT, compat_func) + def read_filter_config(): - d = config.read('filter') - d_p = d.get('positive') if d else None - d_n = d.get('negative') if d else None + d = config.read("filter") + d_p = d.get("positive") if d else None + d_n = d.get("negative") if d else None cfg_p = CFG_FILTER_P_DEFAULT cfg_n = CFG_FILTER_N_DEFAULT if d_p: d_p = cfg_p._asdict() | d_p - d_p = {k:v for k,v in d_p.items() if k in cfg_p._asdict().keys()} + d_p = {k: v for k, v in d_p.items() if k in cfg_p._asdict().keys()} cfg_p = FilterConfig(**d_p) if d_n: d_n = cfg_n._asdict() | d_n - d_n = {k:v for k,v in d_n.items() if k in cfg_n._asdict().keys()} + d_n = {k: v for k, v in d_n.items() if k in cfg_n._asdict().keys()} cfg_n = FilterConfig(**d_n) return cfg_p, cfg_n + def read_batch_edit_config(): - return read_config('batch_edit', BatchEditConfig, CFG_BATCH_EDIT_DEFAULT) + return read_config("batch_edit", BatchEditConfig, CFG_BATCH_EDIT_DEFAULT) + def read_edit_selected_config(): - return read_config('edit_selected', EditSelectedConfig, CFG_EDIT_SELECTED_DEFAULT) + return read_config("edit_selected", EditSelectedConfig, CFG_EDIT_SELECTED_DEFAULT) + def read_move_delete_config(): - return read_config('file_move_delete', MoveDeleteConfig, CFG_MOVE_DELETE_DEFAULT) + return read_config("file_move_delete", MoveDeleteConfig, CFG_MOVE_DELETE_DEFAULT) + # ================================================================ # General Callbacks for Updating UIs # ================================================================ + def get_filters(): - filters = [ui.filter_by_tags.tag_filter_ui.get_filter(), ui.filter_by_tags.tag_filter_ui_neg.get_filter()] + [ui.filter_by_selection.path_filter] + filters = [ + ui.filter_by_tags.tag_filter_ui.get_filter(), + ui.filter_by_tags.tag_filter_ui_neg.get_filter(), + ] + [ui.filter_by_selection.path_filter] return filters + def update_gallery(): img_indices = ui.dte_instance.get_filtered_imgindices(filters=get_filters()) total_image_num = len(ui.dte_instance.dataset) displayed_image_num = len(img_indices) - ui.gallery_state.register_value('Displayed Images', f'{displayed_image_num} / {total_image_num} total') - ui.gallery_state.register_value('Current Tag Filter', f"{ui.filter_by_tags.tag_filter_ui.get_filter()} {' AND ' if ui.filter_by_tags.tag_filter_ui.get_filter().tags and ui.filter_by_tags.tag_filter_ui_neg.get_filter().tags else ''} {ui.filter_by_tags.tag_filter_ui_neg.get_filter()}") - ui.gallery_state.register_value('Current Selection Filter', f'{len(ui.filter_by_selection.path_filter.paths)} images') + ui.gallery_state.register_value( + "Displayed Images", f"{displayed_image_num} / {total_image_num} total" + ) + ui.gallery_state.register_value( + "Current Tag Filter", + f"{ui.filter_by_tags.tag_filter_ui.get_filter()} {' AND ' if ui.filter_by_tags.tag_filter_ui.get_filter().tags and ui.filter_by_tags.tag_filter_ui_neg.get_filter().tags else ''} {ui.filter_by_tags.tag_filter_ui_neg.get_filter()}", + ) + ui.gallery_state.register_value( + "Current Selection Filter", + f"{len(ui.filter_by_selection.path_filter.paths)} images", + ) return [ [str(i) for i in img_indices], 1, -1, -1, -1, - ui.gallery_state.get_current_gallery_txt() - ] + ui.gallery_state.get_current_gallery_txt(), + ] + def update_filter_and_gallery(): - return \ - [ui.filter_by_tags.tag_filter_ui.cbg_tags_update(), ui.filter_by_tags.tag_filter_ui_neg.cbg_tags_update()] +\ - update_gallery() +\ - ui.batch_edit_captions.get_common_tags(get_filters, ui.filter_by_tags) +\ - [', '.join(ui.filter_by_tags.tag_filter_ui.filter.tags)] +\ - [ui.batch_edit_captions.tag_select_ui_remove.cbg_tags_update()] +\ - ['', ''] + return ( + [ + ui.filter_by_tags.tag_filter_ui.cbg_tags_update(), + ui.filter_by_tags.tag_filter_ui_neg.cbg_tags_update(), + ] + + update_gallery() + + ui.batch_edit_captions.get_common_tags(get_filters, ui.filter_by_tags) + + [", ".join(ui.filter_by_tags.tag_filter_ui.filter.tags)] + + [ui.batch_edit_captions.tag_select_ui_remove.cbg_tags_update()] + + ["", ""] + ) # ================================================================ # Script Callbacks # ================================================================ + def on_ui_tabs(): config.load() @@ -201,101 +169,156 @@ def on_ui_tabs(): cfg_edit_selected = read_edit_selected_config() cfg_file_move_delete = read_move_delete_config() + ui.dte_instance.load_interrogators() + with gr.Blocks(analytics_enabled=False) as dataset_tag_editor_interface: - gr.HTML(value=""" + gr.HTML( + value=""" This extension works well with text captions in comma-separated style (such as the tags generated by DeepBooru interrogator). - """) + """ + ) ui.toprow.create_ui(cfg_general) - with gr.Accordion(label='Reload/Save Settings (config.json)', open=False): + with gr.Accordion(label="Reload/Save Settings (config.json)", open=False): with gr.Row(): - btn_reload_config_file = gr.Button(value='Reload settings') - btn_save_setting_as_default = gr.Button(value='Save current settings') - btn_restore_default = gr.Button(value='Restore settings to default') + btn_reload_config_file = gr.Button(value="Reload settings") + btn_save_setting_as_default = gr.Button(value="Save current settings") + btn_restore_default = gr.Button(value="Restore settings to default") - with gr.Row().style(equal_height=False): + with gr.Row(equal_height=False): with gr.Column(): ui.load_dataset.create_ui(cfg_general) ui.dataset_gallery.create_ui(opts.dataset_editor_image_columns) ui.gallery_state.create_ui() - with gr.Tab(label='Filter by Tags'): + with gr.Tab(label="Filter by Tags"): ui.filter_by_tags.create_ui(cfg_filter_p, cfg_filter_n, get_filters) - - with gr.Tab(label='Filter by Selection'): + + with gr.Tab(label="Filter by Selection"): ui.filter_by_selection.create_ui(opts.dataset_editor_image_columns) - with gr.Tab(label='Batch Edit Captions'): + with gr.Tab(label="Batch Edit Captions"): ui.batch_edit_captions.create_ui(cfg_batch_edit, get_filters) - with gr.Tab(label='Edit Caption of Selected Image'): + with gr.Tab(label="Edit Caption of Selected Image"): ui.edit_caption_of_selected_image.create_ui(cfg_edit_selected) - with gr.Tab(label='Move or Delete Files'): + with gr.Tab(label="Move or Delete Files"): ui.move_or_delete_files.create_ui(cfg_file_move_delete) - #---------------------------------------------------------------- + # ---------------------------------------------------------------- # General components_general = [ - ui.toprow.cb_backup, ui.load_dataset.tb_img_directory, ui.load_dataset.tb_caption_file_ext, ui.load_dataset.cb_load_recursive, - ui.load_dataset.cb_load_caption_from_filename, ui.load_dataset.cb_replace_new_line_with_comma, ui.load_dataset.rb_use_interrogator, ui.load_dataset.dd_intterogator_names, - ui.load_dataset.cb_use_custom_threshold_booru, ui.load_dataset.sl_custom_threshold_booru, ui.load_dataset.cb_use_custom_threshold_waifu, ui.load_dataset.sl_custom_threshold_waifu, - ui.toprow.cb_save_kohya_metadata, ui.toprow.tb_metadata_output, ui.toprow.tb_metadata_input, ui.toprow.cb_metadata_overwrite, ui.toprow.cb_metadata_as_caption, ui.toprow.cb_metadata_use_fullpath + ui.toprow.cb_backup, + ui.load_dataset.tb_img_directory, + ui.load_dataset.tb_caption_file_ext, + ui.load_dataset.cb_load_recursive, + ui.load_dataset.cb_load_caption_from_filename, + ui.load_dataset.cb_replace_new_line_with_comma, + ui.load_dataset.rb_use_interrogator, + ui.load_dataset.dd_intterogator_names, + ui.load_dataset.cb_use_custom_threshold_booru, + ui.load_dataset.sl_custom_threshold_booru, + ui.load_dataset.cb_use_custom_threshold_waifu, + ui.load_dataset.sl_custom_threshold_waifu, + ui.load_dataset.sl_custom_threshold_z3d, + ui.toprow.cb_save_kohya_metadata, + ui.toprow.tb_metadata_output, + ui.toprow.tb_metadata_input, + ui.toprow.cb_metadata_overwrite, + ui.toprow.cb_metadata_as_caption, + ui.toprow.cb_metadata_use_fullpath, + ] + components_filter = [ + ui.filter_by_tags.tag_filter_ui.cb_prefix, + ui.filter_by_tags.tag_filter_ui.cb_suffix, + ui.filter_by_tags.tag_filter_ui.cb_regex, + ui.filter_by_tags.tag_filter_ui.rb_sort_by, + ui.filter_by_tags.tag_filter_ui.rb_sort_order, + ui.filter_by_tags.tag_filter_ui.rb_logic, + ] + [ + ui.filter_by_tags.tag_filter_ui_neg.cb_prefix, + ui.filter_by_tags.tag_filter_ui_neg.cb_suffix, + ui.filter_by_tags.tag_filter_ui_neg.cb_regex, + ui.filter_by_tags.tag_filter_ui_neg.rb_sort_by, + ui.filter_by_tags.tag_filter_ui_neg.rb_sort_order, + ui.filter_by_tags.tag_filter_ui_neg.rb_logic, ] - components_filter = \ - [ui.filter_by_tags.tag_filter_ui.cb_prefix, ui.filter_by_tags.tag_filter_ui.cb_suffix, ui.filter_by_tags.tag_filter_ui.cb_regex, ui.filter_by_tags.tag_filter_ui.rb_sort_by, ui.filter_by_tags.tag_filter_ui.rb_sort_order, ui.filter_by_tags.tag_filter_ui.rb_logic] +\ - [ui.filter_by_tags.tag_filter_ui_neg.cb_prefix, ui.filter_by_tags.tag_filter_ui_neg.cb_suffix, ui.filter_by_tags.tag_filter_ui_neg.cb_regex, ui.filter_by_tags.tag_filter_ui_neg.rb_sort_by, ui.filter_by_tags.tag_filter_ui_neg.rb_sort_order, ui.filter_by_tags.tag_filter_ui_neg.rb_logic] components_batch_edit = [ - ui.batch_edit_captions.cb_show_only_tags_selected, ui.batch_edit_captions.cb_prepend_tags, ui.batch_edit_captions.cb_use_regex, + ui.batch_edit_captions.cb_show_only_tags_selected, + ui.batch_edit_captions.cb_prepend_tags, + ui.batch_edit_captions.cb_use_regex, ui.batch_edit_captions.rb_sr_replace_target, - ui.batch_edit_captions.tag_select_ui_remove.cb_prefix, ui.batch_edit_captions.tag_select_ui_remove.cb_suffix, ui.batch_edit_captions.tag_select_ui_remove.cb_regex, - ui.batch_edit_captions.tag_select_ui_remove.rb_sort_by, ui.batch_edit_captions.tag_select_ui_remove.rb_sort_order, - ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order, - ui.batch_edit_captions.nb_token_count + ui.batch_edit_captions.tag_select_ui_remove.cb_prefix, + ui.batch_edit_captions.tag_select_ui_remove.cb_suffix, + ui.batch_edit_captions.tag_select_ui_remove.cb_regex, + ui.batch_edit_captions.tag_select_ui_remove.rb_sort_by, + ui.batch_edit_captions.tag_select_ui_remove.rb_sort_order, + ui.batch_edit_captions.rb_sort_by, + ui.batch_edit_captions.rb_sort_order, + ui.batch_edit_captions.nb_token_count, ] components_edit_selected = [ - ui.edit_caption_of_selected_image.cb_copy_caption_automatically, ui.edit_caption_of_selected_image.cb_sort_caption_on_save, - ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed, ui.edit_caption_of_selected_image.dd_intterogator_names_si, - ui.edit_caption_of_selected_image.rb_sort_by, ui.edit_caption_of_selected_image.rb_sort_order + ui.edit_caption_of_selected_image.cb_copy_caption_automatically, + ui.edit_caption_of_selected_image.cb_sort_caption_on_save, + ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed, + ui.edit_caption_of_selected_image.dd_intterogator_names_si, + ui.edit_caption_of_selected_image.rb_sort_by, + ui.edit_caption_of_selected_image.rb_sort_order, ] components_move_delete = [ - ui.move_or_delete_files.rb_move_or_delete_target_data, ui.move_or_delete_files.cbg_move_or_delete_target_file, - ui.move_or_delete_files.tb_move_or_delete_caption_ext, ui.move_or_delete_files.tb_move_or_delete_destination_dir + ui.move_or_delete_files.rb_move_or_delete_target_data, + ui.move_or_delete_files.cbg_move_or_delete_target_file, + ui.move_or_delete_files.tb_move_or_delete_caption_ext, + ui.move_or_delete_files.tb_move_or_delete_destination_dir, ] - - configurable_components = components_general + components_filter + components_batch_edit + components_edit_selected + components_move_delete + + configurable_components = ( + components_general + + components_filter + + components_batch_edit + + components_edit_selected + + components_move_delete + ) def reload_config_file(): config.load() p, n = read_filter_config() - print('[tag-editor] Reload config.json') - return read_general_config() + p + n + read_batch_edit_config() + read_edit_selected_config() + read_move_delete_config() + logger.write("Reload config.json") + return ( + read_general_config() + + p + + n + + read_batch_edit_config() + + read_edit_selected_config() + + read_move_delete_config() + ) btn_reload_config_file.click( - fn=reload_config_file, - outputs=configurable_components + fn=reload_config_file, outputs=configurable_components ) def save_settings_callback(*a): p = 0 + def inc(v): nonlocal p p += v return p - write_general_config(*a[p:inc(len(components_general))]) - write_filter_config(*a[p:inc(len(components_filter))]) - write_batch_edit_config(*a[p:inc(len(components_batch_edit))]) - write_edit_selected_config(*a[p:inc(len(components_edit_selected))]) + + write_general_config(*a[p : inc(len(components_general))]) + write_filter_config(*a[p : inc(len(components_filter))]) + write_batch_edit_config(*a[p : inc(len(components_batch_edit))]) + write_edit_selected_config(*a[p : inc(len(components_edit_selected))]) write_move_delete_config(*a[p:]) config.save() - print('[tag-editor] Current settings have been saved into config.json') + logger.write("Current settings have been saved into config.json") btn_save_setting_as_default.click( - fn=save_settings_callback, - inputs=configurable_components + fn=save_settings_callback, inputs=configurable_components ) def restore_default_settings(): @@ -304,44 +327,150 @@ def on_ui_tabs(): write_batch_edit_config(*CFG_BATCH_EDIT_DEFAULT) write_edit_selected_config(*CFG_EDIT_SELECTED_DEFAULT) write_move_delete_config(*CFG_MOVE_DELETE_DEFAULT) - print('[tag-editor] Restore default settings') - return CFG_GENERAL_DEFAULT + CFG_FILTER_P_DEFAULT + CFG_FILTER_N_DEFAULT + CFG_BATCH_EDIT_DEFAULT + CFG_EDIT_SELECTED_DEFAULT + CFG_MOVE_DELETE_DEFAULT - + logger.write("Restore default settings") + return ( + CFG_GENERAL_DEFAULT + + CFG_FILTER_P_DEFAULT + + CFG_FILTER_N_DEFAULT + + CFG_BATCH_EDIT_DEFAULT + + CFG_EDIT_SELECTED_DEFAULT + + CFG_MOVE_DELETE_DEFAULT + ) btn_restore_default.click( - fn=restore_default_settings, - outputs=configurable_components + fn=restore_default_settings, outputs=configurable_components ) - o_update_gallery = [ui.dataset_gallery.cbg_hidden_dataset_filter, ui.dataset_gallery.nb_hidden_dataset_filter_apply, ui.dataset_gallery.nb_hidden_image_index, ui.dataset_gallery.nb_hidden_image_index_prev, ui.edit_caption_of_selected_image.nb_hidden_image_index_save_or_not, ui.gallery_state.txt_gallery] + o_update_gallery = [ + ui.dataset_gallery.cbg_hidden_dataset_filter, + ui.dataset_gallery.nb_hidden_dataset_filter_apply, + ui.dataset_gallery.nb_hidden_image_index, + ui.dataset_gallery.nb_hidden_image_index_prev, + ui.edit_caption_of_selected_image.nb_hidden_image_index_save_or_not, + ui.gallery_state.txt_gallery, + ] + + o_update_filter_and_gallery = ( + [ + ui.filter_by_tags.tag_filter_ui.cbg_tags, + ui.filter_by_tags.tag_filter_ui_neg.cbg_tags, + ] + + o_update_gallery + + [ + ui.batch_edit_captions.tb_common_tags, + ui.batch_edit_captions.tb_edit_tags, + ] + + [ui.batch_edit_captions.tb_sr_selected_tags] + + [ui.batch_edit_captions.tag_select_ui_remove.cbg_tags] + + [ + ui.edit_caption_of_selected_image.tb_caption, + ui.edit_caption_of_selected_image.tb_edit_caption, + ] + ) - o_update_filter_and_gallery = \ - [ui.filter_by_tags.tag_filter_ui.cbg_tags, ui.filter_by_tags.tag_filter_ui_neg.cbg_tags] + \ - o_update_gallery + \ - [ui.batch_edit_captions.tb_common_tags, ui.batch_edit_captions.tb_edit_tags] + \ - [ui.batch_edit_captions.tb_sr_selected_tags] +\ - [ui.batch_edit_captions.tag_select_ui_remove.cbg_tags] +\ - [ui.edit_caption_of_selected_image.tb_caption, ui.edit_caption_of_selected_image.tb_edit_caption] - ui.toprow.set_callbacks(ui.load_dataset) - ui.load_dataset.set_callbacks(o_update_filter_and_gallery,ui.toprow, ui.dataset_gallery, ui.filter_by_tags, ui.filter_by_selection, ui.batch_edit_captions, update_filter_and_gallery) + ui.load_dataset.set_callbacks( + o_update_filter_and_gallery, + ui.toprow, + ui.dataset_gallery, + ui.filter_by_tags, + ui.filter_by_selection, + ui.batch_edit_captions, + update_filter_and_gallery, + ) ui.dataset_gallery.set_callbacks(ui.load_dataset, ui.gallery_state, get_filters) ui.gallery_state.set_callbacks(ui.dataset_gallery) - ui.filter_by_tags.set_callbacks(o_update_gallery, o_update_filter_and_gallery, ui.batch_edit_captions, ui.move_or_delete_files, update_gallery, update_filter_and_gallery, get_filters) - ui.filter_by_selection.set_callbacks(o_update_filter_and_gallery, ui.dataset_gallery, ui.filter_by_tags, get_filters, update_filter_and_gallery) - ui.batch_edit_captions.set_callbacks(o_update_filter_and_gallery, ui.load_dataset, ui.filter_by_tags, get_filters, update_filter_and_gallery) - ui.edit_caption_of_selected_image.set_callbacks(o_update_filter_and_gallery, ui.dataset_gallery, ui.load_dataset, get_filters, update_filter_and_gallery) - ui.move_or_delete_files.set_callbacks(o_update_filter_and_gallery, ui.dataset_gallery, ui.filter_by_tags, ui.batch_edit_captions, ui.filter_by_selection, ui.edit_caption_of_selected_image, get_filters, update_filter_and_gallery) - - return [(dataset_tag_editor_interface, "Dataset Tag Editor", "dataset_tag_editor_interface")] + ui.filter_by_tags.set_callbacks( + o_update_gallery, + o_update_filter_and_gallery, + ui.batch_edit_captions, + ui.move_or_delete_files, + update_gallery, + update_filter_and_gallery, + get_filters, + ) + ui.filter_by_selection.set_callbacks( + o_update_filter_and_gallery, + ui.dataset_gallery, + ui.filter_by_tags, + get_filters, + update_filter_and_gallery, + ) + ui.batch_edit_captions.set_callbacks( + o_update_filter_and_gallery, + ui.load_dataset, + ui.filter_by_tags, + get_filters, + update_filter_and_gallery, + ) + ui.edit_caption_of_selected_image.set_callbacks( + o_update_filter_and_gallery, + ui.dataset_gallery, + ui.load_dataset, + get_filters, + update_filter_and_gallery, + ) + ui.move_or_delete_files.set_callbacks( + o_update_filter_and_gallery, + ui.dataset_gallery, + ui.filter_by_tags, + ui.batch_edit_captions, + ui.filter_by_selection, + ui.edit_caption_of_selected_image, + get_filters, + update_filter_and_gallery, + ) + + return [ + ( + dataset_tag_editor_interface, + "Dataset Tag Editor", + "dataset_tag_editor_interface", + ) + ] def on_ui_settings(): - section = ('dataset-tag-editor', "Dataset Tag Editor") - shared.opts.add_option("dataset_editor_image_columns", shared.OptionInfo(6, "Number of columns on image gallery", section=section)) - shared.opts.add_option("dataset_editor_max_res", shared.OptionInfo(0, "Max resolution of temporary files", section=section)) - shared.opts.add_option("dataset_editor_use_temp_files", shared.OptionInfo(False, "Force image gallery to use temporary files", section=section)) - shared.opts.add_option("dataset_editor_use_raw_clip_token", shared.OptionInfo(True, "Use raw CLIP token to calculate token count (without emphasis or embeddings)", section=section)) + section = ("dataset-tag-editor", "Dataset Tag Editor") + shared.opts.add_option( + "dataset_editor_image_columns", + shared.OptionInfo(6, "Number of columns on image gallery", section=section), + ) + shared.opts.add_option( + "dataset_editor_max_res", + shared.OptionInfo(0, "Max resolution of temporary files", section=section), + ) + shared.opts.add_option( + "dataset_editor_use_temp_files", + shared.OptionInfo( + False, "Force image gallery to use temporary files", section=section + ), + ) + shared.opts.add_option( + "dataset_editor_use_raw_clip_token", + shared.OptionInfo( + True, + "Use raw CLIP token to calculate token count (without emphasis or embeddings)", + section=section, + ), + ) + shared.opts.add_option( + "dataset_editor_use_rating", + shared.OptionInfo( + False, + "Use rating tags", + section=section, + ), + ) + + shared.opts.add_option( + "dataset_editor_num_cpu_workers", + shared.OptionInfo( + -1, + "Number of CPU workers when preprocessing images (set -1 to auto)", + section=section, + ), + ) script_callbacks.on_ui_settings(on_ui_settings) diff --git a/scripts/model_loader.py b/scripts/model_loader.py new file mode 100644 index 0000000..f164ed0 --- /dev/null +++ b/scripts/model_loader.py @@ -0,0 +1,16 @@ +from pathlib import Path + +from torch.hub import download_url_to_file + +def load(model_path:Path, model_url:str, progress:bool=True, force_download:bool=False): + model_path = Path(model_path) + if model_path.exists(): + return model_path + + if model_url is not None and (force_download or not model_path.is_file()): + if not model_path.parent.is_dir(): + model_path.parent.mkdir(parents=True) + download_url_to_file(model_url, model_path, progress=progress) + return model_path + + return model_path diff --git a/scripts/paths.py b/scripts/paths.py new file mode 100644 index 0000000..7495da7 --- /dev/null +++ b/scripts/paths.py @@ -0,0 +1,17 @@ +from pathlib import Path + +from scripts.singleton import Singleton + +def base_dir_path(): + return Path(__file__).parents[1].absolute() + +def base_dir(): + return str(base_dir_path()) + +class Paths(Singleton): + def __init__(self): + self.base_path:Path = base_dir_path() + self.script_path: Path = self.base_path / "scripts" + self.userscript_path: Path = self.base_path / "userscripts" + +paths = Paths() \ No newline at end of file diff --git a/scripts/singleton.py b/scripts/singleton.py index 93a1872..ce5d4f1 100644 --- a/scripts/singleton.py +++ b/scripts/singleton.py @@ -1,6 +1,6 @@ class Singleton(object): - @classmethod - def get_instance(cls): - if not hasattr(cls, "_instance"): - cls._instance = cls() - return cls._instance \ No newline at end of file + _instance = None + def __new__(class_, *args, **kwargs): + if not isinstance(class_._instance, class_): + class_._instance = object.__new__(class_, *args, **kwargs) + return class_._instance diff --git a/scripts/tag_editor_ui/block_dataset_gallery.py b/scripts/tag_editor_ui/block_dataset_gallery.py index 6fc9ef2..2e330a7 100644 --- a/scripts/tag_editor_ui/block_dataset_gallery.py +++ b/scripts/tag_editor_ui/block_dataset_gallery.py @@ -22,7 +22,7 @@ class DatasetGalleryUI(UIBase): self.btn_hidden_set_index = gr.Button(elem_id="dataset_tag_editor_btn_hidden_set_index") self.nb_hidden_image_index = gr.Number(value=None, label='hidden_idx_next') self.nb_hidden_image_index_prev = gr.Number(value=None, label='hidden_idx_prev') - self.gl_dataset_images = gr.Gallery(label='Dataset Images', elem_id="dataset_tag_editor_dataset_gallery").style(grid=image_columns) + self.gl_dataset_images = gr.Gallery(label='Dataset Images', elem_id="dataset_tag_editor_dataset_gallery", columns=image_columns) def set_callbacks(self, load_dataset:LoadDatasetUI, gallery_state:GalleryStateUI, get_filters:Callable[[], dte_module.filters.Filter]): gallery_state.register_value('Selected Image', self.selected_path) diff --git a/scripts/tag_editor_ui/block_gallery_state.py b/scripts/tag_editor_ui/block_gallery_state.py index 8c3cf69..794666a 100644 --- a/scripts/tag_editor_ui/block_gallery_state.py +++ b/scripts/tag_editor_ui/block_gallery_state.py @@ -28,7 +28,7 @@ class GalleryStateUI(UIBase): self.txt_gallery = gr.HTML(value=self.get_current_gallery_txt()) def set_callbacks(self, dataset_gallery:DatasetGalleryUI): - dataset_gallery.nb_hidden_image_index.change( + dataset_gallery.nb_hidden_image_index.change(fn=lambda:None).then( fn=self.update_gallery_txt, inputs=None, outputs=self.txt_gallery diff --git a/scripts/tag_editor_ui/block_load_dataset.py b/scripts/tag_editor_ui/block_load_dataset.py index 9437a4d..0bdeb62 100644 --- a/scripts/tag_editor_ui/block_load_dataset.py +++ b/scripts/tag_editor_ui/block_load_dataset.py @@ -11,43 +11,113 @@ from .uibase import UIBase if TYPE_CHECKING: from .ui_classes import * -INTERROGATOR_NAMES = dte_module.INTERROGATOR_NAMES -InterrogateMethod = dte_instance.InterrogateMethod - class LoadDatasetUI(UIBase): def __init__(self): - self.caption_file_ext = '' + self.caption_file_ext = "" def create_ui(self, cfg_general): - with gr.Column(variant='panel'): + with gr.Column(variant="panel"): with gr.Row(): with gr.Column(scale=3): - self.tb_img_directory = gr.Textbox(label='Dataset directory', placeholder='C:\\directory\\of\\datasets', value=cfg_general.dataset_dir) + self.tb_img_directory = gr.Textbox( + label="Dataset directory", + placeholder="C:\\directory\\of\\datasets", + value=cfg_general.dataset_dir, + ) with gr.Column(scale=1, min_width=60): - self.tb_caption_file_ext = gr.Textbox(label='Caption File Ext', placeholder='.txt (on Load and Save)', value=cfg_general.caption_ext) + self.tb_caption_file_ext = gr.Textbox( + label="Caption File Ext", + placeholder=".txt (on Load and Save)", + value=cfg_general.caption_ext, + ) self.caption_file_ext = cfg_general.caption_ext with gr.Column(scale=1, min_width=80): - self.btn_load_datasets = gr.Button(value='Load') - self.btn_unload_datasets = gr.Button(value='Unload') - with gr.Accordion(label='Dataset Load Settings'): + self.btn_load_datasets = gr.Button(value="Load") + self.btn_unload_datasets = gr.Button(value="Unload") + with gr.Accordion(label="Dataset Load Settings"): with gr.Row(): with gr.Column(): - self.cb_load_recursive = gr.Checkbox(value=cfg_general.load_recursive, label='Load from subdirectories') - self.cb_load_caption_from_filename = gr.Checkbox(value=cfg_general.load_caption_from_filename, label='Load caption from filename if no text file exists') - self.cb_replace_new_line_with_comma = gr.Checkbox(value=cfg_general.replace_new_line, label='Replace new-line character with comma') + self.cb_load_recursive = gr.Checkbox( + value=cfg_general.load_recursive, + label="Load from subdirectories", + ) + self.cb_load_caption_from_filename = gr.Checkbox( + value=cfg_general.load_caption_from_filename, + label="Load caption from filename if no text file exists", + ) + self.cb_replace_new_line_with_comma = gr.Checkbox( + value=cfg_general.replace_new_line, + label="Replace new-line character with comma", + ) with gr.Column(): - self.rb_use_interrogator = gr.Radio(choices=['No', 'If Empty', 'Overwrite', 'Prepend', 'Append'], value=cfg_general.use_interrogator, label='Use Interrogator Caption') - self.dd_intterogator_names = gr.Dropdown(label = 'Interrogators', choices=INTERROGATOR_NAMES, value=cfg_general.use_interrogator_names, interactive=True, multiselect=True) - with gr.Accordion(label='Interrogator Settings', open=False): + self.rb_use_interrogator = gr.Radio( + choices=[ + "No", + "If Empty", + "Overwrite", + "Prepend", + "Append", + ], + value=cfg_general.use_interrogator, + label="Use Interrogator Caption", + ) + self.dd_intterogator_names = gr.Dropdown( + label="Interrogators", + choices=dte_instance.INTERROGATOR_NAMES, + value=cfg_general.use_interrogator_names, + interactive=True, + multiselect=True, + ) + with gr.Accordion(label="Interrogator Settings", open=False): with gr.Row(): - self.cb_use_custom_threshold_booru = gr.Checkbox(value=cfg_general.use_custom_threshold_booru, label='Use Custom Threshold (Booru)', interactive=True) - self.sl_custom_threshold_booru = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_booru, step=0.01, interactive=True, label='Booru Score Threshold') + self.cb_use_custom_threshold_booru = gr.Checkbox( + value=cfg_general.use_custom_threshold_booru, + label="Use Custom Threshold (Booru)", + interactive=True, + ) + self.sl_custom_threshold_booru = gr.Slider( + minimum=0, + maximum=1, + value=cfg_general.custom_threshold_booru, + step=0.01, + interactive=True, + label="Booru Score Threshold", + ) with gr.Row(): - self.cb_use_custom_threshold_waifu = gr.Checkbox(value=cfg_general.use_custom_threshold_waifu, label='Use Custom Threshold (WDv1.4 Tagger)', interactive=True) - self.sl_custom_threshold_waifu = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_waifu, step=0.01, interactive=True, label='WDv1.4 Tagger Score Threshold') - - def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], toprow:ToprowUI, dataset_gallery:DatasetGalleryUI, filter_by_tags:FilterByTagsUI, filter_by_selection:FilterBySelectionUI, batch_edit_captions:BatchEditCaptionsUI, update_filter_and_gallery:Callable[[], List]): + self.sl_custom_threshold_z3d = gr.Slider( + minimum=0, + maximum=1, + value=cfg_general.custom_threshold_z3d, + step=0.01, + interactive=True, + label="Z3D-E621 Score Threshold", + ) + with gr.Row(): + self.cb_use_custom_threshold_waifu = gr.Checkbox( + value=cfg_general.use_custom_threshold_waifu, + label="Use Custom Threshold (WDv1.4 Tagger)", + interactive=True, + ) + self.sl_custom_threshold_waifu = gr.Slider( + minimum=0, + maximum=1, + value=cfg_general.custom_threshold_waifu, + step=0.01, + interactive=True, + label="WDv1.4 Tagger Score Threshold", + ) + + def set_callbacks( + self, + o_update_filter_and_gallery: List[gr.components.Component], + toprow: ToprowUI, + dataset_gallery: DatasetGalleryUI, + filter_by_tags: FilterByTagsUI, + filter_by_selection: FilterBySelectionUI, + batch_edit_captions: BatchEditCaptionsUI, + update_filter_and_gallery: Callable[[], List], + ): def load_files_from_dir( dir: str, caption_file_ext: str, @@ -55,63 +125,112 @@ class LoadDatasetUI(UIBase): load_caption_from_filename: bool, replace_new_line: bool, use_interrogator: str, - use_interrogator_names, #: List[str], : to avoid error on gradio v3.23.0 + use_interrogator_names, #: List[str], : to avoid error on gradio v3.23.0 use_custom_threshold_booru: bool, custom_threshold_booru: float, use_custom_threshold_waifu: bool, custom_threshold_waifu: float, + custom_threshold_z3d: float, use_kohya_metadata: bool, - kohya_json_path: str - ): - - interrogate_method = InterrogateMethod.NONE - if use_interrogator == 'If Empty': - interrogate_method = InterrogateMethod.PREFILL - elif use_interrogator == 'Overwrite': - interrogate_method = InterrogateMethod.OVERWRITE - elif use_interrogator == 'Prepend': - interrogate_method = InterrogateMethod.PREPEND - elif use_interrogator == 'Append': - interrogate_method = InterrogateMethod.APPEND + kohya_json_path: str, + ): - threshold_booru = custom_threshold_booru if use_custom_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold - threshold_waifu = custom_threshold_waifu if use_custom_threshold_waifu else -1 + interrogate_method = dte_instance.InterrogateMethod.NONE + if use_interrogator == "If Empty": + interrogate_method = dte_instance.InterrogateMethod.PREFILL + elif use_interrogator == "Overwrite": + interrogate_method = dte_instance.InterrogateMethod.OVERWRITE + elif use_interrogator == "Prepend": + interrogate_method = dte_instance.InterrogateMethod.PREPEND + elif use_interrogator == "Append": + interrogate_method = dte_instance.InterrogateMethod.APPEND - dte_instance.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, replace_new_line, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu, opts.dataset_editor_use_temp_files, kohya_json_path if use_kohya_metadata else None, opts.dataset_editor_max_res) + threshold_booru = ( + custom_threshold_booru + if use_custom_threshold_booru + else opts.interrogate_deepbooru_score_threshold + ) + threshold_waifu = ( + custom_threshold_waifu if use_custom_threshold_waifu else -1 + ) + threshold_z3d = custom_threshold_z3d + + dte_instance.load_dataset( + dir, + caption_file_ext, + recursive, + load_caption_from_filename, + replace_new_line, + interrogate_method, + use_interrogator_names, + threshold_booru, + threshold_waifu, + threshold_z3d, + opts.dataset_editor_use_temp_files, + kohya_json_path if use_kohya_metadata else None, + opts.dataset_editor_max_res, + ) imgs = dte_instance.get_filtered_imgs(filters=[]) img_indices = dte_instance.get_filtered_imgindices(filters=[]) - return [ - imgs, - [] - ] +\ - [gr.CheckboxGroup.update(value=[str(i) for i in img_indices], choices=[str(i) for i in img_indices]), 1] +\ - filter_by_tags.clear_filters(update_filter_and_gallery) +\ - [batch_edit_captions.tag_select_ui_remove.cbg_tags_update()] - + return ( + [imgs, []] + + [ + gr.CheckboxGroup.update( + value=[str(i) for i in img_indices], + choices=[str(i) for i in img_indices], + ), + 1, + ] + + filter_by_tags.clear_filters(update_filter_and_gallery) + + [batch_edit_captions.tag_select_ui_remove.cbg_tags_update()] + ) + self.btn_load_datasets.click( fn=load_files_from_dir, - inputs=[self.tb_img_directory, self.tb_caption_file_ext, self.cb_load_recursive, self.cb_load_caption_from_filename, self.cb_replace_new_line_with_comma, self.rb_use_interrogator, self.dd_intterogator_names, self.cb_use_custom_threshold_booru, self.sl_custom_threshold_booru, self.cb_use_custom_threshold_waifu, self.sl_custom_threshold_waifu, toprow.cb_save_kohya_metadata, toprow.tb_metadata_output], - outputs= - [dataset_gallery.gl_dataset_images, filter_by_selection.gl_filter_images] + - [dataset_gallery.cbg_hidden_dataset_filter, dataset_gallery.nb_hidden_dataset_filter_apply] + - o_update_filter_and_gallery + inputs=[ + self.tb_img_directory, + self.tb_caption_file_ext, + self.cb_load_recursive, + self.cb_load_caption_from_filename, + self.cb_replace_new_line_with_comma, + self.rb_use_interrogator, + self.dd_intterogator_names, + self.cb_use_custom_threshold_booru, + self.sl_custom_threshold_booru, + self.cb_use_custom_threshold_waifu, + self.sl_custom_threshold_waifu, + toprow.cb_save_kohya_metadata, + toprow.tb_metadata_output, + ], + outputs=[ + dataset_gallery.gl_dataset_images, + filter_by_selection.gl_filter_images, + ] + + [ + dataset_gallery.cbg_hidden_dataset_filter, + dataset_gallery.nb_hidden_dataset_filter_apply, + ] + + o_update_filter_and_gallery, ) def unload_files(): dte_instance.clear() - return [ - [], - [] - ] +\ - [gr.CheckboxGroup.update(value=[], choices=[]), 1] +\ - filter_by_tags.clear_filters(update_filter_and_gallery) +\ - [batch_edit_captions.tag_select_ui_remove.cbg_tags_update()] + return ( + [[], []] + + [gr.CheckboxGroup.update(value=[], choices=[]), 1] + + filter_by_tags.clear_filters(update_filter_and_gallery) + + [batch_edit_captions.tag_select_ui_remove.cbg_tags_update()] + ) self.btn_unload_datasets.click( fn=unload_files, - outputs= - [dataset_gallery.gl_dataset_images, filter_by_selection.gl_filter_images] + - [dataset_gallery.cbg_hidden_dataset_filter, dataset_gallery.nb_hidden_dataset_filter_apply] + - o_update_filter_and_gallery + outputs=[ + dataset_gallery.gl_dataset_images, + filter_by_selection.gl_filter_images, + ] + + [ + dataset_gallery.cbg_hidden_dataset_filter, + dataset_gallery.nb_hidden_dataset_filter_apply, + ] + + o_update_filter_and_gallery, ) - diff --git a/scripts/tag_editor_ui/block_tag_filter.py b/scripts/tag_editor_ui/block_tag_filter.py index 1c696d4..d5f055e 100644 --- a/scripts/tag_editor_ui/block_tag_filter.py +++ b/scripts/tag_editor_ui/block_tag_filter.py @@ -66,10 +66,10 @@ class TagFilterUI(): self.rb_logic.change(fn=self.rd_logic_changed, inputs=[self.rb_logic], outputs=[self.cbg_tags]) for fn, inputs, outputs, _js in self.on_filter_update_callbacks: - self.rb_logic.change(fn=fn, inputs=inputs, outputs=outputs, _js=_js) + self.rb_logic.change(fn=lambda:None).then(fn=fn, inputs=inputs, outputs=outputs, _js=_js) self.cbg_tags.change(fn=self.cbg_tags_changed, inputs=[self.cbg_tags], outputs=[self.cbg_tags]) for fn, inputs, outputs, _js in self.on_filter_update_callbacks: - self.cbg_tags.change(fn=fn, inputs=inputs, outputs=outputs, _js=_js) + self.cbg_tags.change(fn=lambda:None).then(fn=fn, inputs=inputs, outputs=outputs, _js=_js) def tb_search_tags_changed(self, tb_search_tags: str): diff --git a/scripts/tag_editor_ui/tab_batch_edit_captions.py b/scripts/tag_editor_ui/tab_batch_edit_captions.py index 15bed9b..0dca5d7 100644 --- a/scripts/tag_editor_ui/tab_batch_edit_captions.py +++ b/scripts/tag_editor_ui/tab_batch_edit_captions.py @@ -95,8 +95,7 @@ class BatchEditCaptionsUI(UIBase): fn=apply_edit_tags, inputs=[self.tb_common_tags, self.tb_edit_tags, self.cb_prepend_tags], outputs=o_update_filter_and_gallery - ) - self.btn_apply_edit_tags.click( + ).then( fn=None, _js='() => dataset_tag_editor_gl_dataset_images_close()' ) @@ -124,8 +123,7 @@ class BatchEditCaptionsUI(UIBase): fn=search_and_replace, inputs=[self.tb_sr_search_tags, self.tb_sr_replace_tags, self.rb_sr_replace_target, self.cb_use_regex], outputs=o_update_filter_and_gallery - ) - self.btn_apply_sr_tags.click( + ).then( fn=None, _js='() => dataset_tag_editor_gl_dataset_images_close()' ) diff --git a/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py b/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py index 47d493d..b0560f2 100644 --- a/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py +++ b/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py @@ -35,7 +35,7 @@ class EditCaptionOfSelectedImageUI(UIBase): with gr.Tab(label='Interrogate Selected Image'): with gr.Row(): - self.dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte_module.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False) + self.dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte_instance.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False) self.btn_interrogate_si = gr.Button(value='Interrogate') with gr.Column(): self.tb_interrogate = gr.Textbox(label='Interrogate Result', interactive=True, lines=6, elem_id='dte_interrogate') @@ -89,7 +89,7 @@ class EditCaptionOfSelectedImageUI(UIBase): _js='(a) => dataset_tag_editor_ask_save_change_or_not(a)', inputs=self.nb_hidden_image_index_save_or_not ) - dataset_gallery.nb_hidden_image_index.change( + dataset_gallery.nb_hidden_image_index.change(lambda:None).then( fn=gallery_index_changed, inputs=[dataset_gallery.nb_hidden_image_index, dataset_gallery.nb_hidden_image_index_prev, self.tb_edit_caption, self.cb_copy_caption_automatically, self.cb_ask_save_when_caption_changed], outputs=[self.nb_hidden_image_index_save_or_not] + [self.tb_caption, self.tb_edit_caption] + [self.tb_hidden_edit_caption] @@ -138,7 +138,7 @@ class EditCaptionOfSelectedImageUI(UIBase): return '' threshold_booru = threshold_booru if use_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold threshold_waifu = threshold_waifu if use_threshold_waifu else -1 - return dte_module.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu) + return dte_instance.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu) self.btn_interrogate_si.click( fn=interrogate_selected_image, diff --git a/scripts/tag_editor_ui/tab_filter_by_selection.py b/scripts/tag_editor_ui/tab_filter_by_selection.py index bd96f9f..f3848bc 100644 --- a/scripts/tag_editor_ui/tab_filter_by_selection.py +++ b/scripts/tag_editor_ui/tab_filter_by_selection.py @@ -32,7 +32,7 @@ class FilterBySelectionUI(UIBase): self.btn_add_image_selection = gr.Button(value='Add selection [Enter]', elem_id='dataset_tag_editor_btn_add_image_selection') self.btn_add_all_displayed_image_selection = gr.Button(value='Add ALL Displayed') - self.gl_filter_images = gr.Gallery(label='Filter Images', elem_id="dataset_tag_editor_filter_gallery").style(grid=image_columns) + self.gl_filter_images = gr.Gallery(label='Filter Images', elem_id="dataset_tag_editor_filter_gallery", columns=image_columns) self.txt_selection = gr.HTML(value=self.get_current_txt_selection()) with gr.Row(): @@ -130,7 +130,7 @@ class FilterBySelectionUI(UIBase): self.path_filter = filters.PathFilter() return clear_image_selection() + update_filter_and_gallery() - filter_by_tags.btn_clear_all_filters.click( + filter_by_tags.btn_clear_all_filters.click(lambda:None).then( fn=clear_image_filter, outputs= [self.gl_filter_images, self.txt_selection, self.nb_hidden_selection_image_index] + @@ -147,8 +147,7 @@ class FilterBySelectionUI(UIBase): self.btn_apply_image_selection_filter.click( fn=apply_image_selection_filter, outputs=o_update_filter_and_gallery - ) - self.btn_apply_image_selection_filter.click( + ).then( fn=None, _js='() => dataset_tag_editor_gl_dataset_images_close()' ) diff --git a/scripts/tag_editor_ui/tab_move_or_delete_files.py b/scripts/tag_editor_ui/tab_move_or_delete_files.py index 9c99700..c6dccec 100644 --- a/scripts/tag_editor_ui/tab_move_or_delete_files.py +++ b/scripts/tag_editor_ui/tab_move_or_delete_files.py @@ -47,17 +47,17 @@ class MoveOrDeleteFilesUI(UIBase): 'outputs' : [self.ta_move_or_delete_target_dataset_num] } - batch_edit_captions.btn_apply_edit_tags.click(**update_args) + batch_edit_captions.btn_apply_edit_tags.click(lambda:None).then(**update_args) - batch_edit_captions.btn_apply_sr_tags.click(**update_args) + batch_edit_captions.btn_apply_sr_tags.click(lambda:None).then(**update_args) - filter_by_selection.btn_apply_image_selection_filter.click(**update_args) + filter_by_selection.btn_apply_image_selection_filter.click(lambda:None).then(**update_args) - filter_by_tags.btn_clear_tag_filters.click(**update_args) + filter_by_tags.btn_clear_tag_filters.click(lambda:None).then(**update_args) - filter_by_tags.btn_clear_all_filters.click(**update_args) + filter_by_tags.btn_clear_all_filters.click(lambda:None).then(**update_args) - edit_caption_of_selected_image.btn_apply_changes_selected_image.click(**update_args) + edit_caption_of_selected_image.btn_apply_changes_selected_image.click(lambda:None).then(**update_args) self.rb_move_or_delete_target_data.change(**update_args) @@ -84,9 +84,7 @@ class MoveOrDeleteFilesUI(UIBase): fn=move_files, inputs=[self.rb_move_or_delete_target_data, self.cbg_move_or_delete_target_file, self.tb_move_or_delete_caption_ext, self.tb_move_or_delete_destination_dir], outputs=o_update_filter_and_gallery - ) - self.btn_move_or_delete_move_files.click(**update_args) - self.btn_move_or_delete_move_files.click( + ).then(**update_args).then( fn=None, _js='() => dataset_tag_editor_gl_dataset_images_close()' ) @@ -114,8 +112,7 @@ class MoveOrDeleteFilesUI(UIBase): inputs=[self.rb_move_or_delete_target_data, self.cbg_move_or_delete_target_file, self.tb_move_or_delete_caption_ext], outputs=o_update_filter_and_gallery ) - self.btn_move_or_delete_delete_files.click(**update_args) - self.btn_move_or_delete_delete_files.click( + self.btn_move_or_delete_delete_files.click(**update_args).then( fn=None, _js='() => dataset_tag_editor_gl_dataset_images_close()' ) diff --git a/scripts/tag_editor_ui/ui_instance.py b/scripts/tag_editor_ui/ui_instance.py index 6660347..232618b 100644 --- a/scripts/tag_editor_ui/ui_instance.py +++ b/scripts/tag_editor_ui/ui_instance.py @@ -4,12 +4,12 @@ __all__ = [ 'toprow', 'load_dataset', 'dataset_gallery', 'gallery_state', 'filter_by_tags', 'filter_by_selection', 'batch_edit_captions', 'edit_caption_of_selected_image', 'move_or_delete_files' ] -toprow = ToprowUI.get_instance() -load_dataset = LoadDatasetUI.get_instance() -dataset_gallery = DatasetGalleryUI.get_instance() -gallery_state = GalleryStateUI.get_instance() -filter_by_tags = FilterByTagsUI.get_instance() -filter_by_selection = FilterBySelectionUI.get_instance() -batch_edit_captions = BatchEditCaptionsUI.get_instance() -edit_caption_of_selected_image = EditCaptionOfSelectedImageUI.get_instance() -move_or_delete_files = MoveOrDeleteFilesUI.get_instance() +toprow = ToprowUI() +load_dataset = LoadDatasetUI() +dataset_gallery = DatasetGalleryUI() +gallery_state = GalleryStateUI() +filter_by_tags = FilterByTagsUI() +filter_by_selection = FilterBySelectionUI() +batch_edit_captions = BatchEditCaptionsUI() +edit_caption_of_selected_image = EditCaptionOfSelectedImageUI() +move_or_delete_files = MoveOrDeleteFilesUI() diff --git a/scripts/tagger.py b/scripts/tagger.py new file mode 100644 index 0000000..341039f --- /dev/null +++ b/scripts/tagger.py @@ -0,0 +1,52 @@ +import re +from typing import Optional, Generator, Any + +from PIL import Image + +from modules import shared, lowvram, devices +from modules import deepbooru as db + +# Custom tagger classes have to inherit from this class +class Tagger: + def __enter__(self): + lowvram.send_everything_to_cpu() + devices.torch_gc() + self.start() + return self + + def __exit__(self, exception_type, exception_value, traceback): + self.stop() + pass + + def start(self): + pass + + def stop(self): + pass + + # predict tags of one image + def predict(self, image: Image.Image, threshold: Optional[float] = None) -> list[str]: + raise NotImplementedError() + + # Please implement if you want to use more efficient data loading system + # None input will come to check if this function is implemented + def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None) -> Generator[list[str], Any, None]: + raise NotImplementedError() + + # Visible name in UI + def name(self): + raise NotImplementedError() + + +def get_replaced_tag(tag: str): + use_spaces = shared.opts.deepbooru_use_spaces + use_escape = shared.opts.deepbooru_escape + if use_spaces: + tag = tag.replace('_', ' ') + if use_escape: + tag = re.sub(db.re_special, r'\\\1', tag) + return tag + + +def get_arranged_tags(probs: dict[str, float]): + return [tag for tag, _ in sorted(probs.items(), key=lambda x: -x[1])] diff --git a/userscripts/taggers/aesthetic_shadow.py b/userscripts/taggers/aesthetic_shadow.py new file mode 100644 index 0000000..51ef301 --- /dev/null +++ b/userscripts/taggers/aesthetic_shadow.py @@ -0,0 +1,54 @@ +import math + +from PIL import Image +from transformers import pipeline +import torch + +from modules import devices, shared +from scripts.tagger import Tagger + +# brought and modified from https://huggingface.co/spaces/cafeai/cafe_aesthetic_demo/blob/main/app.py + +# I'm not sure if this is really working +BATCH_SIZE = 3 + +class AestheticShadowV2(Tagger): + def load(self): + if devices.device.index is None: + dev = torch.device(devices.device.type, 0) + else: + dev = devices.device + self.pipe_aesthetic = pipeline("image-classification", "shadowlilac/aesthetic-shadow-v2", device=dev, batch_size=BATCH_SIZE) + + def unload(self): + if not shared.opts.interrogate_keep_models_in_memory: + self.pipe_aesthetic = None + devices.torch_gc() + + def start(self): + self.load() + return self + + def stop(self): + self.unload() + + def _get_score(self, data): + final = {} + for d in data: + final[d["label"]] = d["score"] + hq = final['hq'] + lq = final['lq'] + return [f"score_{math.floor((hq + (1 - lq))/2 * 10)}"] + + def predict(self, image: Image.Image, threshold=None): + data = self.pipe_aesthetic(image) + return self._get_score(data) + + def predict_pipe(self, data: list[Image.Image], threshold=None): + if data is None: + return + for out in self.pipe_aesthetic(data, batch_size=BATCH_SIZE): + yield self._get_score(out) + + def name(self): + return "aesthetic shadow" \ No newline at end of file diff --git a/userscripts/taggers/cafeai_aesthetic_classifier.py b/userscripts/taggers/cafeai_aesthetic_classifier.py new file mode 100644 index 0000000..8464af6 --- /dev/null +++ b/userscripts/taggers/cafeai_aesthetic_classifier.py @@ -0,0 +1,54 @@ +import math + +from PIL import Image +from transformers import pipeline +import torch + +from modules import devices, shared +from scripts.tagger import Tagger + +# brought and modified from https://huggingface.co/spaces/cafeai/cafe_aesthetic_demo/blob/main/app.py + +# I'm not sure if this is really working +BATCH_SIZE = 8 + +class CafeAIAesthetic(Tagger): + def load(self): + if devices.device.index is None: + dev = torch.device(devices.device.type, 0) + else: + dev = devices.device + self.pipe_aesthetic = pipeline("image-classification", "cafeai/cafe_aesthetic", device=dev, batch_size=BATCH_SIZE) + + def unload(self): + if not shared.opts.interrogate_keep_models_in_memory: + self.pipe_aesthetic = None + devices.torch_gc() + + def start(self): + self.load() + return self + + def stop(self): + self.unload() + + def _get_score(self, data): + final = {} + for d in data: + final[d["label"]] = d["score"] + nae = final['not_aesthetic'] + ae = final['aesthetic'] + return [f"score_{math.floor((ae + (1 - nae))/2 * 10)}"] + + def predict(self, image: Image.Image, threshold=None): + data = self.pipe_aesthetic(image, top_k=2) + return self._get_score(data) + + def predict_pipe(self, data: list[Image.Image], threshold=None): + if data is None: + return + for out in self.pipe_aesthetic(data, batch_size=BATCH_SIZE): + yield self._get_score(out) + + def name(self): + return "cafeai aesthetic classifier" \ No newline at end of file diff --git a/userscripts/taggers/improved_aesthetic_predictor.py b/userscripts/taggers/improved_aesthetic_predictor.py new file mode 100644 index 0000000..fe217c5 --- /dev/null +++ b/userscripts/taggers/improved_aesthetic_predictor.py @@ -0,0 +1,75 @@ +from PIL import Image +import torch +import torch.nn as nn +import numpy as np +import math + +from transformers import CLIPModel, CLIPProcessor + +from modules import devices, shared +from scripts import model_loader +from scripts.paths import paths +from scripts.tagger import Tagger + +# brought from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py and modified +class Classifier(nn.Module): + def __init__(self, input_size): + super().__init__() + self.input_size = input_size + self.layers = nn.Sequential( + nn.Linear(self.input_size, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1) + ) + + def forward(self, x): + return self.layers(x) + +# brought and modified from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py +def image_embeddings(image:Image, model:CLIPModel, processor:CLIPProcessor): + inputs = processor(images=image, return_tensors='pt')['pixel_values'] + inputs = inputs.to(devices.device) + result:np.ndarray = model.get_image_features(pixel_values=inputs).cpu().detach().numpy() + return (result / np.linalg.norm(result)).squeeze(axis=0) + + +class ImprovedAestheticPredictor(Tagger): + def load(self): + MODEL_VERSION = "sac+logos+ava1-l14-linearMSE" + file = model_loader.load( + model_path=paths.models_path / "aesthetic" / f"{MODEL_VERSION}.pth", + model_url=f'https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/{MODEL_VERSION}.pth' + ) + CLIP_REPOS = 'openai/clip-vit-large-patch14' + self.model = Classifier(768) + self.model.load_state_dict(torch.load(file)) + self.model = self.model.to(devices.device) + self.clip_processor = CLIPProcessor.from_pretrained(CLIP_REPOS) + self.clip_model = CLIPModel.from_pretrained(CLIP_REPOS).to(devices.device).eval() + + def unload(self): + if not shared.opts.interrogate_keep_models_in_memory: + self.model = None + self.clip_processor = None + self.clip_model = None + devices.torch_gc() + + def start(self): + self.load() + return self + + def stop(self): + self.unload() + + def predict(self, image: Image.Image, threshold=None): + image_embeds = image_embeddings(image, self.clip_model, self.clip_processor) + prediction:torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device)) + return [f"score_{math.floor(prediction.item())}"] + + def name(self): + return "Improved Aesthetic Predictor" \ No newline at end of file diff --git a/userscripts/taggers/waifu_aesthetic_classifier.py b/userscripts/taggers/waifu_aesthetic_classifier.py new file mode 100644 index 0000000..1952f51 --- /dev/null +++ b/userscripts/taggers/waifu_aesthetic_classifier.py @@ -0,0 +1,73 @@ +from PIL import Image +import torch +import numpy as np +import math + +from transformers import CLIPModel, CLIPProcessor + +from modules import devices, shared +from scripts import model_loader +from scripts.paths import paths +from scripts.tagger import Tagger + +# brought from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py +class Classifier(torch.nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(Classifier, self).__init__() + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.fc2 = torch.nn.Linear(hidden_size, hidden_size//2) + self.fc3 = torch.nn.Linear(hidden_size//2, output_size) + self.relu = torch.nn.ReLU() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x:torch.Tensor): + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + x = self.sigmoid(x) + return x + +# brought and modified from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py +def image_embeddings(image:Image, model:CLIPModel, processor:CLIPProcessor): + inputs = processor(images=image, return_tensors='pt')['pixel_values'] + inputs = inputs.to(devices.device) + result:np.ndarray = model.get_image_features(pixel_values=inputs).cpu().detach().numpy() + return (result / np.linalg.norm(result)).squeeze(axis=0) + + +class WaifuAesthetic(Tagger): + def load(self): + file = model_loader.load( + model_path=paths.models_path / "aesthetic" / "aes-B32-v0.pth", + model_url='https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/models/aes-B32-v0.pth' + ) + CLIP_REPOS = 'openai/clip-vit-base-patch32' + self.model = Classifier(512, 256, 1) + self.model.load_state_dict(torch.load(file)) + self.model = self.model.to(devices.device) + self.clip_processor = CLIPProcessor.from_pretrained(CLIP_REPOS) + self.clip_model = CLIPModel.from_pretrained(CLIP_REPOS).to(devices.device).eval() + + def unload(self): + if not shared.opts.interrogate_keep_models_in_memory: + self.model = None + self.clip_processor = None + self.clip_model = None + devices.torch_gc() + + def start(self): + self.load() + return self + + def stop(self): + self.unload() + + def predict(self, image: Image.Image, threshold=None): + image_embeds = image_embeddings(image, self.clip_model, self.clip_processor) + prediction:torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device)) + return [f"score_{math.floor(prediction.item()*10)}"] + + def name(self): + return "wd aesthetic classifier" \ No newline at end of file