diff --git a/scripts/dataset_tag_editor/dte_logic.py b/scripts/dataset_tag_editor/dte_logic.py index 10b2a81..563d73f 100644 --- a/scripts/dataset_tag_editor/dte_logic.py +++ b/scripts/dataset_tag_editor/dte_logic.py @@ -21,6 +21,7 @@ from . import ( taggers_builtin ) from .custom_scripts import CustomScripts +from .interrogator_names import BLIP2_CAPTIONING_NAMES, WD_TAGGERS, WD_TAGGERS_TIMM from scripts.tokenizer import clip_tokenizer from scripts.tagger import Tagger @@ -68,39 +69,26 @@ class DatasetTagEditor(Singleton): custom_taggers:list[Tagger] = custom_tagger_scripts.load_derived_classes(Tagger) logger.write(f"Custom taggers loaded: {[tagger().name() for tagger in custom_taggers]}") - self.BLIP2_CAPTIONING_NAMES = [ - "blip2-opt-2.7b", - "blip2-opt-2.7b-coco", - "blip2-opt-6.7b", - "blip2-opt-6.7b-coco", - "blip2-flan-t5-xl", - "blip2-flan-t5-xl-coco", - "blip2-flan-t5-xxl", - ] - - self.WD_TAGGERS = { - "wd-v1-4-vit-tagger" : 0.35, - "wd-v1-4-convnext-tagger" : 0.35, - "wd-v1-4-vit-tagger-v2" : 0.3537, - "wd-v1-4-convnext-tagger-v2" : 0.3685, - "wd-v1-4-convnextv2-tagger-v2" : 0.371, - "wd-v1-4-swinv2-tagger-v2" : 0.3771, - "wd-v1-4-moat-tagger-v2" : 0.3771, - "wd-vit-tagger-v3" : 0.2614, - "wd-convnext-tagger-v3" : 0.2682, - "wd-swinv2-tagger-v3" : 0.2653, - } - # {tagger name : default tagger threshold} - # v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf - + def read_wd_batchsize(name:str): + if "vit" in name: + return shared.opts.dataset_editor_batch_size_vit + elif "convnext" in name: + return shared.opts.dataset_editor_batch_size_convnext + elif "swinv2" in name: + return shared.opts.dataset_editor_batch_size_swinv2 + self.INTERROGATORS = ( [taggers_builtin.BLIP()] - + [taggers_builtin.BLIP2(name) for name in self.BLIP2_CAPTIONING_NAMES] + + [taggers_builtin.BLIP2(name) for name in BLIP2_CAPTIONING_NAMES] + [taggers_builtin.GITLarge()] + [taggers_builtin.DeepDanbooru()] + [ taggers_builtin.WaifuDiffusion(name, 
threshold) - for name, threshold in self.WD_TAGGERS.items() + for name, threshold in WD_TAGGERS.items() + ] + + [ + taggers_builtin.WaifuDiffusionTimm(name, threshold, int(read_wd_batchsize(name))) + for name, threshold in WD_TAGGERS_TIMM.items() ] + [taggers_builtin.Z3D_E621()] + [cls_tagger() for cls_tagger in custom_taggers] @@ -118,7 +106,7 @@ class DatasetTagEditor(Singleton): if isinstance(it, taggers_builtin.DeepDanbooru): with it as tg: res = tg.predict(img, threshold_booru) - elif isinstance(it, taggers_builtin.WaifuDiffusion): + elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm): with it as tg: res = tg.predict(img, threshold_wd) elif isinstance(it, taggers_builtin.Z3D_E621): @@ -752,7 +740,7 @@ class DatasetTagEditor(Singleton): if it.name() in interrogator_names: if isinstance(it, taggers_builtin.DeepDanbooru): tagger_thresholds.append((it, threshold_booru)) - elif isinstance(it, taggers_builtin.WaifuDiffusion): + elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm): tagger_thresholds.append((it, threshold_waifu)) elif isinstance(it, taggers_builtin.Z3D_E621): tagger_thresholds.append((it, threshold_z3d)) diff --git a/scripts/dataset_tag_editor/interrogator_names.py b/scripts/dataset_tag_editor/interrogator_names.py new file mode 100644 index 0000000..0ebaedd --- /dev/null +++ b/scripts/dataset_tag_editor/interrogator_names.py @@ -0,0 +1,28 @@ + +BLIP2_CAPTIONING_NAMES = [ + "blip2-opt-2.7b", + "blip2-opt-2.7b-coco", + "blip2-opt-6.7b", + "blip2-opt-6.7b-coco", + "blip2-flan-t5-xl", + "blip2-flan-t5-xl-coco", + "blip2-flan-t5-xxl", +] + + +# {tagger name : default tagger threshold} +# v1: idk if it's okay v2, v3: P=R thresholds on each repo https://huggingface.co/SmilingWolf +WD_TAGGERS = { + "wd-v1-4-vit-tagger" : 0.35, + "wd-v1-4-convnext-tagger" : 0.35, + "wd-v1-4-vit-tagger-v2" : 0.3537, + "wd-v1-4-convnext-tagger-v2" : 0.3685, + 
"wd-v1-4-convnextv2-tagger-v2" : 0.371, + "wd-v1-4-moat-tagger-v2" : 0.3771 +} +WD_TAGGERS_TIMM = { + "wd-v1-4-swinv2-tagger-v2" : 0.3771, + "wd-vit-tagger-v3" : 0.2614, + "wd-convnext-tagger-v3" : 0.2682, + "wd-swinv2-tagger-v3" : 0.2653, +} \ No newline at end of file diff --git a/scripts/dataset_tag_editor/interrogators/__init__.py b/scripts/dataset_tag_editor/interrogators/__init__.py index 3ab3dbf..1fa1552 100644 --- a/scripts/dataset_tag_editor/interrogators/__init__.py +++ b/scripts/dataset_tag_editor/interrogators/__init__.py @@ -1,7 +1,8 @@ from .blip2_captioning import BLIP2Captioning from .git_large_captioning import GITLargeCaptioning from .waifu_diffusion_tagger import WaifuDiffusionTagger +from .waifu_diffusion_tagger_timm import WaifuDiffusionTaggerTimm __all__ = [ - "BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger' + "BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger', 'WaifuDiffusionTaggerTimm' ] \ No newline at end of file diff --git a/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py new file mode 100644 index 0000000..0e7599b --- /dev/null +++ b/scripts/dataset_tag_editor/interrogators/waifu_diffusion_tagger_timm.py @@ -0,0 +1,104 @@ +from PIL import Image +from typing import Tuple + +import torch +from torch.nn import functional as F +import torchvision.transforms as tf +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm + +from modules import shared, devices +import launch + + +class ImageDataset(Dataset): + def __init__(self, images:list[Image.Image], transforms:tf.Compose=None): + self.images = images + self.transforms = transforms + + def __len__(self): + return len(self.images) + + def __getitem__(self, i): + img = self.images[i] + if self.transforms is not None: + img = self.transforms(img) + return img + + + +class WaifuDiffusionTaggerTimm: + # some codes are brought from 
https://github.com/neggles/wdv3-timm and modified + + def __init__(self, model_repo, label_filename="selected_tags.csv"): + self.LABEL_FILENAME = label_filename + self.MODEL_REPO = model_repo + self.model = None + self.transform = None + self.labels = [] + + def load(self): + import huggingface_hub + + if not launch.is_installed("timm"): + launch.run_pip( + "install -U timm", + "requirements for dataset-tag-editor [timm]", + ) + import timm + from timm.data import create_transform, resolve_data_config + + if not self.model: + self.model: torch.nn.Module = timm.create_model( + "hf-hub:" + self.MODEL_REPO + ).eval() + state_dict = timm.models.load_state_dict_from_hf(self.MODEL_REPO) + self.model.load_state_dict(state_dict) + self.model.to(devices.device) + self.transform = create_transform( + **resolve_data_config(self.model.pretrained_cfg, model=self.model) + ) + + path_label = huggingface_hub.hf_hub_download( + self.MODEL_REPO, self.LABEL_FILENAME + ) + import pandas as pd + + self.labels = pd.read_csv(path_label)["name"].tolist() + + def unload(self): + if not shared.opts.interrogate_keep_models_in_memory: + self.model = None + devices.torch_gc() + + def apply(self, image: Image.Image): + if not self.model: + return [] + + image_t: torch.Tensor = self.transform(image).unsqueeze(0) + image_t = image_t[:, [2, 1, 0]] + image_t = image_t.to(devices.device) + + with torch.inference_mode(): + features = self.model.forward(image_t) + probs = F.sigmoid(features).detach().cpu().numpy() + + labels: list[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float))) + + return labels + + + def apply_multi(self, images: list[Image.Image], batch_size: int): + if not self.model: + return [] + + dataset = ImageDataset(images, self.transform) + dataloader = DataLoader(dataset, batch_size=batch_size) + + with torch.inference_mode(): + for batch in tqdm(dataloader): + batch = batch[:, [2, 1, 0]].to(devices.device) + features = self.model.forward(batch) + probs = 
F.sigmoid(features).detach().cpu().numpy() + labels: list[Tuple[str, float]] = [list(zip(self.labels, probs[i].astype(float))) for i in range(probs.shape[0])] + yield labels diff --git a/scripts/dataset_tag_editor/taggers_builtin.py b/scripts/dataset_tag_editor/taggers_builtin.py index bd3a1e3..ab99e4b 100644 --- a/scripts/dataset_tag_editor/taggers_builtin.py +++ b/scripts/dataset_tag_editor/taggers_builtin.py @@ -8,7 +8,7 @@ from modules import devices, shared from modules import deepbooru as db from scripts.tagger import Tagger, get_replaced_tag -from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger +from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger, WaifuDiffusionTaggerTimm class BLIP(Tagger): @@ -142,6 +142,28 @@ class WaifuDiffusion(Tagger): return self.repo_name +class WaifuDiffusionTimm(WaifuDiffusion): + def __init__(self, repo_name, threshold, batch_size=4): + super().__init__(repo_name, threshold) + self.tagger_inst = WaifuDiffusionTaggerTimm("SmilingWolf/" + repo_name) + self.batch_size = batch_size + + def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None): + for labels_list in self.tagger_inst.apply_multi(data, batch_size=self.batch_size): + for labels in labels_list: + if not shared.opts.dataset_editor_use_rating: + labels = labels[4:] + + if threshold is not None: + if threshold < 0: + threshold = self.threshold + tags = [get_replaced_tag(tag) for tag, value in labels if value > threshold] + else: + tags = [get_replaced_tag(tag) for tag, _ in labels] + + yield tags + + class Z3D_E621(Tagger): def __init__(self): self.tagger_inst = WaifuDiffusionTagger("toynya/Z3D-E621-Convnext", label_filename="tags-selected.csv") diff --git a/scripts/main.py b/scripts/main.py index 297018f..6964b0e 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -472,6 +472,33 @@ def on_ui_settings(): ), ) + shared.opts.add_option( + "dataset_editor_batch_size_vit", + 
shared.OptionInfo( + 4, + "Inference batch size for ViT taggers", + section=section, + ), + ) + + shared.opts.add_option( + "dataset_editor_batch_size_convnext", + shared.OptionInfo( + 4, + "Inference batch size for ConvNeXt taggers", + section=section, + ), + ) + + shared.opts.add_option( + "dataset_editor_batch_size_swinv2", + shared.OptionInfo( + 4, + "Inference batch size for SwinTransformerV2 taggers", + section=section, + ), + ) + script_callbacks.on_ui_settings(on_ui_settings) script_callbacks.on_ui_tabs(on_ui_tabs)