Merge changes in standalone version (#93)
* Merge changes in standalone version - New Taggers and Custom Tagger - a slightly more stable UI
feature/visualize-tokens
parent
7a2f4c53fb
commit
c3252d8325
|
|
@ -0,0 +1,143 @@
|
|||
from collections import namedtuple
|
||||
import json
|
||||
|
||||
from scripts import logger
|
||||
from scripts.paths import paths
|
||||
from scripts.dte_instance import dte_instance
|
||||
|
||||
SortBy = dte_instance.SortBy
|
||||
SortOrder = dte_instance.SortOrder
|
||||
|
||||
CONFIG_PATH = paths.base_path / "config.json"
|
||||
|
||||
# Settings records. Field order is significant because instances such as the
# CFG_*_DEFAULT constants are built from positional arguments.
GeneralConfig = namedtuple(
    "GeneralConfig",
    "backup dataset_dir caption_ext load_recursive load_caption_from_filename "
    "replace_new_line use_interrogator use_interrogator_names "
    "use_custom_threshold_booru custom_threshold_booru "
    "use_custom_threshold_waifu custom_threshold_waifu custom_threshold_z3d "
    "save_kohya_metadata meta_output_path meta_input_path meta_overwrite "
    "meta_save_as_caption meta_use_full_path",
)
FilterConfig = namedtuple(
    "FilterConfig", "sw_prefix sw_suffix sw_regex sort_by sort_order logic"
)
# NOTE: "sory_by" (sic) is the established field name. It is kept unchanged
# because renaming it would change both the attribute and the key emitted by
# _asdict().
BatchEditConfig = namedtuple(
    "BatchEditConfig",
    "show_only_selected prepend use_regex target sw_prefix sw_suffix sw_regex "
    "sory_by sort_order batch_sort_by batch_sort_order token_count",
)
EditSelectedConfig = namedtuple(
    "EditSelectedConfig",
    "auto_copy sort_on_save warn_change_not_saved use_interrogator_name "
    "sort_by sort_order",
)
MoveDeleteConfig = namedtuple(
    "MoveDeleteConfig", "range target caption_ext destination"
)
|
||||
|
||||
# Default values for each settings record, spelled out by field name.
CFG_GENERAL_DEFAULT = GeneralConfig(
    backup=True,
    dataset_dir="",
    caption_ext=".txt",
    load_recursive=False,
    load_caption_from_filename=True,
    replace_new_line=False,
    use_interrogator="No",
    use_interrogator_names=[],
    use_custom_threshold_booru=False,
    custom_threshold_booru=0.7,
    use_custom_threshold_waifu=False,
    custom_threshold_waifu=0.35,
    custom_threshold_z3d=0.35,
    save_kohya_metadata=False,
    meta_output_path="",
    meta_input_path="",
    meta_overwrite=True,
    meta_save_as_caption=False,
    meta_use_full_path=False,
)
# The positive and negative filters differ only in their combining logic.
CFG_FILTER_P_DEFAULT = FilterConfig(
    sw_prefix=False,
    sw_suffix=False,
    sw_regex=False,
    sort_by=SortBy.ALPHA.value,
    sort_order=SortOrder.ASC.value,
    logic="AND",
)
CFG_FILTER_N_DEFAULT = FilterConfig(
    sw_prefix=False,
    sw_suffix=False,
    sw_regex=False,
    sort_by=SortBy.ALPHA.value,
    sort_order=SortOrder.ASC.value,
    logic="OR",
)
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(
    show_only_selected=True,
    prepend=False,
    use_regex=False,
    target="Only Selected Tags",
    sw_prefix=False,
    sw_suffix=False,
    sw_regex=False,
    sory_by=SortBy.ALPHA.value,  # field name is spelled this way in BatchEditConfig
    sort_order=SortOrder.ASC.value,
    batch_sort_by=SortBy.ALPHA.value,
    batch_sort_order=SortOrder.ASC.value,
    token_count=75,
)
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(
    auto_copy=False,
    sort_on_save=False,
    warn_change_not_saved=False,
    use_interrogator_name="",
    sort_by=SortBy.ALPHA.value,
    sort_order=SortOrder.ASC.value,
)
CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig(
    range="Selected One", target=[], caption_ext=".txt", destination=""
)
|
||||
|
||||
|
||||
class Config:
    """In-memory copy of ``config.json``: one settings dict per section name."""

    def __init__(self):
        self.config = dict()

    def load(self):
        """Replace ``self.config`` with the contents of CONFIG_PATH.

        ``self.config`` is left as an empty dict when the file does not exist,
        cannot be read or parsed, or does not hold a JSON object at top level.
        """
        self.config = dict()
        if not CONFIG_PATH.is_file():
            return
        try:
            loaded = json.loads(CONFIG_PATH.read_text("utf8"))
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warn("Error on loading config.json. Default settings will be loaded.")
            return
        if not isinstance(loaded, dict):
            # read() relies on dict.get(), so any other JSON root is unusable.
            logger.warn("Error on loading config.json. Default settings will be loaded.")
            return
        self.config = loaded
        logger.write("Settings has been read from config.json")

    def save(self):
        """Write ``self.config`` to CONFIG_PATH as indented UTF-8 JSON."""
        CONFIG_PATH.write_text(json.dumps(self.config, indent=4), "utf8")

    def read(self, name: str):
        """Return the section stored under *name*, or None if it is absent."""
        return self.config.get(name)

    def write(self, cfg: dict, name: str):
        """Store *cfg* under section *name* in memory only (no file access)."""
        self.config[name] = cfg
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
from . import tagger
|
||||
from . import captioning
|
||||
from . import taggers_builtin
|
||||
from . import filters
|
||||
from . import dataset as ds
|
||||
from . import kohya_finetune_metadata
|
||||
|
||||
from .dte_logic import DatasetTagEditor, INTERROGATOR_NAMES, interrogate_image
|
||||
from .dte_logic import DatasetTagEditor
|
||||
|
||||
__all__ = ["ds", "tagger", "captioning", "filters", "kohya_metadata", "INTERROGATOR_NAMES", "interrogate_image", "DatasetTagEditor"]
|
||||
__all__ = ["ds", "taggers_builtin", "filters", "kohya_finetune_metadata", "DatasetTagEditor"]
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
import modules.shared as shared
|
||||
|
||||
from .interrogator import Interrogator
|
||||
from .interrogators import GITLargeCaptioning
|
||||
|
||||
|
||||
class Captioning(Interrogator):
    """Interrogator variant whose ``predict`` takes only an image.

    ``start`` and ``stop`` do nothing here; ``predict`` and ``name`` raise
    NotImplementedError until a subclass overrides them.
    """

    def name(self):
        raise NotImplementedError()

    def predict(self, image):
        raise NotImplementedError()

    def start(self):
        pass

    def stop(self):
        pass
|
||||
|
||||
|
||||
class BLIP(Captioning):
    """Captioner that delegates to ``shared.interrogator``."""

    def name(self):
        return 'BLIP'

    def start(self):
        shared.interrogator.load()

    def stop(self):
        shared.interrogator.unload()

    def predict(self, image):
        """Return the non-empty comma-separated pieces of the generated caption."""
        caption = shared.interrogator.generate_caption(image)
        return list(filter(None, caption.split(',')))
|
||||
|
||||
|
||||
class GITLarge(Captioning):
    """Captioner that owns a GITLargeCaptioning instance."""

    def __init__(self):
        self.interrogator = GITLargeCaptioning()

    def name(self):
        return 'GIT-large-COCO'

    def start(self):
        self.interrogator.load()

    def stop(self):
        self.interrogator.unload()

    def predict(self, image):
        """Return the non-empty comma-separated pieces of the caption."""
        caption = self.interrogator.apply(image)
        return list(filter(None, caption.split(',')))
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
import sys
|
||||
from pathlib import Path
|
||||
import importlib.util
|
||||
from types import ModuleType
|
||||
|
||||
from scripts import logger
|
||||
from scripts.paths import paths
|
||||
|
||||
|
||||
class CustomScripts:
    """Imports every ``*.py`` file in a directory and collects subclasses from them."""

    def _load_module_from(self, path: Path):
        """Execute the file at *path* as a module named after its stem.

        Raises:
            ImportError: if no import spec or loader can be built for *path*.
        """
        module_spec = importlib.util.spec_from_file_location(path.stem, path)
        if module_spec is None or module_spec.loader is None:
            raise ImportError(f"Cannot create an import spec for {path}")
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
        return module

    def _load_derived_classes(self, module: ModuleType, base_class: type):
        """Return the classes in *module* that strictly derive from *base_class*."""
        members = (getattr(module, name) for name in dir(module))
        return [
            obj
            for obj in members
            if isinstance(obj, type)
            and issubclass(obj, base_class)
            and obj is not base_class
        ]

    def __init__(self, scripts_dir: Path) -> None:
        # Loaded modules keyed by file stem; entries persist across calls.
        self.scripts = dict()
        self.scripts_dir = scripts_dir.absolute()

    def load_derived_classes(self, baseclass: type):
        """(Re)load all scripts and return every subclass of *baseclass* found.

        A script that fails to import, or a module that fails inspection, is
        logged with its traceback and skipped, so the remaining scripts still
        contribute their classes. Returns an empty list when ``scripts_dir``
        is not a directory.
        """
        import traceback

        if not self.scripts_dir.is_dir():
            logger.warn(f"NOT A DIRECTORY: {self.scripts_dir}")
            return []

        classes = []
        back_syspath = sys.path
        # Rebind rather than mutate so the original list object can be restored.
        sys.path = [str(paths.base_path)] + sys.path
        try:
            for path in self.scripts_dir.glob("*.py"):
                try:
                    self.scripts[path.stem] = self._load_module_from(path)
                except Exception:
                    logger.error(f"Error on loading {path}")
                    logger.error(traceback.format_exc())
            for stem, module in self.scripts.items():
                try:
                    classes.extend(self._load_derived_classes(module, baseclass))
                except Exception:
                    logger.error(f"Error on loading {stem}")
                    logger.error(traceback.format_exc())
        finally:
            sys.path = back_syspath

        return classes
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,15 +0,0 @@
|
|||
class Interrogator:
    """Base interface with start/stop hooks, usable as a context manager.

    Entering the ``with`` block calls ``start()`` and yields the instance;
    leaving it calls ``stop()``. Exceptions are not suppressed.
    """

    def start(self):
        pass

    def stop(self):
        pass

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        self.stop()
        return None

    def predict(self, image, **kwargs):
        raise NotImplementedError()

    def name(self):
        raise NotImplementedError()
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from .blip2_captioning import BLIP2Captioning
|
||||
from .git_large_captioning import GITLargeCaptioning
|
||||
from .waifu_diffusion_tagger import WaifuDiffusionTagger
|
||||
|
||||
__all__ = [
|
||||
'GITLargeCaptioning', 'WaifuDiffusionTagger'
|
||||
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger'
|
||||
]
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
|
||||
from modules import devices, shared
|
||||
from scripts.paths import paths
|
||||
|
||||
|
||||
class BLIP2Captioning:
    """Lazy wrapper around a BLIP-2 processor/model pair from ``model_repo``."""

    def __init__(self, model_repo: str):
        self.MODEL_REPO = model_repo
        self.processor: Blip2Processor = None
        self.model: Blip2ForConditionalGeneration = None

    def load(self):
        """Instantiate processor and model unless both are already present."""
        if self.model is not None and self.processor is not None:
            return
        self.processor = Blip2Processor.from_pretrained(
            self.MODEL_REPO, cache_dir=paths.setting_model_path
        )
        model = Blip2ForConditionalGeneration.from_pretrained(
            self.MODEL_REPO, cache_dir=paths.setting_model_path
        )
        self.model = model.to(devices.device)

    def unload(self):
        """Drop processor and model unless the keep-in-memory option is set."""
        if shared.opts.interrogate_keep_models_in_memory:
            return
        self.model = None
        self.processor = None
        devices.torch_gc()

    def apply(self, image):
        """Generate captions for *image*.

        Returns the empty string when nothing is loaded; otherwise returns
        whatever ``processor.batch_decode`` produces for the generated ids.
        """
        if self.model is None or self.processor is None:
            return ""
        encoded = self.processor(images=image, return_tensors="pt")
        inputs = encoded.to(devices.device)
        generated_ids = self.model.generate(**inputs)
        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
|
@ -1,26 +1,35 @@
|
|||
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
from modules import shared
|
||||
from modules import shared, devices, lowvram
|
||||
|
||||
|
||||
# brought from https://huggingface.co/docs/transformers/main/en/model_doc/git and modified
|
||||
class GITLargeCaptioning():
|
||||
class GITLargeCaptioning:
|
||||
MODEL_REPO = "microsoft/git-large-coco"
|
||||
|
||||
def __init__(self):
|
||||
self.processor:AutoProcessor = None
|
||||
self.model:AutoModelForCausalLM = None
|
||||
self.processor: AutoProcessor = None
|
||||
self.model: AutoModelForCausalLM = None
|
||||
|
||||
def load(self):
|
||||
if self.model is None or self.processor is None:
|
||||
self.processor = AutoProcessor.from_pretrained(self.MODEL_REPO)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(self.MODEL_REPO).to(shared.device)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(self.MODEL_REPO).to(
|
||||
shared.device
|
||||
)
|
||||
lowvram.send_everything_to_cpu()
|
||||
|
||||
def unload(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
self.model = None
|
||||
self.processor = None
|
||||
devices.torch_gc()
|
||||
|
||||
def apply(self, image):
|
||||
if self.model is None or self.processor is None:
|
||||
return ''
|
||||
inputs = self.processor(images=image, return_tensors='pt').to(shared.device)
|
||||
ids = self.model.generate(pixel_values=inputs.pixel_values, max_length=shared.opts.interrogate_clip_max_length)
|
||||
return self.processor.batch_decode(ids, skip_special_tokens=True)[0]
|
||||
return ""
|
||||
inputs = self.processor(images=image, return_tensors="pt").to(shared.device)
|
||||
ids = self.model.generate(
|
||||
pixel_values=inputs.pixel_values,
|
||||
max_length=shared.opts.interrogate_clip_max_length,
|
||||
)
|
||||
return self.processor.batch_decode(ids, skip_special_tokens=True)[0]
|
||||
|
|
|
|||
|
|
@ -1,74 +1,103 @@
|
|||
from PIL import Image
|
||||
import numpy as np
|
||||
from typing import List, Tuple
|
||||
from modules import shared
|
||||
from modules import shared, devices
|
||||
import launch
|
||||
|
||||
|
||||
class WaifuDiffusionTagger():
|
||||
class WaifuDiffusionTagger:
|
||||
# brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
|
||||
MODEL_FILENAME = "model.onnx"
|
||||
LABEL_FILENAME = "selected_tags.csv"
|
||||
def __init__(self, model_name):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
model_filename="model.onnx",
|
||||
label_filename="selected_tags.csv",
|
||||
):
|
||||
self.MODEL_FILENAME = model_filename
|
||||
self.LABEL_FILENAME = label_filename
|
||||
self.MODEL_REPO = model_name
|
||||
self.model = None
|
||||
self.labels = []
|
||||
|
||||
def load(self):
|
||||
import huggingface_hub
|
||||
|
||||
if not self.model:
|
||||
path_model = huggingface_hub.hf_hub_download(
|
||||
self.MODEL_REPO, self.MODEL_FILENAME
|
||||
)
|
||||
if 'all' in shared.cmd_opts.use_cpu or 'interrogate' in shared.cmd_opts.use_cpu:
|
||||
providers = ['CPUExecutionProvider']
|
||||
if (
|
||||
"all" in shared.cmd_opts.use_cpu
|
||||
or "interrogate" in shared.cmd_opts.use_cpu
|
||||
):
|
||||
providers = ["CPUExecutionProvider"]
|
||||
else:
|
||||
providers = ['CUDAExecutionProvider', 'DmlExecutionProvider', 'CPUExecutionProvider']
|
||||
|
||||
providers = [
|
||||
"CUDAExecutionProvider",
|
||||
"DmlExecutionProvider",
|
||||
"CPUExecutionProvider",
|
||||
]
|
||||
|
||||
def check_available_device():
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
return 'cuda'
|
||||
return "cuda"
|
||||
elif launch.is_installed("torch-directml"):
|
||||
# This code cannot detect DirectML available device without pytorch-directml
|
||||
try:
|
||||
import torch_directml
|
||||
|
||||
torch_directml.device()
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
return 'directml'
|
||||
return 'cpu'
|
||||
return "directml"
|
||||
return "cpu"
|
||||
|
||||
if not launch.is_installed("onnxruntime"):
|
||||
dev = check_available_device()
|
||||
if dev == 'cuda':
|
||||
launch.run_pip("install -U onnxruntime-gpu", "requirements for dataset-tag-editor [onnxruntime-gpu]")
|
||||
elif dev == 'directml':
|
||||
launch.run_pip("install -U onnxruntime-directml", "requirements for dataset-tag-editor [onnxruntime-directml]")
|
||||
if dev == "cuda":
|
||||
launch.run_pip(
|
||||
"install -U onnxruntime-gpu",
|
||||
"requirements for dataset-tag-editor [onnxruntime-gpu]",
|
||||
)
|
||||
elif dev == "directml":
|
||||
launch.run_pip(
|
||||
"install -U onnxruntime-directml",
|
||||
"requirements for dataset-tag-editor [onnxruntime-directml]",
|
||||
)
|
||||
else:
|
||||
print('Your device is not compatible with onnx hardware acceleration. CPU only version will be installed and it may be very slow.')
|
||||
launch.run_pip("install -U onnxruntime", "requirements for dataset-tag-editor [onnxruntime for CPU]")
|
||||
print(
|
||||
"Your device is not compatible with onnx hardware acceleration. CPU only version will be installed and it may be very slow."
|
||||
)
|
||||
launch.run_pip(
|
||||
"install -U onnxruntime",
|
||||
"requirements for dataset-tag-editor [onnxruntime for CPU]",
|
||||
)
|
||||
import onnxruntime as ort
|
||||
|
||||
self.model = ort.InferenceSession(path_model, providers=providers)
|
||||
|
||||
|
||||
path_label = huggingface_hub.hf_hub_download(
|
||||
self.MODEL_REPO, self.LABEL_FILENAME
|
||||
)
|
||||
import pandas as pd
|
||||
|
||||
self.labels = pd.read_csv(path_label)["name"].tolist()
|
||||
|
||||
def unload(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
self.model = None
|
||||
devices.torch_gc()
|
||||
|
||||
# brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
|
||||
def apply(self, image: Image.Image):
|
||||
if not self.model:
|
||||
return dict()
|
||||
|
||||
|
||||
from modules import images
|
||||
|
||||
|
||||
_, height, width, _ = self.model.get_inputs()[0].shape
|
||||
|
||||
# the way to fill empty pixels is quite different from original one;
|
||||
|
|
@ -85,4 +114,4 @@ class WaifuDiffusionTagger():
|
|||
probs = self.model.run([label_name], {input_name: image_np})[0]
|
||||
labels: List[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))
|
||||
|
||||
return labels
|
||||
return labels
|
||||
|
|
|
|||
|
|
@ -1,106 +0,0 @@
|
|||
from PIL import Image
|
||||
import re
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Optional, Dict
|
||||
from modules import devices, shared
|
||||
from modules import deepbooru as db
|
||||
|
||||
from .interrogator import Interrogator
|
||||
from .interrogators import WaifuDiffusionTagger
|
||||
|
||||
|
||||
class Tagger(Interrogator):
    """Interrogator variant whose ``predict`` also receives a threshold.

    ``start`` and ``stop`` do nothing here; ``predict`` and ``name`` raise
    NotImplementedError until a subclass overrides them.
    """

    def name(self):
        raise NotImplementedError()

    def predict(self, image: Image.Image, threshold: Optional[float]):
        raise NotImplementedError()

    def start(self):
        pass

    def stop(self):
        pass
|
||||
|
||||
|
||||
def get_replaced_tag(tag: str):
    """Rewrite *tag* according to the deepbooru display options.

    With ``deepbooru_use_spaces`` set, underscores become spaces. With
    ``deepbooru_escape`` set, each match of ``db.re_special`` is prefixed
    with a backslash.
    """
    options = shared.opts
    underscores_to_spaces = options.deepbooru_use_spaces
    escape_special = options.deepbooru_escape
    result = tag.replace('_', ' ') if underscores_to_spaces else tag
    if escape_special:
        result = re.sub(db.re_special, r'\\\1', result)
    return result
|
||||
|
||||
|
||||
def get_arranged_tags(probs: Dict[str, float]):
    """Return the tags of *probs* as a sorted list.

    Tags are sorted alphabetically when ``deepbooru_sort_alpha`` is set,
    otherwise by descending probability.
    """
    if shared.opts.deepbooru_sort_alpha:
        return sorted(probs)
    ranked = sorted(probs.items(), key=lambda item: -item[1])
    return [name for name, _ in ranked]
|
||||
|
||||
|
||||
class DeepDanbooru(Tagger):
    """Tagger that runs the webui ``deepbooru`` model."""

    def name(self):
        return 'DeepDanbooru'

    def start(self):
        db.model.start()

    def stop(self):
        db.model.stop()

    # brought from webUI modules/deepbooru.py and modified
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        """Return ``{tag: probability}`` for *image*.

        Tags below a truthy *threshold* and tags starting with ``rating:``
        are left out; keys are passed through ``get_replaced_tag``.
        """
        from modules import images

        resized = images.resize_image(2, image.convert("RGB"), 512, 512)
        batch = np.expand_dims(np.array(resized, dtype=np.float32), 0) / 255

        with torch.no_grad(), devices.autocast():
            tensor = torch.from_numpy(batch).to(devices.device)
            scores = db.model.model(tensor)[0].detach().cpu().numpy()

        return {
            get_replaced_tag(tag): probability
            for tag, probability in zip(db.model.model.tags, scores)
            if not (threshold and probability < threshold)
            and not tag.startswith("rating:")
        }
|
||||
|
||||
|
||||
class WaifuDiffusion(Tagger):
    """Tagger backed by a WaifuDiffusionTagger for ``SmilingWolf/<repo_name>``."""

    def __init__(self, repo_name, threshold):
        self.repo_name = repo_name
        self.tagger_inst = WaifuDiffusionTagger("SmilingWolf/" + repo_name)
        self.threshold = threshold

    def name(self):
        return self.repo_name

    def start(self):
        self.tagger_inst.load()
        return self

    def stop(self):
        self.tagger_inst.unload()

    # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        """Return ``{tag: probability}`` built from the labels after the first four.

        With ``threshold=None`` every such label is kept. A negative
        *threshold* is replaced by the instance's own threshold; only labels
        strictly above the threshold are kept.
        """
        labels = self.tagger_inst.apply(image)

        if threshold is None:
            return {get_replaced_tag(tag): prob for tag, prob in labels[4:]}

        if threshold < 0:
            threshold = self.threshold
        return {
            get_replaced_tag(tag): prob
            for tag, prob in labels[4:]
            if prob > threshold
        }
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules import devices, shared
|
||||
from modules import deepbooru as db
|
||||
|
||||
from scripts.tagger import Tagger, get_replaced_tag
|
||||
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger
|
||||
|
||||
|
||||
class BLIP(Tagger):
    """Tagger that delegates to ``shared.interrogator``; the threshold is unused."""

    def name(self):
        return 'BLIP'

    def start(self):
        shared.interrogator.load()

    def stop(self):
        shared.interrogator.unload()

    def predict(self, image: Image.Image, threshold=None):
        """Return the non-empty comma-separated pieces of the generated caption."""
        caption = shared.interrogator.generate_caption(image)
        return list(filter(None, caption.split(',')))
|
||||
|
||||
|
||||
|
||||
class BLIP2(Tagger):
    """Tagger backed by a BLIP2Captioning instance for ``Salesforce/<repo_name>``."""

    def __init__(self, repo_name):
        self.interrogator = BLIP2Captioning("Salesforce/" + repo_name)
        self.repo_name = repo_name

    def start(self):
        self.interrogator.load()

    def stop(self):
        self.interrogator.unload()

    def predict(self, image: Image.Image, threshold=None):
        """Return the non-empty comma-separated pieces of the first caption.

        *threshold* is accepted for interface compatibility and ignored.
        Returns an empty list when the captioner produced nothing.
        """
        captions = self.interrogator.apply(image)
        # apply() returns "" while the model is not loaded; indexing that
        # with [0] would raise IndexError.
        if not captions:
            return []
        return [t for t in captions[0].split(",") if t]

    def name(self):
        return self.repo_name
|
||||
|
||||
|
||||
class GITLarge(Tagger):
    """Tagger that owns a GITLargeCaptioning instance."""

    def __init__(self):
        self.interrogator = GITLargeCaptioning()

    def start(self):
        self.interrogator.load()

    def stop(self):
        self.interrogator.unload()

    def predict(self, image: Image.Image, threshold=None):
        """Return the non-empty comma-separated pieces of the caption.

        *threshold* is accepted for interface compatibility and ignored.
        Returns an empty list when the captioner produced nothing.
        """
        caption = self.interrogator.apply(image)
        # GITLargeCaptioning.apply() already returns a single string ("" when
        # the model is not loaded). Indexing it again with [0] would keep only
        # the first character, so only unwrap when a sequence of captions is
        # returned.
        if not isinstance(caption, str):
            caption = caption[0] if caption else ""
        return [t for t in caption.split(",") if t]

    def name(self):
        return "GIT-large-COCO"
|
||||
|
||||
|
||||
class DeepDanbooru(Tagger):
    """Tagger that runs the webui ``deepbooru`` model."""

    def name(self):
        return 'DeepDanbooru'

    def start(self):
        db.model.start()

    def stop(self):
        db.model.stop()

    # brought from webUI modules/deepbooru.py and modified
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        """Return the list of tags predicted for *image*.

        Tags below a truthy *threshold* are dropped. ``rating:`` tags are
        dropped unless ``shared.opts.dataset_editor_use_rating`` is set.
        Each kept tag is passed through ``get_replaced_tag``.
        """
        from modules import images

        resized = images.resize_image(2, image.convert("RGB"), 512, 512)
        batch = np.expand_dims(np.array(resized, dtype=np.float32), 0) / 255

        with torch.no_grad(), devices.autocast():
            tensor = torch.from_numpy(batch).to(devices.device)
            scores = db.model.model(tensor)[0].detach().cpu().numpy()

        def _keep(tag, probability):
            if threshold and probability < threshold:
                return False
            return shared.opts.dataset_editor_use_rating or not tag.startswith("rating:")

        return [
            get_replaced_tag(tag)
            for tag, probability in zip(db.model.model.tags, scores)
            if _keep(tag, probability)
        ]
|
||||
|
||||
|
||||
class WaifuDiffusion(Tagger):
    """Tagger backed by a WaifuDiffusionTagger for ``SmilingWolf/<repo_name>``."""

    def __init__(self, repo_name, threshold):
        self.repo_name = repo_name
        self.tagger_inst = WaifuDiffusionTagger("SmilingWolf/" + repo_name)
        self.threshold = threshold

    def name(self):
        return self.repo_name

    def start(self):
        self.tagger_inst.load()
        return self

    def stop(self):
        self.tagger_inst.unload()

    # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        """Return the list of tags predicted for *image*.

        The first four labels are skipped unless
        ``shared.opts.dataset_editor_use_rating`` is set. With
        ``threshold=None`` all remaining labels are returned. A negative
        *threshold* is replaced by the instance's own threshold; only labels
        strictly above the threshold are returned.
        """
        labels = self.tagger_inst.apply(image)

        if not shared.opts.dataset_editor_use_rating:
            labels = labels[4:]

        if threshold is None:
            return [get_replaced_tag(tag) for tag, _ in labels]

        cutoff = self.threshold if threshold < 0 else threshold
        return [get_replaced_tag(tag) for tag, value in labels if value > cutoff]
|
||||
|
||||
|
||||
class Z3D_E621(Tagger):
    """Tagger using WaifuDiffusionTagger with the ``toynya/Z3D-E621-Convnext`` repo."""

    def __init__(self):
        self.tagger_inst = WaifuDiffusionTagger(
            "toynya/Z3D-E621-Convnext", label_filename="tags-selected.csv"
        )

    def name(self):
        return "Z3D-E621-Convnext"

    def start(self):
        self.tagger_inst.load()
        return self

    def stop(self):
        self.tagger_inst.unload()

    # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
    def predict(self, image: Image.Image, threshold: Optional[float] = None):
        """Return the list of tags predicted for *image*.

        With ``threshold=None`` every label is returned; otherwise only
        labels strictly above *threshold*. No labels are skipped and a
        negative threshold is used as given.
        """
        labels = self.tagger_inst.apply(image)
        if threshold is None:
            selected = labels
        else:
            selected = [(tag, value) for tag, value in labels if value > threshold]
        return [get_replaced_tag(tag) for tag, _ in selected]
|
||||
|
|
@ -1,2 +1,2 @@
|
|||
import scripts.dataset_tag_editor as dte_module
|
||||
dte_instance = dte_module.DatasetTagEditor.get_instance()
|
||||
dte_instance = dte_module.DatasetTagEditor()
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
def write(content):
    """Print *content* to stdout prefixed with ``[tag-editor] ``.

    *content* may be any object; it is formatted with an f-string, so
    non-string values such as exceptions no longer raise TypeError.
    """
    print(f"[tag-editor] {content}")
|
||||
|
||||
def warn(content):
    """Emit *content* through ``write`` with a ``[tag-editor:WARNING]`` label."""
    message = f"[tag-editor:WARNING] {content}"
    write(message)
|
||||
|
||||
def error(content):
    """Emit *content* through ``write`` with a ``[tag-editor:ERROR]`` label."""
    message = f"[tag-editor:ERROR] {content}"
    write(message)
|
||||
479
scripts/main.py
479
scripts/main.py
|
|
@ -1,197 +1,165 @@
|
|||
from typing import NamedTuple, Type, Dict, Any
|
||||
from modules import shared, script_callbacks, scripts
|
||||
from modules import shared, script_callbacks
|
||||
from modules.shared import opts
|
||||
import gradio as gr
|
||||
import json
|
||||
from pathlib import Path
|
||||
from collections import namedtuple
|
||||
from scripts.config import *
|
||||
|
||||
import scripts.tag_editor_ui as ui
|
||||
from scripts.dte_instance import dte_instance
|
||||
|
||||
# ================================================================
|
||||
# General Callbacks
|
||||
# ================================================================
|
||||
|
||||
CONFIG_PATH = Path(scripts.basedir(), 'config.json')
|
||||
|
||||
|
||||
SortBy = dte_instance.SortBy
|
||||
SortOrder = dte_instance.SortOrder
|
||||
|
||||
# Settings records. Field order is significant because instances are built
# from positional arguments.
GeneralConfig = namedtuple(
    'GeneralConfig',
    """
    backup dataset_dir caption_ext load_recursive load_caption_from_filename
    replace_new_line use_interrogator use_interrogator_names
    use_custom_threshold_booru custom_threshold_booru
    use_custom_threshold_waifu custom_threshold_waifu
    save_kohya_metadata meta_output_path meta_input_path meta_overwrite
    meta_save_as_caption meta_use_full_path
    """.split(),
)
FilterConfig = namedtuple(
    'FilterConfig',
    'sw_prefix sw_suffix sw_regex sort_by sort_order logic'.split(),
)
# NOTE: 'sory_by' (sic) is the established field name and is kept unchanged.
BatchEditConfig = namedtuple(
    'BatchEditConfig',
    """
    show_only_selected prepend use_regex target sw_prefix sw_suffix sw_regex
    sory_by sort_order batch_sort_by batch_sort_order token_count
    """.split(),
)
EditSelectedConfig = namedtuple(
    'EditSelectedConfig',
    """
    auto_copy sort_on_save warn_change_not_saved use_interrogator_name
    sort_by sort_order
    """.split(),
)
MoveDeleteConfig = namedtuple(
    'MoveDeleteConfig', 'range target caption_ext destination'.split()
)
|
||||
|
||||
# Default values for each settings record, in field order.
CFG_GENERAL_DEFAULT = GeneralConfig._make((
    True, '', '.txt', False, True, False,
    'No', [],
    False, 0.7,
    False, 0.35,
    False, '', '', True, False, False,
))
# The positive and negative filters differ only in their combining logic.
CFG_FILTER_P_DEFAULT = FilterConfig._make(
    (False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND')
)
CFG_FILTER_N_DEFAULT = FilterConfig._make(
    (False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR')
)
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig._make((
    True, False, False, 'Only Selected Tags',
    False, False, False,
    SortBy.ALPHA.value, SortOrder.ASC.value,
    SortBy.ALPHA.value, SortOrder.ASC.value,
    75,
))
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig._make(
    (False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
)
CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig._make(('Selected One', [], '.txt', ''))
|
||||
|
||||
class Config:
    """In-memory copy of ``config.json``: one settings dict per section name."""

    def __init__(self):
        self.config = dict()

    def load(self):
        """Replace ``self.config`` with the contents of CONFIG_PATH.

        ``self.config`` is left as an empty dict when the file does not exist,
        cannot be read or parsed, or does not hold a JSON object at top level.
        """
        self.config = dict()
        if not CONFIG_PATH.is_file():
            return
        try:
            loaded = json.loads(CONFIG_PATH.read_text('utf8'))
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print('[tag-editor] Error on loading config.json. Default settings will be loaded.')
            return
        if not isinstance(loaded, dict):
            # read() relies on dict.get(), so any other JSON root is unusable.
            print('[tag-editor] Error on loading config.json. Default settings will be loaded.')
            return
        self.config = loaded
        print('[tag-editor] Settings has been read from config.json')

    def save(self):
        """Write ``self.config`` to CONFIG_PATH as indented UTF-8 JSON."""
        CONFIG_PATH.write_text(json.dumps(self.config, indent=4), 'utf8')

    def read(self, name: str):
        """Return the section stored under *name*, or None if it is absent."""
        return self.config.get(name)

    def write(self, cfg: dict, name: str):
        """Store *cfg* under section *name* in memory only (no file access)."""
        self.config[name] = cfg
|
||||
|
||||
config = Config()
|
||||
|
||||
|
||||
def write_general_config(*args):
    """Store *args*, interpreted as GeneralConfig fields, under "general"."""
    config.write(GeneralConfig(*args)._asdict(), "general")
|
||||
|
||||
|
||||
def write_filter_config(*args):
    """Store the two halves of *args* as positive/negative FilterConfig under "filter"."""
    half = len(args) // 2
    positive = FilterConfig(*args[:half])
    negative = FilterConfig(*args[half:])
    section = {"positive": positive._asdict(), "negative": negative._asdict()}
    config.write(section, "filter")
|
||||
|
||||
|
||||
def write_batch_edit_config(*args):
    """Store *args*, interpreted as BatchEditConfig fields, under "batch_edit"."""
    config.write(BatchEditConfig(*args)._asdict(), "batch_edit")
|
||||
|
||||
|
||||
def write_edit_selected_config(*args):
    """Store *args*, interpreted as EditSelectedConfig fields, under "edit_selected"."""
    config.write(EditSelectedConfig(*args)._asdict(), "edit_selected")
|
||||
|
||||
|
||||
def write_move_delete_config(*args):
    """Store *args*, interpreted as MoveDeleteConfig fields, under "file_move_delete"."""
    config.write(MoveDeleteConfig(*args)._asdict(), "file_move_delete")
|
||||
|
||||
def read_config(name: str, config_type: Type, default: NamedTuple, compat_func=None):
    """Build a *config_type* from the stored section *name*.

    Returns *default* itself when the section is missing or empty. Otherwise
    the stored dict (after *compat_func*, if given) is laid over the default
    values; keys that are not fields of *default* are discarded.
    """
    stored = config.read(name)
    if not stored:
        return default
    if compat_func:
        stored = compat_func(stored)
    known = default._asdict()
    merged = known | stored
    fields = {key: value for key, value in merged.items() if key in known}
    return config_type(**fields)
|
||||
|
||||
|
||||
def read_general_config():
    """Read the "general" section, translating the older per-interrogator flags."""
    # for compatibility
    legacy_flags = (
        ("use_blip_to_prefill", "BLIP"),
        ("use_git_to_prefill", "GIT-large-COCO"),
        ("use_booru_to_prefill", "DeepDanbooru"),
        ("use_waifu_to_prefill", "wd-v1-4-vit-tagger"),
    )

    def compat_func(d: Dict[str, Any]):
        # Only synthesize the name list when the stored dict lacks it.
        if "use_interrogator_names" not in d:
            d["use_interrogator_names"] = [
                label for flag, label in legacy_flags if d.get(flag)
            ]
        return d

    return read_config("general", GeneralConfig, CFG_GENERAL_DEFAULT, compat_func)
|
||||
|
||||
|
||||
def read_filter_config():
    """Read the ``filter`` section and return ``(positive, negative)`` configs.

    Each half falls back to its default (``CFG_FILTER_P_DEFAULT`` /
    ``CFG_FILTER_N_DEFAULT``) when the section or its sub-entry is missing or
    empty. Otherwise the stored values are laid over the defaults and keys
    the default does not have are discarded.
    """

    def _overlay(default, stored):
        # An absent or empty sub-section leaves the default object untouched.
        if not stored:
            return default
        fallback = default._asdict()
        merged = fallback | stored
        return FilterConfig(
            **{key: value for key, value in merged.items() if key in fallback}
        )

    section = config.read("filter")
    stored_positive = section.get("positive") if section else None
    stored_negative = section.get("negative") if section else None

    return (
        _overlay(CFG_FILTER_P_DEFAULT, stored_positive),
        _overlay(CFG_FILTER_N_DEFAULT, stored_negative),
    )
|
||||
|
||||
|
||||
def read_batch_edit_config():
    """Read the ``batch_edit`` section as a ``BatchEditConfig``."""
    return read_config(
        name="batch_edit",
        config_type=BatchEditConfig,
        default=CFG_BATCH_EDIT_DEFAULT,
    )
|
||||
|
||||
|
||||
def read_edit_selected_config():
    """Read the ``edit_selected`` section as an ``EditSelectedConfig``."""
    return read_config(
        name="edit_selected",
        config_type=EditSelectedConfig,
        default=CFG_EDIT_SELECTED_DEFAULT,
    )
|
||||
|
||||
|
||||
def read_move_delete_config():
    """Read the ``file_move_delete`` section as a ``MoveDeleteConfig``."""
    return read_config(
        name="file_move_delete",
        config_type=MoveDeleteConfig,
        default=CFG_MOVE_DELETE_DEFAULT,
    )
|
||||
|
||||
|
||||
# ================================================================
|
||||
# General Callbacks for Updating UIs
|
||||
# ================================================================
|
||||
|
||||
|
||||
def get_filters():
    """Return the active filters: positive tag, negative tag, then path selection."""
    tag_tab = ui.filter_by_tags
    return [
        tag_tab.tag_filter_ui.get_filter(),
        tag_tab.tag_filter_ui_neg.get_filter(),
        ui.filter_by_selection.path_filter,
    ]
|
||||
|
||||
|
||||
def update_gallery():
    """Refresh the gallery status values and return the gallery update payload.

    Registers three values on ``ui.gallery_state`` (displayed/total image
    counts, a text form of the current tag filters, and the number of paths
    in the selection filter).

    Returns:
        A six-element list: the filtered image indices as strings, the
        constants ``1, -1, -1, -1``, and the current gallery status text.
    """
    img_indices = ui.dte_instance.get_filtered_imgindices(filters=get_filters())
    total_image_num = len(ui.dte_instance.dataset)
    displayed_image_num = len(img_indices)
    ui.gallery_state.register_value(
        "Displayed Images", f"{displayed_image_num} / {total_image_num} total"
    )

    tag_tab = ui.filter_by_tags
    positive_txt = f"{tag_tab.tag_filter_ui.get_filter()}"
    # " AND " appears only when both the positive and negative filters have tags.
    if (
        tag_tab.tag_filter_ui.get_filter().tags
        and tag_tab.tag_filter_ui_neg.get_filter().tags
    ):
        joiner = " AND "
    else:
        joiner = ""
    negative_txt = f"{tag_tab.tag_filter_ui_neg.get_filter()}"
    ui.gallery_state.register_value(
        "Current Tag Filter", f"{positive_txt} {joiner} {negative_txt}"
    )

    selected_count = len(ui.filter_by_selection.path_filter.paths)
    ui.gallery_state.register_value(
        "Current Selection Filter", f"{selected_count} images"
    )

    index_labels = [str(index) for index in img_indices]
    return [
        index_labels,
        1,
        -1,
        -1,
        -1,
        ui.gallery_state.get_current_gallery_txt(),
    ]
|
||||
|
||||
|
||||
def update_filter_and_gallery():
    """Return one flat list of updates for the tag filters, gallery and batch-edit UI.

    The list is, in order: the positive and negative tag checkbox updates,
    everything returned by ``update_gallery()``, the result of
    ``get_common_tags``, the positive filter's tags joined with ``", "``,
    the "remove" tag selector update, and two empty strings.
    """
    tag_tab = ui.filter_by_tags

    updates = [
        tag_tab.tag_filter_ui.cbg_tags_update(),
        tag_tab.tag_filter_ui_neg.cbg_tags_update(),
    ]
    updates += update_gallery()
    updates += ui.batch_edit_captions.get_common_tags(get_filters, tag_tab)
    updates.append(", ".join(tag_tab.tag_filter_ui.filter.tags))
    updates.append(ui.batch_edit_captions.tag_select_ui_remove.cbg_tags_update())
    updates += ["", ""]
    return updates
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Script Callbacks
|
||||
# ================================================================
|
||||
|
||||
|
||||
def on_ui_tabs():
|
||||
config.load()
|
||||
|
||||
|
|
@ -201,101 +169,156 @@ def on_ui_tabs():
|
|||
cfg_edit_selected = read_edit_selected_config()
|
||||
cfg_file_move_delete = read_move_delete_config()
|
||||
|
||||
ui.dte_instance.load_interrogators()
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as dataset_tag_editor_interface:
|
||||
|
||||
gr.HTML(value="""
|
||||
gr.HTML(
|
||||
value="""
|
||||
This extension works well with text captions in comma-separated style (such as the tags generated by DeepBooru interrogator).
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
ui.toprow.create_ui(cfg_general)
|
||||
|
||||
with gr.Accordion(label='Reload/Save Settings (config.json)', open=False):
|
||||
with gr.Accordion(label="Reload/Save Settings (config.json)", open=False):
|
||||
with gr.Row():
|
||||
btn_reload_config_file = gr.Button(value='Reload settings')
|
||||
btn_save_setting_as_default = gr.Button(value='Save current settings')
|
||||
btn_restore_default = gr.Button(value='Restore settings to default')
|
||||
btn_reload_config_file = gr.Button(value="Reload settings")
|
||||
btn_save_setting_as_default = gr.Button(value="Save current settings")
|
||||
btn_restore_default = gr.Button(value="Restore settings to default")
|
||||
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Row(equal_height=False):
|
||||
with gr.Column():
|
||||
ui.load_dataset.create_ui(cfg_general)
|
||||
ui.dataset_gallery.create_ui(opts.dataset_editor_image_columns)
|
||||
ui.gallery_state.create_ui()
|
||||
|
||||
with gr.Tab(label='Filter by Tags'):
|
||||
with gr.Tab(label="Filter by Tags"):
|
||||
ui.filter_by_tags.create_ui(cfg_filter_p, cfg_filter_n, get_filters)
|
||||
|
||||
with gr.Tab(label='Filter by Selection'):
|
||||
|
||||
with gr.Tab(label="Filter by Selection"):
|
||||
ui.filter_by_selection.create_ui(opts.dataset_editor_image_columns)
|
||||
|
||||
with gr.Tab(label='Batch Edit Captions'):
|
||||
with gr.Tab(label="Batch Edit Captions"):
|
||||
ui.batch_edit_captions.create_ui(cfg_batch_edit, get_filters)
|
||||
|
||||
with gr.Tab(label='Edit Caption of Selected Image'):
|
||||
with gr.Tab(label="Edit Caption of Selected Image"):
|
||||
ui.edit_caption_of_selected_image.create_ui(cfg_edit_selected)
|
||||
|
||||
with gr.Tab(label='Move or Delete Files'):
|
||||
with gr.Tab(label="Move or Delete Files"):
|
||||
ui.move_or_delete_files.create_ui(cfg_file_move_delete)
|
||||
|
||||
#----------------------------------------------------------------
|
||||
# ----------------------------------------------------------------
|
||||
# General
|
||||
|
||||
components_general = [
|
||||
ui.toprow.cb_backup, ui.load_dataset.tb_img_directory, ui.load_dataset.tb_caption_file_ext, ui.load_dataset.cb_load_recursive,
|
||||
ui.load_dataset.cb_load_caption_from_filename, ui.load_dataset.cb_replace_new_line_with_comma, ui.load_dataset.rb_use_interrogator, ui.load_dataset.dd_intterogator_names,
|
||||
ui.load_dataset.cb_use_custom_threshold_booru, ui.load_dataset.sl_custom_threshold_booru, ui.load_dataset.cb_use_custom_threshold_waifu, ui.load_dataset.sl_custom_threshold_waifu,
|
||||
ui.toprow.cb_save_kohya_metadata, ui.toprow.tb_metadata_output, ui.toprow.tb_metadata_input, ui.toprow.cb_metadata_overwrite, ui.toprow.cb_metadata_as_caption, ui.toprow.cb_metadata_use_fullpath
|
||||
ui.toprow.cb_backup,
|
||||
ui.load_dataset.tb_img_directory,
|
||||
ui.load_dataset.tb_caption_file_ext,
|
||||
ui.load_dataset.cb_load_recursive,
|
||||
ui.load_dataset.cb_load_caption_from_filename,
|
||||
ui.load_dataset.cb_replace_new_line_with_comma,
|
||||
ui.load_dataset.rb_use_interrogator,
|
||||
ui.load_dataset.dd_intterogator_names,
|
||||
ui.load_dataset.cb_use_custom_threshold_booru,
|
||||
ui.load_dataset.sl_custom_threshold_booru,
|
||||
ui.load_dataset.cb_use_custom_threshold_waifu,
|
||||
ui.load_dataset.sl_custom_threshold_waifu,
|
||||
ui.load_dataset.sl_custom_threshold_z3d,
|
||||
ui.toprow.cb_save_kohya_metadata,
|
||||
ui.toprow.tb_metadata_output,
|
||||
ui.toprow.tb_metadata_input,
|
||||
ui.toprow.cb_metadata_overwrite,
|
||||
ui.toprow.cb_metadata_as_caption,
|
||||
ui.toprow.cb_metadata_use_fullpath,
|
||||
]
|
||||
components_filter = [
|
||||
ui.filter_by_tags.tag_filter_ui.cb_prefix,
|
||||
ui.filter_by_tags.tag_filter_ui.cb_suffix,
|
||||
ui.filter_by_tags.tag_filter_ui.cb_regex,
|
||||
ui.filter_by_tags.tag_filter_ui.rb_sort_by,
|
||||
ui.filter_by_tags.tag_filter_ui.rb_sort_order,
|
||||
ui.filter_by_tags.tag_filter_ui.rb_logic,
|
||||
] + [
|
||||
ui.filter_by_tags.tag_filter_ui_neg.cb_prefix,
|
||||
ui.filter_by_tags.tag_filter_ui_neg.cb_suffix,
|
||||
ui.filter_by_tags.tag_filter_ui_neg.cb_regex,
|
||||
ui.filter_by_tags.tag_filter_ui_neg.rb_sort_by,
|
||||
ui.filter_by_tags.tag_filter_ui_neg.rb_sort_order,
|
||||
ui.filter_by_tags.tag_filter_ui_neg.rb_logic,
|
||||
]
|
||||
components_filter = \
|
||||
[ui.filter_by_tags.tag_filter_ui.cb_prefix, ui.filter_by_tags.tag_filter_ui.cb_suffix, ui.filter_by_tags.tag_filter_ui.cb_regex, ui.filter_by_tags.tag_filter_ui.rb_sort_by, ui.filter_by_tags.tag_filter_ui.rb_sort_order, ui.filter_by_tags.tag_filter_ui.rb_logic] +\
|
||||
[ui.filter_by_tags.tag_filter_ui_neg.cb_prefix, ui.filter_by_tags.tag_filter_ui_neg.cb_suffix, ui.filter_by_tags.tag_filter_ui_neg.cb_regex, ui.filter_by_tags.tag_filter_ui_neg.rb_sort_by, ui.filter_by_tags.tag_filter_ui_neg.rb_sort_order, ui.filter_by_tags.tag_filter_ui_neg.rb_logic]
|
||||
components_batch_edit = [
|
||||
ui.batch_edit_captions.cb_show_only_tags_selected, ui.batch_edit_captions.cb_prepend_tags, ui.batch_edit_captions.cb_use_regex,
|
||||
ui.batch_edit_captions.cb_show_only_tags_selected,
|
||||
ui.batch_edit_captions.cb_prepend_tags,
|
||||
ui.batch_edit_captions.cb_use_regex,
|
||||
ui.batch_edit_captions.rb_sr_replace_target,
|
||||
ui.batch_edit_captions.tag_select_ui_remove.cb_prefix, ui.batch_edit_captions.tag_select_ui_remove.cb_suffix, ui.batch_edit_captions.tag_select_ui_remove.cb_regex,
|
||||
ui.batch_edit_captions.tag_select_ui_remove.rb_sort_by, ui.batch_edit_captions.tag_select_ui_remove.rb_sort_order,
|
||||
ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order,
|
||||
ui.batch_edit_captions.nb_token_count
|
||||
ui.batch_edit_captions.tag_select_ui_remove.cb_prefix,
|
||||
ui.batch_edit_captions.tag_select_ui_remove.cb_suffix,
|
||||
ui.batch_edit_captions.tag_select_ui_remove.cb_regex,
|
||||
ui.batch_edit_captions.tag_select_ui_remove.rb_sort_by,
|
||||
ui.batch_edit_captions.tag_select_ui_remove.rb_sort_order,
|
||||
ui.batch_edit_captions.rb_sort_by,
|
||||
ui.batch_edit_captions.rb_sort_order,
|
||||
ui.batch_edit_captions.nb_token_count,
|
||||
]
|
||||
components_edit_selected = [
|
||||
ui.edit_caption_of_selected_image.cb_copy_caption_automatically, ui.edit_caption_of_selected_image.cb_sort_caption_on_save,
|
||||
ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed, ui.edit_caption_of_selected_image.dd_intterogator_names_si,
|
||||
ui.edit_caption_of_selected_image.rb_sort_by, ui.edit_caption_of_selected_image.rb_sort_order
|
||||
ui.edit_caption_of_selected_image.cb_copy_caption_automatically,
|
||||
ui.edit_caption_of_selected_image.cb_sort_caption_on_save,
|
||||
ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed,
|
||||
ui.edit_caption_of_selected_image.dd_intterogator_names_si,
|
||||
ui.edit_caption_of_selected_image.rb_sort_by,
|
||||
ui.edit_caption_of_selected_image.rb_sort_order,
|
||||
]
|
||||
components_move_delete = [
|
||||
ui.move_or_delete_files.rb_move_or_delete_target_data, ui.move_or_delete_files.cbg_move_or_delete_target_file,
|
||||
ui.move_or_delete_files.tb_move_or_delete_caption_ext, ui.move_or_delete_files.tb_move_or_delete_destination_dir
|
||||
ui.move_or_delete_files.rb_move_or_delete_target_data,
|
||||
ui.move_or_delete_files.cbg_move_or_delete_target_file,
|
||||
ui.move_or_delete_files.tb_move_or_delete_caption_ext,
|
||||
ui.move_or_delete_files.tb_move_or_delete_destination_dir,
|
||||
]
|
||||
|
||||
configurable_components = components_general + components_filter + components_batch_edit + components_edit_selected + components_move_delete
|
||||
|
||||
configurable_components = (
|
||||
components_general
|
||||
+ components_filter
|
||||
+ components_batch_edit
|
||||
+ components_edit_selected
|
||||
+ components_move_delete
|
||||
)
|
||||
|
||||
def reload_config_file():
    """Reload the config and return every setting as one flat tuple.

    The values are ordered general, positive filter, negative filter,
    batch edit, edit selected, move/delete.
    """
    config.load()
    positive, negative = read_filter_config()
    logger.write("Reload config.json")
    return (
        *read_general_config(),
        *positive,
        *negative,
        *read_batch_edit_config(),
        *read_edit_selected_config(),
        *read_move_delete_config(),
    )
|
||||
|
||||
btn_reload_config_file.click(
|
||||
fn=reload_config_file,
|
||||
outputs=configurable_components
|
||||
fn=reload_config_file, outputs=configurable_components
|
||||
)
|
||||
|
||||
def save_settings_callback(*a):
    """Split the flat component values into sections, write each, then save.

    ``a`` holds the values in the order general, filter, batch edit,
    edit selected, move/delete. Every section except the last takes as many
    values as its component list has entries; move/delete takes the rest.
    """
    sized_sections = (
        (write_general_config, len(components_general)),
        (write_filter_config, len(components_filter)),
        (write_batch_edit_config, len(components_batch_edit)),
        (write_edit_selected_config, len(components_edit_selected)),
    )

    start = 0
    for write_section, width in sized_sections:
        stop = start + width
        write_section(*a[start:stop])
        start = stop

    # Whatever remains belongs to the move/delete section.
    write_move_delete_config(*a[start:])
    config.save()
    logger.write("Current settings have been saved into config.json")
|
||||
|
||||
btn_save_setting_as_default.click(
|
||||
fn=save_settings_callback,
|
||||
inputs=configurable_components
|
||||
fn=save_settings_callback, inputs=configurable_components
|
||||
)
|
||||
|
||||
def restore_default_settings():
|
||||
|
|
@ -304,44 +327,150 @@ def on_ui_tabs():
|
|||
write_batch_edit_config(*CFG_BATCH_EDIT_DEFAULT)
|
||||
write_edit_selected_config(*CFG_EDIT_SELECTED_DEFAULT)
|
||||
write_move_delete_config(*CFG_MOVE_DELETE_DEFAULT)
|
||||
print('[tag-editor] Restore default settings')
|
||||
return CFG_GENERAL_DEFAULT + CFG_FILTER_P_DEFAULT + CFG_FILTER_N_DEFAULT + CFG_BATCH_EDIT_DEFAULT + CFG_EDIT_SELECTED_DEFAULT + CFG_MOVE_DELETE_DEFAULT
|
||||
|
||||
logger.write("Restore default settings")
|
||||
return (
|
||||
CFG_GENERAL_DEFAULT
|
||||
+ CFG_FILTER_P_DEFAULT
|
||||
+ CFG_FILTER_N_DEFAULT
|
||||
+ CFG_BATCH_EDIT_DEFAULT
|
||||
+ CFG_EDIT_SELECTED_DEFAULT
|
||||
+ CFG_MOVE_DELETE_DEFAULT
|
||||
)
|
||||
|
||||
btn_restore_default.click(
|
||||
fn=restore_default_settings,
|
||||
outputs=configurable_components
|
||||
fn=restore_default_settings, outputs=configurable_components
|
||||
)
|
||||
|
||||
o_update_gallery = [ui.dataset_gallery.cbg_hidden_dataset_filter, ui.dataset_gallery.nb_hidden_dataset_filter_apply, ui.dataset_gallery.nb_hidden_image_index, ui.dataset_gallery.nb_hidden_image_index_prev, ui.edit_caption_of_selected_image.nb_hidden_image_index_save_or_not, ui.gallery_state.txt_gallery]
|
||||
o_update_gallery = [
|
||||
ui.dataset_gallery.cbg_hidden_dataset_filter,
|
||||
ui.dataset_gallery.nb_hidden_dataset_filter_apply,
|
||||
ui.dataset_gallery.nb_hidden_image_index,
|
||||
ui.dataset_gallery.nb_hidden_image_index_prev,
|
||||
ui.edit_caption_of_selected_image.nb_hidden_image_index_save_or_not,
|
||||
ui.gallery_state.txt_gallery,
|
||||
]
|
||||
|
||||
o_update_filter_and_gallery = (
|
||||
[
|
||||
ui.filter_by_tags.tag_filter_ui.cbg_tags,
|
||||
ui.filter_by_tags.tag_filter_ui_neg.cbg_tags,
|
||||
]
|
||||
+ o_update_gallery
|
||||
+ [
|
||||
ui.batch_edit_captions.tb_common_tags,
|
||||
ui.batch_edit_captions.tb_edit_tags,
|
||||
]
|
||||
+ [ui.batch_edit_captions.tb_sr_selected_tags]
|
||||
+ [ui.batch_edit_captions.tag_select_ui_remove.cbg_tags]
|
||||
+ [
|
||||
ui.edit_caption_of_selected_image.tb_caption,
|
||||
ui.edit_caption_of_selected_image.tb_edit_caption,
|
||||
]
|
||||
)
|
||||
|
||||
o_update_filter_and_gallery = \
|
||||
[ui.filter_by_tags.tag_filter_ui.cbg_tags, ui.filter_by_tags.tag_filter_ui_neg.cbg_tags] + \
|
||||
o_update_gallery + \
|
||||
[ui.batch_edit_captions.tb_common_tags, ui.batch_edit_captions.tb_edit_tags] + \
|
||||
[ui.batch_edit_captions.tb_sr_selected_tags] +\
|
||||
[ui.batch_edit_captions.tag_select_ui_remove.cbg_tags] +\
|
||||
[ui.edit_caption_of_selected_image.tb_caption, ui.edit_caption_of_selected_image.tb_edit_caption]
|
||||
|
||||
ui.toprow.set_callbacks(ui.load_dataset)
|
||||
ui.load_dataset.set_callbacks(o_update_filter_and_gallery,ui.toprow, ui.dataset_gallery, ui.filter_by_tags, ui.filter_by_selection, ui.batch_edit_captions, update_filter_and_gallery)
|
||||
ui.load_dataset.set_callbacks(
|
||||
o_update_filter_and_gallery,
|
||||
ui.toprow,
|
||||
ui.dataset_gallery,
|
||||
ui.filter_by_tags,
|
||||
ui.filter_by_selection,
|
||||
ui.batch_edit_captions,
|
||||
update_filter_and_gallery,
|
||||
)
|
||||
ui.dataset_gallery.set_callbacks(ui.load_dataset, ui.gallery_state, get_filters)
|
||||
ui.gallery_state.set_callbacks(ui.dataset_gallery)
|
||||
ui.filter_by_tags.set_callbacks(o_update_gallery, o_update_filter_and_gallery, ui.batch_edit_captions, ui.move_or_delete_files, update_gallery, update_filter_and_gallery, get_filters)
|
||||
ui.filter_by_selection.set_callbacks(o_update_filter_and_gallery, ui.dataset_gallery, ui.filter_by_tags, get_filters, update_filter_and_gallery)
|
||||
ui.batch_edit_captions.set_callbacks(o_update_filter_and_gallery, ui.load_dataset, ui.filter_by_tags, get_filters, update_filter_and_gallery)
|
||||
ui.edit_caption_of_selected_image.set_callbacks(o_update_filter_and_gallery, ui.dataset_gallery, ui.load_dataset, get_filters, update_filter_and_gallery)
|
||||
ui.move_or_delete_files.set_callbacks(o_update_filter_and_gallery, ui.dataset_gallery, ui.filter_by_tags, ui.batch_edit_captions, ui.filter_by_selection, ui.edit_caption_of_selected_image, get_filters, update_filter_and_gallery)
|
||||
|
||||
return [(dataset_tag_editor_interface, "Dataset Tag Editor", "dataset_tag_editor_interface")]
|
||||
ui.filter_by_tags.set_callbacks(
|
||||
o_update_gallery,
|
||||
o_update_filter_and_gallery,
|
||||
ui.batch_edit_captions,
|
||||
ui.move_or_delete_files,
|
||||
update_gallery,
|
||||
update_filter_and_gallery,
|
||||
get_filters,
|
||||
)
|
||||
ui.filter_by_selection.set_callbacks(
|
||||
o_update_filter_and_gallery,
|
||||
ui.dataset_gallery,
|
||||
ui.filter_by_tags,
|
||||
get_filters,
|
||||
update_filter_and_gallery,
|
||||
)
|
||||
ui.batch_edit_captions.set_callbacks(
|
||||
o_update_filter_and_gallery,
|
||||
ui.load_dataset,
|
||||
ui.filter_by_tags,
|
||||
get_filters,
|
||||
update_filter_and_gallery,
|
||||
)
|
||||
ui.edit_caption_of_selected_image.set_callbacks(
|
||||
o_update_filter_and_gallery,
|
||||
ui.dataset_gallery,
|
||||
ui.load_dataset,
|
||||
get_filters,
|
||||
update_filter_and_gallery,
|
||||
)
|
||||
ui.move_or_delete_files.set_callbacks(
|
||||
o_update_filter_and_gallery,
|
||||
ui.dataset_gallery,
|
||||
ui.filter_by_tags,
|
||||
ui.batch_edit_captions,
|
||||
ui.filter_by_selection,
|
||||
ui.edit_caption_of_selected_image,
|
||||
get_filters,
|
||||
update_filter_and_gallery,
|
||||
)
|
||||
|
||||
return [
|
||||
(
|
||||
dataset_tag_editor_interface,
|
||||
"Dataset Tag Editor",
|
||||
"dataset_tag_editor_interface",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def on_ui_settings():
    """Register this extension's options in the "Dataset Tag Editor" settings section."""
    section = ("dataset-tag-editor", "Dataset Tag Editor")

    # (option key, default value, label), registered in this order.
    option_table = (
        ("dataset_editor_image_columns", 6, "Number of columns on image gallery"),
        ("dataset_editor_max_res", 0, "Max resolution of temporary files"),
        (
            "dataset_editor_use_temp_files",
            False,
            "Force image gallery to use temporary files",
        ),
        (
            "dataset_editor_use_raw_clip_token",
            True,
            "Use raw CLIP token to calculate token count (without emphasis or embeddings)",
        ),
        ("dataset_editor_use_rating", False, "Use rating tags"),
        (
            "dataset_editor_num_cpu_workers",
            -1,
            "Number of CPU workers when preprocessing images (set -1 to auto)",
        ),
    )

    for key, default_value, label in option_table:
        shared.opts.add_option(
            key, shared.OptionInfo(default_value, label, section=section)
        )
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
from pathlib import Path
|
||||
|
||||
from torch.hub import download_url_to_file
|
||||
|
||||
def load(model_path: Path, model_url: str, progress: bool = True, force_download: bool = False):
    """Return the local path of a model file, downloading it first if needed.

    Args:
        model_path: Destination of the model file (``Path`` or ``str``).
        model_url: URL to fetch the file from, or ``None`` to never download.
        progress: Forwarded to ``download_url_to_file`` as ``progress``.
        force_download: When true, download even if ``model_path`` already
            exists. Previously the early return on an existing path made this
            flag unreachable.

    Returns:
        ``model_path`` as a ``Path``. The file is not guaranteed to exist when
        ``model_url`` is ``None``.
    """
    model_path = Path(model_path)

    # Reuse whatever is already on disk unless the caller asks for a refresh.
    if model_path.exists() and not force_download:
        return model_path

    if model_url is not None:
        # exist_ok avoids failing if the directory appears between check and create.
        model_path.parent.mkdir(parents=True, exist_ok=True)
        download_url_to_file(model_url, model_path, progress=progress)

    return model_path
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
from pathlib import Path
|
||||
|
||||
from scripts.singleton import Singleton
|
||||
|
||||
def base_dir_path():
    """Return the absolute path of the parent of the directory holding this file."""
    this_file = Path(__file__)
    return this_file.parents[1].absolute()
|
||||
|
||||
def base_dir():
    """Return ``base_dir_path()`` as a string."""
    root = base_dir_path()
    return str(root)
|
||||
|
||||
class Paths(Singleton):
    """Holds the base directory and the sub-directories derived from it."""

    def __init__(self):
        root = base_dir_path()
        # Base directory as returned by base_dir_path().
        self.base_path: Path = root
        # "scripts" directory under the base directory.
        self.script_path: Path = root / "scripts"
        # "userscripts" directory under the base directory.
        self.userscript_path: Path = root / "userscripts"
|
||||
|
||||
paths = Paths()
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
class Singleton(object):
    """Base class whose subclasses each hand out one shared instance.

    ``SubClass(...)`` returns the same object on every call. Python still runs
    ``__init__`` after ``__new__`` each time, so a subclass ``__init__`` is
    re-executed on every construction.
    """

    # Per-class cache; assigned on the concrete class in __new__.
    _instance = None

    def __new__(class_, *args, **kwargs):
        # isinstance (not "is None") so a subclass never reuses an instance
        # inherited from a parent class that was instantiated earlier.
        if not isinstance(class_._instance, class_):
            # object.__new__ rejects extra arguments when __new__ is
            # overridden, so constructor arguments are left to __init__.
            class_._instance = object.__new__(class_)
        return class_._instance

    @classmethod
    def get_instance(cls):
        """Return the shared instance; kept for callers of the older accessor."""
        return cls()
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class DatasetGalleryUI(UIBase):
|
|||
self.btn_hidden_set_index = gr.Button(elem_id="dataset_tag_editor_btn_hidden_set_index")
|
||||
self.nb_hidden_image_index = gr.Number(value=None, label='hidden_idx_next')
|
||||
self.nb_hidden_image_index_prev = gr.Number(value=None, label='hidden_idx_prev')
|
||||
self.gl_dataset_images = gr.Gallery(label='Dataset Images', elem_id="dataset_tag_editor_dataset_gallery").style(grid=image_columns)
|
||||
self.gl_dataset_images = gr.Gallery(label='Dataset Images', elem_id="dataset_tag_editor_dataset_gallery", columns=image_columns)
|
||||
|
||||
def set_callbacks(self, load_dataset:LoadDatasetUI, gallery_state:GalleryStateUI, get_filters:Callable[[], dte_module.filters.Filter]):
|
||||
gallery_state.register_value('Selected Image', self.selected_path)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class GalleryStateUI(UIBase):
|
|||
self.txt_gallery = gr.HTML(value=self.get_current_gallery_txt())
|
||||
|
||||
def set_callbacks(self, dataset_gallery:DatasetGalleryUI):
|
||||
dataset_gallery.nb_hidden_image_index.change(
|
||||
dataset_gallery.nb_hidden_image_index.change(fn=lambda:None).then(
|
||||
fn=self.update_gallery_txt,
|
||||
inputs=None,
|
||||
outputs=self.txt_gallery
|
||||
|
|
|
|||
|
|
@ -11,43 +11,113 @@ from .uibase import UIBase
|
|||
if TYPE_CHECKING:
|
||||
from .ui_classes import *
|
||||
|
||||
INTERROGATOR_NAMES = dte_module.INTERROGATOR_NAMES
|
||||
InterrogateMethod = dte_instance.InterrogateMethod
|
||||
|
||||
|
||||
class LoadDatasetUI(UIBase):
|
||||
def __init__(self):
|
||||
self.caption_file_ext = ''
|
||||
self.caption_file_ext = ""
|
||||
|
||||
def create_ui(self, cfg_general):
|
||||
with gr.Column(variant='panel'):
|
||||
with gr.Column(variant="panel"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
self.tb_img_directory = gr.Textbox(label='Dataset directory', placeholder='C:\\directory\\of\\datasets', value=cfg_general.dataset_dir)
|
||||
self.tb_img_directory = gr.Textbox(
|
||||
label="Dataset directory",
|
||||
placeholder="C:\\directory\\of\\datasets",
|
||||
value=cfg_general.dataset_dir,
|
||||
)
|
||||
with gr.Column(scale=1, min_width=60):
|
||||
self.tb_caption_file_ext = gr.Textbox(label='Caption File Ext', placeholder='.txt (on Load and Save)', value=cfg_general.caption_ext)
|
||||
self.tb_caption_file_ext = gr.Textbox(
|
||||
label="Caption File Ext",
|
||||
placeholder=".txt (on Load and Save)",
|
||||
value=cfg_general.caption_ext,
|
||||
)
|
||||
self.caption_file_ext = cfg_general.caption_ext
|
||||
with gr.Column(scale=1, min_width=80):
|
||||
self.btn_load_datasets = gr.Button(value='Load')
|
||||
self.btn_unload_datasets = gr.Button(value='Unload')
|
||||
with gr.Accordion(label='Dataset Load Settings'):
|
||||
self.btn_load_datasets = gr.Button(value="Load")
|
||||
self.btn_unload_datasets = gr.Button(value="Unload")
|
||||
with gr.Accordion(label="Dataset Load Settings"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
self.cb_load_recursive = gr.Checkbox(value=cfg_general.load_recursive, label='Load from subdirectories')
|
||||
self.cb_load_caption_from_filename = gr.Checkbox(value=cfg_general.load_caption_from_filename, label='Load caption from filename if no text file exists')
|
||||
self.cb_replace_new_line_with_comma = gr.Checkbox(value=cfg_general.replace_new_line, label='Replace new-line character with comma')
|
||||
self.cb_load_recursive = gr.Checkbox(
|
||||
value=cfg_general.load_recursive,
|
||||
label="Load from subdirectories",
|
||||
)
|
||||
self.cb_load_caption_from_filename = gr.Checkbox(
|
||||
value=cfg_general.load_caption_from_filename,
|
||||
label="Load caption from filename if no text file exists",
|
||||
)
|
||||
self.cb_replace_new_line_with_comma = gr.Checkbox(
|
||||
value=cfg_general.replace_new_line,
|
||||
label="Replace new-line character with comma",
|
||||
)
|
||||
with gr.Column():
|
||||
self.rb_use_interrogator = gr.Radio(choices=['No', 'If Empty', 'Overwrite', 'Prepend', 'Append'], value=cfg_general.use_interrogator, label='Use Interrogator Caption')
|
||||
self.dd_intterogator_names = gr.Dropdown(label = 'Interrogators', choices=INTERROGATOR_NAMES, value=cfg_general.use_interrogator_names, interactive=True, multiselect=True)
|
||||
with gr.Accordion(label='Interrogator Settings', open=False):
|
||||
self.rb_use_interrogator = gr.Radio(
|
||||
choices=[
|
||||
"No",
|
||||
"If Empty",
|
||||
"Overwrite",
|
||||
"Prepend",
|
||||
"Append",
|
||||
],
|
||||
value=cfg_general.use_interrogator,
|
||||
label="Use Interrogator Caption",
|
||||
)
|
||||
self.dd_intterogator_names = gr.Dropdown(
|
||||
label="Interrogators",
|
||||
choices=dte_instance.INTERROGATOR_NAMES,
|
||||
value=cfg_general.use_interrogator_names,
|
||||
interactive=True,
|
||||
multiselect=True,
|
||||
)
|
||||
with gr.Accordion(label="Interrogator Settings", open=False):
|
||||
with gr.Row():
|
||||
self.cb_use_custom_threshold_booru = gr.Checkbox(value=cfg_general.use_custom_threshold_booru, label='Use Custom Threshold (Booru)', interactive=True)
|
||||
self.sl_custom_threshold_booru = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_booru, step=0.01, interactive=True, label='Booru Score Threshold')
|
||||
self.cb_use_custom_threshold_booru = gr.Checkbox(
|
||||
value=cfg_general.use_custom_threshold_booru,
|
||||
label="Use Custom Threshold (Booru)",
|
||||
interactive=True,
|
||||
)
|
||||
self.sl_custom_threshold_booru = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
value=cfg_general.custom_threshold_booru,
|
||||
step=0.01,
|
||||
interactive=True,
|
||||
label="Booru Score Threshold",
|
||||
)
|
||||
with gr.Row():
|
||||
self.cb_use_custom_threshold_waifu = gr.Checkbox(value=cfg_general.use_custom_threshold_waifu, label='Use Custom Threshold (WDv1.4 Tagger)', interactive=True)
|
||||
self.sl_custom_threshold_waifu = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_waifu, step=0.01, interactive=True, label='WDv1.4 Tagger Score Threshold')
|
||||
|
||||
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], toprow:ToprowUI, dataset_gallery:DatasetGalleryUI, filter_by_tags:FilterByTagsUI, filter_by_selection:FilterBySelectionUI, batch_edit_captions:BatchEditCaptionsUI, update_filter_and_gallery:Callable[[], List]):
|
||||
self.sl_custom_threshold_z3d = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
value=cfg_general.custom_threshold_z3d,
|
||||
step=0.01,
|
||||
interactive=True,
|
||||
label="Z3D-E621 Score Threshold",
|
||||
)
|
||||
with gr.Row():
|
||||
self.cb_use_custom_threshold_waifu = gr.Checkbox(
|
||||
value=cfg_general.use_custom_threshold_waifu,
|
||||
label="Use Custom Threshold (WDv1.4 Tagger)",
|
||||
interactive=True,
|
||||
)
|
||||
self.sl_custom_threshold_waifu = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
value=cfg_general.custom_threshold_waifu,
|
||||
step=0.01,
|
||||
interactive=True,
|
||||
label="WDv1.4 Tagger Score Threshold",
|
||||
)
|
||||
|
||||
def set_callbacks(
|
||||
self,
|
||||
o_update_filter_and_gallery: List[gr.components.Component],
|
||||
toprow: ToprowUI,
|
||||
dataset_gallery: DatasetGalleryUI,
|
||||
filter_by_tags: FilterByTagsUI,
|
||||
filter_by_selection: FilterBySelectionUI,
|
||||
batch_edit_captions: BatchEditCaptionsUI,
|
||||
update_filter_and_gallery: Callable[[], List],
|
||||
):
|
||||
def load_files_from_dir(
|
||||
dir: str,
|
||||
caption_file_ext: str,
|
||||
|
|
@ -55,63 +125,112 @@ class LoadDatasetUI(UIBase):
|
|||
load_caption_from_filename: bool,
|
||||
replace_new_line: bool,
|
||||
use_interrogator: str,
|
||||
use_interrogator_names, #: List[str], : to avoid error on gradio v3.23.0
|
||||
use_interrogator_names, #: List[str], : to avoid error on gradio v3.23.0
|
||||
use_custom_threshold_booru: bool,
|
||||
custom_threshold_booru: float,
|
||||
use_custom_threshold_waifu: bool,
|
||||
custom_threshold_waifu: float,
|
||||
custom_threshold_z3d: float,
|
||||
use_kohya_metadata: bool,
|
||||
kohya_json_path: str
|
||||
):
|
||||
|
||||
interrogate_method = InterrogateMethod.NONE
|
||||
if use_interrogator == 'If Empty':
|
||||
interrogate_method = InterrogateMethod.PREFILL
|
||||
elif use_interrogator == 'Overwrite':
|
||||
interrogate_method = InterrogateMethod.OVERWRITE
|
||||
elif use_interrogator == 'Prepend':
|
||||
interrogate_method = InterrogateMethod.PREPEND
|
||||
elif use_interrogator == 'Append':
|
||||
interrogate_method = InterrogateMethod.APPEND
|
||||
kohya_json_path: str,
|
||||
):
|
||||
|
||||
threshold_booru = custom_threshold_booru if use_custom_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
|
||||
threshold_waifu = custom_threshold_waifu if use_custom_threshold_waifu else -1
|
||||
interrogate_method = dte_instance.InterrogateMethod.NONE
|
||||
if use_interrogator == "If Empty":
|
||||
interrogate_method = dte_instance.InterrogateMethod.PREFILL
|
||||
elif use_interrogator == "Overwrite":
|
||||
interrogate_method = dte_instance.InterrogateMethod.OVERWRITE
|
||||
elif use_interrogator == "Prepend":
|
||||
interrogate_method = dte_instance.InterrogateMethod.PREPEND
|
||||
elif use_interrogator == "Append":
|
||||
interrogate_method = dte_instance.InterrogateMethod.APPEND
|
||||
|
||||
dte_instance.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, replace_new_line, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu, opts.dataset_editor_use_temp_files, kohya_json_path if use_kohya_metadata else None, opts.dataset_editor_max_res)
|
||||
threshold_booru = (
|
||||
custom_threshold_booru
|
||||
if use_custom_threshold_booru
|
||||
else opts.interrogate_deepbooru_score_threshold
|
||||
)
|
||||
threshold_waifu = (
|
||||
custom_threshold_waifu if use_custom_threshold_waifu else -1
|
||||
)
|
||||
threshold_z3d = custom_threshold_z3d
|
||||
|
||||
dte_instance.load_dataset(
|
||||
dir,
|
||||
caption_file_ext,
|
||||
recursive,
|
||||
load_caption_from_filename,
|
||||
replace_new_line,
|
||||
interrogate_method,
|
||||
use_interrogator_names,
|
||||
threshold_booru,
|
||||
threshold_waifu,
|
||||
threshold_z3d,
|
||||
opts.dataset_editor_use_temp_files,
|
||||
kohya_json_path if use_kohya_metadata else None,
|
||||
opts.dataset_editor_max_res,
|
||||
)
|
||||
imgs = dte_instance.get_filtered_imgs(filters=[])
|
||||
img_indices = dte_instance.get_filtered_imgindices(filters=[])
|
||||
return [
|
||||
imgs,
|
||||
[]
|
||||
] +\
|
||||
[gr.CheckboxGroup.update(value=[str(i) for i in img_indices], choices=[str(i) for i in img_indices]), 1] +\
|
||||
filter_by_tags.clear_filters(update_filter_and_gallery) +\
|
||||
[batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
|
||||
|
||||
return (
|
||||
[imgs, []]
|
||||
+ [
|
||||
gr.CheckboxGroup.update(
|
||||
value=[str(i) for i in img_indices],
|
||||
choices=[str(i) for i in img_indices],
|
||||
),
|
||||
1,
|
||||
]
|
||||
+ filter_by_tags.clear_filters(update_filter_and_gallery)
|
||||
+ [batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
|
||||
)
|
||||
|
||||
self.btn_load_datasets.click(
|
||||
fn=load_files_from_dir,
|
||||
inputs=[self.tb_img_directory, self.tb_caption_file_ext, self.cb_load_recursive, self.cb_load_caption_from_filename, self.cb_replace_new_line_with_comma, self.rb_use_interrogator, self.dd_intterogator_names, self.cb_use_custom_threshold_booru, self.sl_custom_threshold_booru, self.cb_use_custom_threshold_waifu, self.sl_custom_threshold_waifu, toprow.cb_save_kohya_metadata, toprow.tb_metadata_output],
|
||||
outputs=
|
||||
[dataset_gallery.gl_dataset_images, filter_by_selection.gl_filter_images] +
|
||||
[dataset_gallery.cbg_hidden_dataset_filter, dataset_gallery.nb_hidden_dataset_filter_apply] +
|
||||
o_update_filter_and_gallery
|
||||
inputs=[
|
||||
self.tb_img_directory,
|
||||
self.tb_caption_file_ext,
|
||||
self.cb_load_recursive,
|
||||
self.cb_load_caption_from_filename,
|
||||
self.cb_replace_new_line_with_comma,
|
||||
self.rb_use_interrogator,
|
||||
self.dd_intterogator_names,
|
||||
self.cb_use_custom_threshold_booru,
|
||||
self.sl_custom_threshold_booru,
|
||||
self.cb_use_custom_threshold_waifu,
|
||||
self.sl_custom_threshold_waifu,
|
||||
toprow.cb_save_kohya_metadata,
|
||||
toprow.tb_metadata_output,
|
||||
],
|
||||
outputs=[
|
||||
dataset_gallery.gl_dataset_images,
|
||||
filter_by_selection.gl_filter_images,
|
||||
]
|
||||
+ [
|
||||
dataset_gallery.cbg_hidden_dataset_filter,
|
||||
dataset_gallery.nb_hidden_dataset_filter_apply,
|
||||
]
|
||||
+ o_update_filter_and_gallery,
|
||||
)
|
||||
|
||||
def unload_files():
|
||||
dte_instance.clear()
|
||||
return [
|
||||
[],
|
||||
[]
|
||||
] +\
|
||||
[gr.CheckboxGroup.update(value=[], choices=[]), 1] +\
|
||||
filter_by_tags.clear_filters(update_filter_and_gallery) +\
|
||||
[batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
|
||||
return (
|
||||
[[], []]
|
||||
+ [gr.CheckboxGroup.update(value=[], choices=[]), 1]
|
||||
+ filter_by_tags.clear_filters(update_filter_and_gallery)
|
||||
+ [batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
|
||||
)
|
||||
|
||||
self.btn_unload_datasets.click(
|
||||
fn=unload_files,
|
||||
outputs=
|
||||
[dataset_gallery.gl_dataset_images, filter_by_selection.gl_filter_images] +
|
||||
[dataset_gallery.cbg_hidden_dataset_filter, dataset_gallery.nb_hidden_dataset_filter_apply] +
|
||||
o_update_filter_and_gallery
|
||||
outputs=[
|
||||
dataset_gallery.gl_dataset_images,
|
||||
filter_by_selection.gl_filter_images,
|
||||
]
|
||||
+ [
|
||||
dataset_gallery.cbg_hidden_dataset_filter,
|
||||
dataset_gallery.nb_hidden_dataset_filter_apply,
|
||||
]
|
||||
+ o_update_filter_and_gallery,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -66,10 +66,10 @@ class TagFilterUI():
|
|||
|
||||
self.rb_logic.change(fn=self.rd_logic_changed, inputs=[self.rb_logic], outputs=[self.cbg_tags])
|
||||
for fn, inputs, outputs, _js in self.on_filter_update_callbacks:
|
||||
self.rb_logic.change(fn=fn, inputs=inputs, outputs=outputs, _js=_js)
|
||||
self.rb_logic.change(fn=lambda:None).then(fn=fn, inputs=inputs, outputs=outputs, _js=_js)
|
||||
self.cbg_tags.change(fn=self.cbg_tags_changed, inputs=[self.cbg_tags], outputs=[self.cbg_tags])
|
||||
for fn, inputs, outputs, _js in self.on_filter_update_callbacks:
|
||||
self.cbg_tags.change(fn=fn, inputs=inputs, outputs=outputs, _js=_js)
|
||||
self.cbg_tags.change(fn=lambda:None).then(fn=fn, inputs=inputs, outputs=outputs, _js=_js)
|
||||
|
||||
|
||||
def tb_search_tags_changed(self, tb_search_tags: str):
|
||||
|
|
|
|||
|
|
@ -95,8 +95,7 @@ class BatchEditCaptionsUI(UIBase):
|
|||
fn=apply_edit_tags,
|
||||
inputs=[self.tb_common_tags, self.tb_edit_tags, self.cb_prepend_tags],
|
||||
outputs=o_update_filter_and_gallery
|
||||
)
|
||||
self.btn_apply_edit_tags.click(
|
||||
).then(
|
||||
fn=None,
|
||||
_js='() => dataset_tag_editor_gl_dataset_images_close()'
|
||||
)
|
||||
|
|
@ -124,8 +123,7 @@ class BatchEditCaptionsUI(UIBase):
|
|||
fn=search_and_replace,
|
||||
inputs=[self.tb_sr_search_tags, self.tb_sr_replace_tags, self.rb_sr_replace_target, self.cb_use_regex],
|
||||
outputs=o_update_filter_and_gallery
|
||||
)
|
||||
self.btn_apply_sr_tags.click(
|
||||
).then(
|
||||
fn=None,
|
||||
_js='() => dataset_tag_editor_gl_dataset_images_close()'
|
||||
)
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class EditCaptionOfSelectedImageUI(UIBase):
|
|||
|
||||
with gr.Tab(label='Interrogate Selected Image'):
|
||||
with gr.Row():
|
||||
self.dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte_module.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False)
|
||||
self.dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte_instance.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False)
|
||||
self.btn_interrogate_si = gr.Button(value='Interrogate')
|
||||
with gr.Column():
|
||||
self.tb_interrogate = gr.Textbox(label='Interrogate Result', interactive=True, lines=6, elem_id='dte_interrogate')
|
||||
|
|
@ -89,7 +89,7 @@ class EditCaptionOfSelectedImageUI(UIBase):
|
|||
_js='(a) => dataset_tag_editor_ask_save_change_or_not(a)',
|
||||
inputs=self.nb_hidden_image_index_save_or_not
|
||||
)
|
||||
dataset_gallery.nb_hidden_image_index.change(
|
||||
dataset_gallery.nb_hidden_image_index.change(lambda:None).then(
|
||||
fn=gallery_index_changed,
|
||||
inputs=[dataset_gallery.nb_hidden_image_index, dataset_gallery.nb_hidden_image_index_prev, self.tb_edit_caption, self.cb_copy_caption_automatically, self.cb_ask_save_when_caption_changed],
|
||||
outputs=[self.nb_hidden_image_index_save_or_not] + [self.tb_caption, self.tb_edit_caption] + [self.tb_hidden_edit_caption]
|
||||
|
|
@ -138,7 +138,7 @@ class EditCaptionOfSelectedImageUI(UIBase):
|
|||
return ''
|
||||
threshold_booru = threshold_booru if use_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
|
||||
threshold_waifu = threshold_waifu if use_threshold_waifu else -1
|
||||
return dte_module.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu)
|
||||
return dte_instance.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu)
|
||||
|
||||
self.btn_interrogate_si.click(
|
||||
fn=interrogate_selected_image,
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class FilterBySelectionUI(UIBase):
|
|||
self.btn_add_image_selection = gr.Button(value='Add selection [Enter]', elem_id='dataset_tag_editor_btn_add_image_selection')
|
||||
self.btn_add_all_displayed_image_selection = gr.Button(value='Add ALL Displayed')
|
||||
|
||||
self.gl_filter_images = gr.Gallery(label='Filter Images', elem_id="dataset_tag_editor_filter_gallery").style(grid=image_columns)
|
||||
self.gl_filter_images = gr.Gallery(label='Filter Images', elem_id="dataset_tag_editor_filter_gallery", columns=image_columns)
|
||||
self.txt_selection = gr.HTML(value=self.get_current_txt_selection())
|
||||
|
||||
with gr.Row():
|
||||
|
|
@ -130,7 +130,7 @@ class FilterBySelectionUI(UIBase):
|
|||
self.path_filter = filters.PathFilter()
|
||||
return clear_image_selection() + update_filter_and_gallery()
|
||||
|
||||
filter_by_tags.btn_clear_all_filters.click(
|
||||
filter_by_tags.btn_clear_all_filters.click(lambda:None).then(
|
||||
fn=clear_image_filter,
|
||||
outputs=
|
||||
[self.gl_filter_images, self.txt_selection, self.nb_hidden_selection_image_index] +
|
||||
|
|
@ -147,8 +147,7 @@ class FilterBySelectionUI(UIBase):
|
|||
self.btn_apply_image_selection_filter.click(
|
||||
fn=apply_image_selection_filter,
|
||||
outputs=o_update_filter_and_gallery
|
||||
)
|
||||
self.btn_apply_image_selection_filter.click(
|
||||
).then(
|
||||
fn=None,
|
||||
_js='() => dataset_tag_editor_gl_dataset_images_close()'
|
||||
)
|
||||
|
|
|
|||
|
|
@ -47,17 +47,17 @@ class MoveOrDeleteFilesUI(UIBase):
|
|||
'outputs' : [self.ta_move_or_delete_target_dataset_num]
|
||||
}
|
||||
|
||||
batch_edit_captions.btn_apply_edit_tags.click(**update_args)
|
||||
batch_edit_captions.btn_apply_edit_tags.click(lambda:None).then(**update_args)
|
||||
|
||||
batch_edit_captions.btn_apply_sr_tags.click(**update_args)
|
||||
batch_edit_captions.btn_apply_sr_tags.click(lambda:None).then(**update_args)
|
||||
|
||||
filter_by_selection.btn_apply_image_selection_filter.click(**update_args)
|
||||
filter_by_selection.btn_apply_image_selection_filter.click(lambda:None).then(**update_args)
|
||||
|
||||
filter_by_tags.btn_clear_tag_filters.click(**update_args)
|
||||
filter_by_tags.btn_clear_tag_filters.click(lambda:None).then(**update_args)
|
||||
|
||||
filter_by_tags.btn_clear_all_filters.click(**update_args)
|
||||
filter_by_tags.btn_clear_all_filters.click(lambda:None).then(**update_args)
|
||||
|
||||
edit_caption_of_selected_image.btn_apply_changes_selected_image.click(**update_args)
|
||||
edit_caption_of_selected_image.btn_apply_changes_selected_image.click(lambda:None).then(**update_args)
|
||||
|
||||
self.rb_move_or_delete_target_data.change(**update_args)
|
||||
|
||||
|
|
@ -84,9 +84,7 @@ class MoveOrDeleteFilesUI(UIBase):
|
|||
fn=move_files,
|
||||
inputs=[self.rb_move_or_delete_target_data, self.cbg_move_or_delete_target_file, self.tb_move_or_delete_caption_ext, self.tb_move_or_delete_destination_dir],
|
||||
outputs=o_update_filter_and_gallery
|
||||
)
|
||||
self.btn_move_or_delete_move_files.click(**update_args)
|
||||
self.btn_move_or_delete_move_files.click(
|
||||
).then(**update_args).then(
|
||||
fn=None,
|
||||
_js='() => dataset_tag_editor_gl_dataset_images_close()'
|
||||
)
|
||||
|
|
@ -114,8 +112,7 @@ class MoveOrDeleteFilesUI(UIBase):
|
|||
inputs=[self.rb_move_or_delete_target_data, self.cbg_move_or_delete_target_file, self.tb_move_or_delete_caption_ext],
|
||||
outputs=o_update_filter_and_gallery
|
||||
)
|
||||
self.btn_move_or_delete_delete_files.click(**update_args)
|
||||
self.btn_move_or_delete_delete_files.click(
|
||||
self.btn_move_or_delete_delete_files.click(**update_args).then(
|
||||
fn=None,
|
||||
_js='() => dataset_tag_editor_gl_dataset_images_close()'
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ __all__ = [
|
|||
'toprow', 'load_dataset', 'dataset_gallery', 'gallery_state', 'filter_by_tags', 'filter_by_selection', 'batch_edit_captions', 'edit_caption_of_selected_image', 'move_or_delete_files'
|
||||
]
|
||||
|
||||
toprow = ToprowUI.get_instance()
|
||||
load_dataset = LoadDatasetUI.get_instance()
|
||||
dataset_gallery = DatasetGalleryUI.get_instance()
|
||||
gallery_state = GalleryStateUI.get_instance()
|
||||
filter_by_tags = FilterByTagsUI.get_instance()
|
||||
filter_by_selection = FilterBySelectionUI.get_instance()
|
||||
batch_edit_captions = BatchEditCaptionsUI.get_instance()
|
||||
edit_caption_of_selected_image = EditCaptionOfSelectedImageUI.get_instance()
|
||||
move_or_delete_files = MoveOrDeleteFilesUI.get_instance()
|
||||
toprow = ToprowUI()
|
||||
load_dataset = LoadDatasetUI()
|
||||
dataset_gallery = DatasetGalleryUI()
|
||||
gallery_state = GalleryStateUI()
|
||||
filter_by_tags = FilterByTagsUI()
|
||||
filter_by_selection = FilterBySelectionUI()
|
||||
batch_edit_captions = BatchEditCaptionsUI()
|
||||
edit_caption_of_selected_image = EditCaptionOfSelectedImageUI()
|
||||
move_or_delete_files = MoveOrDeleteFilesUI()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
import re
|
||||
from typing import Optional, Generator, Any
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from modules import shared, lowvram, devices
|
||||
from modules import deepbooru as db
|
||||
|
||||
# Custom tagger classes have to inherit from this class
|
||||
class Tagger:
    """Base class for taggers; custom tagger classes have to inherit from it.

    Instances are context managers: entering calls
    ``lowvram.send_everything_to_cpu()`` and ``devices.torch_gc()``, then
    :meth:`start`; leaving calls :meth:`stop`.
    """

    def __enter__(self):
        # NOTE(review): presumably this frees accelerator memory before the
        # tagger loads its own model - confirm against the `modules` package.
        lowvram.send_everything_to_cpu()
        devices.torch_gc()
        self.start()
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Returns None, so an exception raised inside the `with` body propagates.
        self.stop()

    def start(self):
        """Hook called when the context is entered; does nothing by default."""

    def stop(self):
        """Hook called when the context is left; does nothing by default."""

    def predict(self, image: Image.Image, threshold: Optional[float] = None) -> list[str]:
        """Predict the tags of one image. Must be overridden by subclasses."""
        raise NotImplementedError()

    def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None) -> Generator[list[str], Any, None]:
        """Predict tags for many images, yielding one tag list per image.

        Implement this if you want to use a more efficient data loading system.
        ``None`` will be passed as *data* to check whether this function is
        implemented.
        """
        raise NotImplementedError()

    def name(self):
        """Return the name shown in the UI. Must be overridden by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
def get_replaced_tag(tag: str):
    """Apply the deepbooru formatting options to a single tag.

    When ``shared.opts.deepbooru_use_spaces`` is truthy, underscores are
    replaced by spaces. When ``shared.opts.deepbooru_escape`` is truthy, every
    match of ``db.re_special`` is replaced by a backslash followed by the
    match's first capture group.
    """
    options = shared.opts
    replace_underscores, escape_special = options.deepbooru_use_spaces, options.deepbooru_escape
    result = tag.replace('_', ' ') if replace_underscores else tag
    if not escape_special:
        return result
    return re.sub(db.re_special, r'\\\1', result)
|
||||
|
||||
|
||||
def get_arranged_tags(probs: dict[str, float]):
    """Return the tags of *probs* ordered from highest to lowest probability.

    Tags with equal probability keep their insertion order, because the sort
    is stable.
    """
    def descending(tag):
        # Negating the score makes the ascending sort put high scores first.
        return -probs[tag]

    return sorted(probs, key=descending)
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
import math
|
||||
|
||||
from PIL import Image
|
||||
from transformers import pipeline
|
||||
import torch
|
||||
|
||||
from modules import devices, shared
|
||||
from scripts.tagger import Tagger
|
||||
|
||||
# brought and modified from https://huggingface.co/spaces/cafeai/cafe_aesthetic_demo/blob/main/app.py
|
||||
|
||||
# I'm not sure if this is really working
|
||||
BATCH_SIZE = 3
|
||||
|
||||
class AestheticShadowV2(Tagger):
    """Tagger that turns the "shadowlilac/aesthetic-shadow-v2" classifier output into one ``score_N`` tag."""

    def load(self):
        """Create the image-classification pipeline unless one is still held."""
        # unload() keeps the pipeline when interrogate_keep_models_in_memory is
        # set; reuse it instead of rebuilding the model on every start().
        if getattr(self, "pipe_aesthetic", None) is not None:
            return
        if devices.device.index is None:
            # NOTE(review): presumably the pipeline wants an explicit device
            # index, hence index 0 is filled in - confirm.
            dev = torch.device(devices.device.type, 0)
        else:
            dev = devices.device
        self.pipe_aesthetic = pipeline("image-classification", "shadowlilac/aesthetic-shadow-v2", device=dev, batch_size=BATCH_SIZE)

    def unload(self):
        """Drop the pipeline unless the user asked to keep models in memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            self.pipe_aesthetic = None
            devices.torch_gc()

    def start(self):
        """Load the model; called on context entry."""
        self.load()
        return self

    def stop(self):
        """Unload the model; called on context exit."""
        self.unload()

    def _get_score(self, data):
        """Convert one classification result into a single-element tag list.

        *data* is an iterable of dicts with "label" and "score" keys; the
        labels 'hq' and 'lq' must both be present (KeyError otherwise).
        """
        final = {d["label"]: d["score"] for d in data}
        hq = final['hq']
        lq = final['lq']
        # Average of hq and (1 - lq), scaled by 10 and floored.
        return [f"score_{math.floor((hq + (1 - lq))/2 * 10)}"]

    def predict(self, image: Image.Image, threshold=None):
        """Return the score tag of one image; *threshold* is unused."""
        data = self.pipe_aesthetic(image)
        return self._get_score(data)

    def predict_pipe(self, data: list[Image.Image], threshold=None):
        """Yield the score tag list of every image in *data*.

        ``None`` as *data* yields nothing; *threshold* is unused.
        """
        if data is None:
            return
        for out in self.pipe_aesthetic(data, batch_size=BATCH_SIZE):
            yield self._get_score(out)

    def name(self):
        """Return the name shown in the UI."""
        return "aesthetic shadow"
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
import math
|
||||
|
||||
from PIL import Image
|
||||
from transformers import pipeline
|
||||
import torch
|
||||
|
||||
from modules import devices, shared
|
||||
from scripts.tagger import Tagger
|
||||
|
||||
# brought and modified from https://huggingface.co/spaces/cafeai/cafe_aesthetic_demo/blob/main/app.py
|
||||
|
||||
# I'm not sure if this is really working
|
||||
BATCH_SIZE = 8
|
||||
|
||||
class CafeAIAesthetic(Tagger):
    """Tagger that turns the "cafeai/cafe_aesthetic" classifier output into one ``score_N`` tag."""

    def load(self):
        """Create the image-classification pipeline unless one is still held."""
        # unload() keeps the pipeline when interrogate_keep_models_in_memory is
        # set; reuse it instead of rebuilding the model on every start().
        if getattr(self, "pipe_aesthetic", None) is not None:
            return
        if devices.device.index is None:
            # NOTE(review): presumably the pipeline wants an explicit device
            # index, hence index 0 is filled in - confirm.
            dev = torch.device(devices.device.type, 0)
        else:
            dev = devices.device
        self.pipe_aesthetic = pipeline("image-classification", "cafeai/cafe_aesthetic", device=dev, batch_size=BATCH_SIZE)

    def unload(self):
        """Drop the pipeline unless the user asked to keep models in memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            self.pipe_aesthetic = None
            devices.torch_gc()

    def start(self):
        """Load the model; called on context entry."""
        self.load()
        return self

    def stop(self):
        """Unload the model; called on context exit."""
        self.unload()

    def _get_score(self, data):
        """Convert one classification result into a single-element tag list.

        *data* is an iterable of dicts with "label" and "score" keys; the
        labels 'aesthetic' and 'not_aesthetic' must both be present
        (KeyError otherwise).
        """
        final = {d["label"]: d["score"] for d in data}
        nae = final['not_aesthetic']
        ae = final['aesthetic']
        # Average of ae and (1 - nae), scaled by 10 and floored.
        return [f"score_{math.floor((ae + (1 - nae))/2 * 10)}"]

    def predict(self, image: Image.Image, threshold=None):
        """Return the score tag of one image; *threshold* is unused."""
        data = self.pipe_aesthetic(image, top_k=2)
        return self._get_score(data)

    def predict_pipe(self, data: list[Image.Image], threshold=None):
        """Yield the score tag list of every image in *data*.

        ``None`` as *data* yields nothing; *threshold* is unused.
        """
        if data is None:
            return
        for out in self.pipe_aesthetic(data, batch_size=BATCH_SIZE):
            yield self._get_score(out)

    def name(self):
        """Return the name shown in the UI."""
        return "cafeai aesthetic classifier"
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
from PIL import Image
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from modules import devices, shared
|
||||
from scripts import model_loader
|
||||
from scripts.paths import paths
|
||||
from scripts.tagger import Tagger
|
||||
|
||||
# brought from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py and modified
|
||||
class Classifier(nn.Module):
    """Linear stack mapping a vector of size ``input_size`` to a single value.

    Layer widths are input_size -> 1024 -> 128 -> 64 -> 16 -> 1, with dropout
    after the first three linear layers and no activation functions.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        # The position of each module inside the Sequential fixes the
        # state-dict key names ("layers.0.weight", "layers.2.weight", ...), so
        # the parameter-free Dropout entries must stay where they are.
        stack = [
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run *x* through the layer stack and return the result."""
        output = self.layers(x)
        return output
|
||||
|
||||
# brought and modified from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py
|
||||
def image_embeddings(image:Image.Image, model:CLIPModel, processor:CLIPProcessor):
    """Return the norm-scaled CLIP image features of one image as a numpy array.

    The image is preprocessed by *processor*, moved to ``devices.device`` and
    passed to ``model.get_image_features``. The result is divided by its norm
    and the leading axis is squeezed away.
    """
    inputs = processor(images=image, return_tensors='pt')['pixel_values']
    inputs = inputs.to(devices.device)
    # Inference only: do not record an autograd graph for the CLIP forward pass.
    with torch.no_grad():
        features = model.get_image_features(pixel_values=inputs)
    result:np.ndarray = features.cpu().detach().numpy()
    # The norm is taken over the whole array and axis 0 is squeezed, so this
    # only works for a single image (leading axis of length 1).
    return (result / np.linalg.norm(result)).squeeze(axis=0)
|
||||
|
||||
|
||||
class ImprovedAestheticPredictor(Tagger):
    """Tagger that scores an image with a CLIP ViT-L/14 embedding and a linear head, emitting one ``score_N`` tag."""

    def load(self):
        """Load the classifier head, CLIP processor and CLIP model unless all are still held."""
        # unload() keeps the models when interrogate_keep_models_in_memory is
        # set; reuse them instead of reloading everything on every start().
        held = (getattr(self, attr, None) for attr in ("model", "clip_processor", "clip_model"))
        if all(obj is not None for obj in held):
            return
        MODEL_VERSION = "sac+logos+ava1-l14-linearMSE"
        file = model_loader.load(
            model_path=paths.models_path / "aesthetic" / f"{MODEL_VERSION}.pth",
            model_url=f'https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/{MODEL_VERSION}.pth'
        )
        CLIP_REPOS = 'openai/clip-vit-large-patch14'
        self.model = Classifier(768)
        # Deserialize onto the CPU so the checkpoint loads regardless of the
        # device it was saved from; the module is moved to the target device below.
        self.model.load_state_dict(torch.load(file, map_location="cpu"))
        # eval() is required: the head contains Dropout layers, which would
        # otherwise stay active and randomly perturb every prediction.
        self.model = self.model.to(devices.device).eval()
        self.clip_processor = CLIPProcessor.from_pretrained(CLIP_REPOS)
        self.clip_model = CLIPModel.from_pretrained(CLIP_REPOS).to(devices.device).eval()

    def unload(self):
        """Drop all models unless the user asked to keep models in memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            self.clip_processor = None
            self.clip_model = None
            devices.torch_gc()

    def start(self):
        """Load the models; called on context entry."""
        self.load()
        return self

    def stop(self):
        """Unload the models; called on context exit."""
        self.unload()

    def predict(self, image: Image.Image, threshold=None):
        """Return ``["score_N"]`` where N is the floored prediction; *threshold* is unused."""
        image_embeds = image_embeddings(image, self.clip_model, self.clip_processor)
        # Inference only: no autograd graph is needed for the head either.
        with torch.no_grad():
            prediction:torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device))
        return [f"score_{math.floor(prediction.item())}"]

    def name(self):
        """Return the name shown in the UI."""
        return "Improved Aesthetic Predictor"
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
from PIL import Image
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from modules import devices, shared
|
||||
from scripts import model_loader
|
||||
from scripts.paths import paths
|
||||
from scripts.tagger import Tagger
|
||||
|
||||
# brought from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py
|
||||
class Classifier(torch.nn.Module):
    """Three-layer perceptron: input -> hidden -> hidden // 2 -> output.

    ReLU follows the first two linear layers and a sigmoid follows the last.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        half_hidden = hidden_size // 2
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, half_hidden)
        self.fc3 = torch.nn.Linear(half_hidden, output_size)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x:torch.Tensor):
        """Return the sigmoid-activated output of the network for *x*."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.sigmoid(self.fc3(hidden))
|
||||
|
||||
# brought and modified from https://github.com/waifu-diffusion/aesthetic/blob/main/aesthetic.py
|
||||
def image_embeddings(image:Image.Image, model:CLIPModel, processor:CLIPProcessor):
    """Return the norm-scaled CLIP image features of one image as a numpy array.

    The image is preprocessed by *processor*, moved to ``devices.device`` and
    passed to ``model.get_image_features``. The result is divided by its norm
    and the leading axis is squeezed away.
    """
    inputs = processor(images=image, return_tensors='pt')['pixel_values']
    inputs = inputs.to(devices.device)
    # Inference only: do not record an autograd graph for the CLIP forward pass.
    with torch.no_grad():
        features = model.get_image_features(pixel_values=inputs)
    result:np.ndarray = features.cpu().detach().numpy()
    # The norm is taken over the whole array and axis 0 is squeezed, so this
    # only works for a single image (leading axis of length 1).
    return (result / np.linalg.norm(result)).squeeze(axis=0)
|
||||
|
||||
|
||||
class WaifuAesthetic(Tagger):
    """Tagger that scores an image with a CLIP ViT-B/32 embedding and the "aes-B32-v0" head, emitting one ``score_N`` tag."""

    def load(self):
        """Load the classifier head, CLIP processor and CLIP model unless all are still held."""
        # unload() keeps the models when interrogate_keep_models_in_memory is
        # set; reuse them instead of reloading everything on every start().
        held = (getattr(self, attr, None) for attr in ("model", "clip_processor", "clip_model"))
        if all(obj is not None for obj in held):
            return
        file = model_loader.load(
            model_path=paths.models_path / "aesthetic" / "aes-B32-v0.pth",
            model_url='https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/models/aes-B32-v0.pth'
        )
        CLIP_REPOS = 'openai/clip-vit-base-patch32'
        self.model = Classifier(512, 256, 1)
        # Deserialize onto the CPU so the checkpoint loads regardless of the
        # device it was saved from; the module is moved to the target device below.
        self.model.load_state_dict(torch.load(file, map_location="cpu"))
        self.model = self.model.to(devices.device).eval()
        self.clip_processor = CLIPProcessor.from_pretrained(CLIP_REPOS)
        self.clip_model = CLIPModel.from_pretrained(CLIP_REPOS).to(devices.device).eval()

    def unload(self):
        """Drop all models unless the user asked to keep models in memory."""
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model = None
            self.clip_processor = None
            self.clip_model = None
            devices.torch_gc()

    def start(self):
        """Load the models; called on context entry."""
        self.load()
        return self

    def stop(self):
        """Unload the models; called on context exit."""
        self.unload()

    def predict(self, image: Image.Image, threshold=None):
        """Return ``["score_N"]`` where N is the floored prediction times ten; *threshold* is unused."""
        image_embeds = image_embeddings(image, self.clip_model, self.clip_processor)
        # Inference only: no autograd graph is needed for the head.
        with torch.no_grad():
            prediction:torch.Tensor = self.model(torch.from_numpy(image_embeds).float().to(devices.device))
        return [f"score_{math.floor(prediction.item()*10)}"]

    def name(self):
        """Return the name shown in the UI."""
        return "wd aesthetic classifier"
|
||||
Loading…
Reference in New Issue