Compare commits

...

17 Commits
v0.3.1 ... main

Author SHA1 Message Date
toshiaki1729 d6c7dcef02 fix wrong func name 2024-06-27 20:03:26 +09:00
toshiaki1729 b69c821b97 load large images 2024-06-25 17:44:05 +09:00
toshiaki1729 620ebe1333 Make image shape square before interrogating
to fix #101
2024-06-25 17:24:58 +09:00
toshiaki1729 74de5b7a07 fix: "move or delete files" shows wrong number of target images 2024-06-10 11:29:54 +09:00
toshiaki1729 652286e46d Merge branch 'main' of https://github.com/toshiaki1729/stable-diffusion-webui-dataset-tag-editor 2024-05-26 13:17:19 +09:00
toshiaki1729 da5053bc2e Update waifu_diffusion_tagger_timm.py
fix #100
2024-05-26 13:17:17 +09:00
toshiaki1729 c716c83db0 Update version.txt 2024-05-21 00:48:45 +09:00
toshiaki1729 94cc5a2a6b Add batch inference feature to WD Tagger (#99)
same as in the standalone version
2024-05-21 00:47:51 +09:00
toshiaki1729 fe4bc8b141 Update version.txt 2024-05-20 19:25:58 +09:00
toshiaki1729 8feb29de40 fix #97 (#98)
and a bug: caption file extension not applied on save
2024-05-20 19:07:40 +09:00
Zak 2f99cc54d8 Add Z3D threshold to image interrogator function (#96)
* Update tab_edit_caption_of_selected_image.py

* Update tab_edit_caption_of_selected_image.py

fix interrogate_selected_image() missing threshold_z3d

* Update tab_edit_caption_of_selected_image.py

remove print
2024-05-18 09:53:59 +09:00
toshiaki1729 89df266d90 Update version.txt 2024-05-17 14:10:49 +09:00
Ririan 6d929765eb Add missing Z3D threshold (#95)
Currently, attempting to use the Z3D interrogator doesn't work because the threshold value isn't getting passed into the load_dataset function. This fixes that.

Co-authored-by: The Divine Heir <the.divine.heir.kuro@gmail.com>
2024-05-17 14:06:48 +09:00
toshiaki1729 41f452f81a Add troubleshooting 2024-05-11 19:39:51 +09:00
toshiaki1729 42a42f0bee Update issue templates 2024-05-11 16:23:54 +09:00
toshiaki1729 9863241452 Update issue templates 2024-05-11 16:22:15 +09:00
toshiaki1729 1c5dddc336 change aesthetic score tags 2024-05-10 03:32:45 +09:00
22 changed files with 419 additions and 86 deletions

.github/ISSUE_TEMPLATE/bug_report.md (new file, 33 lines)
@@ -0,0 +1,33 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Environment (please complete the following information):**
- OS: [e.g. Windows, Linux]
- Browser: [e.g. chrome, safari]
- Version of SD WebUI: [e.g. v1.9.3, by AUTOMATIC1111]
- Version of this app: [e.g. v0.0.7]
**Additional context**
Add any other context about the problem here.

@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: enhancement
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

@@ -113,6 +113,12 @@ git clone https://github.com/toshiaki1729/stable-diffusion-webui-dataset-tag-edi
In the "Settings" tab, specify a folder where thumbnail images will be stored temporarily.
Set a path in "Directory to save temporary files" and check "Force using temporary file…".
### Slow when opening many images or very large images
In the "Settings" tab, check "Force image gallery to use temporary files" and enter the desired resolution in "Maximum resolution of ...".
This may not help with extremely large datasets, such as those containing millions of images.
Alternatively, try the [**standalone version**](https://github.com/toshiaki1729/dataset-tag-editor-standalone).
![](pic/ss12.png)
## Description of Display

@@ -115,6 +115,12 @@ Please note that all batch editing will be applied **only to displayed images (=
Set a folder to store temporary images in the "Settings" tab.
Enter a path in "Directory to save temporary files" and check "Force using temporary file…".
### So laggy when opening many images or extremely large images
Check "Force image gallery to use temporary files" and enter a number in "Maximum resolution of ..." in the "Settings" tab.
It may not work with datasets containing millions of images.
If it doesn't work, please consider using the [**standalone version**](https://github.com/toshiaki1729/dataset-tag-editor-standalone).
![](pic/ss12.png)
## Description of Display

BIN pic/ss12.png (new binary file, 80 KiB)
@@ -3,6 +3,8 @@ import re, sys
from typing import List, Set, Optional
from enum import Enum
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from PIL import Image
from tqdm import tqdm
@@ -11,7 +13,7 @@ from modules.textual_inversion.dataset import re_numbers_at_start
from scripts.singleton import Singleton
from scripts import logger
from scripts import logger, utilities
from scripts.paths import paths
from . import (
@@ -21,14 +23,17 @@ from . import (
taggers_builtin
)
from .custom_scripts import CustomScripts
from .interrogator_names import BLIP2_CAPTIONING_NAMES, WD_TAGGERS, WD_TAGGERS_TIMM
from scripts.tokenizer import clip_tokenizer
from scripts.tagger import Tagger
re_tags = re.compile(r"^([\s\S]+?)( \[\d+\])?$")
re_newlines = re.compile(r"[\r\n]+")
def convert_rgb(data:Image.Image):
return data.convert("RGB")
def get_square_rgb(data:Image.Image):
data_rgb = utilities.get_rgb_image(data)
size = max(data.size)
return utilities.resize_and_fill(data_rgb, (size, size))
class DatasetTagEditor(Singleton):
class SortBy(Enum):
@@ -68,39 +73,26 @@ class DatasetTagEditor(Singleton):
custom_taggers:list[Tagger] = custom_tagger_scripts.load_derived_classes(Tagger)
logger.write(f"Custom taggers loaded: {[tagger().name() for tagger in custom_taggers]}")
self.BLIP2_CAPTIONING_NAMES = [
"blip2-opt-2.7b",
"blip2-opt-2.7b-coco",
"blip2-opt-6.7b",
"blip2-opt-6.7b-coco",
"blip2-flan-t5-xl",
"blip2-flan-t5-xl-coco",
"blip2-flan-t5-xxl",
]
self.WD_TAGGERS = {
"wd-v1-4-vit-tagger" : 0.35,
"wd-v1-4-convnext-tagger" : 0.35,
"wd-v1-4-vit-tagger-v2" : 0.3537,
"wd-v1-4-convnext-tagger-v2" : 0.3685,
"wd-v1-4-convnextv2-tagger-v2" : 0.371,
"wd-v1-4-swinv2-tagger-v2" : 0.3771,
"wd-v1-4-moat-tagger-v2" : 0.3771,
"wd-vit-tagger-v3" : 0.2614,
"wd-convnext-tagger-v3" : 0.2682,
"wd-swinv2-tagger-v3" : 0.2653,
}
# {tagger name : default tagger threshold}
# v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf
def read_wd_batchsize(name:str):
if "vit" in name:
return shared.opts.dataset_editor_batch_size_vit
elif "convnext" in name:
return shared.opts.dataset_editor_batch_size_convnext
elif "swinv2" in name:
return shared.opts.dataset_editor_batch_size_swinv2
self.INTERROGATORS = (
[taggers_builtin.BLIP()]
+ [taggers_builtin.BLIP2(name) for name in self.BLIP2_CAPTIONING_NAMES]
+ [taggers_builtin.BLIP2(name) for name in BLIP2_CAPTIONING_NAMES]
+ [taggers_builtin.GITLarge()]
+ [taggers_builtin.DeepDanbooru()]
+ [
taggers_builtin.WaifuDiffusion(name, threshold)
for name, threshold in self.WD_TAGGERS.items()
for name, threshold in WD_TAGGERS.items()
]
+ [
taggers_builtin.WaifuDiffusionTimm(name, threshold, int(read_wd_batchsize(name)))
for name, threshold in WD_TAGGERS_TIMM.items()
]
+ [taggers_builtin.Z3D_E621()]
+ [cls_tagger() for cls_tagger in custom_taggers]
@@ -109,7 +101,7 @@ class DatasetTagEditor(Singleton):
def interrogate_image(self, path: str, interrogator_name: str, threshold_booru, threshold_wd, threshold_z3d):
try:
img = Image.open(path).convert("RGB")
img = get_square_rgb(Image.open(path))
except:
return ""
else:
@@ -118,7 +110,7 @@
if isinstance(it, taggers_builtin.DeepDanbooru):
with it as tg:
res = tg.predict(img, threshold_booru)
elif isinstance(it, taggers_builtin.WaifuDiffusion):
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
with it as tg:
res = tg.predict(img, threshold_wd)
elif isinstance(it, taggers_builtin.Z3D_E621):
@@ -704,19 +696,27 @@ class DatasetTagEditor(Singleton):
continue
try:
img = Image.open(img_path)
if (max_res > 0):
img_res = int(max_res), int(max_res)
img.thumbnail(img_res)
except:
continue
else:
abs_path = str(img_path.absolute())
if not use_temp_dir and max_res <= 0:
img.already_saved_as = abs_path
images[abs_path] = img
imgpaths.append(abs_path)
imgpaths.append(abs_path)
return imgpaths, images
def load_thumbnails(images_raw: dict[str, Image.Image]):
images = {}
if max_res > 0:
for img_path, img in images_raw.items():
img_res = int(max_res), int(max_res)
images[img_path] = img.copy()
images[img_path].thumbnail(img_res)
else:
for img_path, img in images_raw.items():
if not use_temp_dir:
img.already_saved_as = img_path
images[img_path] = img
return images
def load_captions(imgpaths: list[str]):
taglists = []
@@ -752,7 +752,7 @@ class DatasetTagEditor(Singleton):
if it.name() in interrogator_names:
if isinstance(it, taggers_builtin.DeepDanbooru):
tagger_thresholds.append((it, threshold_booru))
elif isinstance(it, taggers_builtin.WaifuDiffusion):
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
tagger_thresholds.append((it, threshold_waifu))
elif isinstance(it, taggers_builtin.Z3D_E621):
tagger_thresholds.append((it, threshold_z3d))
@@ -760,15 +760,23 @@
tagger_thresholds.append((it, None))
if kohya_json_path:
imgpaths, self.images, taglists = kohya_metadata.read(
imgpaths, images_raw, taglists = kohya_metadata.read(
img_dir, kohya_json_path, use_temp_dir
)
else:
imgpaths, self.images = load_images(filepaths)
imgpaths, images_raw = load_images(filepaths)
taglists = load_captions(imgpaths)
self.images = load_thumbnails(images_raw)
interrogate_tags = {img_path : [] for img_path in imgpaths}
if interrogate_method != self.InterrogateMethod.NONE:
img_to_interrogate = [
img_path for i, img_path in enumerate(imgpaths)
if (not taglists[i] or interrogate_method != self.InterrogateMethod.PREFILL)
]
if interrogate_method != self.InterrogateMethod.NONE and img_to_interrogate:
logger.write("Preprocess images...")
max_workers = shared.opts.dataset_editor_num_cpu_workers
if max_workers < 0:
@@ -777,11 +785,11 @@
def gen_data(paths:list[str], images:dict[str, Image.Image]):
for img_path in paths:
yield images.get(img_path)
yield images[img_path]
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=max_workers) as executor:
result = list(executor.map(convert_rgb, gen_data(imgpaths, self.images)))
result = list(executor.map(get_square_rgb, gen_data(img_to_interrogate, images_raw)))
logger.write("Preprocess completed")
for tg, th in tqdm(tagger_thresholds):
@@ -798,10 +806,10 @@
continue
try:
if use_pipe:
for img_path, tags in tqdm(zip(imgpaths, tg.predict_pipe(result, th)), desc=tg.name(), total=len(imgpaths)):
for img_path, tags in tqdm(zip(img_to_interrogate, tg.predict_pipe(result, th)), desc=tg.name(), total=len(img_to_interrogate)):
interrogate_tags[img_path] += tags
else:
for img_path, data in tqdm(zip(imgpaths, result), desc=tg.name(), total=len(imgpaths)):
for img_path, data in tqdm(zip(img_to_interrogate, result), desc=tg.name(), total=len(img_to_interrogate)):
interrogate_tags[img_path] += tg.predict(data, th)
except Exception as e:
tb = sys.exc_info()[2]
@@ -814,7 +822,7 @@
tags = interrogate_tags[img_path]
elif interrogate_method == self.InterrogateMethod.PREPEND:
tags = interrogate_tags[img_path] + tags
else:
elif interrogate_method != self.InterrogateMethod.PREFILL:
tags = tags + interrogate_tags[img_path]
self.set_tags_by_image_path(img_path, tags)
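The tail of this diff changes the final tag-merge branch so that PREFILL no longer appends interrogated tags onto captions that already have some. A minimal sketch of the resulting merge policy (a hypothetical helper, not in the repo; the OVERWRITE member name is inferred from context):

```python
from enum import Enum

# Hypothetical mirror of DatasetTagEditor.InterrogateMethod (member names assumed).
class InterrogateMethod(Enum):
    NONE = 0
    PREFILL = 1
    PREPEND = 2
    APPEND = 3
    OVERWRITE = 4

def merge_tags(existing, interrogated, method):
    """Combine interrogated tags with a caption's existing tags.

    PREFILL only interrogates untagged images, so already-tagged
    captions are left untouched (the `elif` added at the end of the diff).
    """
    if method == InterrogateMethod.OVERWRITE:
        return interrogated
    if method == InterrogateMethod.PREPEND:
        return interrogated + existing
    if method == InterrogateMethod.PREFILL:
        return existing if existing else interrogated
    return existing + interrogated  # APPEND
```

Under PREFILL an image with tags keeps them unchanged, while an untagged image receives the interrogated tags.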

@@ -0,0 +1,28 @@
BLIP2_CAPTIONING_NAMES = [
"blip2-opt-2.7b",
"blip2-opt-2.7b-coco",
"blip2-opt-6.7b",
"blip2-opt-6.7b-coco",
"blip2-flan-t5-xl",
"blip2-flan-t5-xl-coco",
"blip2-flan-t5-xxl",
]
# {tagger name : default tagger threshold}
# v1: idk if it's okay v2, v3: P=R thresholds on each repo https://huggingface.co/SmilingWolf
WD_TAGGERS = {
"wd-v1-4-vit-tagger" : 0.35,
"wd-v1-4-convnext-tagger" : 0.35,
"wd-v1-4-vit-tagger-v2" : 0.3537,
"wd-v1-4-convnext-tagger-v2" : 0.3685,
"wd-v1-4-convnextv2-tagger-v2" : 0.371,
"wd-v1-4-moat-tagger-v2" : 0.3771
}
WD_TAGGERS_TIMM = {
"wd-v1-4-swinv2-tagger-v2" : 0.3771,
"wd-vit-tagger-v3" : 0.2614,
"wd-convnext-tagger-v3" : 0.2682,
"wd-swinv2-tagger-v3" : 0.2653,
}

@@ -1,7 +1,8 @@
from .blip2_captioning import BLIP2Captioning
from .git_large_captioning import GITLargeCaptioning
from .waifu_diffusion_tagger import WaifuDiffusionTagger
from .waifu_diffusion_tagger_timm import WaifuDiffusionTaggerTimm
__all__ = [
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger'
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger', 'WaifuDiffusionTaggerTimm'
]

@@ -0,0 +1,104 @@
from PIL import Image
from typing import Tuple
import torch
from torch.nn import functional as F
import torchvision.transforms as tf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from modules import shared, devices
import launch
class ImageDataset(Dataset):
def __init__(self, images:list[Image.Image], transforms:tf.Compose=None):
self.images = images
self.transforms = transforms
def __len__(self):
return len(self.images)
def __getitem__(self, i):
img = self.images[i]
if self.transforms is not None:
img = self.transforms(img)
return img
class WaifuDiffusionTaggerTimm:
# some code is borrowed from https://github.com/neggles/wdv3-timm and modified
def __init__(self, model_repo, label_filename="selected_tags.csv"):
self.LABEL_FILENAME = label_filename
self.MODEL_REPO = model_repo
self.model = None
self.transform = None
self.labels = []
def load(self):
import huggingface_hub
if not launch.is_installed("timm"):
launch.run_pip(
"install -U timm",
"requirements for dataset-tag-editor [timm]",
)
import timm
from timm.data import create_transform, resolve_data_config
if not self.model:
self.model: torch.nn.Module = timm.create_model(
"hf-hub:" + self.MODEL_REPO
).eval()
state_dict = timm.models.load_state_dict_from_hf(self.MODEL_REPO)
self.model.load_state_dict(state_dict)
self.model.to(devices.device)
self.transform = create_transform(
**resolve_data_config(self.model.pretrained_cfg, model=self.model)
)
path_label = huggingface_hub.hf_hub_download(
self.MODEL_REPO, self.LABEL_FILENAME
)
import pandas as pd
self.labels = pd.read_csv(path_label)["name"].tolist()
def unload(self):
if not shared.opts.interrogate_keep_models_in_memory:
self.model = None
devices.torch_gc()
def apply(self, image: Image.Image):
if not self.model:
return []
image_t: torch.Tensor = self.transform(image).unsqueeze(0)
image_t = image_t[:, [2, 1, 0]]
image_t = image_t.to(devices.device)
with torch.inference_mode():
features = self.model.forward(image_t)
probs = F.sigmoid(features).detach().cpu().numpy()
labels: list[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))
return labels
def apply_multi(self, images: list[Image.Image], batch_size: int):
if not self.model:
return []
dataset = ImageDataset(images, self.transform)
dataloader = DataLoader(dataset, batch_size=batch_size)
with torch.inference_mode():
for batch in tqdm(dataloader):
batch = batch[:, [2, 1, 0]].to(devices.device)
features = self.model.forward(batch)
probs = F.sigmoid(features).detach().cpu().numpy()
labels: list[Tuple[str, float]] = [list(zip(self.labels, probs[i].astype(float))) for i in range(probs.shape[0])]
yield labels
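`apply_multi` streams results batch by batch rather than holding every model output in memory at once. A torch-free sketch of the same generator shape, where `score_fn` stands in for `model.forward` and all names are illustrative:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def apply_multi_sketch(images, labels, score_fn, batch_size):
    """Yield, per batch, one list of (label, probability) pairs per image,
    mirroring the shape of WaifuDiffusionTaggerTimm.apply_multi without torch."""
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size]
        logits = [score_fn(img) for img in batch]  # stands in for model.forward
        yield [
            [(lbl, sigmoid(v)) for lbl, v in zip(labels, row)]
            for row in logits
        ]
```

The caller (`predict_pipe` in taggers_builtin.py) consumes one batch of label lists at a time, which is what keeps peak memory bounded by `batch_size`.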

@@ -8,7 +8,7 @@ from modules import devices, shared
from modules import deepbooru as db
from scripts.tagger import Tagger, get_replaced_tag
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger, WaifuDiffusionTaggerTimm
class BLIP(Tagger):
@@ -142,6 +142,28 @@ class WaifuDiffusion(Tagger):
return self.repo_name
class WaifuDiffusionTimm(WaifuDiffusion):
def __init__(self, repo_name, threshold, batch_size=4):
super().__init__(repo_name, threshold)
self.tagger_inst = WaifuDiffusionTaggerTimm("SmilingWolf/" + repo_name)
self.batch_size = batch_size
def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None):
for labels_list in self.tagger_inst.apply_multi(data, batch_size=self.batch_size):
for labels in labels_list:
if not shared.opts.dataset_editor_use_rating:
labels = labels[4:]
if threshold is not None:
if threshold < 0:
threshold = self.threshold
tags = [get_replaced_tag(tag) for tag, value in labels if value > threshold]
else:
tags = [get_replaced_tag(tag) for tag, _ in labels]
yield tags
class Z3D_E621(Tagger):
def __init__(self):
self.tagger_inst = WaifuDiffusionTagger("toynya/Z3D-E621-Convnext", label_filename="tags-selected.csv")
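The new `predict_pipe` filters each image's labels with the same conventions as `predict`: the first four labels are WD rating tags, and a negative threshold falls back to the model's default. A standalone sketch of that filtering (`get_replaced_tag` omitted; names illustrative):

```python
def filter_tags(labels, threshold, default_threshold, use_rating=False):
    """Select tag names from (label, confidence) pairs the way
    predict_pipe does; labels[:4] are the WD rating tags."""
    if not use_rating:
        labels = labels[4:]
    if threshold is not None:
        if threshold < 0:
            # -1 from the UI means "use the model's default threshold"
            threshold = default_threshold
        return [tag for tag, value in labels if value > threshold]
    return [tag for tag, _ in labels]
```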

@@ -413,10 +413,6 @@ def on_ui_tabs():
ui.move_or_delete_files.set_callbacks(
o_update_filter_and_gallery,
ui.dataset_gallery,
ui.filter_by_tags,
ui.batch_edit_captions,
ui.filter_by_selection,
ui.edit_caption_of_selected_image,
get_filters,
update_filter_and_gallery,
)
@@ -472,6 +468,33 @@
),
)
shared.opts.add_option(
"dataset_editor_batch_size_vit",
shared.OptionInfo(
4,
"Inference batch size for ViT taggers",
section=section,
),
)
shared.opts.add_option(
"dataset_editor_batch_size_convnext",
shared.OptionInfo(
4,
"Inference batch size for ConvNeXt taggers",
section=section,
),
)
shared.opts.add_option(
"dataset_editor_batch_size_swinv2",
shared.OptionInfo(
4,
"Inference batch size for SwinTransformerV2 taggers",
section=section,
),
)
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(on_ui_tabs)
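The three options added above feed the `read_wd_batchsize` closure in dataset_tag_editor.py, which dispatches on a substring of the model name. A sketch with a plain dict standing in for `shared.opts` (the fallback value is an assumption; the original has no explicit default):

```python
def read_wd_batchsize(name: str, opts: dict) -> int:
    """Pick the inference batch size by model family, mirroring the
    substring dispatch in the diff; `opts` stands in for shared.opts."""
    if "vit" in name:
        return opts["dataset_editor_batch_size_vit"]
    elif "convnext" in name:
        return opts["dataset_editor_batch_size_convnext"]
    elif "swinv2" in name:
        return opts["dataset_editor_batch_size_swinv2"]
    return 4  # assumed fallback; every WD_TAGGERS_TIMM name matches a family
```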

@@ -199,6 +199,7 @@ class LoadDatasetUI(UIBase):
self.sl_custom_threshold_booru,
self.cb_use_custom_threshold_waifu,
self.sl_custom_threshold_waifu,
self.sl_custom_threshold_z3d,
toprow.cb_save_kohya_metadata,
toprow.tb_metadata_output,
],

@@ -33,14 +33,14 @@ class ToprowUI(UIBase):
def set_callbacks(self, load_dataset:LoadDatasetUI):
def save_all_changes(backup: bool, save_kohya_metadata:bool, metadata_output:str, metadata_input:str, metadata_overwrite:bool, metadata_as_caption:bool, metadata_use_fullpath:bool):
def save_all_changes(backup: bool, save_kohya_metadata:bool, metadata_output:str, metadata_input:str, metadata_overwrite:bool, metadata_as_caption:bool, metadata_use_fullpath:bool, caption_file_ext:str):
if not metadata_input:
metadata_input = None
dte_instance.save_dataset(backup, load_dataset.caption_file_ext, save_kohya_metadata, metadata_output, metadata_input, metadata_overwrite, metadata_as_caption, metadata_use_fullpath)
dte_instance.save_dataset(backup, caption_file_ext, save_kohya_metadata, metadata_output, metadata_input, metadata_overwrite, metadata_as_caption, metadata_use_fullpath)
self.btn_save_all_changes.click(
fn=save_all_changes,
inputs=[self.cb_backup, self.cb_save_kohya_metadata, self.tb_metadata_output, self.tb_metadata_input, self.cb_metadata_overwrite, self.cb_metadata_as_caption, self.cb_metadata_use_fullpath]
inputs=[self.cb_backup, self.cb_save_kohya_metadata, self.tb_metadata_output, self.tb_metadata_input, self.cb_metadata_overwrite, self.cb_metadata_as_caption, self.cb_metadata_use_fullpath, load_dataset.tb_caption_file_ext]
)
self.cb_save_kohya_metadata.change(

@@ -133,16 +133,17 @@ class EditCaptionOfSelectedImageUI(UIBase):
outputs=[self.tb_edit_caption]
)
def interrogate_selected_image(interrogator_name: str, use_threshold_booru: bool, threshold_booru: float, use_threshold_waifu: bool, threshold_waifu: float):
def interrogate_selected_image(interrogator_name: str, use_threshold_booru: bool, threshold_booru: float, use_threshold_waifu: bool, threshold_waifu: float, threshold_z3d: float):
if not interrogator_name:
return ''
threshold_booru = threshold_booru if use_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
threshold_waifu = threshold_waifu if use_threshold_waifu else -1
return dte_instance.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu)
return dte_instance.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu, threshold_z3d)
self.btn_interrogate_si.click(
fn=interrogate_selected_image,
inputs=[self.dd_intterogator_names_si, load_dataset.cb_use_custom_threshold_booru, load_dataset.sl_custom_threshold_booru, load_dataset.cb_use_custom_threshold_waifu, load_dataset.sl_custom_threshold_waifu],
inputs=[self.dd_intterogator_names_si, load_dataset.cb_use_custom_threshold_booru, load_dataset.sl_custom_threshold_booru, load_dataset.cb_use_custom_threshold_waifu, load_dataset.sl_custom_threshold_waifu, load_dataset.sl_custom_threshold_z3d],
outputs=[self.tb_interrogate]
)

@@ -37,7 +37,7 @@ class FilterByTagsUI(UIBase):
common_callback = lambda : \
update_gallery() + \
batch_edit_captions.get_common_tags(get_filters, self) + \
[move_or_delete_files.get_current_move_or_delete_target_num()] + \
[move_or_delete_files.update_current_move_or_delete_target_num()] + \
[batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
common_callback_output = \

@@ -14,6 +14,7 @@ class MoveOrDeleteFilesUI(UIBase):
def __init__(self):
self.target_data = 'Selected One'
self.current_target_txt = ''
self.update_func = None
def create_ui(self, cfg_file_move_delete):
gr.HTML(value='<b>Note: </b>Moved or deleted images will be unloaded.')
@@ -27,11 +28,15 @@
gr.HTML(value='<b>Note: </b>DELETE cannot be undone. The files will be deleted completely.')
self.btn_move_or_delete_delete_files = gr.Button(value='DELETE File(s)', variant='primary')
def get_current_move_or_delete_target_num(self):
return self.current_target_txt
def update_current_move_or_delete_target_num(self):
if self.update_func:
return self.update_func(self.target_data)
else:
return self.current_target_txt
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, filter_by_tags:FilterByTagsUI, batch_edit_captions:BatchEditCaptionsUI, filter_by_selection:FilterBySelectionUI, edit_caption_of_selected_image:EditCaptionOfSelectedImageUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
def _get_current_move_or_delete_target_num():
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
def _get_current_move_or_delete_target_num(text: str):
self.target_data = text
if self.target_data == 'Selected One':
self.current_target_txt = f'Target dataset num: {1 if dataset_gallery.selected_index != -1 else 0}'
elif self.target_data == 'All Displayed Ones':
@@ -41,25 +46,17 @@ class MoveOrDeleteFilesUI(UIBase):
self.current_target_txt = f'Target dataset num: 0'
return self.current_target_txt
self.update_func = _get_current_move_or_delete_target_num
update_args = {
'fn' : _get_current_move_or_delete_target_num,
'inputs' : None,
'fn': self.update_func,
'inputs': [self.rb_move_or_delete_target_data],
'outputs' : [self.ta_move_or_delete_target_dataset_num]
}
batch_edit_captions.btn_apply_edit_tags.click(lambda:None).then(**update_args)
batch_edit_captions.btn_apply_sr_tags.click(lambda:None).then(**update_args)
filter_by_selection.btn_apply_image_selection_filter.click(lambda:None).then(**update_args)
filter_by_tags.btn_clear_tag_filters.click(lambda:None).then(**update_args)
filter_by_tags.btn_clear_all_filters.click(lambda:None).then(**update_args)
edit_caption_of_selected_image.btn_apply_changes_selected_image.click(lambda:None).then(**update_args)
self.rb_move_or_delete_target_data.change(**update_args)
dataset_gallery.cbg_hidden_dataset_filter.change(lambda:None).then(**update_args)
dataset_gallery.nb_hidden_image_index.change(lambda:None).then(**update_args)
def move_files(
target_data: str,

scripts/utilities.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from typing import Tuple
import math
from PIL import Image
if not hasattr(Image, 'Resampling'): # Pillow<9.0
Image.Resampling = Image
def resize(image: Image.Image, size: Tuple[int, int]):
return image.resize(size, resample=Image.Resampling.LANCZOS)
def get_rgb_image(image:Image.Image):
if image.mode not in ["RGB", "RGBA"]:
image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
if image.mode == "RGBA":
white = Image.new("RGBA", image.size, (255, 255, 255, 255))
white.alpha_composite(image)
image = white.convert("RGB")
return image
def resize_and_fill(image: Image.Image, size: Tuple[int, int], repeat_edge = True, fill_rgb:tuple[int,int,int] = (255, 255, 255)):
width, height = size
scale_w, scale_h = width / image.width, height / image.height
resized_w, resized_h = width, height
if scale_w < scale_h:
resized_h = image.height * resized_w // image.width
elif scale_h < scale_w:
resized_w = image.width * resized_h // image.height
resized = resize(image, (resized_w, resized_h))
if resized_w == width and resized_h == height:
return resized
if repeat_edge:
fill_l = math.floor((width - resized_w) / 2)
fill_r = width - resized_w - fill_l
fill_t = math.floor((height - resized_h) / 2)
fill_b = height - resized_h - fill_t
result = Image.new("RGB", (width, height))
result.paste(resized, (fill_l, fill_t))
if fill_t > 0:
result.paste(resized.resize((width, fill_t), box=(0, 0, width, 0)), (0, 0))
if fill_b > 0:
result.paste(
resized.resize(
(width, fill_b), box=(0, resized.height, width, resized.height)
),
(0, resized.height + fill_t),
)
if fill_l > 0:
result.paste(resized.resize((fill_l, height), box=(0, 0, 0, height)), (0, 0))
if fill_r > 0:
result.paste(
resized.resize(
(fill_r, height), box=(resized.width, 0, resized.width, height)
),
(resized.width + fill_l, 0),
)
return result
else:
result = Image.new("RGB", size, fill_rgb)
result.paste(resized, box=((width - resized_w) // 2, (height - resized_h) // 2))
return result.convert("RGB")
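`resize_and_fill` first scales the image so the larger relative side fits the target, then centers it and fills the remainder; `get_square_rgb` calls it with `size = (max(w, h),) * 2`. Just the geometry, extracted into a PIL-free sketch (a hypothetical helper, not in the repo):

```python
import math

def fit_and_margins(src_w, src_h, dst_w, dst_h):
    """Resized size plus left/top fill margins for centering, using the
    same integer arithmetic as resize_and_fill."""
    scale_w, scale_h = dst_w / src_w, dst_h / src_h
    resized_w, resized_h = dst_w, dst_h
    if scale_w < scale_h:      # width limits: shrink height proportionally
        resized_h = src_h * resized_w // src_w
    elif scale_h < scale_w:    # height limits: shrink width proportionally
        resized_w = src_w * resized_h // src_h
    fill_l = math.floor((dst_w - resized_w) / 2)
    fill_t = math.floor((dst_h - resized_h) / 2)
    return resized_w, resized_h, fill_l, fill_t
```

For a 400x300 image squared to 400x400, the image keeps its full width and gains a 50 px band above and below, which is then filled either by repeating the edge rows or with a solid color.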

@@ -15,6 +15,19 @@ from scripts.tagger import Tagger
# I'm not sure if this is really working
BATCH_SIZE = 3
# tags used in Animagine-XL
SCORE_N = {
'very aesthetic':0.71,
'aesthetic':0.45,
'displeasing':0.27,
'very displeasing':-float('inf'),
}
def get_aesthetic_tag(score:float):
for k, v in SCORE_N.items():
if score > v:
return k
class AestheticShadowV2(Tagger):
def load(self):
if devices.device.index is None:
@@ -40,8 +53,7 @@ class AestheticShadowV2(Tagger):
for d in data:
final[d["label"]] = d["score"]
hq = final['hq']
lq = final['lq']
return [f"score_{math.floor((hq + (1 - lq))/2 * 10)}"]
return [get_aesthetic_tag(hq)]
def predict(self, image: Image.Image, threshold=None):
data = self.pipe_aesthetic(image)
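`get_aesthetic_tag` replaces the old numeric `score_N` tags with Animagine-XL labels. It relies on `SCORE_N` being iterated in insertion order (highest threshold first) and returns the first label whose threshold the score exceeds; reproduced here with a usage check:

```python
# Thresholds copied from the diff above (tags used in Animagine-XL).
SCORE_N = {
    'very aesthetic': 0.71,
    'aesthetic': 0.45,
    'displeasing': 0.27,
    'very displeasing': -float('inf'),
}

def get_aesthetic_tag(score: float):
    # dicts preserve insertion order, so the first matching bucket wins
    for tag, lower_bound in SCORE_N.items():
        if score > lower_bound:
            return tag
```

The `-inf` sentinel guarantees every score falls into some bucket, so the function never returns `None`.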

@@ -36,9 +36,10 @@ class CafeAIAesthetic(Tagger):
final = {}
for d in data:
final[d["label"]] = d["score"]
nae = final['not_aesthetic']
ae = final['aesthetic']
return [f"score_{math.floor((ae + (1 - nae))/2 * 10)}"]
# edit here to change tag
return [f"[CAFE]score_{math.floor(ae*10)}"]
def predict(self, image: Image.Image, threshold=None):
data = self.pipe_aesthetic(image, top_k=2)

@@ -69,7 +69,8 @@ class ImprovedAestheticPredictor(Tagger):
def predict(self, image: Image.Image, threshold=None):
image_embeds = image_embeddings(image, self.clip_model, self.clip_processor)
prediction:torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device))
return [f"score_{math.floor(prediction.item())}"]
# edit here to change tag
return [f"[IAP]score_{math.floor(prediction.item())}"]
def name(self):
return "Improved Aesthetic Predictor"

@@ -67,7 +67,8 @@ class WaifuAesthetic(Tagger):
def predict(self, image: Image.Image, threshold=None):
image_embeds = image_embeddings(image, self.clip_model, self.clip_processor)
prediction:torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device))
return [f"score_{math.floor(prediction.item()*10)}"]
# edit here to change tag
return [f"[WD]score_{math.floor(prediction.item()*10)}"]
def name(self):
return "wd aesthetic classifier"

@@ -1 +1 @@
0.3.1
0.3.4