Merge changes in standalone version (#93)

* Merge changes in standalone version
 - New Taggers and Custom Tagger
 - A slightly more stable UI
feature/visualize-tokens
toshiaki1729 2024-05-08 03:58:58 +09:00 committed by GitHub
parent 7a2f4c53fb
commit c3252d8325
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 1839 additions and 735 deletions

143
scripts/config.py Normal file
View File

@ -0,0 +1,143 @@
from collections import namedtuple
import json
from scripts import logger
from scripts.paths import paths
from scripts.dte_instance import dte_instance
# Sorting enums re-exported from the shared tag editor instance.
SortBy = dte_instance.SortBy
SortOrder = dte_instance.SortOrder
# Settings file lives at the extension root.
CONFIG_PATH = paths.base_path / "config.json"
# Section layouts for config.json. Field order matters: the CFG_*_DEFAULT
# instances below supply values positionally in this exact order.
GeneralConfig = namedtuple(
    "GeneralConfig",
    [
        "backup",
        "dataset_dir",
        "caption_ext",
        "load_recursive",
        "load_caption_from_filename",
        "replace_new_line",
        "use_interrogator",
        "use_interrogator_names",
        "use_custom_threshold_booru",
        "custom_threshold_booru",
        "use_custom_threshold_waifu",
        "custom_threshold_waifu",
        "custom_threshold_z3d",
        "save_kohya_metadata",
        "meta_output_path",
        "meta_input_path",
        "meta_overwrite",
        "meta_save_as_caption",
        "meta_use_full_path",
    ],
)
FilterConfig = namedtuple(
    "FilterConfig",
    ["sw_prefix", "sw_suffix", "sw_regex", "sort_by", "sort_order", "logic"],
)
BatchEditConfig = namedtuple(
    "BatchEditConfig",
    [
        "show_only_selected",
        "prepend",
        "use_regex",
        "target",
        "sw_prefix",
        "sw_suffix",
        "sw_regex",
        # NOTE(review): "sory_by" is a long-standing typo; it is kept because
        # the field name doubles as a key in users' saved config.json files.
        "sory_by",
        "sort_order",
        "batch_sort_by",
        "batch_sort_order",
        "token_count",
    ],
)
EditSelectedConfig = namedtuple(
    "EditSelectedConfig",
    [
        "auto_copy",
        "sort_on_save",
        "warn_change_not_saved",
        "use_interrogator_name",
        "sort_by",
        "sort_order",
    ],
)
MoveDeleteConfig = namedtuple(
    "MoveDeleteConfig", ["range", "target", "caption_ext", "destination"]
)
# Default values, positional per the namedtuple field order above.
CFG_GENERAL_DEFAULT = GeneralConfig(
    True,
    "",
    ".txt",
    False,
    True,
    False,
    "No",
    [],
    False,
    0.7,
    False,
    0.35,
    0.35,
    False,
    "",
    "",
    True,
    False,
    False,
)
CFG_FILTER_P_DEFAULT = FilterConfig(
    False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, "AND"
)
CFG_FILTER_N_DEFAULT = FilterConfig(
    False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, "OR"
)
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(
    True,
    False,
    False,
    "Only Selected Tags",
    False,
    False,
    False,
    SortBy.ALPHA.value,
    SortOrder.ASC.value,
    SortBy.ALPHA.value,
    SortOrder.ASC.value,
    75,
)
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(
    False, False, False, "", SortBy.ALPHA.value, SortOrder.ASC.value
)
CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig("Selected One", [], ".txt", "")
class Config:
    """Thin wrapper around the extension's JSON settings file (config.json).

    Settings are kept in memory as a flat dict of section name -> section
    dict, and persisted at CONFIG_PATH.
    """

    def __init__(self):
        # In-memory settings; empty until load() is called.
        self.config = dict()

    def load(self):
        """Read settings from CONFIG_PATH, falling back to empty on any error."""
        if not CONFIG_PATH.is_file():
            self.config = dict()
            return
        try:
            self.config = json.loads(CONFIG_PATH.read_text("utf8"))
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError (malformed file);
            # OSError covers an unreadable file. The previous bare `except:`
            # also swallowed KeyboardInterrupt/SystemExit.
            logger.warn("Error on loading config.json. Default settings will be loaded.")
            self.config = dict()
        else:
            logger.write("Settings has been read from config.json")

    def save(self):
        """Write the current settings to CONFIG_PATH as indented JSON."""
        CONFIG_PATH.write_text(json.dumps(self.config, indent=4), "utf8")

    def read(self, name: str):
        """Return the section stored under *name*, or None if absent."""
        return self.config.get(name)

    def write(self, cfg: dict, name: str):
        """Store section *cfg* under *name* (in memory only; call save() to persist)."""
        self.config[name] = cfg

View File

@ -1,8 +1,8 @@
from . import tagger
from . import captioning
from . import taggers_builtin
from . import filters
from . import dataset as ds
from . import kohya_finetune_metadata
from .dte_logic import DatasetTagEditor, INTERROGATOR_NAMES, interrogate_image
from .dte_logic import DatasetTagEditor
__all__ = ["ds", "tagger", "captioning", "filters", "kohya_metadata", "INTERROGATOR_NAMES", "interrogate_image", "DatasetTagEditor"]
__all__ = ["ds", "taggers_builtin", "filters", "kohya_finetune_metadata", "DatasetTagEditor"]

View File

@ -1,47 +0,0 @@
import modules.shared as shared
from .interrogator import Interrogator
from .interrogators import GITLargeCaptioning
class Captioning(Interrogator):
    """Interrogator variant that produces caption-derived tag lists."""

    def start(self):
        """No-op by default; concrete captioners load their model here."""

    def stop(self):
        """No-op by default; concrete captioners release their model here."""

    def predict(self, image):
        """Return a list of tags for *image*; must be overridden."""
        raise NotImplementedError()

    def name(self):
        """Return a human-readable captioner name; must be overridden."""
        raise NotImplementedError()
class BLIP(Captioning):
    """Captioning backend backed by the web UI's built-in BLIP interrogator."""

    def start(self):
        shared.interrogator.load()

    def stop(self):
        shared.interrogator.unload()

    def predict(self, image):
        # Split the generated caption on commas and drop empty fragments.
        caption = shared.interrogator.generate_caption(image)
        return [fragment for fragment in caption.split(',') if fragment]

    def name(self):
        return 'BLIP'
class GITLarge(Captioning):
    """Captioning backend wrapping the GIT-large-COCO model."""

    def __init__(self):
        self.interrogator = GITLargeCaptioning()

    def start(self):
        self.interrogator.load()

    def stop(self):
        self.interrogator.unload()

    def predict(self, image):
        # Comma-split the caption and keep only non-empty pieces.
        caption = self.interrogator.apply(image)
        return [piece for piece in caption.split(',') if piece]

    def name(self):
        return 'GIT-large-COCO'

View File

@ -0,0 +1,50 @@
import sys
from pathlib import Path
import importlib.util
from types import ModuleType
from scripts import logger
from scripts.paths import paths
class CustomScripts:
    """Loads user-provided tagger scripts and collects classes derived from a base class."""

    def __init__(self, scripts_dir: Path) -> None:
        # Loaded modules keyed by file stem.
        self.scripts = dict()
        self.scripts_dir = scripts_dir.absolute()

    def _load_module_from(self, path: Path):
        """Import the file at *path* as a standalone module and return it."""
        module_spec = importlib.util.spec_from_file_location(path.stem, path)
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
        return module

    def _load_derived_classes(self, module: ModuleType, base_class: type):
        """Return all classes in *module* that subclass *base_class*, excluding the base itself."""
        derived_classes = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, base_class) and obj is not base_class:
                derived_classes.append(obj)
        return derived_classes

    def load_derived_classes(self, baseclass: type):
        """Load every *.py file in scripts_dir and return the classes derived from *baseclass*.

        Returns an empty list when scripts_dir does not exist or loading fails.
        """
        # Keep an independent copy so restoration is exact even if callees mutate sys.path.
        back_syspath = sys.path.copy()
        if not self.scripts_dir.is_dir():
            logger.warn(f"NOT A DIRECTORY: {self.scripts_dir}")
            return []
        classes = []
        path = None  # track the file being processed; None if failure precedes the loop
        try:
            # Make the extension root importable for the custom scripts.
            sys.path = [str(paths.base_path)] + sys.path
            for path in self.scripts_dir.glob("*.py"):
                self.scripts[path.stem] = self._load_module_from(path)
            for module in self.scripts.values():
                classes.extend(self._load_derived_classes(module, baseclass))
        except Exception as e:
            # Previously `path` could be unbound here, raising NameError and
            # masking the original error; it is now pre-initialized above.
            tb = sys.exc_info()[2]
            logger.error(f"Error on loading {path}")
            logger.error(e.with_traceback(tb))
        finally:
            sys.path = back_syspath
        return classes

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +0,0 @@
class Interrogator:
    """Base class for all interrogators; usable as a context manager.

    Entering the context calls start() (model load) and leaving it calls
    stop() (model release).
    """

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        self.stop()

    def start(self):
        """Default no-op; subclasses load their model here."""

    def stop(self):
        """Default no-op; subclasses release their model here."""

    def predict(self, image, **kwargs):
        """Return predictions for *image*; must be overridden."""
        raise NotImplementedError()

    def name(self):
        """Return a display name; must be overridden."""
        raise NotImplementedError()

View File

@ -1,6 +1,7 @@
from .blip2_captioning import BLIP2Captioning
from .git_large_captioning import GITLargeCaptioning
from .waifu_diffusion_tagger import WaifuDiffusionTagger
__all__ = [
'GITLargeCaptioning', 'WaifuDiffusionTagger'
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger'
]

View File

@ -0,0 +1,33 @@
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from modules import devices, shared
from scripts.paths import paths
class BLIP2Captioning:
    """Wrapper around a BLIP-2 checkpoint for image captioning."""

    def __init__(self, model_repo: str):
        self.MODEL_REPO = model_repo
        self.processor: Blip2Processor = None
        self.model: Blip2ForConditionalGeneration = None

    def load(self):
        """Download (if needed) and instantiate processor and model."""
        if self.model is not None and self.processor is not None:
            return  # already loaded
        self.processor = Blip2Processor.from_pretrained(
            self.MODEL_REPO, cache_dir=paths.setting_model_path
        )
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            self.MODEL_REPO, cache_dir=paths.setting_model_path
        ).to(devices.device)

    def unload(self):
        """Drop model references (unless configured to keep them) and reclaim memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            self.processor = None
            devices.torch_gc()

    def apply(self, image):
        """Caption *image*; returns decoded strings, or "" when not loaded."""
        if self.model is None or self.processor is None:
            return ""
        inputs = self.processor(images=image, return_tensors="pt").to(devices.device)
        ids = self.model.generate(**inputs)
        return self.processor.batch_decode(ids, skip_special_tokens=True)

View File

@ -1,26 +1,35 @@
from transformers import AutoProcessor, AutoModelForCausalLM
from modules import shared
from modules import shared, devices, lowvram
# brought from https://huggingface.co/docs/transformers/main/en/model_doc/git and modified
class GITLargeCaptioning():
class GITLargeCaptioning:
MODEL_REPO = "microsoft/git-large-coco"
def __init__(self):
self.processor:AutoProcessor = None
self.model:AutoModelForCausalLM = None
self.processor: AutoProcessor = None
self.model: AutoModelForCausalLM = None
def load(self):
if self.model is None or self.processor is None:
self.processor = AutoProcessor.from_pretrained(self.MODEL_REPO)
self.model = AutoModelForCausalLM.from_pretrained(self.MODEL_REPO).to(shared.device)
self.model = AutoModelForCausalLM.from_pretrained(self.MODEL_REPO).to(
shared.device
)
lowvram.send_everything_to_cpu()
def unload(self):
if not shared.opts.interrogate_keep_models_in_memory:
self.model = None
self.processor = None
devices.torch_gc()
def apply(self, image):
if self.model is None or self.processor is None:
return ''
inputs = self.processor(images=image, return_tensors='pt').to(shared.device)
ids = self.model.generate(pixel_values=inputs.pixel_values, max_length=shared.opts.interrogate_clip_max_length)
return self.processor.batch_decode(ids, skip_special_tokens=True)[0]
return ""
inputs = self.processor(images=image, return_tensors="pt").to(shared.device)
ids = self.model.generate(
pixel_values=inputs.pixel_values,
max_length=shared.opts.interrogate_clip_max_length,
)
return self.processor.batch_decode(ids, skip_special_tokens=True)[0]

View File

@ -1,74 +1,103 @@
from PIL import Image
import numpy as np
from typing import List, Tuple
from modules import shared
from modules import shared, devices
import launch
class WaifuDiffusionTagger():
class WaifuDiffusionTagger:
# brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
def __init__(self, model_name):
def __init__(
self,
model_name,
model_filename="model.onnx",
label_filename="selected_tags.csv",
):
self.MODEL_FILENAME = model_filename
self.LABEL_FILENAME = label_filename
self.MODEL_REPO = model_name
self.model = None
self.labels = []
def load(self):
import huggingface_hub
if not self.model:
path_model = huggingface_hub.hf_hub_download(
self.MODEL_REPO, self.MODEL_FILENAME
)
if 'all' in shared.cmd_opts.use_cpu or 'interrogate' in shared.cmd_opts.use_cpu:
providers = ['CPUExecutionProvider']
if (
"all" in shared.cmd_opts.use_cpu
or "interrogate" in shared.cmd_opts.use_cpu
):
providers = ["CPUExecutionProvider"]
else:
providers = ['CUDAExecutionProvider', 'DmlExecutionProvider', 'CPUExecutionProvider']
providers = [
"CUDAExecutionProvider",
"DmlExecutionProvider",
"CPUExecutionProvider",
]
def check_available_device():
import torch
if torch.cuda.is_available():
return 'cuda'
return "cuda"
elif launch.is_installed("torch-directml"):
# This code cannot detect DirectML available device without pytorch-directml
try:
import torch_directml
torch_directml.device()
except:
pass
else:
return 'directml'
return 'cpu'
return "directml"
return "cpu"
if not launch.is_installed("onnxruntime"):
dev = check_available_device()
if dev == 'cuda':
launch.run_pip("install -U onnxruntime-gpu", "requirements for dataset-tag-editor [onnxruntime-gpu]")
elif dev == 'directml':
launch.run_pip("install -U onnxruntime-directml", "requirements for dataset-tag-editor [onnxruntime-directml]")
if dev == "cuda":
launch.run_pip(
"install -U onnxruntime-gpu",
"requirements for dataset-tag-editor [onnxruntime-gpu]",
)
elif dev == "directml":
launch.run_pip(
"install -U onnxruntime-directml",
"requirements for dataset-tag-editor [onnxruntime-directml]",
)
else:
print('Your device is not compatible with onnx hardware acceleration. CPU only version will be installed and it may be very slow.')
launch.run_pip("install -U onnxruntime", "requirements for dataset-tag-editor [onnxruntime for CPU]")
print(
"Your device is not compatible with onnx hardware acceleration. CPU only version will be installed and it may be very slow."
)
launch.run_pip(
"install -U onnxruntime",
"requirements for dataset-tag-editor [onnxruntime for CPU]",
)
import onnxruntime as ort
self.model = ort.InferenceSession(path_model, providers=providers)
path_label = huggingface_hub.hf_hub_download(
self.MODEL_REPO, self.LABEL_FILENAME
)
import pandas as pd
self.labels = pd.read_csv(path_label)["name"].tolist()
def unload(self):
if not shared.opts.interrogate_keep_models_in_memory:
self.model = None
devices.torch_gc()
# brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
def apply(self, image: Image.Image):
if not self.model:
return dict()
from modules import images
_, height, width, _ = self.model.get_inputs()[0].shape
# the way to fill empty pixels is quite different from original one;
@ -85,4 +114,4 @@ class WaifuDiffusionTagger():
probs = self.model.run([label_name], {input_name: image_np})[0]
labels: List[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))
return labels
return labels

View File

@ -1,106 +0,0 @@
from PIL import Image
import re
import torch
import numpy as np
from typing import Optional, Dict
from modules import devices, shared
from modules import deepbooru as db
from .interrogator import Interrogator
from .interrogators import WaifuDiffusionTagger
class Tagger(Interrogator):
    """Interrogator variant returning per-tag probabilities for an image."""

    def start(self):
        """No-op by default; subclasses load their model here."""

    def stop(self):
        """No-op by default; subclasses release their model here."""

    def predict(self, image: Image.Image, threshold: Optional[float]):
        """Return tag predictions for *image*; must be overridden."""
        raise NotImplementedError()

    def name(self):
        """Return the tagger's display name; must be overridden."""
        raise NotImplementedError()
def get_replaced_tag(tag: str):
    """Apply the web UI's deepbooru display options (spaces / escaping) to *tag*."""
    if shared.opts.deepbooru_use_spaces:
        tag = tag.replace('_', ' ')
    if shared.opts.deepbooru_escape:
        tag = re.sub(db.re_special, r'\\\1', tag)
    return tag
def get_arranged_tags(probs: Dict[str, float]):
    """Order tags alphabetically or by descending probability, per UI settings."""
    if shared.opts.deepbooru_sort_alpha:
        return sorted(probs)
    return [tag for tag, _ in sorted(probs.items(), key=lambda item: -item[1])]
class DeepDanbooru(Tagger):
    """Tagger backed by the web UI's bundled DeepDanbooru model; returns {tag: probability}."""

    def start(self):
        db.model.start()

    def stop(self):
        db.model.stop()

    # brought from webUI modules/deepbooru.py and modified
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        """Run DeepDanbooru on *image*, filtering by *threshold* and dropping rating: tags."""
        from modules import images

        # The model expects a 512x512 RGB image normalized to [0, 1].
        pic = images.resize_image(2, image.convert("RGB"), 512, 512)
        a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
        with torch.no_grad(), devices.autocast():
            x = torch.from_numpy(a).to(devices.device)
            y = db.model.model(x)[0].detach().cpu().numpy()
        probability_dict = dict()
        for tag, probability in zip(db.model.model.tags, y):
            # NOTE: a threshold of 0 (or None) disables filtering because 0 is falsy.
            if threshold and probability < threshold:
                continue
            # "rating:*" pseudo-tags are always excluded here.
            if tag.startswith("rating:"):
                continue
            probability_dict[get_replaced_tag(tag)] = probability
        return probability_dict

    def name(self):
        return 'DeepDanbooru'
class WaifuDiffusion(Tagger):
    """Tagger backed by a SmilingWolf/wd-* ONNX model; returns {tag: probability}."""

    def __init__(self, repo_name, threshold):
        self.repo_name = repo_name
        self.tagger_inst = WaifuDiffusionTagger("SmilingWolf/" + repo_name)
        self.threshold = threshold

    def start(self):
        self.tagger_inst.load()
        return self

    def stop(self):
        self.tagger_inst.unload()

    # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
    # set threshold<0 to use default value for now...
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        # The first four labels are ratings and are always skipped here.
        tag_probs = self.tagger_inst.apply(image)[4:]
        if threshold is None:
            return {get_replaced_tag(tag): prob for tag, prob in tag_probs}
        if threshold < 0:
            threshold = self.threshold
        return {
            get_replaced_tag(tag): prob for tag, prob in tag_probs if prob > threshold
        }

    def name(self):
        return self.repo_name

View File

@ -0,0 +1,171 @@
from typing import Optional
from PIL import Image
import numpy as np
import torch
from modules import devices, shared
from modules import deepbooru as db
from scripts.tagger import Tagger, get_replaced_tag
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger
class BLIP(Tagger):
    """Caption-based tagger using the web UI's built-in BLIP interrogator."""

    def start(self):
        shared.interrogator.load()

    def stop(self):
        shared.interrogator.unload()

    def predict(self, image: Image.Image, threshold=None):
        # Comma-split the generated caption, dropping empty fragments.
        caption = shared.interrogator.generate_caption(image)
        return [fragment for fragment in caption.split(',') if fragment]

    def name(self):
        return 'BLIP'
class BLIP2(Tagger):
    """Caption-based tagger wrapping a Salesforce BLIP-2 checkpoint."""

    def __init__(self, repo_name):
        self.interrogator = BLIP2Captioning("Salesforce/" + repo_name)
        self.repo_name = repo_name

    def start(self):
        self.interrogator.load()

    def stop(self):
        self.interrogator.unload()

    def predict(self, image: Image, threshold=None):
        # apply() returns a list of captions; only the first one is used.
        caption = self.interrogator.apply(image)[0]
        return [piece for piece in caption.split(",") if piece]

    def name(self):
        return self.repo_name
class GITLarge(Tagger):
    """Caption-based tagger wrapping the GIT-large-COCO model."""

    def __init__(self):
        self.interrogator = GITLargeCaptioning()

    def start(self):
        self.interrogator.load()

    def stop(self):
        self.interrogator.unload()

    def predict(self, image: Image, threshold=None):
        # apply() returns a list of captions; only the first one is used.
        caption = self.interrogator.apply(image)[0]
        return [piece for piece in caption.split(",") if piece]

    def name(self):
        return "GIT-large-COCO"
class DeepDanbooru(Tagger):
    """Tagger backed by the web UI's bundled DeepDanbooru model; returns a list of tags."""

    def start(self):
        db.model.start()

    def stop(self):
        db.model.stop()

    # brought from webUI modules/deepbooru.py and modified
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        """Return the tags whose probability exceeds *threshold* (0/None disables filtering)."""
        from modules import images

        # The model expects a 512x512 RGB image normalized to [0, 1].
        pic = images.resize_image(2, image.convert("RGB"), 512, 512)
        a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
        with torch.no_grad(), devices.autocast():
            x = torch.from_numpy(a).to(devices.device)
            y = db.model.model(x)[0].detach().cpu().numpy()
        tags = []
        for tag, probability in zip(db.model.model.tags, y):
            # NOTE: a threshold of 0 (or None) disables filtering because 0 is falsy.
            if threshold and probability < threshold:
                continue
            # "rating:*" pseudo-tags are skipped unless the user opted in.
            if not shared.opts.dataset_editor_use_rating and tag.startswith("rating:"):
                continue
            tags.append(get_replaced_tag(tag))
        return tags

    def name(self):
        return 'DeepDanbooru'
class WaifuDiffusion(Tagger):
    """Tagger backed by a SmilingWolf/wd-* ONNX model; returns a list of tags."""

    def __init__(self, repo_name, threshold):
        self.repo_name = repo_name
        self.tagger_inst = WaifuDiffusionTagger("SmilingWolf/" + repo_name)
        self.threshold = threshold

    def start(self):
        self.tagger_inst.load()
        return self

    def stop(self):
        self.tagger_inst.unload()

    # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
    # set threshold<0 to use default value for now...
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        labels = self.tagger_inst.apply(image)
        if not shared.opts.dataset_editor_use_rating:
            # The first four labels are ratings; drop them unless requested.
            labels = labels[4:]
        if threshold is None:
            return [get_replaced_tag(tag) for tag, _ in labels]
        if threshold < 0:
            threshold = self.threshold
        return [get_replaced_tag(tag) for tag, prob in labels if prob > threshold]

    def name(self):
        return self.repo_name
class Z3D_E621(Tagger):
    """Tagger using the toynya/Z3D-E621-Convnext ONNX model (e621 tag set)."""

    def __init__(self):
        # Reuses the WaifuDiffusion ONNX wrapper; this repo ships its label
        # list as "tags-selected.csv" rather than the default "selected_tags.csv".
        self.tagger_inst = WaifuDiffusionTagger("toynya/Z3D-E621-Convnext", label_filename="tags-selected.csv")

    def start(self):
        self.tagger_inst.load()
        return self

    def stop(self):
        self.tagger_inst.unload()

    # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        """Return the tags whose score exceeds *threshold*; None returns all tags.

        NOTE(review): unlike WaifuDiffusion there is no stored default
        threshold here, so a negative threshold effectively disables
        filtering instead of falling back to a default — confirm intended.
        """
        labels = self.tagger_inst.apply(image)
        if threshold is not None:
            tags = [get_replaced_tag(tag) for tag, value in labels if value > threshold]
        else:
            tags = [get_replaced_tag(tag) for tag, _ in labels]
        return tags

    def name(self):
        return "Z3D-E621-Convnext"

View File

@ -1,2 +1,2 @@
import scripts.dataset_tag_editor as dte_module
dte_instance = dte_module.DatasetTagEditor.get_instance()
dte_instance = dte_module.DatasetTagEditor()

8
scripts/logger.py Normal file
View File

@ -0,0 +1,8 @@
def write(content):
    """Print *content* to stdout with the extension's log prefix.

    Coerces *content* with str() so non-string payloads (e.g. exception
    objects passed to error()) no longer raise TypeError on concatenation.
    """
    print("[tag-editor] " + str(content))


def warn(content):
    """Log *content* as a warning."""
    write(f"[tag-editor:WARNING] {content}")


def error(content):
    """Log *content* as an error."""
    write(f"[tag-editor:ERROR] {content}")

View File

@ -1,197 +1,165 @@
from typing import NamedTuple, Type, Dict, Any
from modules import shared, script_callbacks, scripts
from modules import shared, script_callbacks
from modules.shared import opts
import gradio as gr
import json
from pathlib import Path
from collections import namedtuple
from scripts.config import *
import scripts.tag_editor_ui as ui
from scripts.dte_instance import dte_instance
# ================================================================
# General Callbacks
# ================================================================
CONFIG_PATH = Path(scripts.basedir(), 'config.json')
SortBy = dte_instance.SortBy
SortOrder = dte_instance.SortOrder
GeneralConfig = namedtuple('GeneralConfig', [
'backup',
'dataset_dir',
'caption_ext',
'load_recursive',
'load_caption_from_filename',
'replace_new_line',
'use_interrogator',
'use_interrogator_names',
'use_custom_threshold_booru',
'custom_threshold_booru',
'use_custom_threshold_waifu',
'custom_threshold_waifu',
'save_kohya_metadata',
'meta_output_path',
'meta_input_path',
'meta_overwrite',
'meta_save_as_caption',
'meta_use_full_path'
])
FilterConfig = namedtuple('FilterConfig', ['sw_prefix', 'sw_suffix', 'sw_regex','sort_by', 'sort_order', 'logic'])
BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order', 'token_count'])
EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination'])
CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, False, 'No', [], False, 0.7, False, 0.35, False, '', '', True, False, False)
CFG_FILTER_P_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND')
CFG_FILTER_N_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR')
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value, 75)
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig('Selected One', [], '.txt', '')
class Config:
def __init__(self):
self.config = dict()
def load(self):
if not CONFIG_PATH.is_file():
self.config = dict()
return
try:
self.config = json.loads(CONFIG_PATH.read_text('utf8'))
except:
print('[tag-editor] Error on loading config.json. Default settings will be loaded.')
self.config = dict()
else:
print('[tag-editor] Settings has been read from config.json')
def save(self):
CONFIG_PATH.write_text(json.dumps(self.config, indent=4), 'utf8')
def read(self, name: str):
return self.config.get(name)
def write(self, cfg: dict, name: str):
self.config[name] = cfg
config = Config()
def write_general_config(*args):
cfg = GeneralConfig(*args)
config.write(cfg._asdict(), 'general')
config.write(cfg._asdict(), "general")
def write_filter_config(*args):
hlen = len(args) // 2
cfg_p = FilterConfig(*args[:hlen])
cfg_n = FilterConfig(*args[hlen:])
config.write({'positive':cfg_p._asdict(), 'negative':cfg_n._asdict()}, 'filter')
config.write({"positive": cfg_p._asdict(), "negative": cfg_n._asdict()}, "filter")
def write_batch_edit_config(*args):
cfg = BatchEditConfig(*args)
config.write(cfg._asdict(), 'batch_edit')
config.write(cfg._asdict(), "batch_edit")
def write_edit_selected_config(*args):
cfg = EditSelectedConfig(*args)
config.write(cfg._asdict(), 'edit_selected')
config.write(cfg._asdict(), "edit_selected")
def write_move_delete_config(*args):
cfg = MoveDeleteConfig(*args)
config.write(cfg._asdict(), 'file_move_delete')
config.write(cfg._asdict(), "file_move_delete")
def read_config(name: str, config_type: Type, default: NamedTuple, compat_func = None):
def read_config(name: str, config_type: Type, default: NamedTuple, compat_func=None):
d = config.read(name)
cfg = default
if d:
if compat_func: d = compat_func(d)
if compat_func:
d = compat_func(d)
d = cfg._asdict() | d
d = {k:v for k,v in d.items() if k in cfg._asdict().keys()}
d = {k: v for k, v in d.items() if k in cfg._asdict().keys()}
cfg = config_type(**d)
return cfg
def read_general_config():
# for compatibility
generalcfg_intterogator_names = [
('use_blip_to_prefill', 'BLIP'),
('use_git_to_prefill', 'GIT-large-COCO'),
('use_booru_to_prefill', 'DeepDanbooru'),
('use_waifu_to_prefill', 'wd-v1-4-vit-tagger')
("use_blip_to_prefill", "BLIP"),
("use_git_to_prefill", "GIT-large-COCO"),
("use_booru_to_prefill", "DeepDanbooru"),
("use_waifu_to_prefill", "wd-v1-4-vit-tagger"),
]
use_interrogator_names = []
def compat_func(d: Dict[str, Any]):
if 'use_interrogator_names' in d.keys():
if "use_interrogator_names" in d.keys():
return d
for cfg in generalcfg_intterogator_names:
if d.get(cfg[0]):
use_interrogator_names.append(cfg[1])
d['use_interrogator_names'] = use_interrogator_names
d["use_interrogator_names"] = use_interrogator_names
return d
return read_config('general', GeneralConfig, CFG_GENERAL_DEFAULT, compat_func)
return read_config("general", GeneralConfig, CFG_GENERAL_DEFAULT, compat_func)
def read_filter_config():
d = config.read('filter')
d_p = d.get('positive') if d else None
d_n = d.get('negative') if d else None
d = config.read("filter")
d_p = d.get("positive") if d else None
d_n = d.get("negative") if d else None
cfg_p = CFG_FILTER_P_DEFAULT
cfg_n = CFG_FILTER_N_DEFAULT
if d_p:
d_p = cfg_p._asdict() | d_p
d_p = {k:v for k,v in d_p.items() if k in cfg_p._asdict().keys()}
d_p = {k: v for k, v in d_p.items() if k in cfg_p._asdict().keys()}
cfg_p = FilterConfig(**d_p)
if d_n:
d_n = cfg_n._asdict() | d_n
d_n = {k:v for k,v in d_n.items() if k in cfg_n._asdict().keys()}
d_n = {k: v for k, v in d_n.items() if k in cfg_n._asdict().keys()}
cfg_n = FilterConfig(**d_n)
return cfg_p, cfg_n
def read_batch_edit_config():
return read_config('batch_edit', BatchEditConfig, CFG_BATCH_EDIT_DEFAULT)
return read_config("batch_edit", BatchEditConfig, CFG_BATCH_EDIT_DEFAULT)
def read_edit_selected_config():
return read_config('edit_selected', EditSelectedConfig, CFG_EDIT_SELECTED_DEFAULT)
return read_config("edit_selected", EditSelectedConfig, CFG_EDIT_SELECTED_DEFAULT)
def read_move_delete_config():
return read_config('file_move_delete', MoveDeleteConfig, CFG_MOVE_DELETE_DEFAULT)
return read_config("file_move_delete", MoveDeleteConfig, CFG_MOVE_DELETE_DEFAULT)
# ================================================================
# General Callbacks for Updating UIs
# ================================================================
def get_filters():
filters = [ui.filter_by_tags.tag_filter_ui.get_filter(), ui.filter_by_tags.tag_filter_ui_neg.get_filter()] + [ui.filter_by_selection.path_filter]
filters = [
ui.filter_by_tags.tag_filter_ui.get_filter(),
ui.filter_by_tags.tag_filter_ui_neg.get_filter(),
] + [ui.filter_by_selection.path_filter]
return filters
def update_gallery():
img_indices = ui.dte_instance.get_filtered_imgindices(filters=get_filters())
total_image_num = len(ui.dte_instance.dataset)
displayed_image_num = len(img_indices)
ui.gallery_state.register_value('Displayed Images', f'{displayed_image_num} / {total_image_num} total')
ui.gallery_state.register_value('Current Tag Filter', f"{ui.filter_by_tags.tag_filter_ui.get_filter()} {' AND ' if ui.filter_by_tags.tag_filter_ui.get_filter().tags and ui.filter_by_tags.tag_filter_ui_neg.get_filter().tags else ''} {ui.filter_by_tags.tag_filter_ui_neg.get_filter()}")
ui.gallery_state.register_value('Current Selection Filter', f'{len(ui.filter_by_selection.path_filter.paths)} images')
ui.gallery_state.register_value(
"Displayed Images", f"{displayed_image_num} / {total_image_num} total"
)
ui.gallery_state.register_value(
"Current Tag Filter",
f"{ui.filter_by_tags.tag_filter_ui.get_filter()} {' AND ' if ui.filter_by_tags.tag_filter_ui.get_filter().tags and ui.filter_by_tags.tag_filter_ui_neg.get_filter().tags else ''} {ui.filter_by_tags.tag_filter_ui_neg.get_filter()}",
)
ui.gallery_state.register_value(
"Current Selection Filter",
f"{len(ui.filter_by_selection.path_filter.paths)} images",
)
return [
[str(i) for i in img_indices],
1,
-1,
-1,
-1,
ui.gallery_state.get_current_gallery_txt()
]
ui.gallery_state.get_current_gallery_txt(),
]
def update_filter_and_gallery():
return \
[ui.filter_by_tags.tag_filter_ui.cbg_tags_update(), ui.filter_by_tags.tag_filter_ui_neg.cbg_tags_update()] +\
update_gallery() +\
ui.batch_edit_captions.get_common_tags(get_filters, ui.filter_by_tags) +\
[', '.join(ui.filter_by_tags.tag_filter_ui.filter.tags)] +\
[ui.batch_edit_captions.tag_select_ui_remove.cbg_tags_update()] +\
['', '']
return (
[
ui.filter_by_tags.tag_filter_ui.cbg_tags_update(),
ui.filter_by_tags.tag_filter_ui_neg.cbg_tags_update(),
]
+ update_gallery()
+ ui.batch_edit_captions.get_common_tags(get_filters, ui.filter_by_tags)
+ [", ".join(ui.filter_by_tags.tag_filter_ui.filter.tags)]
+ [ui.batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
+ ["", ""]
)
# ================================================================
# Script Callbacks
# ================================================================
def on_ui_tabs():
config.load()
@ -201,101 +169,156 @@ def on_ui_tabs():
cfg_edit_selected = read_edit_selected_config()
cfg_file_move_delete = read_move_delete_config()
ui.dte_instance.load_interrogators()
with gr.Blocks(analytics_enabled=False) as dataset_tag_editor_interface:
gr.HTML(value="""
gr.HTML(
value="""
This extension works well with text captions in comma-separated style (such as the tags generated by DeepBooru interrogator).
""")
"""
)
ui.toprow.create_ui(cfg_general)
with gr.Accordion(label='Reload/Save Settings (config.json)', open=False):
with gr.Accordion(label="Reload/Save Settings (config.json)", open=False):
with gr.Row():
btn_reload_config_file = gr.Button(value='Reload settings')
btn_save_setting_as_default = gr.Button(value='Save current settings')
btn_restore_default = gr.Button(value='Restore settings to default')
btn_reload_config_file = gr.Button(value="Reload settings")
btn_save_setting_as_default = gr.Button(value="Save current settings")
btn_restore_default = gr.Button(value="Restore settings to default")
with gr.Row().style(equal_height=False):
with gr.Row(equal_height=False):
with gr.Column():
ui.load_dataset.create_ui(cfg_general)
ui.dataset_gallery.create_ui(opts.dataset_editor_image_columns)
ui.gallery_state.create_ui()
with gr.Tab(label='Filter by Tags'):
with gr.Tab(label="Filter by Tags"):
ui.filter_by_tags.create_ui(cfg_filter_p, cfg_filter_n, get_filters)
with gr.Tab(label='Filter by Selection'):
with gr.Tab(label="Filter by Selection"):
ui.filter_by_selection.create_ui(opts.dataset_editor_image_columns)
with gr.Tab(label='Batch Edit Captions'):
with gr.Tab(label="Batch Edit Captions"):
ui.batch_edit_captions.create_ui(cfg_batch_edit, get_filters)
with gr.Tab(label='Edit Caption of Selected Image'):
with gr.Tab(label="Edit Caption of Selected Image"):
ui.edit_caption_of_selected_image.create_ui(cfg_edit_selected)
with gr.Tab(label='Move or Delete Files'):
with gr.Tab(label="Move or Delete Files"):
ui.move_or_delete_files.create_ui(cfg_file_move_delete)
#----------------------------------------------------------------
# ----------------------------------------------------------------
# General
components_general = [
ui.toprow.cb_backup, ui.load_dataset.tb_img_directory, ui.load_dataset.tb_caption_file_ext, ui.load_dataset.cb_load_recursive,
ui.load_dataset.cb_load_caption_from_filename, ui.load_dataset.cb_replace_new_line_with_comma, ui.load_dataset.rb_use_interrogator, ui.load_dataset.dd_intterogator_names,
ui.load_dataset.cb_use_custom_threshold_booru, ui.load_dataset.sl_custom_threshold_booru, ui.load_dataset.cb_use_custom_threshold_waifu, ui.load_dataset.sl_custom_threshold_waifu,
ui.toprow.cb_save_kohya_metadata, ui.toprow.tb_metadata_output, ui.toprow.tb_metadata_input, ui.toprow.cb_metadata_overwrite, ui.toprow.cb_metadata_as_caption, ui.toprow.cb_metadata_use_fullpath
ui.toprow.cb_backup,
ui.load_dataset.tb_img_directory,
ui.load_dataset.tb_caption_file_ext,
ui.load_dataset.cb_load_recursive,
ui.load_dataset.cb_load_caption_from_filename,
ui.load_dataset.cb_replace_new_line_with_comma,
ui.load_dataset.rb_use_interrogator,
ui.load_dataset.dd_intterogator_names,
ui.load_dataset.cb_use_custom_threshold_booru,
ui.load_dataset.sl_custom_threshold_booru,
ui.load_dataset.cb_use_custom_threshold_waifu,
ui.load_dataset.sl_custom_threshold_waifu,
ui.load_dataset.sl_custom_threshold_z3d,
ui.toprow.cb_save_kohya_metadata,
ui.toprow.tb_metadata_output,
ui.toprow.tb_metadata_input,
ui.toprow.cb_metadata_overwrite,
ui.toprow.cb_metadata_as_caption,
ui.toprow.cb_metadata_use_fullpath,
]
components_filter = [
ui.filter_by_tags.tag_filter_ui.cb_prefix,
ui.filter_by_tags.tag_filter_ui.cb_suffix,
ui.filter_by_tags.tag_filter_ui.cb_regex,
ui.filter_by_tags.tag_filter_ui.rb_sort_by,
ui.filter_by_tags.tag_filter_ui.rb_sort_order,
ui.filter_by_tags.tag_filter_ui.rb_logic,
] + [
ui.filter_by_tags.tag_filter_ui_neg.cb_prefix,
ui.filter_by_tags.tag_filter_ui_neg.cb_suffix,
ui.filter_by_tags.tag_filter_ui_neg.cb_regex,
ui.filter_by_tags.tag_filter_ui_neg.rb_sort_by,
ui.filter_by_tags.tag_filter_ui_neg.rb_sort_order,
ui.filter_by_tags.tag_filter_ui_neg.rb_logic,
]
components_filter = \
[ui.filter_by_tags.tag_filter_ui.cb_prefix, ui.filter_by_tags.tag_filter_ui.cb_suffix, ui.filter_by_tags.tag_filter_ui.cb_regex, ui.filter_by_tags.tag_filter_ui.rb_sort_by, ui.filter_by_tags.tag_filter_ui.rb_sort_order, ui.filter_by_tags.tag_filter_ui.rb_logic] +\
[ui.filter_by_tags.tag_filter_ui_neg.cb_prefix, ui.filter_by_tags.tag_filter_ui_neg.cb_suffix, ui.filter_by_tags.tag_filter_ui_neg.cb_regex, ui.filter_by_tags.tag_filter_ui_neg.rb_sort_by, ui.filter_by_tags.tag_filter_ui_neg.rb_sort_order, ui.filter_by_tags.tag_filter_ui_neg.rb_logic]
components_batch_edit = [
ui.batch_edit_captions.cb_show_only_tags_selected, ui.batch_edit_captions.cb_prepend_tags, ui.batch_edit_captions.cb_use_regex,
ui.batch_edit_captions.cb_show_only_tags_selected,
ui.batch_edit_captions.cb_prepend_tags,
ui.batch_edit_captions.cb_use_regex,
ui.batch_edit_captions.rb_sr_replace_target,
ui.batch_edit_captions.tag_select_ui_remove.cb_prefix, ui.batch_edit_captions.tag_select_ui_remove.cb_suffix, ui.batch_edit_captions.tag_select_ui_remove.cb_regex,
ui.batch_edit_captions.tag_select_ui_remove.rb_sort_by, ui.batch_edit_captions.tag_select_ui_remove.rb_sort_order,
ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order,
ui.batch_edit_captions.nb_token_count
ui.batch_edit_captions.tag_select_ui_remove.cb_prefix,
ui.batch_edit_captions.tag_select_ui_remove.cb_suffix,
ui.batch_edit_captions.tag_select_ui_remove.cb_regex,
ui.batch_edit_captions.tag_select_ui_remove.rb_sort_by,
ui.batch_edit_captions.tag_select_ui_remove.rb_sort_order,
ui.batch_edit_captions.rb_sort_by,
ui.batch_edit_captions.rb_sort_order,
ui.batch_edit_captions.nb_token_count,
]
components_edit_selected = [
ui.edit_caption_of_selected_image.cb_copy_caption_automatically, ui.edit_caption_of_selected_image.cb_sort_caption_on_save,
ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed, ui.edit_caption_of_selected_image.dd_intterogator_names_si,
ui.edit_caption_of_selected_image.rb_sort_by, ui.edit_caption_of_selected_image.rb_sort_order
ui.edit_caption_of_selected_image.cb_copy_caption_automatically,
ui.edit_caption_of_selected_image.cb_sort_caption_on_save,
ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed,
ui.edit_caption_of_selected_image.dd_intterogator_names_si,
ui.edit_caption_of_selected_image.rb_sort_by,
ui.edit_caption_of_selected_image.rb_sort_order,
]
components_move_delete = [
ui.move_or_delete_files.rb_move_or_delete_target_data, ui.move_or_delete_files.cbg_move_or_delete_target_file,
ui.move_or_delete_files.tb_move_or_delete_caption_ext, ui.move_or_delete_files.tb_move_or_delete_destination_dir
ui.move_or_delete_files.rb_move_or_delete_target_data,
ui.move_or_delete_files.cbg_move_or_delete_target_file,
ui.move_or_delete_files.tb_move_or_delete_caption_ext,
ui.move_or_delete_files.tb_move_or_delete_destination_dir,
]
configurable_components = components_general + components_filter + components_batch_edit + components_edit_selected + components_move_delete
configurable_components = (
components_general
+ components_filter
+ components_batch_edit
+ components_edit_selected
+ components_move_delete
)
def reload_config_file():
    """Reload config.json from disk and return fresh values for every
    configurable component (same order as ``configurable_components``).

    NOTE: the merge residue kept the old ``print`` + return pair, which made
    the new ``logger.write`` + return unreachable; only the new path is kept.
    """
    config.load()
    # Positive / negative tag-filter configs come back as a pair.
    p, n = read_filter_config()
    logger.write("Reload config.json")
    return (
        read_general_config()
        + p
        + n
        + read_batch_edit_config()
        + read_edit_selected_config()
        + read_move_delete_config()
    )
btn_reload_config_file.click(
fn=reload_config_file,
outputs=configurable_components
fn=reload_config_file, outputs=configurable_components
)
def save_settings_callback(*a):
    """Persist the current UI state to config.json.

    *a* is the flat tuple of all ``configurable_components`` values; it is
    sliced section by section using a running offset. Slice bounds evaluate
    left-to-right, so ``a[p : inc(len(...))]`` reads the old offset before
    ``inc`` advances it.

    NOTE: the merge residue duplicated the four ``write_*_config`` calls
    (old and new formatting); running both would consume the argument tuple
    twice at wrong offsets, so only one sequence is kept.
    """
    p = 0

    def inc(v):
        # Advance the running offset and return the new slice end.
        nonlocal p
        p += v
        return p

    write_general_config(*a[p : inc(len(components_general))])
    write_filter_config(*a[p : inc(len(components_filter))])
    write_batch_edit_config(*a[p : inc(len(components_batch_edit))])
    write_edit_selected_config(*a[p : inc(len(components_edit_selected))])
    write_move_delete_config(*a[p:])
    config.save()
    logger.write("Current settings have been saved into config.json")
btn_save_setting_as_default.click(
fn=save_settings_callback,
inputs=configurable_components
fn=save_settings_callback, inputs=configurable_components
)
def restore_default_settings():
@ -304,44 +327,150 @@ def on_ui_tabs():
write_batch_edit_config(*CFG_BATCH_EDIT_DEFAULT)
write_edit_selected_config(*CFG_EDIT_SELECTED_DEFAULT)
write_move_delete_config(*CFG_MOVE_DELETE_DEFAULT)
print('[tag-editor] Restore default settings')
return CFG_GENERAL_DEFAULT + CFG_FILTER_P_DEFAULT + CFG_FILTER_N_DEFAULT + CFG_BATCH_EDIT_DEFAULT + CFG_EDIT_SELECTED_DEFAULT + CFG_MOVE_DELETE_DEFAULT
logger.write("Restore default settings")
return (
CFG_GENERAL_DEFAULT
+ CFG_FILTER_P_DEFAULT
+ CFG_FILTER_N_DEFAULT
+ CFG_BATCH_EDIT_DEFAULT
+ CFG_EDIT_SELECTED_DEFAULT
+ CFG_MOVE_DELETE_DEFAULT
)
btn_restore_default.click(
fn=restore_default_settings,
outputs=configurable_components
fn=restore_default_settings, outputs=configurable_components
)
o_update_gallery = [ui.dataset_gallery.cbg_hidden_dataset_filter, ui.dataset_gallery.nb_hidden_dataset_filter_apply, ui.dataset_gallery.nb_hidden_image_index, ui.dataset_gallery.nb_hidden_image_index_prev, ui.edit_caption_of_selected_image.nb_hidden_image_index_save_or_not, ui.gallery_state.txt_gallery]
o_update_gallery = [
ui.dataset_gallery.cbg_hidden_dataset_filter,
ui.dataset_gallery.nb_hidden_dataset_filter_apply,
ui.dataset_gallery.nb_hidden_image_index,
ui.dataset_gallery.nb_hidden_image_index_prev,
ui.edit_caption_of_selected_image.nb_hidden_image_index_save_or_not,
ui.gallery_state.txt_gallery,
]
o_update_filter_and_gallery = (
[
ui.filter_by_tags.tag_filter_ui.cbg_tags,
ui.filter_by_tags.tag_filter_ui_neg.cbg_tags,
]
+ o_update_gallery
+ [
ui.batch_edit_captions.tb_common_tags,
ui.batch_edit_captions.tb_edit_tags,
]
+ [ui.batch_edit_captions.tb_sr_selected_tags]
+ [ui.batch_edit_captions.tag_select_ui_remove.cbg_tags]
+ [
ui.edit_caption_of_selected_image.tb_caption,
ui.edit_caption_of_selected_image.tb_edit_caption,
]
)
o_update_filter_and_gallery = \
[ui.filter_by_tags.tag_filter_ui.cbg_tags, ui.filter_by_tags.tag_filter_ui_neg.cbg_tags] + \
o_update_gallery + \
[ui.batch_edit_captions.tb_common_tags, ui.batch_edit_captions.tb_edit_tags] + \
[ui.batch_edit_captions.tb_sr_selected_tags] +\
[ui.batch_edit_captions.tag_select_ui_remove.cbg_tags] +\
[ui.edit_caption_of_selected_image.tb_caption, ui.edit_caption_of_selected_image.tb_edit_caption]
ui.toprow.set_callbacks(ui.load_dataset)
ui.load_dataset.set_callbacks(o_update_filter_and_gallery,ui.toprow, ui.dataset_gallery, ui.filter_by_tags, ui.filter_by_selection, ui.batch_edit_captions, update_filter_and_gallery)
ui.load_dataset.set_callbacks(
o_update_filter_and_gallery,
ui.toprow,
ui.dataset_gallery,
ui.filter_by_tags,
ui.filter_by_selection,
ui.batch_edit_captions,
update_filter_and_gallery,
)
ui.dataset_gallery.set_callbacks(ui.load_dataset, ui.gallery_state, get_filters)
ui.gallery_state.set_callbacks(ui.dataset_gallery)
ui.filter_by_tags.set_callbacks(o_update_gallery, o_update_filter_and_gallery, ui.batch_edit_captions, ui.move_or_delete_files, update_gallery, update_filter_and_gallery, get_filters)
ui.filter_by_selection.set_callbacks(o_update_filter_and_gallery, ui.dataset_gallery, ui.filter_by_tags, get_filters, update_filter_and_gallery)
ui.batch_edit_captions.set_callbacks(o_update_filter_and_gallery, ui.load_dataset, ui.filter_by_tags, get_filters, update_filter_and_gallery)
ui.edit_caption_of_selected_image.set_callbacks(o_update_filter_and_gallery, ui.dataset_gallery, ui.load_dataset, get_filters, update_filter_and_gallery)
ui.move_or_delete_files.set_callbacks(o_update_filter_and_gallery, ui.dataset_gallery, ui.filter_by_tags, ui.batch_edit_captions, ui.filter_by_selection, ui.edit_caption_of_selected_image, get_filters, update_filter_and_gallery)
return [(dataset_tag_editor_interface, "Dataset Tag Editor", "dataset_tag_editor_interface")]
ui.filter_by_tags.set_callbacks(
o_update_gallery,
o_update_filter_and_gallery,
ui.batch_edit_captions,
ui.move_or_delete_files,
update_gallery,
update_filter_and_gallery,
get_filters,
)
ui.filter_by_selection.set_callbacks(
o_update_filter_and_gallery,
ui.dataset_gallery,
ui.filter_by_tags,
get_filters,
update_filter_and_gallery,
)
ui.batch_edit_captions.set_callbacks(
o_update_filter_and_gallery,
ui.load_dataset,
ui.filter_by_tags,
get_filters,
update_filter_and_gallery,
)
ui.edit_caption_of_selected_image.set_callbacks(
o_update_filter_and_gallery,
ui.dataset_gallery,
ui.load_dataset,
get_filters,
update_filter_and_gallery,
)
ui.move_or_delete_files.set_callbacks(
o_update_filter_and_gallery,
ui.dataset_gallery,
ui.filter_by_tags,
ui.batch_edit_captions,
ui.filter_by_selection,
ui.edit_caption_of_selected_image,
get_filters,
update_filter_and_gallery,
)
return [
(
dataset_tag_editor_interface,
"Dataset Tag Editor",
"dataset_tag_editor_interface",
)
]
def on_ui_settings():
    """Register the extension's options on the WebUI settings page.

    NOTE: the merge residue duplicated the ``section`` assignment and the
    first four ``add_option`` calls (old one-line forms plus new formatted
    forms); only one registration per option is kept, including the two
    options added by the merge (rating tags, CPU worker count).
    """
    section = ("dataset-tag-editor", "Dataset Tag Editor")
    shared.opts.add_option(
        "dataset_editor_image_columns",
        shared.OptionInfo(6, "Number of columns on image gallery", section=section),
    )
    shared.opts.add_option(
        "dataset_editor_max_res",
        shared.OptionInfo(0, "Max resolution of temporary files", section=section),
    )
    shared.opts.add_option(
        "dataset_editor_use_temp_files",
        shared.OptionInfo(
            False, "Force image gallery to use temporary files", section=section
        ),
    )
    shared.opts.add_option(
        "dataset_editor_use_raw_clip_token",
        shared.OptionInfo(
            True,
            "Use raw CLIP token to calculate token count (without emphasis or embeddings)",
            section=section,
        ),
    )
    shared.opts.add_option(
        "dataset_editor_use_rating",
        shared.OptionInfo(
            False,
            "Use rating tags",
            section=section,
        ),
    )
    shared.opts.add_option(
        "dataset_editor_num_cpu_workers",
        shared.OptionInfo(
            -1,
            "Number of CPU workers when preprocessing images (set -1 to auto)",
            section=section,
        ),
    )
script_callbacks.on_ui_settings(on_ui_settings)

16
scripts/model_loader.py Normal file
View File

@ -0,0 +1,16 @@
from pathlib import Path
from torch.hub import download_url_to_file
def load(
    model_path: Path,
    model_url: str,
    progress: bool = True,
    force_download: bool = False,
) -> Path:
    """Ensure a model file exists locally and return its path.

    Args:
        model_path: Target location of the model file (str or Path).
        model_url: URL to fetch the file from; ``None`` disables downloading.
        progress: Show a download progress bar.
        force_download: Re-download even if the file already exists.

    Returns:
        ``model_path`` as a :class:`pathlib.Path` (the file may still be
        missing when ``model_url`` is ``None``).

    Bug fix: the original returned as soon as ``model_path.exists()``, so
    ``force_download=True`` could never trigger a re-download. It also
    checked ``parent.is_dir()`` before ``mkdir`` (race-prone); ``exist_ok``
    handles that atomically.
    """
    model_path = Path(model_path)
    if model_path.exists() and not force_download:
        return model_path
    if model_url is not None:
        model_path.parent.mkdir(parents=True, exist_ok=True)
        download_url_to_file(model_url, model_path, progress=progress)
    return model_path

17
scripts/paths.py Normal file
View File

@ -0,0 +1,17 @@
from pathlib import Path
from scripts.singleton import Singleton
def base_dir_path() -> Path:
    """Return the absolute path of the extension's root directory.

    The root is one level above the directory containing this module.
    """
    this_file = Path(__file__)
    return this_file.parents[1].absolute()
def base_dir() -> str:
    """Return the extension's root directory as a plain string."""
    root = base_dir_path()
    return str(root)
class Paths(Singleton):
    """Singleton collecting the filesystem locations used by the extension."""

    def __init__(self):
        # Extension root (parent of the "scripts" directory).
        self.base_path: Path = base_dir_path()
        # Scripts shipped with the extension.
        self.script_path: Path = self.base_path / "scripts"
        # User-provided scripts (e.g. custom taggers).
        self.userscript_path: Path = self.base_path / "userscripts"


# Shared module-level instance; import this rather than constructing Paths.
paths = Paths()

View File

@ -1,6 +1,6 @@
class Singleton(object):
    """Base class whose subclasses each have at most one instance.

    ``__new__`` caches the instance per concrete subclass (the assignment
    ``class_._instance = ...`` creates a class attribute on that subclass,
    so subclasses do not share instances with each other).

    NOTE: the merge residue interleaved the old ``get_instance`` accessor
    with the new ``__new__``-based implementation; both are kept coherent
    here, with ``get_instance`` retained as a backward-compatible alias.
    """

    _instance = None

    def __new__(class_, *args, **kwargs):
        if not isinstance(class_._instance, class_):
            # Bug fix: object.__new__() must be called with the class only —
            # passing *args/**kwargs through raises TypeError on Python 3.
            class_._instance = object.__new__(class_)
        return class_._instance

    @classmethod
    def get_instance(cls):
        """Backward-compatible accessor; equivalent to calling ``cls()``."""
        return cls()

View File

@ -22,7 +22,7 @@ class DatasetGalleryUI(UIBase):
self.btn_hidden_set_index = gr.Button(elem_id="dataset_tag_editor_btn_hidden_set_index")
self.nb_hidden_image_index = gr.Number(value=None, label='hidden_idx_next')
self.nb_hidden_image_index_prev = gr.Number(value=None, label='hidden_idx_prev')
self.gl_dataset_images = gr.Gallery(label='Dataset Images', elem_id="dataset_tag_editor_dataset_gallery").style(grid=image_columns)
self.gl_dataset_images = gr.Gallery(label='Dataset Images', elem_id="dataset_tag_editor_dataset_gallery", columns=image_columns)
def set_callbacks(self, load_dataset:LoadDatasetUI, gallery_state:GalleryStateUI, get_filters:Callable[[], dte_module.filters.Filter]):
gallery_state.register_value('Selected Image', self.selected_path)

View File

@ -28,7 +28,7 @@ class GalleryStateUI(UIBase):
self.txt_gallery = gr.HTML(value=self.get_current_gallery_txt())
def set_callbacks(self, dataset_gallery:DatasetGalleryUI):
dataset_gallery.nb_hidden_image_index.change(
dataset_gallery.nb_hidden_image_index.change(fn=lambda:None).then(
fn=self.update_gallery_txt,
inputs=None,
outputs=self.txt_gallery

View File

@ -11,43 +11,113 @@ from .uibase import UIBase
if TYPE_CHECKING:
from .ui_classes import *
INTERROGATOR_NAMES = dte_module.INTERROGATOR_NAMES
InterrogateMethod = dte_instance.InterrogateMethod
class LoadDatasetUI(UIBase):
def __init__(self):
self.caption_file_ext = ''
self.caption_file_ext = ""
def create_ui(self, cfg_general):
with gr.Column(variant='panel'):
with gr.Column(variant="panel"):
with gr.Row():
with gr.Column(scale=3):
self.tb_img_directory = gr.Textbox(label='Dataset directory', placeholder='C:\\directory\\of\\datasets', value=cfg_general.dataset_dir)
self.tb_img_directory = gr.Textbox(
label="Dataset directory",
placeholder="C:\\directory\\of\\datasets",
value=cfg_general.dataset_dir,
)
with gr.Column(scale=1, min_width=60):
self.tb_caption_file_ext = gr.Textbox(label='Caption File Ext', placeholder='.txt (on Load and Save)', value=cfg_general.caption_ext)
self.tb_caption_file_ext = gr.Textbox(
label="Caption File Ext",
placeholder=".txt (on Load and Save)",
value=cfg_general.caption_ext,
)
self.caption_file_ext = cfg_general.caption_ext
with gr.Column(scale=1, min_width=80):
self.btn_load_datasets = gr.Button(value='Load')
self.btn_unload_datasets = gr.Button(value='Unload')
with gr.Accordion(label='Dataset Load Settings'):
self.btn_load_datasets = gr.Button(value="Load")
self.btn_unload_datasets = gr.Button(value="Unload")
with gr.Accordion(label="Dataset Load Settings"):
with gr.Row():
with gr.Column():
self.cb_load_recursive = gr.Checkbox(value=cfg_general.load_recursive, label='Load from subdirectories')
self.cb_load_caption_from_filename = gr.Checkbox(value=cfg_general.load_caption_from_filename, label='Load caption from filename if no text file exists')
self.cb_replace_new_line_with_comma = gr.Checkbox(value=cfg_general.replace_new_line, label='Replace new-line character with comma')
self.cb_load_recursive = gr.Checkbox(
value=cfg_general.load_recursive,
label="Load from subdirectories",
)
self.cb_load_caption_from_filename = gr.Checkbox(
value=cfg_general.load_caption_from_filename,
label="Load caption from filename if no text file exists",
)
self.cb_replace_new_line_with_comma = gr.Checkbox(
value=cfg_general.replace_new_line,
label="Replace new-line character with comma",
)
with gr.Column():
self.rb_use_interrogator = gr.Radio(choices=['No', 'If Empty', 'Overwrite', 'Prepend', 'Append'], value=cfg_general.use_interrogator, label='Use Interrogator Caption')
self.dd_intterogator_names = gr.Dropdown(label = 'Interrogators', choices=INTERROGATOR_NAMES, value=cfg_general.use_interrogator_names, interactive=True, multiselect=True)
with gr.Accordion(label='Interrogator Settings', open=False):
self.rb_use_interrogator = gr.Radio(
choices=[
"No",
"If Empty",
"Overwrite",
"Prepend",
"Append",
],
value=cfg_general.use_interrogator,
label="Use Interrogator Caption",
)
self.dd_intterogator_names = gr.Dropdown(
label="Interrogators",
choices=dte_instance.INTERROGATOR_NAMES,
value=cfg_general.use_interrogator_names,
interactive=True,
multiselect=True,
)
with gr.Accordion(label="Interrogator Settings", open=False):
with gr.Row():
self.cb_use_custom_threshold_booru = gr.Checkbox(value=cfg_general.use_custom_threshold_booru, label='Use Custom Threshold (Booru)', interactive=True)
self.sl_custom_threshold_booru = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_booru, step=0.01, interactive=True, label='Booru Score Threshold')
self.cb_use_custom_threshold_booru = gr.Checkbox(
value=cfg_general.use_custom_threshold_booru,
label="Use Custom Threshold (Booru)",
interactive=True,
)
self.sl_custom_threshold_booru = gr.Slider(
minimum=0,
maximum=1,
value=cfg_general.custom_threshold_booru,
step=0.01,
interactive=True,
label="Booru Score Threshold",
)
with gr.Row():
self.cb_use_custom_threshold_waifu = gr.Checkbox(value=cfg_general.use_custom_threshold_waifu, label='Use Custom Threshold (WDv1.4 Tagger)', interactive=True)
self.sl_custom_threshold_waifu = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_waifu, step=0.01, interactive=True, label='WDv1.4 Tagger Score Threshold')
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], toprow:ToprowUI, dataset_gallery:DatasetGalleryUI, filter_by_tags:FilterByTagsUI, filter_by_selection:FilterBySelectionUI, batch_edit_captions:BatchEditCaptionsUI, update_filter_and_gallery:Callable[[], List]):
self.sl_custom_threshold_z3d = gr.Slider(
minimum=0,
maximum=1,
value=cfg_general.custom_threshold_z3d,
step=0.01,
interactive=True,
label="Z3D-E621 Score Threshold",
)
with gr.Row():
self.cb_use_custom_threshold_waifu = gr.Checkbox(
value=cfg_general.use_custom_threshold_waifu,
label="Use Custom Threshold (WDv1.4 Tagger)",
interactive=True,
)
self.sl_custom_threshold_waifu = gr.Slider(
minimum=0,
maximum=1,
value=cfg_general.custom_threshold_waifu,
step=0.01,
interactive=True,
label="WDv1.4 Tagger Score Threshold",
)
def set_callbacks(
self,
o_update_filter_and_gallery: List[gr.components.Component],
toprow: ToprowUI,
dataset_gallery: DatasetGalleryUI,
filter_by_tags: FilterByTagsUI,
filter_by_selection: FilterBySelectionUI,
batch_edit_captions: BatchEditCaptionsUI,
update_filter_and_gallery: Callable[[], List],
):
def load_files_from_dir(
dir: str,
caption_file_ext: str,
@ -55,63 +125,112 @@ class LoadDatasetUI(UIBase):
load_caption_from_filename: bool,
replace_new_line: bool,
use_interrogator: str,
use_interrogator_names, #: List[str], : to avoid error on gradio v3.23.0
use_interrogator_names, #: List[str], : to avoid error on gradio v3.23.0
use_custom_threshold_booru: bool,
custom_threshold_booru: float,
use_custom_threshold_waifu: bool,
custom_threshold_waifu: float,
custom_threshold_z3d: float,
use_kohya_metadata: bool,
kohya_json_path: str
):
interrogate_method = InterrogateMethod.NONE
if use_interrogator == 'If Empty':
interrogate_method = InterrogateMethod.PREFILL
elif use_interrogator == 'Overwrite':
interrogate_method = InterrogateMethod.OVERWRITE
elif use_interrogator == 'Prepend':
interrogate_method = InterrogateMethod.PREPEND
elif use_interrogator == 'Append':
interrogate_method = InterrogateMethod.APPEND
kohya_json_path: str,
):
threshold_booru = custom_threshold_booru if use_custom_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
threshold_waifu = custom_threshold_waifu if use_custom_threshold_waifu else -1
interrogate_method = dte_instance.InterrogateMethod.NONE
if use_interrogator == "If Empty":
interrogate_method = dte_instance.InterrogateMethod.PREFILL
elif use_interrogator == "Overwrite":
interrogate_method = dte_instance.InterrogateMethod.OVERWRITE
elif use_interrogator == "Prepend":
interrogate_method = dte_instance.InterrogateMethod.PREPEND
elif use_interrogator == "Append":
interrogate_method = dte_instance.InterrogateMethod.APPEND
dte_instance.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, replace_new_line, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu, opts.dataset_editor_use_temp_files, kohya_json_path if use_kohya_metadata else None, opts.dataset_editor_max_res)
threshold_booru = (
custom_threshold_booru
if use_custom_threshold_booru
else opts.interrogate_deepbooru_score_threshold
)
threshold_waifu = (
custom_threshold_waifu if use_custom_threshold_waifu else -1
)
threshold_z3d = custom_threshold_z3d
dte_instance.load_dataset(
dir,
caption_file_ext,
recursive,
load_caption_from_filename,
replace_new_line,
interrogate_method,
use_interrogator_names,
threshold_booru,
threshold_waifu,
threshold_z3d,
opts.dataset_editor_use_temp_files,
kohya_json_path if use_kohya_metadata else None,
opts.dataset_editor_max_res,
)
imgs = dte_instance.get_filtered_imgs(filters=[])
img_indices = dte_instance.get_filtered_imgindices(filters=[])
return [
imgs,
[]
] +\
[gr.CheckboxGroup.update(value=[str(i) for i in img_indices], choices=[str(i) for i in img_indices]), 1] +\
filter_by_tags.clear_filters(update_filter_and_gallery) +\
[batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
return (
[imgs, []]
+ [
gr.CheckboxGroup.update(
value=[str(i) for i in img_indices],
choices=[str(i) for i in img_indices],
),
1,
]
+ filter_by_tags.clear_filters(update_filter_and_gallery)
+ [batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
)
self.btn_load_datasets.click(
fn=load_files_from_dir,
inputs=[self.tb_img_directory, self.tb_caption_file_ext, self.cb_load_recursive, self.cb_load_caption_from_filename, self.cb_replace_new_line_with_comma, self.rb_use_interrogator, self.dd_intterogator_names, self.cb_use_custom_threshold_booru, self.sl_custom_threshold_booru, self.cb_use_custom_threshold_waifu, self.sl_custom_threshold_waifu, toprow.cb_save_kohya_metadata, toprow.tb_metadata_output],
outputs=
[dataset_gallery.gl_dataset_images, filter_by_selection.gl_filter_images] +
[dataset_gallery.cbg_hidden_dataset_filter, dataset_gallery.nb_hidden_dataset_filter_apply] +
o_update_filter_and_gallery
inputs=[
self.tb_img_directory,
self.tb_caption_file_ext,
self.cb_load_recursive,
self.cb_load_caption_from_filename,
self.cb_replace_new_line_with_comma,
self.rb_use_interrogator,
self.dd_intterogator_names,
self.cb_use_custom_threshold_booru,
self.sl_custom_threshold_booru,
self.cb_use_custom_threshold_waifu,
self.sl_custom_threshold_waifu,
toprow.cb_save_kohya_metadata,
toprow.tb_metadata_output,
],
outputs=[
dataset_gallery.gl_dataset_images,
filter_by_selection.gl_filter_images,
]
+ [
dataset_gallery.cbg_hidden_dataset_filter,
dataset_gallery.nb_hidden_dataset_filter_apply,
]
+ o_update_filter_and_gallery,
)
def unload_files():
dte_instance.clear()
return [
[],
[]
] +\
[gr.CheckboxGroup.update(value=[], choices=[]), 1] +\
filter_by_tags.clear_filters(update_filter_and_gallery) +\
[batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
return (
[[], []]
+ [gr.CheckboxGroup.update(value=[], choices=[]), 1]
+ filter_by_tags.clear_filters(update_filter_and_gallery)
+ [batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
)
self.btn_unload_datasets.click(
fn=unload_files,
outputs=
[dataset_gallery.gl_dataset_images, filter_by_selection.gl_filter_images] +
[dataset_gallery.cbg_hidden_dataset_filter, dataset_gallery.nb_hidden_dataset_filter_apply] +
o_update_filter_and_gallery
outputs=[
dataset_gallery.gl_dataset_images,
filter_by_selection.gl_filter_images,
]
+ [
dataset_gallery.cbg_hidden_dataset_filter,
dataset_gallery.nb_hidden_dataset_filter_apply,
]
+ o_update_filter_and_gallery,
)

View File

@ -66,10 +66,10 @@ class TagFilterUI():
self.rb_logic.change(fn=self.rd_logic_changed, inputs=[self.rb_logic], outputs=[self.cbg_tags])
for fn, inputs, outputs, _js in self.on_filter_update_callbacks:
self.rb_logic.change(fn=fn, inputs=inputs, outputs=outputs, _js=_js)
self.rb_logic.change(fn=lambda:None).then(fn=fn, inputs=inputs, outputs=outputs, _js=_js)
self.cbg_tags.change(fn=self.cbg_tags_changed, inputs=[self.cbg_tags], outputs=[self.cbg_tags])
for fn, inputs, outputs, _js in self.on_filter_update_callbacks:
self.cbg_tags.change(fn=fn, inputs=inputs, outputs=outputs, _js=_js)
self.cbg_tags.change(fn=lambda:None).then(fn=fn, inputs=inputs, outputs=outputs, _js=_js)
def tb_search_tags_changed(self, tb_search_tags: str):

View File

@ -95,8 +95,7 @@ class BatchEditCaptionsUI(UIBase):
fn=apply_edit_tags,
inputs=[self.tb_common_tags, self.tb_edit_tags, self.cb_prepend_tags],
outputs=o_update_filter_and_gallery
)
self.btn_apply_edit_tags.click(
).then(
fn=None,
_js='() => dataset_tag_editor_gl_dataset_images_close()'
)
@ -124,8 +123,7 @@ class BatchEditCaptionsUI(UIBase):
fn=search_and_replace,
inputs=[self.tb_sr_search_tags, self.tb_sr_replace_tags, self.rb_sr_replace_target, self.cb_use_regex],
outputs=o_update_filter_and_gallery
)
self.btn_apply_sr_tags.click(
).then(
fn=None,
_js='() => dataset_tag_editor_gl_dataset_images_close()'
)

View File

@ -35,7 +35,7 @@ class EditCaptionOfSelectedImageUI(UIBase):
with gr.Tab(label='Interrogate Selected Image'):
with gr.Row():
self.dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte_module.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False)
self.dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte_instance.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False)
self.btn_interrogate_si = gr.Button(value='Interrogate')
with gr.Column():
self.tb_interrogate = gr.Textbox(label='Interrogate Result', interactive=True, lines=6, elem_id='dte_interrogate')
@ -89,7 +89,7 @@ class EditCaptionOfSelectedImageUI(UIBase):
_js='(a) => dataset_tag_editor_ask_save_change_or_not(a)',
inputs=self.nb_hidden_image_index_save_or_not
)
dataset_gallery.nb_hidden_image_index.change(
dataset_gallery.nb_hidden_image_index.change(lambda:None).then(
fn=gallery_index_changed,
inputs=[dataset_gallery.nb_hidden_image_index, dataset_gallery.nb_hidden_image_index_prev, self.tb_edit_caption, self.cb_copy_caption_automatically, self.cb_ask_save_when_caption_changed],
outputs=[self.nb_hidden_image_index_save_or_not] + [self.tb_caption, self.tb_edit_caption] + [self.tb_hidden_edit_caption]
@ -138,7 +138,7 @@ class EditCaptionOfSelectedImageUI(UIBase):
return ''
threshold_booru = threshold_booru if use_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
threshold_waifu = threshold_waifu if use_threshold_waifu else -1
return dte_module.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu)
return dte_instance.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu)
self.btn_interrogate_si.click(
fn=interrogate_selected_image,

View File

@ -32,7 +32,7 @@ class FilterBySelectionUI(UIBase):
self.btn_add_image_selection = gr.Button(value='Add selection [Enter]', elem_id='dataset_tag_editor_btn_add_image_selection')
self.btn_add_all_displayed_image_selection = gr.Button(value='Add ALL Displayed')
self.gl_filter_images = gr.Gallery(label='Filter Images', elem_id="dataset_tag_editor_filter_gallery").style(grid=image_columns)
self.gl_filter_images = gr.Gallery(label='Filter Images', elem_id="dataset_tag_editor_filter_gallery", columns=image_columns)
self.txt_selection = gr.HTML(value=self.get_current_txt_selection())
with gr.Row():
@ -130,7 +130,7 @@ class FilterBySelectionUI(UIBase):
self.path_filter = filters.PathFilter()
return clear_image_selection() + update_filter_and_gallery()
filter_by_tags.btn_clear_all_filters.click(
filter_by_tags.btn_clear_all_filters.click(lambda:None).then(
fn=clear_image_filter,
outputs=
[self.gl_filter_images, self.txt_selection, self.nb_hidden_selection_image_index] +
@ -147,8 +147,7 @@ class FilterBySelectionUI(UIBase):
self.btn_apply_image_selection_filter.click(
fn=apply_image_selection_filter,
outputs=o_update_filter_and_gallery
)
self.btn_apply_image_selection_filter.click(
).then(
fn=None,
_js='() => dataset_tag_editor_gl_dataset_images_close()'
)

View File

@ -47,17 +47,17 @@ class MoveOrDeleteFilesUI(UIBase):
'outputs' : [self.ta_move_or_delete_target_dataset_num]
}
batch_edit_captions.btn_apply_edit_tags.click(**update_args)
batch_edit_captions.btn_apply_edit_tags.click(lambda:None).then(**update_args)
batch_edit_captions.btn_apply_sr_tags.click(**update_args)
batch_edit_captions.btn_apply_sr_tags.click(lambda:None).then(**update_args)
filter_by_selection.btn_apply_image_selection_filter.click(**update_args)
filter_by_selection.btn_apply_image_selection_filter.click(lambda:None).then(**update_args)
filter_by_tags.btn_clear_tag_filters.click(**update_args)
filter_by_tags.btn_clear_tag_filters.click(lambda:None).then(**update_args)
filter_by_tags.btn_clear_all_filters.click(**update_args)
filter_by_tags.btn_clear_all_filters.click(lambda:None).then(**update_args)
edit_caption_of_selected_image.btn_apply_changes_selected_image.click(**update_args)
edit_caption_of_selected_image.btn_apply_changes_selected_image.click(lambda:None).then(**update_args)
self.rb_move_or_delete_target_data.change(**update_args)
@ -84,9 +84,7 @@ class MoveOrDeleteFilesUI(UIBase):
fn=move_files,
inputs=[self.rb_move_or_delete_target_data, self.cbg_move_or_delete_target_file, self.tb_move_or_delete_caption_ext, self.tb_move_or_delete_destination_dir],
outputs=o_update_filter_and_gallery
)
self.btn_move_or_delete_move_files.click(**update_args)
self.btn_move_or_delete_move_files.click(
).then(**update_args).then(
fn=None,
_js='() => dataset_tag_editor_gl_dataset_images_close()'
)
@ -114,8 +112,7 @@ class MoveOrDeleteFilesUI(UIBase):
inputs=[self.rb_move_or_delete_target_data, self.cbg_move_or_delete_target_file, self.tb_move_or_delete_caption_ext],
outputs=o_update_filter_and_gallery
)
self.btn_move_or_delete_delete_files.click(**update_args)
self.btn_move_or_delete_delete_files.click(
self.btn_move_or_delete_delete_files.click(**update_args).then(
fn=None,
_js='() => dataset_tag_editor_gl_dataset_images_close()'
)

View File

@ -4,12 +4,12 @@ __all__ = [
'toprow', 'load_dataset', 'dataset_gallery', 'gallery_state', 'filter_by_tags', 'filter_by_selection', 'batch_edit_captions', 'edit_caption_of_selected_image', 'move_or_delete_files'
]
toprow = ToprowUI.get_instance()
load_dataset = LoadDatasetUI.get_instance()
dataset_gallery = DatasetGalleryUI.get_instance()
gallery_state = GalleryStateUI.get_instance()
filter_by_tags = FilterByTagsUI.get_instance()
filter_by_selection = FilterBySelectionUI.get_instance()
batch_edit_captions = BatchEditCaptionsUI.get_instance()
edit_caption_of_selected_image = EditCaptionOfSelectedImageUI.get_instance()
move_or_delete_files = MoveOrDeleteFilesUI.get_instance()
toprow = ToprowUI()
load_dataset = LoadDatasetUI()
dataset_gallery = DatasetGalleryUI()
gallery_state = GalleryStateUI()
filter_by_tags = FilterByTagsUI()
filter_by_selection = FilterBySelectionUI()
batch_edit_captions = BatchEditCaptionsUI()
edit_caption_of_selected_image = EditCaptionOfSelectedImageUI()
move_or_delete_files = MoveOrDeleteFilesUI()

52
scripts/tagger.py Normal file
View File

@ -0,0 +1,52 @@
import re
from typing import Optional, Generator, Any
from PIL import Image
from modules import shared, lowvram, devices
from modules import deepbooru as db
# Custom tagger classes have to inherit from this class
class Tagger:
    """Base class and context-manager protocol for all taggers.

    Entering the context frees VRAM and calls start(); leaving calls stop().
    Subclasses implement predict() (and optionally predict_pipe()) plus name().
    """

    def __enter__(self):
        # Make room on the GPU before the tagger loads its own model.
        lowvram.send_everything_to_cpu()
        devices.torch_gc()
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.stop()

    def start(self):
        # Hook: load models / allocate resources. Default is a no-op.
        pass

    def stop(self):
        # Hook: release models / free resources. Default is a no-op.
        pass

    def predict(self, image: Image.Image, threshold: Optional[float] = None) -> list[str]:
        # Predict the tags of a single image.
        raise NotImplementedError()

    def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None) -> Generator[list[str], Any, None]:
        # Optional batched variant for more efficient data loading.
        # A None input is sent to probe whether this is implemented.
        raise NotImplementedError()

    def name(self):
        # Human-readable name shown in the UI.
        raise NotImplementedError()
def get_replaced_tag(tag: str):
    """Rewrite a raw booru tag according to the user's deepbooru display options."""
    opts = shared.opts
    if opts.deepbooru_use_spaces:
        tag = tag.replace('_', ' ')
    if opts.deepbooru_escape:
        tag = re.sub(db.re_special, r'\\\1', tag)
    return tag
def get_arranged_tags(probs: dict[str, float]):
    """Return tag names sorted by confidence, highest first (stable for ties)."""
    return sorted(probs, key=probs.get, reverse=True)

View File

@ -0,0 +1,54 @@
import math
from PIL import Image
from transformers import pipeline
import torch
from modules import devices, shared
from scripts.tagger import Tagger
# brought and modified from https://huggingface.co/spaces/cafeai/cafe_aesthetic_demo/blob/main/app.py
# I'm not sure if this is really working
BATCH_SIZE = 3
class AestheticShadowV2(Tagger):
    """Scores images with the shadowlilac/aesthetic-shadow-v2 HF classifier.

    The model emits 'hq'/'lq' probabilities, which are collapsed into a
    single score_0..score_10 style tag.
    """

    def load(self):
        # pipeline() wants a concrete device index; default to index 0
        # when the configured device has none set.
        if devices.device.index is None:
            target_device = torch.device(devices.device.type, 0)
        else:
            target_device = devices.device
        self.pipe_aesthetic = pipeline(
            "image-classification",
            "shadowlilac/aesthetic-shadow-v2",
            device=target_device,
            batch_size=BATCH_SIZE,
        )

    def unload(self):
        # Drop the pipeline unless the user asked to keep models resident.
        if not shared.opts.interrogate_keep_models_in_memory:
            self.pipe_aesthetic = None
            devices.torch_gc()

    def start(self):
        self.load()
        return self

    def stop(self):
        self.unload()

    def _get_score(self, data):
        # Fold the hq / lq probabilities into one 0-10 bucket tag.
        by_label = {entry["label"]: entry["score"] for entry in data}
        combined = (by_label['hq'] + (1 - by_label['lq'])) / 2
        return [f"score_{math.floor(combined * 10)}"]

    def predict(self, image: Image.Image, threshold=None):
        return self._get_score(self.pipe_aesthetic(image))

    def predict_pipe(self, data: list[Image.Image], threshold=None):
        # The framework probes with None to check if batched prediction exists.
        if data is None:
            return
        for out in self.pipe_aesthetic(data, batch_size=BATCH_SIZE):
            yield self._get_score(out)

    def name(self):
        return "aesthetic shadow"

View File

@ -0,0 +1,54 @@
import math
from PIL import Image
from transformers import pipeline
import torch
from modules import devices, shared
from scripts.tagger import Tagger
# brought and modified from https://huggingface.co/spaces/cafeai/cafe_aesthetic_demo/blob/main/app.py
# I'm not sure if this is really working
BATCH_SIZE = 8
class CafeAIAesthetic(Tagger):
    """Scores images with the cafeai/cafe_aesthetic HF classifier.

    The model's 'aesthetic'/'not_aesthetic' probabilities are collapsed
    into a single score_0..score_10 style tag.
    """

    def load(self):
        # pipeline() wants a concrete device index; default to index 0
        # when the configured device has none set.
        if devices.device.index is None:
            target_device = torch.device(devices.device.type, 0)
        else:
            target_device = devices.device
        self.pipe_aesthetic = pipeline(
            "image-classification",
            "cafeai/cafe_aesthetic",
            device=target_device,
            batch_size=BATCH_SIZE,
        )

    def unload(self):
        # Drop the pipeline unless the user asked to keep models resident.
        if not shared.opts.interrogate_keep_models_in_memory:
            self.pipe_aesthetic = None
            devices.torch_gc()

    def start(self):
        self.load()
        return self

    def stop(self):
        self.unload()

    def _get_score(self, data):
        # Fold the aesthetic / not_aesthetic probabilities into one 0-10 bucket.
        by_label = {entry["label"]: entry["score"] for entry in data}
        combined = (by_label['aesthetic'] + (1 - by_label['not_aesthetic'])) / 2
        return [f"score_{math.floor(combined * 10)}"]

    def predict(self, image: Image.Image, threshold=None):
        return self._get_score(self.pipe_aesthetic(image, top_k=2))

    def predict_pipe(self, data: list[Image.Image], threshold=None):
        # The framework probes with None to check if batched prediction exists.
        if data is None:
            return
        for out in self.pipe_aesthetic(data, batch_size=BATCH_SIZE):
            yield self._get_score(out)

    def name(self):
        return "cafeai aesthetic classifier"

View File

@ -0,0 +1,75 @@
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
import math
from transformers import CLIPModel, CLIPProcessor
from modules import devices, shared
from scripts import model_loader
from scripts.paths import paths
from scripts.tagger import Tagger
# brought from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py and modified
class Classifier(nn.Module):
    """MLP head mapping a CLIP image embedding (input_size-d) to one scalar aesthetic score."""

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        # Funnel 768 -> 1024 -> 128 -> 64 -> 16 -> 1 with light dropout in between,
        # matching the published improved-aesthetic-predictor checkpoint layout.
        stack = [
            nn.Linear(self.input_size, 1024), nn.Dropout(0.2),
            nn.Linear(1024, 128), nn.Dropout(0.2),
            nn.Linear(128, 64), nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        return self.layers(x)
# brought and modified from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py
def image_embeddings(image: Image, model: CLIPModel, processor: CLIPProcessor):
    """Return the L2-normalized CLIP embedding of `image` as a 1-D ndarray."""
    pixel_values = processor(images=image, return_tensors='pt')['pixel_values'].to(devices.device)
    features: np.ndarray = model.get_image_features(pixel_values=pixel_values).cpu().detach().numpy()
    return (features / np.linalg.norm(features)).squeeze(axis=0)
class ImprovedAestheticPredictor(Tagger):
    """Tags images with a score_N tag using the 'improved aesthetic predictor'
    MLP head on top of CLIP ViT-L/14 image embeddings."""

    def load(self):
        MODEL_VERSION = "sac+logos+ava1-l14-linearMSE"
        file = model_loader.load(
            model_path=paths.models_path / "aesthetic" / f"{MODEL_VERSION}.pth",
            model_url=f'https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/{MODEL_VERSION}.pth'
        )
        CLIP_REPOS = 'openai/clip-vit-large-patch14'
        self.model = Classifier(768)
        # map_location='cpu' so a CUDA-saved checkpoint also loads on CPU-only
        # hosts; the module is moved to the target device right after.
        self.model.load_state_dict(torch.load(file, map_location='cpu'))
        # eval() is required here: Classifier contains Dropout layers, which
        # would otherwise stay active at inference and make scores
        # non-deterministic (the CLIP model below already gets eval()).
        self.model = self.model.to(devices.device).eval()
        self.clip_processor = CLIPProcessor.from_pretrained(CLIP_REPOS)
        self.clip_model = CLIPModel.from_pretrained(CLIP_REPOS).to(devices.device).eval()

    def unload(self):
        # Free models unless the user opted to keep them resident in memory.
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            self.clip_processor = None
            self.clip_model = None
            devices.torch_gc()

    def start(self):
        self.load()
        return self

    def stop(self):
        self.unload()

    def predict(self, image: Image.Image, threshold=None):
        image_embeds = image_embeddings(image, self.clip_model, self.clip_processor)
        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            prediction: torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device))
        return [f"score_{math.floor(prediction.item())}"]

    def name(self):
        return "Improved Aesthetic Predictor"

View File

@ -0,0 +1,73 @@
from PIL import Image
import torch
import numpy as np
import math
from transformers import CLIPModel, CLIPProcessor
from modules import devices, shared
from scripts import model_loader
from scripts.paths import paths
from scripts.tagger import Tagger
# brought from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py
class Classifier(torch.nn.Module):
    """Two-hidden-layer MLP with ReLU activations and a sigmoid output,
    matching the waifu-diffusion aesthetic checkpoint layout.

    Parameter names fc1/fc2/fc3 are kept so load_state_dict() keys match.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size // 2)
        self.fc3 = torch.nn.Linear(hidden_size // 2, output_size)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.sigmoid(self.fc3(hidden))
# brought and modified from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py
def image_embeddings(image: Image, model: CLIPModel, processor: CLIPProcessor):
    """Return the L2-normalized CLIP embedding of `image` as a 1-D ndarray."""
    pixel_values = processor(images=image, return_tensors='pt')['pixel_values'].to(devices.device)
    features: np.ndarray = model.get_image_features(pixel_values=pixel_values).cpu().detach().numpy()
    return (features / np.linalg.norm(features)).squeeze(axis=0)
class WaifuAesthetic(Tagger):
    """Tags images with a score_N tag using the waifu-diffusion aesthetic MLP
    on top of CLIP ViT-B/32 image embeddings."""

    def load(self):
        file = model_loader.load(
            model_path=paths.models_path / "aesthetic" / "aes-B32-v0.pth",
            model_url='https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/models/aes-B32-v0.pth'
        )
        CLIP_REPOS = 'openai/clip-vit-base-patch32'
        self.model = Classifier(512, 256, 1)
        # map_location='cpu' so a CUDA-saved checkpoint also loads on CPU-only
        # hosts; the module is moved to the target device right after.
        self.model.load_state_dict(torch.load(file, map_location='cpu'))
        # eval() puts the head in inference mode, consistent with the CLIP
        # model below.
        self.model = self.model.to(devices.device).eval()
        self.clip_processor = CLIPProcessor.from_pretrained(CLIP_REPOS)
        self.clip_model = CLIPModel.from_pretrained(CLIP_REPOS).to(devices.device).eval()

    def unload(self):
        # Free models unless the user opted to keep them resident in memory.
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            self.clip_processor = None
            self.clip_model = None
            devices.torch_gc()

    def start(self):
        self.load()
        return self

    def stop(self):
        self.unload()

    def predict(self, image: Image.Image, threshold=None):
        image_embeds = image_embeddings(image, self.clip_model, self.clip_processor)
        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            prediction: torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device))
        # Sigmoid output in (0, 1) is scaled into the score_0..score_10 buckets.
        return [f"score_{math.floor(prediction.item() * 10)}"]

    def name(self):
        return "wd aesthetic classifier"