diff --git a/install.py b/install.py new file mode 100644 index 0000000..d24b145 --- /dev/null +++ b/install.py @@ -0,0 +1,9 @@ +import launch +from modules.shared import cmd_opts + +if 'all' in cmd_opts.use_cpu or 'interrogate' in cmd_opts.use_cpu: + if not launch.is_installed("onnxruntime"): + launch.run_pip("install onnxruntime", "requirements for using SmilingWolf/wd-v1-4-vit-tagger on CPU device") +else: + if not launch.is_installed("onnxruntime-gpu"): + launch.run_pip("install onnxruntime-gpu", "requirements for using SmilingWolf/wd-v1-4-vit-tagger on GPU device") \ No newline at end of file diff --git a/scripts/dataset_tag_editor/dataset_tag_editor.py b/scripts/dataset_tag_editor/dataset_tag_editor.py index 8e99093..a55f851 100644 --- a/scripts/dataset_tag_editor/dataset_tag_editor.py +++ b/scripts/dataset_tag_editor/dataset_tag_editor.py @@ -8,6 +8,7 @@ from PIL import Image from enum import Enum import modules.deepbooru as deepbooru from scripts.dataset_tag_editor.dataset import Dataset, Data +from . 
import tag_scorer re_tags = re.compile(r'^(.+) \[\d+\]$') @@ -38,6 +39,17 @@ def interrogate_image_booru(path): return deepbooru.model.tag(img) +def interrogate_image_waifu(path): + try: + img = Image.open(path).convert('RGB') + except: + return '' + else: + with tag_scorer.WaifuDiffusion() as scorer: + res = scorer.predict(img, threshold=shared.opts.interrogate_deepbooru_score_threshold) + return ', '.join(tag_scorer.get_arranged_tags(res)) + + def get_filepath_set(dir: str, recursive: bool): if recursive: dirs_to_see = [dir] @@ -65,6 +77,8 @@ class DatasetTagEditor: self.dataset = Dataset() self.tag_counts = {} self.dataset_dir = '' + self.booru_tag_scores = None + self.waifu_tag_scores = None def get_tag_list(self): if len(self.tag_counts) == 0: @@ -342,7 +356,25 @@ class DatasetTagEditor: print(e) - def load_dataset(self, img_dir: str, recursive: bool = False, load_caption_from_filename: bool = True, interrogate_method: InterrogateMethod = InterrogateMethod.NONE, use_booru: bool = True, use_clip: bool = False): + def score_dataset_booru(self): + with tag_scorer.DeepDanbooru() as scorer: + self.booru_tag_scores = dict() + for img_path in self.dataset.datas.keys(): + img = Image.open(img_path) + probs = scorer.predict(img) + self.booru_tag_scores[img_path] = probs + + + def score_dataset_waifu(self): + with tag_scorer.WaifuDiffusion() as scorer: + self.waifu_tag_scores = dict() + for img_path in self.dataset.datas.keys(): + img = Image.open(img_path) + probs = scorer.predict(img) + self.waifu_tag_scores[img_path] = probs + + + def load_dataset(self, img_dir: str, recursive: bool = False, load_caption_from_filename: bool = True, interrogate_method: InterrogateMethod = InterrogateMethod.NONE, use_booru: bool = True, use_clip: bool = False, use_waifu: bool = False): self.clear() print(f'Loading dataset from {img_dir}') if recursive: @@ -359,7 +391,9 @@ class DatasetTagEditor: print(f'Total {len(filepath_set)} files under the directory including not image files.') - 
def load_images(filepath_set: Set[str]): + + from . import tag_scorer + def load_images(filepath_set: Set[str], scorers: List[tag_scorer.TagScorer]): for img_path in filepath_set: img_dir = os.path.dirname(img_path) img_filename, img_ext = os.path.splitext(os.path.basename(img_path)) @@ -387,46 +421,63 @@ class DatasetTagEditor: tokens = self.re_word.findall(caption_text) caption_text = (shared.opts.dataset_filename_join_string or "").join(tokens) - if interrogate_method != InterrogateMethod.NONE and ((interrogate_method != InterrogateMethod.PREFILL) or (interrogate_method == InterrogateMethod.PREFILL and not caption_text)): + interrogate_tags = [] + caption_tags = [t.strip() for t in caption_text.split(',')] + if interrogate_method != InterrogateMethod.NONE and ((interrogate_method != InterrogateMethod.PREFILL) or (interrogate_method == InterrogateMethod.PREFILL and not caption_tags)): try: img = Image.open(img_path).convert('RGB') except Exception as e: print(e) print(f'Cannot interrogate file: {img_path}') else: - interrogate_text = '' if use_clip: - interrogate_text += shared.interrogator.generate_caption(img) + tmp = [t.strip() for t in shared.interrogator.generate_caption(img).split(',')] + interrogate_tags += [t for t in tmp if t] - if use_booru: - tmp = deepbooru.model.tag_multi(img) - interrogate_text += (', ' if interrogate_text and tmp else '') + tmp + for scorer in scorers: + probs = scorer.predict(img) + interrogate_tags += [t for t, p in probs.items() if p > shared.opts.interrogate_deepbooru_score_threshold] + if isinstance(scorer, tag_scorer.DeepDanbooru): + if not self.booru_tag_scores: + self.booru_tag_scores = dict() + self.booru_tag_scores[img_path] = probs + elif isinstance(scorer, tag_scorer.WaifuDiffusion): + if not self.waifu_tag_scores: + self.waifu_tag_scores = dict() + self.waifu_tag_scores[img_path] = probs - if interrogate_method == InterrogateMethod.OVERWRITE: - caption_text = interrogate_text - elif interrogate_method == 
InterrogateMethod.PREPEND: - caption_text = interrogate_text + (', ' if interrogate_text and caption_text else '') + caption_text - else: - caption_text += (', ' if interrogate_text and caption_text else '') + interrogate_text img.close() - self.set_tags_by_image_path(img_path, [t.strip() for t in caption_text.split(',')]) - + if interrogate_method == InterrogateMethod.OVERWRITE: + tags = interrogate_tags + elif interrogate_method == InterrogateMethod.PREPEND: + tags = interrogate_tags + caption_tags + else: + tags = caption_tags + interrogate_tags + self.set_tags_by_image_path(img_path, tags) + try: + scorers = [] if interrogate_method != InterrogateMethod.NONE: if use_clip: shared.interrogator.load() if use_booru: - deepbooru.model.start() + scorer = tag_scorer.DeepDanbooru() + scorer.start() + scorers.append(scorer) + if use_waifu: + scorer = tag_scorer.WaifuDiffusion() + scorer.start() + scorers.append(scorer) - load_images(filepath_set = filepath_set) + load_images(filepath_set = filepath_set, scorers=scorers) finally: if interrogate_method != InterrogateMethod.NONE: if use_clip: shared.interrogator.send_blip_to_ram() - if use_booru: - deepbooru.model.stop() + for scorer in scorers: + scorer.stop() self.construct_tag_counts() print(f'Loading Completed: {len(self.dataset)} images found') @@ -481,6 +532,8 @@ class DatasetTagEditor: self.dataset.clear() self.tag_counts.clear() self.dataset_dir = '' + self.booru_tag_scores = None + self.waifu_tag_scores = None def construct_tag_counts(self): diff --git a/scripts/dataset_tag_editor/filters.py b/scripts/dataset_tag_editor/filters.py index 6816db6..1bf4a76 100644 --- a/scripts/dataset_tag_editor/filters.py +++ b/scripts/dataset_tag_editor/filters.py @@ -1,5 +1,5 @@ from scripts.dataset_tag_editor.dataset import Dataset -from typing import Set, List +from typing import Set, Dict from enum import Enum class TagFilter(Dataset.Filter): @@ -92,4 +92,28 @@ class PathFilter(Dataset.Filter): return dataset - \ No newline at 
end of file +class TagScoreFilter(Dataset.Filter): + class Mode(Enum): + NONE = 0 + LESS_THAN = 1 + GREATER_THAN = 2 + + def __init__(self, scores: Dict[str, Dict[str, float]], tag: str, threshold: float, mode: Mode = Mode.NONE): + self.scores = scores + self.mode = mode + self.tag = tag + self.threshold = threshold + + def apply(self, dataset: Dataset): + if self.mode == TagScoreFilter.Mode.NONE: + return dataset + + paths_remove = {path for path, scores in self.scores.items() if (scores.get(self.tag) or 0) > self.threshold} + + if self.mode == TagScoreFilter.Mode.GREATER_THAN: + paths_remove = {path for path in dataset.datas.keys()} - paths_remove + + for path in paths_remove: + dataset.remove_by_path(path) + + return dataset \ No newline at end of file diff --git a/scripts/dataset_tag_editor/tag_scorer.py b/scripts/dataset_tag_editor/tag_scorer.py new file mode 100644 index 0000000..c2dd3f4 --- /dev/null +++ b/scripts/dataset_tag_editor/tag_scorer.py @@ -0,0 +1,101 @@ +from PIL import Image +import re +import torch +import numpy as np +from typing import Optional, Dict +from modules import devices, shared +from modules import deepbooru as db +from . 
import waifu_diffusion_tagger + +class TagScorer: + def __enter__(self): + self.start() + return self + def __exit__(self, exception_type, exception_value, traceback): + self.stop() + pass + def start(self): + pass + def stop(self): + pass + def predict(self,image: Image.Image, threshold: Optional[float]): + raise NotImplementedError + def name(self): + raise NotImplementedError + + +def get_replaced_tag(tag: str): + use_spaces = shared.opts.deepbooru_use_spaces + use_escape = shared.opts.deepbooru_escape + if use_spaces: + tag = tag.replace('_', ' ') + if use_escape: + tag = re.sub(db.re_special, r'\\\1', tag) + return tag + + +def get_arranged_tags(probs: Dict[str, float]): + alpha_sort = shared.opts.deepbooru_sort_alpha + if alpha_sort: + return sorted(probs) + else: + return [tag for tag, _ in sorted(probs.items(), key=lambda x: -x[1])] + + +class DeepDanbooru(TagScorer): + def start(self): + db.model.start() + + def stop(self): + db.model.stop() + + # brought from webUI modules/deepbooru.py and modified + def predict(self, image: Image.Image, threshold: Optional[float] = None): + from modules import images + + pic = images.resize_image(2, image.convert("RGB"), 512, 512) + a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 + + with torch.no_grad(), devices.autocast(): + x = torch.from_numpy(a).to(devices.device) + y = db.model.model(x)[0].detach().cpu().numpy() + + probability_dict = dict() + + for tag, probability in zip(db.model.model.tags, y): + if threshold and probability < threshold: + continue + if tag.startswith("rating:"): + continue + probability_dict[get_replaced_tag(tag)] = probability + + return probability_dict + + def name(self): + return 'DeepDanbooru' + + +class WaifuDiffusion(TagScorer): + def start(self): + waifu_diffusion_tagger.instance.load() + return self + + def stop(self): + waifu_diffusion_tagger.instance.unload() + + # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified + def 
predict(self, image: Image.Image, threshold: Optional[float] = None): + # may not use ratings + # rating = dict(labels[:4]) + + labels = waifu_diffusion_tagger.instance.apply(image) + + if threshold: + probability_dict = dict([(get_replaced_tag(x[0]), x[1]) for x in labels[4:] if x[1] > threshold]) + else: + probability_dict = dict([(get_replaced_tag(x[0]), x[1]) for x in labels[4:]]) + + return probability_dict + + def name(self): + return 'wd-v1-4-tags' \ No newline at end of file diff --git a/scripts/dataset_tag_editor/waifu_diffusion_tagger.py b/scripts/dataset_tag_editor/waifu_diffusion_tagger.py new file mode 100644 index 0000000..55d01a4 --- /dev/null +++ b/scripts/dataset_tag_editor/waifu_diffusion_tagger.py @@ -0,0 +1,65 @@ +from PIL import Image +import numpy as np +from typing import Optional, List, Tuple +import onnxruntime as ort +from modules import shared + + +class WaifuDiffusionTagger(): + # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified + MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger" + MODEL_FILENAME = "model.onnx" + LABEL_FILENAME = "selected_tags.csv" + def __init__(self): + self.model: ort.InferenceSession = None + self.labels = [] + + def load(self): + import huggingface_hub + if not self.model: + path_model = huggingface_hub.hf_hub_download( + self.MODEL_REPO, self.MODEL_FILENAME + ) + if 'all' in shared.cmd_opts.use_cpu or 'interrogate' in shared.cmd_opts.use_cpu: + providers = ['CPUExecutionProvider'] + else: + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + self.model = ort.InferenceSession(path_model, providers=providers) + + path_label = huggingface_hub.hf_hub_download( + self.MODEL_REPO, self.LABEL_FILENAME + ) + import pandas as pd + self.labels = pd.read_csv(path_label)["name"].tolist() + + def unload(self): + if not shared.opts.interrogate_keep_models_in_memory: + self.model = None + + # brought from 
https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified + def apply(self, image: Image.Image): + if not self.model: + return [] + + from modules import images + + _, height, width, _ = self.model.get_inputs()[0].shape + + # the way to fill empty pixels is quite different from original one; + # original: fill by white pixels + # this: repeat the pixels on the edge + image = images.resize_image(2, image.convert("RGB"), width, height) + image_np = np.array(image, dtype=np.float32) + # PIL RGB to OpenCV BGR + image_np = image_np[:, :, ::-1] + image_np = np.expand_dims(image_np, 0) + + input_name = self.model.get_inputs()[0].name + label_name = self.model.get_outputs()[0].name + probs = self.model.run([label_name], {input_name: image_np})[0] + labels: List[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float))) + + return labels + + +instance = WaifuDiffusionTagger() diff --git a/scripts/main.py b/scripts/main.py index 6438973..cab0fde 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -2,7 +2,7 @@ from typing import List, Set from modules import shared, script_callbacks from modules.shared import opts, cmd_opts import gradio as gr -from scripts.dataset_tag_editor.dataset_tag_editor import DatasetTagEditor, interrogate_image_clip, interrogate_image_booru, InterrogateMethod +from scripts.dataset_tag_editor.dataset_tag_editor import DatasetTagEditor, InterrogateMethod from scripts.dataset_tag_editor.filters import TagFilter, PathFilter from scripts.dataset_tag_editor.ui import TagFilterUI @@ -52,7 +52,7 @@ def get_current_move_or_delete_target_num(target_data: str, idx: int): return f'Target dataset num: 0' -def load_files_from_dir(dir: str, recursive: bool, load_caption_from_filename: bool, use_interrogator: str, use_clip: bool, use_booru: bool): +def load_files_from_dir(dir: str, recursive: bool, load_caption_from_filename: bool, use_interrogator: str, use_clip: bool, use_booru: bool, use_waifu: bool): global 
total_image_num, displayed_image_num, tmp_selection_img_path_set, gallery_selected_image_path, selection_selected_image_path, path_filter interrogate_method = InterrogateMethod.NONE @@ -65,7 +65,7 @@ def load_files_from_dir(dir: str, recursive: bool, load_caption_from_filename: b elif use_interrogator == 'Append': interrogate_method = InterrogateMethod.APPEND - dataset_tag_editor.load_dataset(img_dir=dir, recursive=recursive, load_caption_from_filename=load_caption_from_filename, interrogate_method=interrogate_method, use_clip=use_clip, use_booru=use_booru) + dataset_tag_editor.load_dataset(dir, recursive, load_caption_from_filename, interrogate_method, use_booru, use_clip, use_waifu) img_paths = dataset_tag_editor.get_filtered_imgpaths(filters=get_filters()) path_filter = PathFilter() total_image_num = displayed_image_num = len(dataset_tag_editor.get_img_path_set()) @@ -237,14 +237,22 @@ def change_selected_image_caption(tags_text: str, idx: int): def interrogate_selected_image_clip(): global gallery_selected_image_path + from scripts.dataset_tag_editor.dataset_tag_editor import interrogate_image_clip return interrogate_image_clip(gallery_selected_image_path) def interrogate_selected_image_booru(): global gallery_selected_image_path + from scripts.dataset_tag_editor.dataset_tag_editor import interrogate_image_booru return interrogate_image_booru(gallery_selected_image_path) +def interrogate_selected_image_waifu(): + global gallery_selected_image_path + from scripts.dataset_tag_editor.dataset_tag_editor import interrogate_image_waifu + return interrogate_image_waifu(gallery_selected_image_path) + + # ================================================================ # Callbacks for "Batch Edit Captions" tab # ================================================================ @@ -365,6 +373,7 @@ def on_ui_tabs(): with gr.Row(): cb_use_clip_to_prefill = gr.Checkbox(value=False, label='Use BLIP') cb_use_booru_to_prefill = gr.Checkbox(value=False, label='Use DeepDanbooru') 
+ cb_use_waifu_to_prefill = gr.Checkbox(value=False, label='Use WDv1.4 Tagger') gl_dataset_images = gr.Gallery(label='Dataset Images', elem_id="dataset_tag_editor_dataset_gallery").style(grid=opts.dataset_editor_image_columns) txt_gallery = gr.HTML(value=get_current_gallery_txt()) @@ -452,6 +461,7 @@ def on_ui_tabs(): with gr.Row(): btn_interrogate_clip = gr.Button(value='Interrogate with BLIP') btn_interrogate_booru = gr.Button(value='Interrogate with DeepDanbooru') + btn_interrogate_waifu = gr.Button(value='Interrogate with WDv1.4 tagger') tb_interrogate_selected_image = gr.Textbox(label='Interrogate Result', interactive=True, lines=6) with gr.Row(): btn_copy_interrogate = gr.Button(value='Copy and Overwrite') @@ -500,7 +510,7 @@ def on_ui_tabs(): btn_load_datasets.click( fn=load_files_from_dir, - inputs=[tb_img_directory, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, cb_use_clip_to_prefill, cb_use_booru_to_prefill], + inputs=[tb_img_directory, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, cb_use_clip_to_prefill, cb_use_booru_to_prefill, cb_use_waifu_to_prefill], outputs=[gl_dataset_images, gl_selected_images, txt_gallery, txt_selection] + [tag_filter_ui.cbg_tags, tag_filter_ui_neg.cbg_tags, gl_dataset_images, nb_hidden_image_index, txt_gallery] + [tb_common_tags, tb_edit_tags] ) btn_load_datasets.click( @@ -629,6 +639,11 @@ def on_ui_tabs(): outputs=[tb_interrogate_selected_image] ) + btn_interrogate_waifu.click( + fn=interrogate_selected_image_waifu, + outputs=[tb_interrogate_selected_image] + ) + btn_copy_interrogate.click( fn=lambda a:a, inputs=[tb_interrogate_selected_image],