Add batch inference feature to WD Tagger
as same as standalone versionfeatures/batch-inference
parent
8feb29de40
commit
570b0f5d94
|
|
@ -21,6 +21,7 @@ from . import (
|
|||
taggers_builtin
|
||||
)
|
||||
from .custom_scripts import CustomScripts
|
||||
from .interrogator_names import BLIP2_CAPTIONING_NAMES, WD_TAGGERS, WD_TAGGERS_TIMM
|
||||
from scripts.tokenizer import clip_tokenizer
|
||||
from scripts.tagger import Tagger
|
||||
|
||||
|
|
@ -68,39 +69,26 @@ class DatasetTagEditor(Singleton):
|
|||
custom_taggers:list[Tagger] = custom_tagger_scripts.load_derived_classes(Tagger)
|
||||
logger.write(f"Custom taggers loaded: {[tagger().name() for tagger in custom_taggers]}")
|
||||
|
||||
self.BLIP2_CAPTIONING_NAMES = [
|
||||
"blip2-opt-2.7b",
|
||||
"blip2-opt-2.7b-coco",
|
||||
"blip2-opt-6.7b",
|
||||
"blip2-opt-6.7b-coco",
|
||||
"blip2-flan-t5-xl",
|
||||
"blip2-flan-t5-xl-coco",
|
||||
"blip2-flan-t5-xxl",
|
||||
]
|
||||
|
||||
self.WD_TAGGERS = {
|
||||
"wd-v1-4-vit-tagger" : 0.35,
|
||||
"wd-v1-4-convnext-tagger" : 0.35,
|
||||
"wd-v1-4-vit-tagger-v2" : 0.3537,
|
||||
"wd-v1-4-convnext-tagger-v2" : 0.3685,
|
||||
"wd-v1-4-convnextv2-tagger-v2" : 0.371,
|
||||
"wd-v1-4-swinv2-tagger-v2" : 0.3771,
|
||||
"wd-v1-4-moat-tagger-v2" : 0.3771,
|
||||
"wd-vit-tagger-v3" : 0.2614,
|
||||
"wd-convnext-tagger-v3" : 0.2682,
|
||||
"wd-swinv2-tagger-v3" : 0.2653,
|
||||
}
|
||||
# {tagger name : default tagger threshold}
|
||||
# v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf
|
||||
|
||||
def read_wd_batchsize(name: str) -> int:
    """Return the user-configured inference batch size for the WD tagger *name*.

    The option is selected by the model architecture embedded in the
    repository name (vit / convnext / swinv2).
    """
    if "vit" in name:
        return shared.opts.dataset_editor_batch_size_vit
    elif "convnext" in name:
        return shared.opts.dataset_editor_batch_size_convnext
    elif "swinv2" in name:
        return shared.opts.dataset_editor_batch_size_swinv2
    # Robustness: previously this fell through and returned None, which would
    # crash the caller's int(...) conversion if a tagger with a new
    # architecture name were ever added. Fall back to a safe batch size of 1.
    return 1
|
||||
|
||||
self.INTERROGATORS = (
|
||||
[taggers_builtin.BLIP()]
|
||||
+ [taggers_builtin.BLIP2(name) for name in self.BLIP2_CAPTIONING_NAMES]
|
||||
+ [taggers_builtin.BLIP2(name) for name in BLIP2_CAPTIONING_NAMES]
|
||||
+ [taggers_builtin.GITLarge()]
|
||||
+ [taggers_builtin.DeepDanbooru()]
|
||||
+ [
|
||||
taggers_builtin.WaifuDiffusion(name, threshold)
|
||||
for name, threshold in self.WD_TAGGERS.items()
|
||||
for name, threshold in WD_TAGGERS.items()
|
||||
]
|
||||
+ [
|
||||
taggers_builtin.WaifuDiffusionTimm(name, threshold, int(read_wd_batchsize(name)))
|
||||
for name, threshold in WD_TAGGERS_TIMM.items()
|
||||
]
|
||||
+ [taggers_builtin.Z3D_E621()]
|
||||
+ [cls_tagger() for cls_tagger in custom_taggers]
|
||||
|
|
@ -118,7 +106,7 @@ class DatasetTagEditor(Singleton):
|
|||
if isinstance(it, taggers_builtin.DeepDanbooru):
|
||||
with it as tg:
|
||||
res = tg.predict(img, threshold_booru)
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion):
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
|
||||
with it as tg:
|
||||
res = tg.predict(img, threshold_wd)
|
||||
elif isinstance(it, taggers_builtin.Z3D_E621):
|
||||
|
|
@ -752,7 +740,7 @@ class DatasetTagEditor(Singleton):
|
|||
if it.name() in interrogator_names:
|
||||
if isinstance(it, taggers_builtin.DeepDanbooru):
|
||||
tagger_thresholds.append((it, threshold_booru))
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion):
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
|
||||
tagger_thresholds.append((it, threshold_waifu))
|
||||
elif isinstance(it, taggers_builtin.Z3D_E621):
|
||||
tagger_thresholds.append((it, threshold_z3d))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
|
||||
# BLIP-2 captioning model variants selectable for interrogation.
BLIP2_CAPTIONING_NAMES = [
    "blip2-opt-2.7b",
    "blip2-opt-2.7b-coco",
    "blip2-opt-6.7b",
    "blip2-opt-6.7b-coco",
    "blip2-flan-t5-xl",
    "blip2-flan-t5-xl-coco",
    "blip2-flan-t5-xxl",
]


# {tagger name : default tagger threshold}
# v1: idk if it's okay v2, v3: P=R thresholds on each repo https://huggingface.co/SmilingWolf
# Taggers served by the original WaifuDiffusion backend (single-image inference).
WD_TAGGERS = {
    "wd-v1-4-vit-tagger" : 0.35,
    "wd-v1-4-convnext-tagger" : 0.35,
    "wd-v1-4-vit-tagger-v2" : 0.3537,
    "wd-v1-4-convnext-tagger-v2" : 0.3685,
    "wd-v1-4-convnextv2-tagger-v2" : 0.371,
    "wd-v1-4-moat-tagger-v2" : 0.3771
}
# Taggers served by the timm-based backend, which supports batch inference.
WD_TAGGERS_TIMM = {
    "wd-v1-4-swinv2-tagger-v2" : 0.3771,
    "wd-vit-tagger-v3" : 0.2614,
    "wd-convnext-tagger-v3" : 0.2682,
    "wd-swinv2-tagger-v3" : 0.2653,
}
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
from .blip2_captioning import BLIP2Captioning
|
||||
from .git_large_captioning import GITLargeCaptioning
|
||||
from .waifu_diffusion_tagger import WaifuDiffusionTagger
|
||||
from .waifu_diffusion_tagger_timm import WaifuDiffusionTaggerTimm
|
||||
|
||||
__all__ = [
|
||||
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger'
|
||||
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger', 'WaifuDiffusionTaggerTimm'
|
||||
]
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
from PIL import Image
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import torchvision.transforms as tf
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import shared, devices
|
||||
import launch
|
||||
|
||||
|
||||
class ImageDataset(Dataset):
    """Torch dataset over an in-memory list of PIL images.

    An optional torchvision transform pipeline is applied lazily on access.
    """

    def __init__(self, images: list[Image.Image], transforms: tf.Compose = None):
        self.images = images
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        item = self.images[i]
        return item if self.transforms is None else self.transforms(item)
||||
|
||||
|
||||
|
||||
class WaifuDiffusionTaggerTimm:
    """WD tagger backend built on timm models, with batched inference support.

    Some code is brought from https://github.com/neggles/wdv3-timm and modified.
    """

    def __init__(self, model_repo, label_filename="selected_tags.csv"):
        self.LABEL_FILENAME = label_filename
        self.MODEL_REPO = model_repo
        self.model = None
        self.transform = None
        self.labels = []

    def load(self):
        """Instantiate the timm model, its preprocessing transform, and the
        tag label list from the Hugging Face Hub. Installs timm on first use."""
        import huggingface_hub

        if not launch.is_installed("timm"):
            launch.run_pip(
                "install -U timm",
                "requirements for dataset-tag-editor [timm]",
            )
        import timm
        from timm.data import create_transform, resolve_data_config

        if not self.model:
            self.model: torch.nn.Module = timm.create_model(
                "hf-hub:" + self.MODEL_REPO
            ).eval()
            state_dict = timm.models.load_state_dict_from_hf(self.MODEL_REPO)
            self.model.load_state_dict(state_dict)
            self.model.to(devices.device)
            self.transform = create_transform(
                **resolve_data_config(self.model.pretrained_cfg, model=self.model)
            )

        path_label = huggingface_hub.hf_hub_download(
            self.MODEL_REPO, self.LABEL_FILENAME
        )
        import pandas as pd

        self.labels = pd.read_csv(path_label)["name"].tolist()

    def unload(self):
        # Keep the model resident when the user opted to cache interrogators.
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            devices.torch_gc()

    def apply(self, image: Image.Image):
        """Tag a single image; returns a list of (label, probability) pairs.

        Returns [] when the model has not been loaded.
        """
        if not self.model:
            return []

        image_t: torch.Tensor = self.transform(image).unsqueeze(0)
        # RGB -> BGR channel swap, matching apply_multi's preprocessing.
        image_t = image_t[:, [2, 1, 0]]
        image_t = image_t.to(devices.device)

        with torch.inference_mode():
            features = self.model.forward(image_t)
            # BUGFIX: convert to a numpy array before .astype() below —
            # torch.Tensor has no .astype method, so the original
            # F.sigmoid(features).detach().cpu() result raised
            # AttributeError at probs[0].astype(float). apply_multi
            # already calls .numpy() here.
            probs = F.sigmoid(features).detach().cpu().numpy()

        labels: list[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))

        return labels

    def apply_multi(self, images: list[Image.Image], batch_size: int):
        """Yield, per batch, a list of (label, probability) lists — one per image.

        Generator: yields nothing when the model has not been loaded.
        """
        if not self.model:
            return []

        dataset = ImageDataset(images, self.transform)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        with torch.inference_mode():
            for batch in tqdm(dataloader):
                # RGB -> BGR, then move the whole batch to the inference device.
                batch = batch[:, [2, 1, 0]].to(devices.device)
                features = self.model.forward(batch)
                probs = F.sigmoid(features).detach().cpu().numpy()
                labels: list[Tuple[str, float]] = [
                    list(zip(self.labels, probs[i].astype(float)))
                    for i in range(probs.shape[0])
                ]
                yield labels
|
||||
|
|
@ -8,7 +8,7 @@ from modules import devices, shared
|
|||
from modules import deepbooru as db
|
||||
|
||||
from scripts.tagger import Tagger, get_replaced_tag
|
||||
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger
|
||||
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger, WaifuDiffusionTaggerTimm
|
||||
|
||||
|
||||
class BLIP(Tagger):
|
||||
|
|
@ -142,6 +142,28 @@ class WaifuDiffusion(Tagger):
|
|||
return self.repo_name
|
||||
|
||||
|
||||
class WaifuDiffusionTimm(WaifuDiffusion):
    """WaifuDiffusion tagger variant backed by the timm backend, adding
    batched inference via predict_pipe."""

    def __init__(self, repo_name, threshold, batch_size=4):
        super().__init__(repo_name, threshold)
        self.tagger_inst = WaifuDiffusionTaggerTimm("SmilingWolf/" + repo_name)
        self.batch_size = batch_size

    def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None):
        """Yield one list of tags per input image, running inference in batches."""
        for batch_labels in self.tagger_inst.apply_multi(data, batch_size=self.batch_size):
            for image_labels in batch_labels:
                if not shared.opts.dataset_editor_use_rating:
                    # Skip the leading rating entries of the label list.
                    image_labels = image_labels[4:]

                if threshold is None:
                    tags = [get_replaced_tag(tag) for tag, _ in image_labels]
                else:
                    # A negative threshold means "use this tagger's default".
                    if threshold < 0:
                        threshold = self.threshold
                    tags = [
                        get_replaced_tag(tag)
                        for tag, value in image_labels
                        if value > threshold
                    ]

                yield tags
|
||||
|
||||
|
||||
class Z3D_E621(Tagger):
|
||||
def __init__(self):
|
||||
self.tagger_inst = WaifuDiffusionTagger("toynya/Z3D-E621-Convnext", label_filename="tags-selected.csv")
|
||||
|
|
|
|||
|
|
@ -472,6 +472,33 @@ def on_ui_settings():
|
|||
),
|
||||
)
|
||||
|
||||
shared.opts.add_option(
|
||||
"dataset_editor_batch_size_vit",
|
||||
shared.OptionInfo(
|
||||
4,
|
||||
"Inference batch size for ViT taggers",
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
shared.opts.add_option(
|
||||
"dataset_editor_batch_size_convnext",
|
||||
shared.OptionInfo(
|
||||
4,
|
||||
"Inference batch size for ConvNeXt taggers",
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
shared.opts.add_option(
|
||||
"dataset_editor_batch_size_swinv2",
|
||||
shared.OptionInfo(
|
||||
4,
|
||||
"Inference batch size for SwinTransformerV2 taggers",
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
script_callbacks.on_ui_tabs(on_ui_tabs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue