Compare commits
17 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
d6c7dcef02 | |
|
|
b69c821b97 | |
|
|
620ebe1333 | |
|
|
74de5b7a07 | |
|
|
652286e46d | |
|
|
da5053bc2e | |
|
|
c716c83db0 | |
|
|
94cc5a2a6b | |
|
|
fe4bc8b141 | |
|
|
8feb29de40 | |
|
|
2f99cc54d8 | |
|
|
89df266d90 | |
|
|
6d929765eb | |
|
|
41f452f81a | |
|
|
42a42f0bee | |
|
|
9863241452 | |
|
|
1c5dddc336 |
|
|
@ -0,0 +1,33 @@
|
|||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Scroll down to '....'
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Environment (please complete the following information):**
|
||||
- OS: [e.g. Windows, Linux]
|
||||
- Browser: [e.g. chrome, safari]
|
||||
- Version of SD WebUI: [e.g. v1.9.3, by AUTOMATIC1111]
|
||||
- Version of this app: [e.g. v0.0.7]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: enhancement
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
|
|
@ -113,6 +113,12 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-dataset-tag-edi
|
|||
"Settings" タブで、サムネイル画像を一時保存するフォルダを指定してください。
|
||||
"Directory to save temporary files" にパスを指定して "Force using temporary file…" をチェックしてください。
|
||||
|
||||
### 大量の画像や巨大な画像を開いたときに動作が遅くなる
|
||||
"Settings" タブで、"Force image gallery to use temporary files" にチェックを入れて、 "Maximum resolution of ..." に希望の解像度を入れてください。
|
||||
数百万もの画像を含むなど、あまりにも巨大なデータセットでは効果がないかもしれません。
|
||||
もしくは、[**スタンドアロン版**](https://github.com/toshiaki1729/dataset-tag-editor-standalone)を試してください。
|
||||

|
||||
|
||||
|
||||
## 表示内容
|
||||
|
||||
|
|
|
|||
|
|
@ -115,6 +115,12 @@ Please note that all batch editing will be applyed **only to displayed images (=
|
|||
Set a folder to store temporary images in the "Settings" tab.
|
||||
Enter the path in "Directory to save temporary files" and check "Force using temporary file…"
|
||||
|
||||
### The UI becomes laggy when opening many images or extremely large images
|
||||
Check "Force image gallery to use temporary files" and enter a number in "Maximum resolution of ..." in the "Settings" tab.
|
||||
It may not help with datasets that contain millions of images.
|
||||
If it doesn't work, please consider using the [**standalone version**](https://github.com/toshiaki1729/dataset-tag-editor-standalone).
|
||||

|
||||
|
||||
|
||||
## Description of Display
|
||||
|
||||
|
|
|
|||
Binary file not shown.
|
After Width: | Height: | Size: 80 KiB |
|
|
@ -3,6 +3,8 @@ import re, sys
|
|||
from typing import List, Set, Optional
|
||||
from enum import Enum
|
||||
|
||||
from PIL import ImageFile
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
|
@ -11,7 +13,7 @@ from modules.textual_inversion.dataset import re_numbers_at_start
|
|||
|
||||
from scripts.singleton import Singleton
|
||||
|
||||
from scripts import logger
|
||||
from scripts import logger, utilities
|
||||
from scripts.paths import paths
|
||||
|
||||
from . import (
|
||||
|
|
@ -21,14 +23,17 @@ from . import (
|
|||
taggers_builtin
|
||||
)
|
||||
from .custom_scripts import CustomScripts
|
||||
from .interrogator_names import BLIP2_CAPTIONING_NAMES, WD_TAGGERS, WD_TAGGERS_TIMM
|
||||
from scripts.tokenizer import clip_tokenizer
|
||||
from scripts.tagger import Tagger
|
||||
|
||||
re_tags = re.compile(r"^([\s\S]+?)( \[\d+\])?$")
|
||||
re_newlines = re.compile(r"[\r\n]+")
|
||||
|
||||
def convert_rgb(data:Image.Image):
|
||||
return data.convert("RGB")
|
||||
def get_square_rgb(data:Image.Image):
|
||||
data_rgb = utilities.get_rgb_image(data)
|
||||
size = max(data.size)
|
||||
return utilities.resize_and_fill(data_rgb, (size, size))
|
||||
|
||||
class DatasetTagEditor(Singleton):
|
||||
class SortBy(Enum):
|
||||
|
|
@ -68,39 +73,26 @@ class DatasetTagEditor(Singleton):
|
|||
custom_taggers:list[Tagger] = custom_tagger_scripts.load_derived_classes(Tagger)
|
||||
logger.write(f"Custom taggers loaded: {[tagger().name() for tagger in custom_taggers]}")
|
||||
|
||||
self.BLIP2_CAPTIONING_NAMES = [
|
||||
"blip2-opt-2.7b",
|
||||
"blip2-opt-2.7b-coco",
|
||||
"blip2-opt-6.7b",
|
||||
"blip2-opt-6.7b-coco",
|
||||
"blip2-flan-t5-xl",
|
||||
"blip2-flan-t5-xl-coco",
|
||||
"blip2-flan-t5-xxl",
|
||||
]
|
||||
|
||||
self.WD_TAGGERS = {
|
||||
"wd-v1-4-vit-tagger" : 0.35,
|
||||
"wd-v1-4-convnext-tagger" : 0.35,
|
||||
"wd-v1-4-vit-tagger-v2" : 0.3537,
|
||||
"wd-v1-4-convnext-tagger-v2" : 0.3685,
|
||||
"wd-v1-4-convnextv2-tagger-v2" : 0.371,
|
||||
"wd-v1-4-swinv2-tagger-v2" : 0.3771,
|
||||
"wd-v1-4-moat-tagger-v2" : 0.3771,
|
||||
"wd-vit-tagger-v3" : 0.2614,
|
||||
"wd-convnext-tagger-v3" : 0.2682,
|
||||
"wd-swinv2-tagger-v3" : 0.2653,
|
||||
}
|
||||
# {tagger name : default tagger threshold}
|
||||
# v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf
|
||||
|
||||
def read_wd_batchsize(name:str):
|
||||
if "vit" in name:
|
||||
return shared.opts.dataset_editor_batch_size_vit
|
||||
elif "convnext" in name:
|
||||
return shared.opts.dataset_editor_batch_size_convnext
|
||||
elif "swinv2" in name:
|
||||
return shared.opts.dataset_editor_batch_size_swinv2
|
||||
|
||||
self.INTERROGATORS = (
|
||||
[taggers_builtin.BLIP()]
|
||||
+ [taggers_builtin.BLIP2(name) for name in self.BLIP2_CAPTIONING_NAMES]
|
||||
+ [taggers_builtin.BLIP2(name) for name in BLIP2_CAPTIONING_NAMES]
|
||||
+ [taggers_builtin.GITLarge()]
|
||||
+ [taggers_builtin.DeepDanbooru()]
|
||||
+ [
|
||||
taggers_builtin.WaifuDiffusion(name, threshold)
|
||||
for name, threshold in self.WD_TAGGERS.items()
|
||||
for name, threshold in WD_TAGGERS.items()
|
||||
]
|
||||
+ [
|
||||
taggers_builtin.WaifuDiffusionTimm(name, threshold, int(read_wd_batchsize(name)))
|
||||
for name, threshold in WD_TAGGERS_TIMM.items()
|
||||
]
|
||||
+ [taggers_builtin.Z3D_E621()]
|
||||
+ [cls_tagger() for cls_tagger in custom_taggers]
|
||||
|
|
@ -109,7 +101,7 @@ class DatasetTagEditor(Singleton):
|
|||
|
||||
def interrogate_image(self, path: str, interrogator_name: str, threshold_booru, threshold_wd, threshold_z3d):
|
||||
try:
|
||||
img = Image.open(path).convert("RGB")
|
||||
img = get_square_rgb(Image.open(path))
|
||||
except:
|
||||
return ""
|
||||
else:
|
||||
|
|
@ -118,7 +110,7 @@ class DatasetTagEditor(Singleton):
|
|||
if isinstance(it, taggers_builtin.DeepDanbooru):
|
||||
with it as tg:
|
||||
res = tg.predict(img, threshold_booru)
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion):
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
|
||||
with it as tg:
|
||||
res = tg.predict(img, threshold_wd)
|
||||
elif isinstance(it, taggers_builtin.Z3D_E621):
|
||||
|
|
@ -704,19 +696,27 @@ class DatasetTagEditor(Singleton):
|
|||
continue
|
||||
try:
|
||||
img = Image.open(img_path)
|
||||
if (max_res > 0):
|
||||
img_res = int(max_res), int(max_res)
|
||||
img.thumbnail(img_res)
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
abs_path = str(img_path.absolute())
|
||||
if not use_temp_dir and max_res <= 0:
|
||||
img.already_saved_as = abs_path
|
||||
images[abs_path] = img
|
||||
|
||||
imgpaths.append(abs_path)
|
||||
imgpaths.append(abs_path)
|
||||
return imgpaths, images
|
||||
|
||||
def load_thumbnails(images_raw: dict[str, Image.Image]):
|
||||
images = {}
|
||||
if max_res > 0:
|
||||
for img_path, img in images_raw.items():
|
||||
img_res = int(max_res), int(max_res)
|
||||
images[img_path] = img.copy()
|
||||
images[img_path].thumbnail(img_res)
|
||||
else:
|
||||
for img_path, img in images_raw.items():
|
||||
if not use_temp_dir:
|
||||
img.already_saved_as = img_path
|
||||
images[img_path] = img
|
||||
return images
|
||||
|
||||
def load_captions(imgpaths: list[str]):
|
||||
taglists = []
|
||||
|
|
@ -752,7 +752,7 @@ class DatasetTagEditor(Singleton):
|
|||
if it.name() in interrogator_names:
|
||||
if isinstance(it, taggers_builtin.DeepDanbooru):
|
||||
tagger_thresholds.append((it, threshold_booru))
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion):
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
|
||||
tagger_thresholds.append((it, threshold_waifu))
|
||||
elif isinstance(it, taggers_builtin.Z3D_E621):
|
||||
tagger_thresholds.append((it, threshold_z3d))
|
||||
|
|
@ -760,15 +760,23 @@ class DatasetTagEditor(Singleton):
|
|||
tagger_thresholds.append((it, None))
|
||||
|
||||
if kohya_json_path:
|
||||
imgpaths, self.images, taglists = kohya_metadata.read(
|
||||
imgpaths, images_raw, taglists = kohya_metadata.read(
|
||||
img_dir, kohya_json_path, use_temp_dir
|
||||
)
|
||||
else:
|
||||
imgpaths, self.images = load_images(filepaths)
|
||||
imgpaths, images_raw = load_images(filepaths)
|
||||
taglists = load_captions(imgpaths)
|
||||
|
||||
self.images = load_thumbnails(images_raw)
|
||||
|
||||
interrogate_tags = {img_path : [] for img_path in imgpaths}
|
||||
if interrogate_method != self.InterrogateMethod.NONE:
|
||||
|
||||
img_to_interrogate = [
|
||||
img_path for i, img_path in enumerate(imgpaths)
|
||||
if (not taglists[i] or interrogate_method != self.InterrogateMethod.PREFILL)
|
||||
]
|
||||
|
||||
if interrogate_method != self.InterrogateMethod.NONE and img_to_interrogate:
|
||||
logger.write("Preprocess images...")
|
||||
max_workers = shared.opts.dataset_editor_num_cpu_workers
|
||||
if max_workers < 0:
|
||||
|
|
@ -777,11 +785,11 @@ class DatasetTagEditor(Singleton):
|
|||
|
||||
def gen_data(paths:list[str], images:dict[str, Image.Image]):
|
||||
for img_path in paths:
|
||||
yield images.get(img_path)
|
||||
yield images[img_path]
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
result = list(executor.map(convert_rgb, gen_data(imgpaths, self.images)))
|
||||
result = list(executor.map(get_square_rgb, gen_data(img_to_interrogate, images_raw)))
|
||||
logger.write("Preprocess completed")
|
||||
|
||||
for tg, th in tqdm(tagger_thresholds):
|
||||
|
|
@ -798,10 +806,10 @@ class DatasetTagEditor(Singleton):
|
|||
continue
|
||||
try:
|
||||
if use_pipe:
|
||||
for img_path, tags in tqdm(zip(imgpaths, tg.predict_pipe(result, th)), desc=tg.name(), total=len(imgpaths)):
|
||||
for img_path, tags in tqdm(zip(img_to_interrogate, tg.predict_pipe(result, th)), desc=tg.name(), total=len(img_to_interrogate)):
|
||||
interrogate_tags[img_path] += tags
|
||||
else:
|
||||
for img_path, data in tqdm(zip(imgpaths, result), desc=tg.name(), total=len(imgpaths)):
|
||||
for img_path, data in tqdm(zip(img_to_interrogate, result), desc=tg.name(), total=len(img_to_interrogate)):
|
||||
interrogate_tags[img_path] += tg.predict(data, th)
|
||||
except Exception as e:
|
||||
tb = sys.exc_info()[2]
|
||||
|
|
@ -814,7 +822,7 @@ class DatasetTagEditor(Singleton):
|
|||
tags = interrogate_tags[img_path]
|
||||
elif interrogate_method == self.InterrogateMethod.PREPEND:
|
||||
tags = interrogate_tags[img_path] + tags
|
||||
else:
|
||||
elif interrogate_method != self.InterrogateMethod.PREFILL:
|
||||
tags = tags + interrogate_tags[img_path]
|
||||
|
||||
self.set_tags_by_image_path(img_path, tags)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
|
||||
# Model names handed to taggers_builtin.BLIP2 when the interrogator list is built.
BLIP2_CAPTIONING_NAMES = [
    "blip2-opt-2.7b",
    "blip2-opt-2.7b-coco",
    "blip2-opt-6.7b",
    "blip2-opt-6.7b-coco",
    "blip2-flan-t5-xl",
    "blip2-flan-t5-xl-coco",
    "blip2-flan-t5-xxl",
]


# {tagger name : default tagger threshold}
# v1: idk if it's okay  v2, v3: P=R thresholds on each repo https://huggingface.co/SmilingWolf
# Entries here are instantiated as taggers_builtin.WaifuDiffusion.
WD_TAGGERS = {
    "wd-v1-4-vit-tagger": 0.35,
    "wd-v1-4-convnext-tagger": 0.35,
    "wd-v1-4-vit-tagger-v2": 0.3537,
    "wd-v1-4-convnext-tagger-v2": 0.3685,
    "wd-v1-4-convnextv2-tagger-v2": 0.371,
    "wd-v1-4-moat-tagger-v2": 0.3771,
}

# Same {tagger name : default tagger threshold} layout, but these entries are
# instantiated as taggers_builtin.WaifuDiffusionTimm (batched inference).
WD_TAGGERS_TIMM = {
    "wd-v1-4-swinv2-tagger-v2": 0.3771,
    "wd-vit-tagger-v3": 0.2614,
    "wd-convnext-tagger-v3": 0.2682,
    "wd-swinv2-tagger-v3": 0.2653,
}
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
from .blip2_captioning import BLIP2Captioning
|
||||
from .git_large_captioning import GITLargeCaptioning
|
||||
from .waifu_diffusion_tagger import WaifuDiffusionTagger
|
||||
from .waifu_diffusion_tagger_timm import WaifuDiffusionTaggerTimm
|
||||
|
||||
__all__ = [
|
||||
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger'
|
||||
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger', 'WaifuDiffusionTaggerTimm'
|
||||
]
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
from PIL import Image
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import torchvision.transforms as tf
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import shared, devices
|
||||
import launch
|
||||
|
||||
|
||||
class ImageDataset(Dataset):
    """Dataset wrapper around an in-memory list of images.

    Indexing returns the image at that position, passed through
    ``transforms`` first when one was supplied.
    """

    def __init__(self, images:list[Image.Image], transforms:tf.Compose=None):
        self.images = images
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        img = self.images[i]
        if self.transforms is not None:
            img = self.transforms(img)
        return img
|
||||
|
||||
|
||||
|
||||
class WaifuDiffusionTaggerTimm:
    """Tagger that runs a timm model fetched from the Hugging Face Hub.

    Some code is taken from https://github.com/neggles/wdv3-timm and modified.
    """

    def __init__(self, model_repo, label_filename="selected_tags.csv"):
        self.LABEL_FILENAME = label_filename
        self.MODEL_REPO = model_repo
        self.model = None
        self.transform = None
        self.labels = []

    def load(self):
        """Create the model and its input transform if needed, then read the label list."""
        import huggingface_hub

        # timm is installed on demand the first time a timm tagger is loaded.
        if not launch.is_installed("timm"):
            launch.run_pip(
                "install -U timm",
                "requirements for dataset-tag-editor [timm]",
            )
        import timm
        from timm.data import create_transform, resolve_data_config

        if self.model is None:
            self.model: torch.nn.Module = timm.create_model(
                "hf-hub:" + self.MODEL_REPO
            ).eval()
            state_dict = timm.models.load_state_dict_from_hf(self.MODEL_REPO)
            self.model.load_state_dict(state_dict)
            self.model.to(devices.device)
            # The transform is derived only from the model's pretrained config.
            self.transform = create_transform(
                **resolve_data_config(self.model.pretrained_cfg, model=self.model)
            )

        path_label = huggingface_hub.hf_hub_download(
            self.MODEL_REPO, self.LABEL_FILENAME
        )
        import pandas as pd

        self.labels = pd.read_csv(path_label)["name"].tolist()

    def unload(self):
        """Drop the model unless the web UI is configured to keep interrogators in memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            devices.torch_gc()

    def apply(self, image: Image.Image):
        """Return ``(label, probability)`` pairs for one image, or ``[]`` when no model is loaded."""
        if self.model is None:
            return []

        image_t: torch.Tensor = self.transform(image).unsqueeze(0)
        # Reverse the channel axis (RGB -> BGR, assuming the transform yields RGB).
        image_t = image_t[:, [2, 1, 0]]
        image_t = image_t.to(devices.device)

        with torch.inference_mode():
            features = self.model.forward(image_t)
            probs = F.sigmoid(features).detach().cpu().numpy()

        labels: list[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))

        return labels

    def apply_multi(self, images: list[Image.Image], batch_size: int):
        """Yield, per batch, one list of ``(label, probability)`` pairs for each image.

        This is a generator: when no model is loaded it yields nothing.
        """
        if self.model is None:
            return

        dataset = ImageDataset(images, self.transform)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        with torch.inference_mode():
            for batch in tqdm(dataloader):
                # Same channel reversal as in apply().
                batch = batch[:, [2, 1, 0]].to(devices.device)
                features = self.model.forward(batch)
                probs = F.sigmoid(features).detach().cpu().numpy()
                labels: list[list[Tuple[str, float]]] = [
                    list(zip(self.labels, probs[i].astype(float)))
                    for i in range(probs.shape[0])
                ]
                yield labels
|
||||
|
|
@ -8,7 +8,7 @@ from modules import devices, shared
|
|||
from modules import deepbooru as db
|
||||
|
||||
from scripts.tagger import Tagger, get_replaced_tag
|
||||
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger
|
||||
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger, WaifuDiffusionTaggerTimm
|
||||
|
||||
|
||||
class BLIP(Tagger):
|
||||
|
|
@ -142,6 +142,28 @@ class WaifuDiffusion(Tagger):
|
|||
return self.repo_name
|
||||
|
||||
|
||||
class WaifuDiffusionTimm(WaifuDiffusion):
    """WaifuDiffusion tagger variant that runs batched inference through WaifuDiffusionTaggerTimm."""

    def __init__(self, repo_name, threshold, batch_size=4):
        super().__init__(repo_name, threshold)
        # Replace the tagger instance set up by the base class with the timm one.
        self.tagger_inst = WaifuDiffusionTaggerTimm("SmilingWolf/" + repo_name)
        self.batch_size = batch_size

    def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None):
        """Yield one tag list per input image.

        Args:
            data: images to tag; they are processed ``self.batch_size`` at a time.
            threshold: minimum score (exclusive) for a tag to be kept. A negative
                value selects this tagger's default threshold; ``None`` keeps
                every tag.
        """
        # Resolve the effective threshold once; it does not depend on the image.
        if threshold is not None and threshold < 0:
            threshold = self.threshold

        for labels_list in self.tagger_inst.apply_multi(data, batch_size=self.batch_size):
            for labels in labels_list:
                if not shared.opts.dataset_editor_use_rating:
                    # Drop the first four labels (presumably the rating labels,
                    # going by the option name).
                    labels = labels[4:]

                if threshold is not None:
                    tags = [get_replaced_tag(tag) for tag, value in labels if value > threshold]
                else:
                    tags = [get_replaced_tag(tag) for tag, _ in labels]

                yield tags
|
||||
|
||||
|
||||
class Z3D_E621(Tagger):
|
||||
def __init__(self):
|
||||
self.tagger_inst = WaifuDiffusionTagger("toynya/Z3D-E621-Convnext", label_filename="tags-selected.csv")
|
||||
|
|
|
|||
|
|
@ -413,10 +413,6 @@ def on_ui_tabs():
|
|||
ui.move_or_delete_files.set_callbacks(
|
||||
o_update_filter_and_gallery,
|
||||
ui.dataset_gallery,
|
||||
ui.filter_by_tags,
|
||||
ui.batch_edit_captions,
|
||||
ui.filter_by_selection,
|
||||
ui.edit_caption_of_selected_image,
|
||||
get_filters,
|
||||
update_filter_and_gallery,
|
||||
)
|
||||
|
|
@ -472,6 +468,33 @@ def on_ui_settings():
|
|||
),
|
||||
)
|
||||
|
||||
shared.opts.add_option(
|
||||
"dataset_editor_batch_size_vit",
|
||||
shared.OptionInfo(
|
||||
4,
|
||||
"Inference batch size for ViT taggers",
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
shared.opts.add_option(
|
||||
"dataset_editor_batch_size_convnext",
|
||||
shared.OptionInfo(
|
||||
4,
|
||||
"Inference batch size for ConvNeXt taggers",
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
shared.opts.add_option(
|
||||
"dataset_editor_batch_size_swinv2",
|
||||
shared.OptionInfo(
|
||||
4,
|
||||
"Inference batch size for SwinTransformerV2 taggers",
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
script_callbacks.on_ui_tabs(on_ui_tabs)
|
||||
|
|
|
|||
|
|
@ -199,6 +199,7 @@ class LoadDatasetUI(UIBase):
|
|||
self.sl_custom_threshold_booru,
|
||||
self.cb_use_custom_threshold_waifu,
|
||||
self.sl_custom_threshold_waifu,
|
||||
self.sl_custom_threshold_z3d,
|
||||
toprow.cb_save_kohya_metadata,
|
||||
toprow.tb_metadata_output,
|
||||
],
|
||||
|
|
|
|||
|
|
@ -33,14 +33,14 @@ class ToprowUI(UIBase):
|
|||
|
||||
def set_callbacks(self, load_dataset:LoadDatasetUI):
|
||||
|
||||
def save_all_changes(backup: bool, save_kohya_metadata:bool, metadata_output:str, metadata_input:str, metadata_overwrite:bool, metadata_as_caption:bool, metadata_use_fullpath:bool):
|
||||
def save_all_changes(backup: bool, save_kohya_metadata:bool, metadata_output:str, metadata_input:str, metadata_overwrite:bool, metadata_as_caption:bool, metadata_use_fullpath:bool, caption_file_ext:str):
|
||||
if not metadata_input:
|
||||
metadata_input = None
|
||||
dte_instance.save_dataset(backup, load_dataset.caption_file_ext, save_kohya_metadata, metadata_output, metadata_input, metadata_overwrite, metadata_as_caption, metadata_use_fullpath)
|
||||
dte_instance.save_dataset(backup, caption_file_ext, save_kohya_metadata, metadata_output, metadata_input, metadata_overwrite, metadata_as_caption, metadata_use_fullpath)
|
||||
|
||||
self.btn_save_all_changes.click(
|
||||
fn=save_all_changes,
|
||||
inputs=[self.cb_backup, self.cb_save_kohya_metadata, self.tb_metadata_output, self.tb_metadata_input, self.cb_metadata_overwrite, self.cb_metadata_as_caption, self.cb_metadata_use_fullpath]
|
||||
inputs=[self.cb_backup, self.cb_save_kohya_metadata, self.tb_metadata_output, self.tb_metadata_input, self.cb_metadata_overwrite, self.cb_metadata_as_caption, self.cb_metadata_use_fullpath, load_dataset.tb_caption_file_ext]
|
||||
)
|
||||
|
||||
self.cb_save_kohya_metadata.change(
|
||||
|
|
|
|||
|
|
@ -133,16 +133,17 @@ class EditCaptionOfSelectedImageUI(UIBase):
|
|||
outputs=[self.tb_edit_caption]
|
||||
)
|
||||
|
||||
def interrogate_selected_image(interrogator_name: str, use_threshold_booru: bool, threshold_booru: float, use_threshold_waifu: bool, threshold_waifu: float):
|
||||
def interrogate_selected_image(interrogator_name: str, use_threshold_booru: bool, threshold_booru: float, use_threshold_waifu: bool, threshold_waifu: float, threshold_z3d: float):
|
||||
|
||||
if not interrogator_name:
|
||||
return ''
|
||||
threshold_booru = threshold_booru if use_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
|
||||
threshold_waifu = threshold_waifu if use_threshold_waifu else -1
|
||||
return dte_instance.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu)
|
||||
return dte_instance.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu, threshold_z3d)
|
||||
|
||||
self.btn_interrogate_si.click(
|
||||
fn=interrogate_selected_image,
|
||||
inputs=[self.dd_intterogator_names_si, load_dataset.cb_use_custom_threshold_booru, load_dataset.sl_custom_threshold_booru, load_dataset.cb_use_custom_threshold_waifu, load_dataset.sl_custom_threshold_waifu],
|
||||
inputs=[self.dd_intterogator_names_si, load_dataset.cb_use_custom_threshold_booru, load_dataset.sl_custom_threshold_booru, load_dataset.cb_use_custom_threshold_waifu, load_dataset.sl_custom_threshold_waifu, load_dataset.sl_custom_threshold_z3d],
|
||||
outputs=[self.tb_interrogate]
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class FilterByTagsUI(UIBase):
|
|||
common_callback = lambda : \
|
||||
update_gallery() + \
|
||||
batch_edit_captions.get_common_tags(get_filters, self) + \
|
||||
[move_or_delete_files.get_current_move_or_delete_target_num()] + \
|
||||
[move_or_delete_files.update_current_move_or_delete_target_num()] + \
|
||||
[batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
|
||||
|
||||
common_callback_output = \
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ class MoveOrDeleteFilesUI(UIBase):
|
|||
def __init__(self):
|
||||
self.target_data = 'Selected One'
|
||||
self.current_target_txt = ''
|
||||
self.update_func = None
|
||||
|
||||
def create_ui(self, cfg_file_move_delete):
|
||||
gr.HTML(value='<b>Note: </b>Moved or deleted images will be unloaded.')
|
||||
|
|
@ -27,11 +28,15 @@ class MoveOrDeleteFilesUI(UIBase):
|
|||
gr.HTML(value='<b>Note: </b>DELETE cannot be undone. The files will be deleted completely.')
|
||||
self.btn_move_or_delete_delete_files = gr.Button(value='DELETE File(s)', variant='primary')
|
||||
|
||||
def get_current_move_or_delete_target_num(self):
|
||||
return self.current_target_txt
|
||||
def update_current_move_or_delete_target_num(self):
|
||||
if self.update_func:
|
||||
return self.update_func(self.target_data)
|
||||
else:
|
||||
return self.current_target_txt
|
||||
|
||||
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, filter_by_tags:FilterByTagsUI, batch_edit_captions:BatchEditCaptionsUI, filter_by_selection:FilterBySelectionUI, edit_caption_of_selected_image:EditCaptionOfSelectedImageUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
|
||||
def _get_current_move_or_delete_target_num():
|
||||
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
|
||||
def _get_current_move_or_delete_target_num(text: str):
|
||||
self.target_data = text
|
||||
if self.target_data == 'Selected One':
|
||||
self.current_target_txt = f'Target dataset num: {1 if dataset_gallery.selected_index != -1 else 0}'
|
||||
elif self.target_data == 'All Displayed Ones':
|
||||
|
|
@ -41,25 +46,17 @@ class MoveOrDeleteFilesUI(UIBase):
|
|||
self.current_target_txt = f'Target dataset num: 0'
|
||||
return self.current_target_txt
|
||||
|
||||
self.update_func = _get_current_move_or_delete_target_num
|
||||
|
||||
update_args = {
|
||||
'fn' : _get_current_move_or_delete_target_num,
|
||||
'inputs' : None,
|
||||
'fn': self.update_func,
|
||||
'inputs': [self.rb_move_or_delete_target_data],
|
||||
'outputs' : [self.ta_move_or_delete_target_dataset_num]
|
||||
}
|
||||
|
||||
batch_edit_captions.btn_apply_edit_tags.click(lambda:None).then(**update_args)
|
||||
|
||||
batch_edit_captions.btn_apply_sr_tags.click(lambda:None).then(**update_args)
|
||||
|
||||
filter_by_selection.btn_apply_image_selection_filter.click(lambda:None).then(**update_args)
|
||||
|
||||
filter_by_tags.btn_clear_tag_filters.click(lambda:None).then(**update_args)
|
||||
|
||||
filter_by_tags.btn_clear_all_filters.click(lambda:None).then(**update_args)
|
||||
|
||||
edit_caption_of_selected_image.btn_apply_changes_selected_image.click(lambda:None).then(**update_args)
|
||||
|
||||
self.rb_move_or_delete_target_data.change(**update_args)
|
||||
dataset_gallery.cbg_hidden_dataset_filter.change(lambda:None).then(**update_args)
|
||||
dataset_gallery.nb_hidden_image_index.change(lambda:None).then(**update_args)
|
||||
|
||||
def move_files(
|
||||
target_data: str,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
from typing import Tuple
|
||||
import math
|
||||
|
||||
from PIL import Image
|
||||
|
||||
if not hasattr(Image, 'Resampling'): # Pillow<9.0
|
||||
Image.Resampling = Image
|
||||
|
||||
|
||||
def resize(image: Image.Image, size: Tuple[int, int]):
    """Return ``image`` resized to ``size`` (width, height) using LANCZOS resampling."""
    return image.resize(size, resample=Image.Resampling.LANCZOS)
|
||||
|
||||
|
||||
def get_rgb_image(image:Image.Image):
    """Return ``image`` as an RGB image, flattening any transparency onto white.

    Images that are already RGB are returned unchanged. Images that carry
    transparency (an alpha band, or a ``transparency`` entry in ``image.info``)
    are converted to RGBA and composited over a white background.
    """
    if image.mode not in ["RGB", "RGBA"]:
        # Modes such as "LA" / "PA" carry an alpha band without a "transparency"
        # info entry; they must go through RGBA so the alpha is composited
        # rather than silently dropped.
        has_alpha = "transparency" in image.info or "A" in image.getbands()
        image = image.convert("RGBA") if has_alpha else image.convert("RGB")
    if image.mode == "RGBA":
        white = Image.new("RGBA", image.size, (255, 255, 255, 255))
        white.alpha_composite(image)  # in-place composite onto the white canvas
        image = white.convert("RGB")
    return image
|
||||
|
||||
|
||||
def resize_and_fill(image: Image.Image, size: Tuple[int, int], repeat_edge = True, fill_rgb: Tuple[int, int, int] = (255, 255, 255)):
    """Fit ``image`` inside ``size`` keeping its aspect ratio and pad the remainder.

    Args:
        image: source image.
        size: target ``(width, height)``.
        repeat_edge: when true, the padding is produced by stretching the
            outermost rows/columns of the resized image; otherwise the padding
            is a solid ``fill_rgb`` colour.
        fill_rgb: padding colour used when ``repeat_edge`` is false.

    Returns:
        The resized image itself when it already fills ``size`` exactly,
        otherwise a new RGB image of ``size`` with the resized image centred.
    """
    width, height = size
    scale_w, scale_h = width / image.width, height / image.height
    resized_w, resized_h = width, height
    # Shrink the side that would otherwise overflow; clamp to 1px so extreme
    # aspect ratios cannot produce a zero-sized resize target.
    if scale_w < scale_h:
        resized_h = max(1, image.height * resized_w // image.width)
    elif scale_h < scale_w:
        resized_w = max(1, image.width * resized_h // image.height)

    resized = resize(image, (resized_w, resized_h))
    if resized_w == width and resized_h == height:
        return resized

    if repeat_edge:
        fill_l = (width - resized_w) // 2
        fill_r = width - resized_w - fill_l
        fill_t = (height - resized_h) // 2
        fill_b = height - resized_h - fill_t
        result = Image.new("RGB", (width, height))
        result.paste(resized, (fill_l, fill_t))
        # Each padding band is resampled from a zero-thickness box lying on the
        # corresponding edge of the resized image.
        if fill_t > 0:
            result.paste(resized.resize((width, fill_t), box=(0, 0, width, 0)), (0, 0))
        if fill_b > 0:
            result.paste(
                resized.resize(
                    (width, fill_b), box=(0, resized.height, width, resized.height)
                ),
                (0, resized.height + fill_t),
            )
        if fill_l > 0:
            result.paste(resized.resize((fill_l, height), box=(0, 0, 0, height)), (0, 0))
        if fill_r > 0:
            result.paste(
                resized.resize(
                    (fill_r, height), box=(resized.width, 0, resized.width, height)
                ),
                (resized.width + fill_l, 0),
            )
        return result

    result = Image.new("RGB", size, fill_rgb)
    result.paste(resized, box=((width - resized_w) // 2, (height - resized_h) // 2))
    return result
|
||||
|
||||
|
||||
|
|
@ -15,6 +15,19 @@ from scripts.tagger import Tagger
|
|||
# I'm not sure if this is really working
|
||||
BATCH_SIZE = 3
|
||||
|
||||
# tags used in Animagine-XL
# {tag : exclusive lower bound of the score}, ordered from best to worst.
SCORE_N = {
    'very aesthetic': 0.71,
    'aesthetic': 0.45,
    'displeasing': 0.27,
    'very displeasing': -float('inf'),
}


def get_aesthetic_tag(score: float):
    """Return the first tag in ``SCORE_N`` whose lower bound ``score`` exceeds.

    Scores that exceed no bound at all (NaN or negative infinity) fall back to
    the last, lowest tag instead of returning ``None``.
    """
    for tag, lower_bound in SCORE_N.items():
        if score > lower_bound:
            return tag
    # NaN and -inf compare false against every bound, including -inf.
    return list(SCORE_N)[-1]
|
||||
|
||||
class AestheticShadowV2(Tagger):
|
||||
def load(self):
|
||||
if devices.device.index is None:
|
||||
|
|
@ -40,8 +53,7 @@ class AestheticShadowV2(Tagger):
|
|||
for d in data:
|
||||
final[d["label"]] = d["score"]
|
||||
hq = final['hq']
|
||||
lq = final['lq']
|
||||
return [f"score_{math.floor((hq + (1 - lq))/2 * 10)}"]
|
||||
return [get_aesthetic_tag(hq)]
|
||||
|
||||
def predict(self, image: Image.Image, threshold=None):
|
||||
data = self.pipe_aesthetic(image)
|
||||
|
|
|
|||
|
|
@ -36,9 +36,10 @@ class CafeAIAesthetic(Tagger):
|
|||
final = {}
|
||||
for d in data:
|
||||
final[d["label"]] = d["score"]
|
||||
nae = final['not_aesthetic']
|
||||
ae = final['aesthetic']
|
||||
return [f"score_{math.floor((ae + (1 - nae))/2 * 10)}"]
|
||||
|
||||
# edit here to change tag
|
||||
return [f"[CAFE]score_{math.floor(ae*10)}"]
|
||||
|
||||
def predict(self, image: Image.Image, threshold=None):
|
||||
data = self.pipe_aesthetic(image, top_k=2)
|
||||
|
|
|
|||
|
|
@ -69,7 +69,8 @@ class ImprovedAestheticPredictor(Tagger):
|
|||
def predict(self, image: Image.Image, threshold=None):
|
||||
image_embeds = image_embeddings(image, self.clip_model, self.clip_processor)
|
||||
prediction:torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device))
|
||||
return [f"score_{math.floor(prediction.item())}"]
|
||||
# edit here to change tag
|
||||
return [f"[IAP]score_{math.floor(prediction.item())}"]
|
||||
|
||||
def name(self):
|
||||
return "Improved Aesthetic Predictor"
|
||||
|
|
@ -67,7 +67,8 @@ class WaifuAesthetic(Tagger):
|
|||
def predict(self, image: Image.Image, threshold=None):
|
||||
image_embeds = image_embeddings(image, self.clip_model, self.clip_processor)
|
||||
prediction:torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device))
|
||||
return [f"score_{math.floor(prediction.item()*10)}"]
|
||||
# edit here to change tag
|
||||
return [f"[WD]score_{math.floor(prediction.item()*10)}"]
|
||||
|
||||
def name(self):
|
||||
return "wd aesthetic classifier"
|
||||
|
|
@ -1 +1 @@
|
|||
0.3.1
|
||||
0.3.4
|
||||
|
|
|
|||
Loading…
Reference in New Issue