Compare commits

...

8 Commits
v0.3.3 ... main

Author SHA1 Message Date
toshiaki1729 d6c7dcef02 fix wrong func name 2024-06-27 20:03:26 +09:00
toshiaki1729 b69c821b97 load large images 2024-06-25 17:44:05 +09:00
toshiaki1729 620ebe1333 Make image shape square before interrogating
to fix #101
2024-06-25 17:24:58 +09:00
toshiaki1729 74de5b7a07 fix: "move or delete files" shows wrong number of target images 2024-06-10 11:29:54 +09:00
toshiaki1729 652286e46d Merge branch 'main' of https://github.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor 2024-05-26 13:17:19 +09:00
toshiaki1729 da5053bc2e Update waifu_diffusion_tagger_timm.py
fix #100
2024-05-26 13:17:17 +09:00
toshiaki1729 c716c83db0
Update version.txt 2024-05-21 00:48:45 +09:00
toshiaki1729 94cc5a2a6b
Add batch inference feature to WD Tagger (#99)
as same as standalone version
2024-05-21 00:47:51 +09:00
10 changed files with 315 additions and 70 deletions

View File

@ -3,6 +3,8 @@ import re, sys
from typing import List, Set, Optional
from enum import Enum
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from PIL import Image
from tqdm import tqdm
@ -11,7 +13,7 @@ from modules.textual_inversion.dataset import re_numbers_at_start
from scripts.singleton import Singleton
from scripts import logger
from scripts import logger, utilities
from scripts.paths import paths
from . import (
@ -21,14 +23,17 @@ from . import (
taggers_builtin
)
from .custom_scripts import CustomScripts
from .interrogator_names import BLIP2_CAPTIONING_NAMES, WD_TAGGERS, WD_TAGGERS_TIMM
from scripts.tokenizer import clip_tokenizer
from scripts.tagger import Tagger
re_tags = re.compile(r"^([\s\S]+?)( \[\d+\])?$")
re_newlines = re.compile(r"[\r\n]+")
def convert_rgb(data:Image.Image):
    """Return *data* converted to RGB mode."""
    rgb_image = data.convert("RGB")
    return rgb_image
def get_square_rgb(data:Image.Image):
    """Convert *data* to RGB and pad it out to a square of its longer edge."""
    edge = max(data.size)
    squared = utilities.resize_and_fill(utilities.get_rgb_image(data), (edge, edge))
    return squared
class DatasetTagEditor(Singleton):
class SortBy(Enum):
@ -68,39 +73,26 @@ class DatasetTagEditor(Singleton):
custom_taggers:list[Tagger] = custom_tagger_scripts.load_derived_classes(Tagger)
logger.write(f"Custom taggers loaded: {[tagger().name() for tagger in custom_taggers]}")
self.BLIP2_CAPTIONING_NAMES = [
"blip2-opt-2.7b",
"blip2-opt-2.7b-coco",
"blip2-opt-6.7b",
"blip2-opt-6.7b-coco",
"blip2-flan-t5-xl",
"blip2-flan-t5-xl-coco",
"blip2-flan-t5-xxl",
]
self.WD_TAGGERS = {
"wd-v1-4-vit-tagger" : 0.35,
"wd-v1-4-convnext-tagger" : 0.35,
"wd-v1-4-vit-tagger-v2" : 0.3537,
"wd-v1-4-convnext-tagger-v2" : 0.3685,
"wd-v1-4-convnextv2-tagger-v2" : 0.371,
"wd-v1-4-swinv2-tagger-v2" : 0.3771,
"wd-v1-4-moat-tagger-v2" : 0.3771,
"wd-vit-tagger-v3" : 0.2614,
"wd-convnext-tagger-v3" : 0.2682,
"wd-swinv2-tagger-v3" : 0.2653,
}
# {tagger name : default tagger threshold}
# v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf
def read_wd_batchsize(name:str):
if "vit" in name:
return shared.opts.dataset_editor_batch_size_vit
elif "convnext" in name:
return shared.opts.dataset_editor_batch_size_convnext
elif "swinv2" in name:
return shared.opts.dataset_editor_batch_size_swinv2
self.INTERROGATORS = (
[taggers_builtin.BLIP()]
+ [taggers_builtin.BLIP2(name) for name in self.BLIP2_CAPTIONING_NAMES]
+ [taggers_builtin.BLIP2(name) for name in BLIP2_CAPTIONING_NAMES]
+ [taggers_builtin.GITLarge()]
+ [taggers_builtin.DeepDanbooru()]
+ [
taggers_builtin.WaifuDiffusion(name, threshold)
for name, threshold in self.WD_TAGGERS.items()
for name, threshold in WD_TAGGERS.items()
]
+ [
taggers_builtin.WaifuDiffusionTimm(name, threshold, int(read_wd_batchsize(name)))
for name, threshold in WD_TAGGERS_TIMM.items()
]
+ [taggers_builtin.Z3D_E621()]
+ [cls_tagger() for cls_tagger in custom_taggers]
@ -109,7 +101,7 @@ class DatasetTagEditor(Singleton):
def interrogate_image(self, path: str, interrogator_name: str, threshold_booru, threshold_wd, threshold_z3d):
try:
img = Image.open(path).convert("RGB")
img = get_square_rgb(Image.open(path))
except:
return ""
else:
@ -118,7 +110,7 @@ class DatasetTagEditor(Singleton):
if isinstance(it, taggers_builtin.DeepDanbooru):
with it as tg:
res = tg.predict(img, threshold_booru)
elif isinstance(it, taggers_builtin.WaifuDiffusion):
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
with it as tg:
res = tg.predict(img, threshold_wd)
elif isinstance(it, taggers_builtin.Z3D_E621):
@ -704,19 +696,27 @@ class DatasetTagEditor(Singleton):
continue
try:
img = Image.open(img_path)
if (max_res > 0):
img_res = int(max_res), int(max_res)
img.thumbnail(img_res)
except:
continue
else:
abs_path = str(img_path.absolute())
if not use_temp_dir and max_res <= 0:
img.already_saved_as = abs_path
images[abs_path] = img
imgpaths.append(abs_path)
imgpaths.append(abs_path)
return imgpaths, images
def load_thumbnails(images_raw: dict[str, Image.Image]):
    # Build the gallery-facing image dict from the full-size raw images.
    # NOTE(review): reads `max_res` and `use_temp_dir` from the enclosing
    # scope — confirm both are in scope at every call site.
    images = {}
    if max_res > 0:
        for img_path, img in images_raw.items():
            img_res = int(max_res), int(max_res)
            # Copy first: thumbnail() resizes in place (keeping aspect
            # ratio), and the raw image must stay full-size for other uses.
            images[img_path] = img.copy()
            images[img_path].thumbnail(img_res)
    else:
        # No downscaling requested: reuse the raw images directly.
        for img_path, img in images_raw.items():
            if not use_temp_dir:
                # presumably a hint to the webui gallery that the file is
                # already on disk and need not be re-saved — TODO confirm
                img.already_saved_as = img_path
            images[img_path] = img
    return images
def load_captions(imgpaths: list[str]):
taglists = []
@ -752,7 +752,7 @@ class DatasetTagEditor(Singleton):
if it.name() in interrogator_names:
if isinstance(it, taggers_builtin.DeepDanbooru):
tagger_thresholds.append((it, threshold_booru))
elif isinstance(it, taggers_builtin.WaifuDiffusion):
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
tagger_thresholds.append((it, threshold_waifu))
elif isinstance(it, taggers_builtin.Z3D_E621):
tagger_thresholds.append((it, threshold_z3d))
@ -760,13 +760,15 @@ class DatasetTagEditor(Singleton):
tagger_thresholds.append((it, None))
if kohya_json_path:
imgpaths, self.images, taglists = kohya_metadata.read(
imgpaths, images_raw, taglists = kohya_metadata.read(
img_dir, kohya_json_path, use_temp_dir
)
else:
imgpaths, self.images = load_images(filepaths)
imgpaths, images_raw = load_images(filepaths)
taglists = load_captions(imgpaths)
self.images = load_thumbnails(images_raw)
interrogate_tags = {img_path : [] for img_path in imgpaths}
img_to_interrogate = [
@ -783,11 +785,11 @@ class DatasetTagEditor(Singleton):
def gen_data(paths:list[str], images:dict[str, Image.Image]):
for img_path in paths:
yield images.get(img_path)
yield images[img_path]
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=max_workers) as executor:
result = list(executor.map(convert_rgb, gen_data(img_to_interrogate, self.images)))
result = list(executor.map(get_square_rgb, gen_data(img_to_interrogate, images_raw)))
logger.write("Preprocess completed")
for tg, th in tqdm(tagger_thresholds):

View File

@ -0,0 +1,28 @@
# Name and default-threshold constants for the built-in captioning/tagging
# models, shared by the interrogator setup code.

# Hugging Face BLIP-2 captioning model repository names.
BLIP2_CAPTIONING_NAMES = [
    "blip2-opt-2.7b",
    "blip2-opt-2.7b-coco",
    "blip2-opt-6.7b",
    "blip2-opt-6.7b-coco",
    "blip2-flan-t5-xl",
    "blip2-flan-t5-xl-coco",
    "blip2-flan-t5-xxl",
]

# {tagger name : default tagger threshold}
# v1: idk if it's okay v2, v3: P=R thresholds on each repo https://huggingface.co/SmilingWolf
WD_TAGGERS = {
    "wd-v1-4-vit-tagger" : 0.35,
    "wd-v1-4-convnext-tagger" : 0.35,
    "wd-v1-4-vit-tagger-v2" : 0.3537,
    "wd-v1-4-convnext-tagger-v2" : 0.3685,
    "wd-v1-4-convnextv2-tagger-v2" : 0.371,
    "wd-v1-4-moat-tagger-v2" : 0.3771
}

# WD taggers that are loaded through timm instead (handled by
# WaifuDiffusionTaggerTimm); same {name : default threshold} layout.
WD_TAGGERS_TIMM = {
    "wd-v1-4-swinv2-tagger-v2" : 0.3771,
    "wd-vit-tagger-v3" : 0.2614,
    "wd-convnext-tagger-v3" : 0.2682,
    "wd-swinv2-tagger-v3" : 0.2653,
}

View File

@ -1,7 +1,8 @@
from .blip2_captioning import BLIP2Captioning
from .git_large_captioning import GITLargeCaptioning
from .waifu_diffusion_tagger import WaifuDiffusionTagger
from .waifu_diffusion_tagger_timm import WaifuDiffusionTaggerTimm
__all__ = [
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger'
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger', 'WaifuDiffusionTaggerTimm'
]

View File

@ -0,0 +1,104 @@
from PIL import Image
from typing import Tuple
import torch
from torch.nn import functional as F
import torchvision.transforms as tf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from modules import shared, devices
import launch
class ImageDataset(Dataset):
    """Thin torch Dataset over an in-memory list of PIL images.

    The optional *transforms* pipeline is applied lazily, on each item
    access, rather than up front.
    """

    def __init__(self, images:list[Image.Image], transforms:tf.Compose=None):
        self.images = images
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        item = self.images[i]
        # Preprocess on access so DataLoader workers share the cost.
        return item if self.transforms is None else self.transforms(item)
class WaifuDiffusionTaggerTimm:
    # some codes are brought from https://github.com/neggles/wdv3-timm and modified
    # Runs SmilingWolf WD taggers through timm (instead of ONNX), with
    # optional batched inference via a torch DataLoader.

    def __init__(self, model_repo, label_filename="selected_tags.csv"):
        self.LABEL_FILENAME = label_filename  # CSV of tag names on the HF repo
        self.MODEL_REPO = model_repo          # Hugging Face repo id
        self.model = None                     # lazily loaded in load()
        self.transform = None                 # timm preprocessing pipeline
        self.labels = []                      # tag names, row order of the CSV

    def load(self):
        """Lazily install timm, then load model, transform and labels."""
        import huggingface_hub
        # Install timm on demand so the extension works without it until
        # a timm-based tagger is actually used.
        if not launch.is_installed("timm"):
            launch.run_pip(
                "install -U timm",
                "requirements for dataset-tag-editor [timm]",
            )
        import timm
        from timm.data import create_transform, resolve_data_config
        if not self.model:
            self.model: torch.nn.Module = timm.create_model(
                "hf-hub:" + self.MODEL_REPO
            ).eval()
            state_dict = timm.models.load_state_dict_from_hf(self.MODEL_REPO)
            self.model.load_state_dict(state_dict)
            self.model.to(devices.device)
            # Build the preprocessing transform from the model's own
            # pretrained config (resize, crop, normalization).
            self.transform = create_transform(
                **resolve_data_config(self.model.pretrained_cfg, model=self.model)
            )
            path_label = huggingface_hub.hf_hub_download(
                self.MODEL_REPO, self.LABEL_FILENAME
            )
            import pandas as pd
            self.labels = pd.read_csv(path_label)["name"].tolist()

    def unload(self):
        # Free the model unless the user opted to keep models resident.
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            devices.torch_gc()

    def apply(self, image: Image.Image):
        """Tag a single image; returns [(label, probability), ...]."""
        if not self.model:
            return []
        image_t: torch.Tensor = self.transform(image).unsqueeze(0)
        # Swap the channel dim: RGB -> BGR, presumably because the WD
        # models expect BGR input (as in upstream wdv3-timm) — TODO confirm.
        image_t = image_t[:, [2, 1, 0]]
        image_t = image_t.to(devices.device)
        with torch.inference_mode():
            features = self.model.forward(image_t)
            probs = F.sigmoid(features).detach().cpu().numpy()
        labels: list[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))
        return labels

    def apply_multi(self, images: list[Image.Image], batch_size: int):
        """Tag a list of images in batches; yields one labels-list per batch."""
        if not self.model:
            return []
        dataset = ImageDataset(images, self.transform)
        dataloader = DataLoader(dataset, batch_size=batch_size)
        with torch.inference_mode():
            for batch in tqdm(dataloader):
                # Same RGB -> BGR channel swap as in apply().
                batch = batch[:, [2, 1, 0]].to(devices.device)
                features = self.model.forward(batch)
                probs = F.sigmoid(features).detach().cpu().numpy()
                # One [(label, probability), ...] list per image in the batch.
                labels: list[Tuple[str, float]] = [list(zip(self.labels, probs[i].astype(float))) for i in range(probs.shape[0])]
                yield labels

View File

@ -8,7 +8,7 @@ from modules import devices, shared
from modules import deepbooru as db
from scripts.tagger import Tagger, get_replaced_tag
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger, WaifuDiffusionTaggerTimm
class BLIP(Tagger):
@ -142,6 +142,28 @@ class WaifuDiffusion(Tagger):
return self.repo_name
class WaifuDiffusionTimm(WaifuDiffusion):
    # WD tagger variant backed by the timm loader, adding batched
    # inference (`predict_pipe`) on top of the base WaifuDiffusion tagger.

    def __init__(self, repo_name, threshold, batch_size=4):
        super().__init__(repo_name, threshold)
        # Replace the base class's tagger instance with the timm-based one.
        self.tagger_inst = WaifuDiffusionTaggerTimm("SmilingWolf/" + repo_name)
        self.batch_size = batch_size

    def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None):
        """Yield a tag list per image, filtering by *threshold* if given.

        A negative threshold means "use this tagger's default threshold";
        None disables filtering entirely.
        """
        for labels_list in self.tagger_inst.apply_multi(data, batch_size=self.batch_size):
            for labels in labels_list:
                if not shared.opts.dataset_editor_use_rating:
                    # Drop the leading entries — presumably the 4 rating
                    # tags the WD label CSVs put first — TODO confirm.
                    labels = labels[4:]
                if threshold is not None:
                    if threshold < 0:
                        threshold = self.threshold
                    tags = [get_replaced_tag(tag) for tag, value in labels if value > threshold]
                else:
                    tags = [get_replaced_tag(tag) for tag, _ in labels]
                yield tags
class Z3D_E621(Tagger):
def __init__(self):
self.tagger_inst = WaifuDiffusionTagger("toynya/Z3D-E621-Convnext", label_filename="tags-selected.csv")

View File

@ -413,10 +413,6 @@ def on_ui_tabs():
ui.move_or_delete_files.set_callbacks(
o_update_filter_and_gallery,
ui.dataset_gallery,
ui.filter_by_tags,
ui.batch_edit_captions,
ui.filter_by_selection,
ui.edit_caption_of_selected_image,
get_filters,
update_filter_and_gallery,
)
@ -472,6 +468,33 @@ def on_ui_settings():
),
)
shared.opts.add_option(
"dataset_editor_batch_size_vit",
shared.OptionInfo(
4,
"Inference batch size for ViT taggers",
section=section,
),
)
shared.opts.add_option(
"dataset_editor_batch_size_convnext",
shared.OptionInfo(
4,
"Inference batch size for ConvNeXt taggers",
section=section,
),
)
shared.opts.add_option(
"dataset_editor_batch_size_swinv2",
shared.OptionInfo(
4,
"Inference batch size for SwinTransformerV2 taggers",
section=section,
),
)
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(on_ui_tabs)

View File

@ -37,7 +37,7 @@ class FilterByTagsUI(UIBase):
common_callback = lambda : \
update_gallery() + \
batch_edit_captions.get_common_tags(get_filters, self) + \
[move_or_delete_files.get_current_move_or_delete_target_num()] + \
[move_or_delete_files.update_current_move_or_delete_target_num()] + \
[batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
common_callback_output = \

View File

@ -14,6 +14,7 @@ class MoveOrDeleteFilesUI(UIBase):
def __init__(self):
self.target_data = 'Selected One'
self.current_target_txt = ''
self.update_func = None
def create_ui(self, cfg_file_move_delete):
gr.HTML(value='<b>Note: </b>Moved or deleted images will be unloaded.')
@ -27,11 +28,15 @@ class MoveOrDeleteFilesUI(UIBase):
gr.HTML(value='<b>Note: </b>DELETE cannot be undone. The files will be deleted completely.')
self.btn_move_or_delete_delete_files = gr.Button(value='DELETE File(s)', variant='primary')
def get_current_move_or_delete_target_num(self):
return self.current_target_txt
def update_current_move_or_delete_target_num(self):
if self.update_func:
return self.update_func(self.target_data)
else:
return self.current_target_txt
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, filter_by_tags:FilterByTagsUI, batch_edit_captions:BatchEditCaptionsUI, filter_by_selection:FilterBySelectionUI, edit_caption_of_selected_image:EditCaptionOfSelectedImageUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
def _get_current_move_or_delete_target_num():
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
def _get_current_move_or_delete_target_num(text: str):
self.target_data = text
if self.target_data == 'Selected One':
self.current_target_txt = f'Target dataset num: {1 if dataset_gallery.selected_index != -1 else 0}'
elif self.target_data == 'All Displayed Ones':
@ -41,25 +46,17 @@ class MoveOrDeleteFilesUI(UIBase):
self.current_target_txt = f'Target dataset num: 0'
return self.current_target_txt
self.update_func = _get_current_move_or_delete_target_num
update_args = {
'fn' : _get_current_move_or_delete_target_num,
'inputs' : None,
'fn': self.update_func,
'inputs': [self.rb_move_or_delete_target_data],
'outputs' : [self.ta_move_or_delete_target_dataset_num]
}
batch_edit_captions.btn_apply_edit_tags.click(lambda:None).then(**update_args)
batch_edit_captions.btn_apply_sr_tags.click(lambda:None).then(**update_args)
filter_by_selection.btn_apply_image_selection_filter.click(lambda:None).then(**update_args)
filter_by_tags.btn_clear_tag_filters.click(lambda:None).then(**update_args)
filter_by_tags.btn_clear_all_filters.click(lambda:None).then(**update_args)
edit_caption_of_selected_image.btn_apply_changes_selected_image.click(lambda:None).then(**update_args)
self.rb_move_or_delete_target_data.change(**update_args)
dataset_gallery.cbg_hidden_dataset_filter.change(lambda:None).then(**update_args)
dataset_gallery.nb_hidden_image_index.change(lambda:None).then(**update_args)
def move_files(
target_data: str,

68
scripts/utilities.py Normal file
View File

@ -0,0 +1,68 @@
from typing import Tuple
import math
from PIL import Image
# Compatibility shim: Pillow 9 moved resampling filters into the
# Image.Resampling namespace; on older Pillow, alias it to Image itself
# so Image.Resampling.LANCZOS resolves either way.
if not hasattr(Image, 'Resampling'): # Pillow<9.0
    Image.Resampling = Image

def resize(image: Image.Image, size: Tuple[int, int]):
    # All resizes in this module use high-quality Lanczos resampling.
    return image.resize(size, resample=Image.Resampling.LANCZOS)
def get_rgb_image(image:Image.Image):
    """Flatten any PIL image to RGB, compositing transparency onto white."""
    if image.mode not in ["RGB", "RGBA"]:
        # Keep the alpha channel only when the source actually declares
        # transparency; otherwise go straight to RGB.
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)
    if image.mode == "RGBA":
        # Composite onto an opaque white backdrop, then drop the alpha.
        backdrop = Image.new("RGBA", image.size, (255, 255, 255, 255))
        backdrop.alpha_composite(image)
        image = backdrop.convert("RGB")
    return image
def resize_and_fill(image: Image.Image, size: Tuple[int, int], repeat_edge = True, fill_rgb:tuple[int,int,int] = (255, 255, 255)):
    """Resize *image* to fit inside *size* (keeping aspect ratio), centered,
    and pad the leftover margins.

    With repeat_edge=True the margins are filled by stretching the nearest
    edge row/column of the resized image; otherwise they are filled with
    the solid color *fill_rgb*. Always returns an RGB image of exactly
    *size*.
    """
    width, height = size
    scale_w, scale_h = width / image.width, height / image.height
    resized_w, resized_h = width, height
    # Fit to the tighter dimension; the other is scaled proportionally
    # (integer floor division, so it may undershoot by up to a pixel).
    if scale_w < scale_h:
        resized_h = image.height * resized_w // image.width
    elif scale_h < scale_w:
        resized_w = image.width * resized_h // image.height
    resized = resize(image, (resized_w, resized_h))
    if resized_w == width and resized_h == height:
        # Aspect ratios matched exactly — no padding needed.
        return resized
    if repeat_edge:
        # Split each margin evenly on both sides (extra pixel goes to the
        # bottom/right).
        fill_l = math.floor((width - resized_w) / 2)
        fill_r = width - resized_w - fill_l
        fill_t = math.floor((height - resized_h) / 2)
        fill_b = height - resized_h - fill_t
        result = Image.new("RGB", (width, height))
        result.paste(resized, (fill_l, fill_t))
        # Fill each margin by resampling a zero-height (or zero-width)
        # box at the image's edge, i.e. stretching the edge pixels.
        if fill_t > 0:
            result.paste(resized.resize((width, fill_t), box=(0, 0, width, 0)), (0, 0))
        if fill_b > 0:
            result.paste(
                resized.resize(
                    (width, fill_b), box=(0, resized.height, width, resized.height)
                ),
                (0, resized.height + fill_t),
            )
        if fill_l > 0:
            result.paste(resized.resize((fill_l, height), box=(0, 0, 0, height)), (0, 0))
        if fill_r > 0:
            result.paste(
                resized.resize(
                    (fill_r, height), box=(resized.width, 0, resized.width, height)
                ),
                (resized.width + fill_l, 0),
            )
        return result
    else:
        # Solid-color padding: paste the resized image centered on a
        # fill_rgb canvas.
        result = Image.new("RGB", size, fill_rgb)
        result.paste(resized, box=((width - resized_w) // 2, (height - resized_h) // 2))
        return result.convert("RGB")

View File

@ -1 +1 @@
0.3.3
0.3.4