Add batch inference feature to WD Tagger (#99)

Same as in the standalone version.
main
toshiaki1729 2024-05-21 00:47:51 +09:00 committed by GitHub
parent fe4bc8b141
commit 94cc5a2a6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 201 additions and 31 deletions

View File

@ -21,6 +21,7 @@ from . import (
taggers_builtin
)
from .custom_scripts import CustomScripts
from .interrogator_names import BLIP2_CAPTIONING_NAMES, WD_TAGGERS, WD_TAGGERS_TIMM
from scripts.tokenizer import clip_tokenizer
from scripts.tagger import Tagger
@ -68,39 +69,26 @@ class DatasetTagEditor(Singleton):
custom_taggers:list[Tagger] = custom_tagger_scripts.load_derived_classes(Tagger)
logger.write(f"Custom taggers loaded: {[tagger().name() for tagger in custom_taggers]}")
self.BLIP2_CAPTIONING_NAMES = [
"blip2-opt-2.7b",
"blip2-opt-2.7b-coco",
"blip2-opt-6.7b",
"blip2-opt-6.7b-coco",
"blip2-flan-t5-xl",
"blip2-flan-t5-xl-coco",
"blip2-flan-t5-xxl",
]
self.WD_TAGGERS = {
"wd-v1-4-vit-tagger" : 0.35,
"wd-v1-4-convnext-tagger" : 0.35,
"wd-v1-4-vit-tagger-v2" : 0.3537,
"wd-v1-4-convnext-tagger-v2" : 0.3685,
"wd-v1-4-convnextv2-tagger-v2" : 0.371,
"wd-v1-4-swinv2-tagger-v2" : 0.3771,
"wd-v1-4-moat-tagger-v2" : 0.3771,
"wd-vit-tagger-v3" : 0.2614,
"wd-convnext-tagger-v3" : 0.2682,
"wd-swinv2-tagger-v3" : 0.2653,
}
# {tagger name : default tagger threshold}
# v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf
def read_wd_batchsize(name:str):
if "vit" in name:
return shared.opts.dataset_editor_batch_size_vit
elif "convnext" in name:
return shared.opts.dataset_editor_batch_size_convnext
elif "swinv2" in name:
return shared.opts.dataset_editor_batch_size_swinv2
self.INTERROGATORS = (
[taggers_builtin.BLIP()]
+ [taggers_builtin.BLIP2(name) for name in self.BLIP2_CAPTIONING_NAMES]
+ [taggers_builtin.BLIP2(name) for name in BLIP2_CAPTIONING_NAMES]
+ [taggers_builtin.GITLarge()]
+ [taggers_builtin.DeepDanbooru()]
+ [
taggers_builtin.WaifuDiffusion(name, threshold)
for name, threshold in self.WD_TAGGERS.items()
for name, threshold in WD_TAGGERS.items()
]
+ [
taggers_builtin.WaifuDiffusionTimm(name, threshold, int(read_wd_batchsize(name)))
for name, threshold in WD_TAGGERS_TIMM.items()
]
+ [taggers_builtin.Z3D_E621()]
+ [cls_tagger() for cls_tagger in custom_taggers]
@ -118,7 +106,7 @@ class DatasetTagEditor(Singleton):
if isinstance(it, taggers_builtin.DeepDanbooru):
with it as tg:
res = tg.predict(img, threshold_booru)
elif isinstance(it, taggers_builtin.WaifuDiffusion):
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
with it as tg:
res = tg.predict(img, threshold_wd)
elif isinstance(it, taggers_builtin.Z3D_E621):
@ -752,7 +740,7 @@ class DatasetTagEditor(Singleton):
if it.name() in interrogator_names:
if isinstance(it, taggers_builtin.DeepDanbooru):
tagger_thresholds.append((it, threshold_booru))
elif isinstance(it, taggers_builtin.WaifuDiffusion):
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
tagger_thresholds.append((it, threshold_waifu))
elif isinstance(it, taggers_builtin.Z3D_E621):
tagger_thresholds.append((it, threshold_z3d))

View File

@ -0,0 +1,28 @@
# BLIP-2 captioning model variants selectable as interrogators.
BLIP2_CAPTIONING_NAMES = [
    "blip2-opt-2.7b",
    "blip2-opt-2.7b-coco",
    "blip2-opt-6.7b",
    "blip2-opt-6.7b-coco",
    "blip2-flan-t5-xl",
    "blip2-flan-t5-xl-coco",
    "blip2-flan-t5-xxl",
]

# Both dicts map {tagger repo name : default confidence threshold}.
# v1 thresholds are unverified defaults; v2/v3 use the P=R thresholds
# published on each model repo (https://huggingface.co/SmilingWolf).

# WD taggers served through the original (non-timm) backend.
WD_TAGGERS = {
    "wd-v1-4-vit-tagger" : 0.35,
    "wd-v1-4-convnext-tagger" : 0.35,
    "wd-v1-4-vit-tagger-v2" : 0.3537,
    "wd-v1-4-convnext-tagger-v2" : 0.3685,
    "wd-v1-4-convnextv2-tagger-v2" : 0.371,
    "wd-v1-4-moat-tagger-v2" : 0.3771
}

# WD taggers served through the timm backend (supports batched inference).
WD_TAGGERS_TIMM = {
    "wd-v1-4-swinv2-tagger-v2" : 0.3771,
    "wd-vit-tagger-v3" : 0.2614,
    "wd-convnext-tagger-v3" : 0.2682,
    "wd-swinv2-tagger-v3" : 0.2653,
}

View File

@ -1,7 +1,8 @@
from .blip2_captioning import BLIP2Captioning
from .git_large_captioning import GITLargeCaptioning
from .waifu_diffusion_tagger import WaifuDiffusionTagger
from .waifu_diffusion_tagger_timm import WaifuDiffusionTaggerTimm
__all__ = [
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger'
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger', 'WaifuDiffusionTaggerTimm'
]

View File

@ -0,0 +1,104 @@
from PIL import Image
from typing import Tuple
import torch
from torch.nn import functional as F
import torchvision.transforms as tf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from modules import shared, devices
import launch
class ImageDataset(Dataset):
    """Torch Dataset over a list of PIL images, applying an optional transform on access."""

    def __init__(self, images:list[Image.Image], transforms:tf.Compose=None):
        self.images = images
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        # Transform lazily, per item, so the DataLoader workers do the work.
        raw = self.images[i]
        return raw if self.transforms is None else self.transforms(raw)
class WaifuDiffusionTaggerTimm:
    """WD tagger backend for timm-hosted SmilingWolf models.

    Loads the model from the Hugging Face hub via timm and predicts per-tag
    confidences for a single image (`apply`) or for batches (`apply_multi`).
    Some code adapted from https://github.com/neggles/wdv3-timm.
    """

    def __init__(self, model_repo, label_filename="selected_tags.csv"):
        # CSV file on the repo that lists the tag names, one per row.
        self.LABEL_FILENAME = label_filename
        # Hugging Face hub repo id of the model.
        self.MODEL_REPO = model_repo
        self.model = None
        self.transform = None
        self.labels = []

    def load(self):
        """Download and initialize the model, its preprocessing transform, and the labels."""
        import huggingface_hub

        # timm is an optional dependency; install it on first use.
        if not launch.is_installed("timm"):
            launch.run_pip(
                "install -U timm",
                "requirements for dataset-tag-editor [timm]",
            )
        import timm
        from timm.data import create_transform, resolve_data_config

        if not self.model:
            self.model: torch.nn.Module = timm.create_model(
                "hf-hub:" + self.MODEL_REPO
            ).eval()
            state_dict = timm.models.load_state_dict_from_hf(self.MODEL_REPO)
            self.model.load_state_dict(state_dict)
            self.model.to(devices.device)
            # Preprocessing (resize/normalize) derived from the model's own config.
            self.transform = create_transform(
                **resolve_data_config(self.model.pretrained_cfg, model=self.model)
            )
            path_label = huggingface_hub.hf_hub_download(
                self.MODEL_REPO, self.LABEL_FILENAME
            )
            import pandas as pd

            self.labels = pd.read_csv(path_label)["name"].tolist()

    def unload(self):
        """Drop the model (unless the user keeps models in memory) and reclaim VRAM."""
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            devices.torch_gc()

    def apply(self, image: Image.Image):
        """Tag one image; returns [(tag, confidence), ...], or [] when no model is loaded."""
        if not self.model:
            return []
        image_t: torch.Tensor = self.transform(image).unsqueeze(0)
        # RGB -> BGR channel swap; presumably what the WD models expect
        # (mirrors apply_multi) -- TODO confirm against upstream.
        image_t = image_t[:, [2, 1, 0]]
        image_t = image_t.to(devices.device)
        with torch.inference_mode():
            features = self.model.forward(image_t)
            # BUG FIX: torch.Tensor has no .astype(); convert to a NumPy
            # array first, exactly as apply_multi already does.
            probs = F.sigmoid(features).detach().cpu().numpy()
        labels: list[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))
        return labels

    def apply_multi(self, images: list[Image.Image], batch_size: int):
        """Yield, per batch, a list of [(tag, confidence), ...] — one entry per image.

        Generator so progress (tqdm) is visible per batch; yields nothing
        when no model is loaded.
        """
        if not self.model:
            return
        dataset = ImageDataset(images, self.transform)
        dataloader = DataLoader(dataset, batch_size=batch_size)
        with torch.inference_mode():
            for batch in tqdm(dataloader):
                # Same RGB -> BGR swap as apply, for the whole batch.
                batch = batch[:, [2, 1, 0]].to(devices.device)
                features = self.model.forward(batch)
                probs = F.sigmoid(features).detach().cpu().numpy()
                labels: list[Tuple[str, float]] = [
                    list(zip(self.labels, probs[i].astype(float)))
                    for i in range(probs.shape[0])
                ]
                yield labels

View File

@ -8,7 +8,7 @@ from modules import devices, shared
from modules import deepbooru as db
from scripts.tagger import Tagger, get_replaced_tag
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger, WaifuDiffusionTaggerTimm
class BLIP(Tagger):
@ -142,6 +142,28 @@ class WaifuDiffusion(Tagger):
return self.repo_name
class WaifuDiffusionTimm(WaifuDiffusion):
    """WaifuDiffusion tagger variant backed by the timm loader, with batched prediction."""

    def __init__(self, repo_name, threshold, batch_size=4):
        super().__init__(repo_name, threshold)
        # Swap in the timm-based tagger implementation for this repo.
        self.tagger_inst = WaifuDiffusionTaggerTimm("SmilingWolf/" + repo_name)
        self.batch_size = batch_size

    def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None):
        """Yield one tag list per image, filtering by confidence when a threshold is given."""
        batches = self.tagger_inst.apply_multi(data, batch_size=self.batch_size)
        for batch_labels in batches:
            for image_labels in batch_labels:
                if not shared.opts.dataset_editor_use_rating:
                    # Drop the leading entries (rating tags — presumably 4 of them).
                    image_labels = image_labels[4:]
                if threshold is None:
                    tags = [get_replaced_tag(tag) for tag, _ in image_labels]
                else:
                    if threshold < 0:
                        # A negative threshold means "use this tagger's default".
                        threshold = self.threshold
                    tags = [
                        get_replaced_tag(tag)
                        for tag, value in image_labels
                        if value > threshold
                    ]
                yield tags
class Z3D_E621(Tagger):
def __init__(self):
self.tagger_inst = WaifuDiffusionTagger("toynya/Z3D-E621-Convnext", label_filename="tags-selected.csv")

View File

@ -472,6 +472,33 @@ def on_ui_settings():
),
)
# Per-architecture inference batch sizes for the timm-based WD taggers;
# default is 4 for each. Architecture is matched by substring of the
# tagger name ("vit" / "convnext" / "swinv2") elsewhere in the extension.
shared.opts.add_option(
    "dataset_editor_batch_size_vit",
    shared.OptionInfo(
        4,
        "Inference batch size for ViT taggers",
        section=section,
    ),
)
shared.opts.add_option(
    "dataset_editor_batch_size_convnext",
    shared.OptionInfo(
        4,
        "Inference batch size for ConvNeXt taggers",
        section=section,
    ),
)
shared.opts.add_option(
    "dataset_editor_batch_size_swinv2",
    shared.OptionInfo(
        4,
        "Inference batch size for SwinTransformerV2 taggers",
        section=section,
    ),
)
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(on_ui_tabs)