diff --git a/scripts/dataset_tag_editor/captioning.py b/scripts/dataset_tag_editor/captioning.py index 113c084..f4bd405 100644 --- a/scripts/dataset_tag_editor/captioning.py +++ b/scripts/dataset_tag_editor/captioning.py @@ -1,15 +1,10 @@ import modules.shared as shared +from scripts.dataset_tag_editor.interrogator import Interrogator from scripts.dynamic_import import dynamic_import git_large_captioning = dynamic_import('scripts/dataset_tag_editor/interrogators/git_large_captioning.py') -class Captioning: - def __enter__(self): - self.start() - return self - def __exit__(self, exception_type, exception_value, traceback): - self.stop() - pass +class Captioning(Interrogator): def start(self): pass def stop(self): diff --git a/scripts/dataset_tag_editor/dataset_tag_editor.py b/scripts/dataset_tag_editor/dataset_tag_editor.py index c432287..59c3858 100644 --- a/scripts/dataset_tag_editor/dataset_tag_editor.py +++ b/scripts/dataset_tag_editor/dataset_tag_editor.py @@ -15,6 +15,8 @@ filters = dynamic_import('scripts/dataset_tag_editor/filters.py') re_tags = re.compile(r'^(.+) \[\d+\]$') +INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru()] + [tagger.WaifuDiffusion(name) for name in tagger.WD_TAGGER_NAMES] +INTERROGATOR_NAMES = [it.name() for it in INTERROGATORS] class InterrogateMethod(Enum): NONE = 0 @@ -24,14 +26,23 @@ class InterrogateMethod(Enum): APPEND = 4 -def interrogate_image_blip(path): +def interrogate_image(path: str, interrogator_name: str, threshold_booru, threshold_wd): try: img = Image.open(path).convert('RGB') except: return '' else: - with captioning.BLIP() as cap: - res = cap.predict(img) + for it in INTERROGATORS: + if it.name() == interrogator_name: + if isinstance(it, tagger.DeepDanbooru): + with it as tg: + res = tg.predict(img, threshold_booru) + elif isinstance(it, tagger.WaifuDiffusion): + with it as tg: + res = tg.predict(img, threshold_wd) + else: + with it as cap: + res = cap.predict(img) return ', '.join(res) @@ -66,6 +77,17 @@ def interrogate_image_waifu(path, threshold): with tagger.WaifuDiffusion() as tg: res = tg.predict(img, threshold=threshold) return ', '.join(tagger.get_arranged_tags(res)) + + +def interrogate_image_waifu_v2(path, threshold): + try: + img = Image.open(path).convert('RGB') + except: + return '' + else: + with tagger.WaifuDiffusionV2() as tg: + res = tg.predict(img, threshold=threshold) + return ', '.join(tagger.get_arranged_tags(res)) def get_filepath_set(dir: str, recursive: bool): @@ -401,7 +423,7 @@ class DatasetTagEditor: print(e) - def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, use_booru: bool, use_blip: bool, use_git:bool, use_waifu: bool, threshold_booru: float, threshold_waifu: float): + def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, interrogator_names: List[str], threshold_booru: float, threshold_waifu: float): self.clear() print(f'[tag-editor] Loading dataset from {img_dir}') if recursive: @@ -477,22 +499,17 @@ class DatasetTagEditor: captionings = [] taggers = [] if interrogate_method != InterrogateMethod.NONE: - if use_blip: - cap = captioning.BLIP() - cap.start() - captionings.append(cap) - if use_git: - cap = captioning.GITLarge() - cap.start() - captionings.append(cap) - if use_booru: - tg = tagger.DeepDanbooru() - tg.start() - taggers.append((tg, threshold_booru)) - if use_waifu: - tg = tagger.WaifuDiffusion() - tg.start() - taggers.append((tg, threshold_waifu)) + for it in INTERROGATORS: + if it.name() in interrogator_names: + it.start() + if isinstance(it, tagger.Tagger): + if isinstance(it, tagger.DeepDanbooru): + taggers.append((it, threshold_booru)) + if isinstance(it, tagger.WaifuDiffusion): + taggers.append((it, threshold_waifu)) + elif isinstance(it, captioning.Captioning): + captionings.append(it) + load_images(filepath_set=filepath_set, captionings=captionings, taggers=taggers) diff --git a/scripts/dataset_tag_editor/interrogator.py b/scripts/dataset_tag_editor/interrogator.py new file mode 100644 index 0000000..d5a0a8e --- /dev/null +++ b/scripts/dataset_tag_editor/interrogator.py @@ -0,0 +1,15 @@ +class Interrogator: + def __enter__(self): + self.start() + return self + def __exit__(self, exception_type, exception_value, traceback): + self.stop() + pass + def start(self): + pass + def stop(self): + pass + def predict(self, image, **kwargs): + raise NotImplementedError + def name(self): + raise NotImplementedError diff --git a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py index d1b45d2..4a88bcd 100644 --- a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py +++ b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py @@ -7,10 +7,10 @@ import launch class WaifuDiffusionTagger(): # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified - MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger" MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" - def __init__(self): + def __init__(self, model_name): + self.MODEL_REPO = model_name self.model = None self.labels = [] @@ -63,7 +63,4 @@ class WaifuDiffusionTagger(): probs = self.model.run([label_name], {input_name: image_np})[0] labels: List[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float))) - return labels - - -instance = WaifuDiffusionTagger() + return labels \ No newline at end of file diff --git a/scripts/dataset_tag_editor/tagger.py b/scripts/dataset_tag_editor/tagger.py index 9659058..d2351fd 100644 --- a/scripts/dataset_tag_editor/tagger.py +++ b/scripts/dataset_tag_editor/tagger.py @@ -6,22 +6,17 @@ from typing import Optional, Dict from modules import devices, shared from modules import deepbooru as db +from scripts.dataset_tag_editor.interrogator import Interrogator from scripts.dynamic_import import dynamic_import waifu_diffusion_tagger = dynamic_import('scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger.py') -class Tagger: - def __enter__(self): - self.start() - return self - def __exit__(self, exception_type, exception_value, traceback): - self.stop() - pass +class Tagger(Interrogator): def start(self): pass def stop(self): pass - def predict(self,image: Image.Image, threshold: Optional[float]): + def predict(self, image: Image.Image, threshold: Optional[float]): raise NotImplementedError def name(self): raise NotImplementedError @@ -78,20 +73,26 @@ class DeepDanbooru(Tagger): return 'DeepDanbooru' +WD_TAGGER_NAMES = ["wd-v1-4-vit-tagger", "wd-v1-4-convnext-tagger", "wd-v1-4-vit-tagger-v2", "wd-v1-4-convnext-tagger-v2", "wd-v1-4-swinv2-tagger-v2"] + class WaifuDiffusion(Tagger): + def __init__(self, repo_name): + self.repo_name = repo_name + self.tagger_inst = waifu_diffusion_tagger.WaifuDiffusionTagger("SmilingWolf/" + repo_name) + def start(self): - waifu_diffusion_tagger.instance.load() + self.tagger_inst.load() return self def stop(self): - waifu_diffusion_tagger.instance.unload() + self.tagger_inst.unload() # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified def predict(self, image: Image.Image, threshold: Optional[float] = None): # may not use ratings # rating = dict(labels[:4]) - labels = waifu_diffusion_tagger.instance.apply(image) + labels = self.tagger_inst.apply(image) if threshold: probability_dict = dict([(get_replaced_tag(x[0]), x[1]) for x in labels[4:] if x[1] > threshold]) @@ -101,4 +102,4 @@ class WaifuDiffusion(Tagger): return probability_dict def name(self): - return 'wd-v1-4-tagger' \ No newline at end of file + return self.repo_name \ No newline at end of file diff --git a/scripts/main.py b/scripts/main.py index 8bb6e4a..36c1414 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -1,4 +1,4 @@ -from typing import List, NamedTuple, Type +from typing import List, NamedTuple, Type, Dict, Any from modules import shared, script_callbacks, scripts from modules.shared import opts import gradio as gr @@ -44,10 +44,7 @@ GeneralConfig = namedtuple('GeneralConfig', [ 'load_recursive', 'load_caption_from_filename', 'use_interrogator', - 'use_blip_to_prefill', - 'use_git_to_prefill', - 'use_booru_to_prefill', - 'use_waifu_to_prefill', + 'use_interrogator_names', 'use_custom_threshold_booru', 'custom_threshold_booru', 'use_custom_threshold_waifu', @@ -55,14 +52,14 @@ GeneralConfig = namedtuple('GeneralConfig', [ ]) FilterConfig = namedtuple('FilterConfig', ['sort_by', 'sort_order', 'logic']) BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sory_by', 'sort_order']) -EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'warn_change_not_saved']) +EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'warn_change_not_saved', 'use_interrogator_name']) MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination']) -CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, 'No', False, False, False, False, False, 0.7, False, 0.5) +CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, 'No', [], False, 0.7, False, 0.5) CFG_FILTER_P_DEFAULT = FilterConfig('Alphabetical Order', 'Ascending', 'AND') CFG_FILTER_N_DEFAULT = FilterConfig('Alphabetical Order', 'Ascending', 'OR') CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', 'Alphabetical Order', 'Ascending') -CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(False, False) +CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(False, False, '') CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig('Selected One', [], '.txt', '') class Config: @@ -114,17 +111,34 @@ def write_move_delete_config(*args): cfg = MoveDeleteConfig(*args) config.write(cfg._asdict(), 'file_move_delete') -def read_config(name: str, config_type: Type, default: NamedTuple): +def read_config(name: str, config_type: Type, default: NamedTuple, compat_func = None): d = config.read(name) cfg = default if d: + if compat_func: d = compat_func(d) d = cfg._asdict() | d d = {k:v for k,v in d.items() if k in cfg._asdict().keys()} cfg = config_type(**d) return cfg def read_general_config(): - return read_config('general', GeneralConfig, CFG_GENERAL_DEFAULT) + # for compatibility + generalcfg_intterogator_names = [ + ('use_blip_to_prefill', 'BLIP'), + ('use_git_to_prefill', 'GIT-large-COCO'), + ('use_booru_to_prefill', 'DeepDanbooru'), + ('use_waifu_to_prefill', 'wd-v1-4-vit-tagger') + ] + use_interrogator_names = [] + def compat_func(d: Dict[str, Any]): + if 'use_interrogator_names' in d.keys(): + return d + for cfg in generalcfg_intterogator_names: + if d.get(cfg[0]): + use_interrogator_names.append(cfg[1]) + d['use_interrogator_names'] = use_interrogator_names + return d + return read_config('general', GeneralConfig, CFG_GENERAL_DEFAULT, compat_func) def read_filter_config(): d = config.read('filter') @@ -187,10 +201,7 @@ def load_files_from_dir( recursive: bool, load_caption_from_filename: bool, use_interrogator: str, - use_blip: bool, - use_git: bool, - use_booru: bool, - use_waifu: bool, + use_interrogator_names: List[str], use_custom_threshold_booru: bool, custom_threshold_booru: float, use_custom_threshold_waifu: bool, @@ -211,7 +222,7 @@ def load_files_from_dir( threshold_booru = custom_threshold_booru if use_custom_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold threshold_waifu = custom_threshold_waifu if use_custom_threshold_waifu else shared.opts.interrogate_deepbooru_score_threshold - dataset_tag_editor.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, interrogate_method, use_booru, use_blip, use_git, use_waifu, threshold_booru, threshold_waifu) + dataset_tag_editor.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu) img_paths = dataset_tag_editor.get_filtered_imgpaths(filters=[]) img_indices = dataset_tag_editor.get_filtered_imgindices(filters=[]) path_filter = filters.PathFilter() @@ -402,26 +413,11 @@ def change_selected_image_caption(tags_text: str, idx: int): return update_filter_and_gallery() -def interrogate_selected_image_blip(): +def interrogate_selected_image(interrogator_name: str, use_threshold_booru: bool, threshold_booru: float, use_threshold_waifu: bool, threshold_waifu: float): global gallery_selected_image_path - return dte.interrogate_image_blip(gallery_selected_image_path) - - -def interrogate_selected_image_git(): - global gallery_selected_image_path - return dte.interrogate_image_git(gallery_selected_image_path) - - -def interrogate_selected_image_booru(use_threshold: bool, threshold: float): - global gallery_selected_image_path - threshold = threshold if use_threshold else shared.opts.interrogate_deepbooru_score_threshold - return dte.interrogate_image_booru(gallery_selected_image_path, threshold) - - -def interrogate_selected_image_waifu(use_threshold: bool, threshold: float): - global gallery_selected_image_path - threshold = threshold if use_threshold else shared.opts.interrogate_deepbooru_score_threshold - return dte.interrogate_image_waifu(gallery_selected_image_path, threshold) + threshold_booru = threshold_booru if use_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold + threshold_waifu = threshold_waifu if use_threshold_waifu else shared.opts.interrogate_deepbooru_score_threshold + return dte.interrogate_image(gallery_selected_image_path, interrogator_name, threshold_booru, threshold_waifu) # ================================================================ @@ -581,11 +577,7 @@ def on_ui_tabs(): cb_load_caption_from_filename = gr.Checkbox(value=cfg_general.load_caption_from_filename, label='Load caption from filename if no text file exists') with gr.Column(): rb_use_interrogator = gr.Radio(choices=['No', 'If Empty', 'Overwrite', 'Prepend', 'Append'], value=cfg_general.use_interrogator, label='Use Interrogator Caption') - with gr.Row(): - cb_use_blip_to_prefill = gr.Checkbox(value=cfg_general.use_blip_to_prefill, label='Use BLIP') - cb_use_git_to_prefill = gr.Checkbox(value=cfg_general.use_git_to_prefill, label='Use GIT', visible=False) - cb_use_booru_to_prefill = gr.Checkbox(value=cfg_general.use_booru_to_prefill, label='Use DeepDanbooru') - cb_use_waifu_to_prefill = gr.Checkbox(value=cfg_general.use_waifu_to_prefill, label='Use WDv1.4 Tagger') + dd_intterogator_names = gr.Dropdown(label = 'Interrogators', choices=dte.INTERROGATOR_NAMES, value=cfg_general.use_interrogator_names, interactive=True, multiselect=True) with gr.Accordion(label='Interrogator Settings', open=False): with gr.Row(): cb_use_custom_threshold_booru = gr.Checkbox(value=cfg_general.use_custom_threshold_booru, label='Use Custom Threshold (Booru)', interactive=True) @@ -689,10 +681,8 @@ def on_ui_tabs(): with gr.Tab(label='Interrogate Selected Image'): with gr.Row(): - btn_interrogate_blip = gr.Button(value='Interrogate with BLIP') - btn_interrogate_git = gr.Button(value='Interrogate with GIT Large', visible=False) - btn_interrogate_booru = gr.Button(value='Interrogate with DeepDanbooru') - btn_interrogate_waifu = gr.Button(value='Interrogate with WDv1.4 tagger') + dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False) + btn_interrogate_si = gr.Button(value='Interrogate') tb_interrogate_selected_image = gr.Textbox(label='Interrogate Result', interactive=True, lines=6) with gr.Row(): btn_copy_interrogate = gr.Button(value='Copy and Overwrite') @@ -725,10 +715,10 @@ def on_ui_tabs(): #---------------------------------------------------------------- # General - components_general = [cb_backup, tb_img_directory, tb_caption_file_ext, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, cb_use_blip_to_prefill, cb_use_git_to_prefill, cb_use_booru_to_prefill, cb_use_waifu_to_prefill, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu] + components_general = [cb_backup, tb_img_directory, tb_caption_file_ext, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, dd_intterogator_names, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu] components_filter = [tag_filter_ui.rb_sort_by, tag_filter_ui.rb_sort_order, tag_filter_ui.rb_logic, tag_filter_ui_neg.rb_sort_by, tag_filter_ui_neg.rb_sort_order, tag_filter_ui_neg.rb_logic] components_batch_edit = [cb_show_only_tags_selected, cb_prepend_tags, cb_use_regex, rb_sr_replace_target, tag_select_ui_remove.rb_sort_by, tag_select_ui_remove.rb_sort_order] - components_edit_selected = [cb_copy_caption_automatically, cb_ask_save_when_caption_changed] + components_edit_selected = [cb_copy_caption_automatically, cb_ask_save_when_caption_changed, dd_intterogator_names_si] components_move_delete = [rb_move_or_delete_target_data, cbg_move_or_delete_target_file, tb_move_or_delete_caption_ext, tb_move_or_delete_destination_dir] configurable_components = components_general + components_filter + components_batch_edit + components_edit_selected + components_move_delete @@ -820,7 +810,7 @@ def on_ui_tabs(): btn_load_datasets.click( fn=load_files_from_dir, - inputs=[tb_img_directory, tb_caption_file_ext, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, cb_use_blip_to_prefill, cb_use_git_to_prefill, cb_use_booru_to_prefill, cb_use_waifu_to_prefill, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu], + inputs=[tb_img_directory, tb_caption_file_ext, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, dd_intterogator_names, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu], outputs= [gl_dataset_images, gl_filter_images, txt_gallery, txt_selection] + [cbg_hidden_dataset_filter, nb_hidden_dataset_filter_apply] + @@ -998,25 +988,9 @@ def on_ui_tabs(): outputs=[tb_edit_caption_selected_image] ) - btn_interrogate_blip.click( - fn=interrogate_selected_image_blip, - outputs=[tb_interrogate_selected_image] - ) - - btn_interrogate_git.click( - fn=interrogate_selected_image_git, - outputs=[tb_interrogate_selected_image] - ) - - btn_interrogate_booru.click( - fn=interrogate_selected_image_booru, - inputs=[cb_use_custom_threshold_booru, sl_custom_threshold_booru], - outputs=[tb_interrogate_selected_image] - ) - - btn_interrogate_waifu.click( - fn=interrogate_selected_image_waifu, - inputs=[cb_use_custom_threshold_waifu, sl_custom_threshold_waifu], + btn_interrogate_si.click( + fn=interrogate_selected_image, + inputs=[dd_intterogator_names_si, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu], outputs=[tb_interrogate_selected_image] )