feat: support cl tagger

pull/720/head
Akegarasu 2025-09-01 14:44:24 +08:00
parent 9a09518786
commit 3c7802fc05
No known key found for this signature in database
GPG Key ID: DACA951FEBA569A2
6 changed files with 538 additions and 185 deletions

View File

@ -207,6 +207,9 @@ async def run_interrogate(req: TaggerInterrogateRequest, background_tasks: Backg
batch_output_save_json=False,
interrogator=interrogator,
threshold=req.threshold,
character_threshold=req.character_threshold,
add_rating_tag=req.add_rating_tag,
add_model_tag=req.add_model_tag,
additional_tags=req.additional_tags,
exclude_tags=req.exclude_tags,
sort_by_alphabetical_order=False,

View File

@ -12,6 +12,13 @@ class TaggerInterrogateRequest(BaseModel):
ge=0,
le=1
)
character_threshold: float = Field(
default=0.6,
ge=0,
le=1
)
add_rating_tag: bool = False
add_model_tag: bool = False
additional_tags: str = ""
exclude_tags: str = ""
escape_tag: bool = True

View File

@ -14,193 +14,13 @@ from PIL import UnidentifiedImageError
from huggingface_hub import hf_hub_download
from mikazuki.tagger import dbimutils, format
from mikazuki.tagger.interrogators.base import Interrogator
from mikazuki.tagger.interrogators.wd14 import WaifuDiffusionInterrogator
from mikazuki.tagger.interrogators.cl import CLTaggerInterrogator
tag_escape_pattern = re.compile(r'([\\()])')


class Interrogator:
    """Base class for image tag interrogators (WD14-style tagger models)."""

    @staticmethod
    def postprocess_tags(
        tags: Dict[str, float],
        threshold=0.35,
        additional_tags: "Optional[List[str]]" = None,
        exclude_tags: "Optional[List[str]]" = None,
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=False,
        replace_underscore_excludes: "Optional[List[str]]" = None,
        escape_tag=False
    ) -> Dict[str, float]:
        """Filter, sort and reformat raw tag confidences.

        :param tags: mapping of tag name -> confidence; the caller's dict is
            left unmodified (a copy is processed).
        :param threshold: minimum confidence for a tag to be kept.
        :param additional_tags: tags force-added with confidence 1.0.
        :param exclude_tags: tags dropped regardless of confidence.
        :param sort_by_alphabetical_order: sort by name instead of confidence.
        :param add_confident_as_weight: emit tags as '(tag:confidence)'.
        :param replace_underscore: turn '_' into ' ' (except excluded tags).
        :param escape_tag: backslash-escape '\\', '(' and ')' for prompt syntax.
        :return: filtered mapping of (possibly reformatted) tag -> confidence.
        """
        # BUG FIX: previous revision used mutable list defaults ([]), which
        # are shared across calls in Python; use None sentinels instead.
        additional_tags = additional_tags or []
        exclude_tags = set(exclude_tags or [])
        replace_underscore_excludes = replace_underscore_excludes or []

        # work on a copy so the caller's dict is not mutated by force-added tags
        tags = dict(tags)
        for t in additional_tags:
            tags[t] = 1.0

        # sort by tag name (alphabetical) or by confidence (descending),
        # then drop tags below the threshold or explicitly excluded
        tags = {
            t: c
            for t, c in sorted(
                tags.items(),
                key=lambda i: i[0 if sort_by_alphabetical_order else 1],
                reverse=not sort_by_alphabetical_order
            )
            if c >= threshold and t not in exclude_tags
        }

        new_tags = []
        for tag in list(tags):
            new_tag = tag
            if replace_underscore and tag not in replace_underscore_excludes:
                new_tag = new_tag.replace('_', ' ')
            if escape_tag:
                # escape backslashes and parentheses for prompt syntax
                new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)
            if add_confident_as_weight:
                new_tag = f'({new_tag}:{tags[tag]})'
            new_tags.append((new_tag, tags[tag]))
        return dict(new_tags)

    def __init__(self, name: str) -> None:
        self.name = name

    def load(self):
        """Load model weights; implemented by subclasses."""
        raise NotImplementedError()

    def unload(self) -> bool:
        """Release the model and tag table; return True if a model was unloaded."""
        unloaded = False
        if hasattr(self, 'model') and self.model is not None:
            del self.model
            unloaded = True
            print(f'Unloaded {self.name}')
        if hasattr(self, 'tags'):
            del self.tags
        return unloaded

    def interrogate(
        self,
        image: "Image"
    ) -> Tuple[
        Dict[str, float],  # rating confidents
        Dict[str, float]   # tag confidents
    ]:
        raise NotImplementedError()
class WaifuDiffusionInterrogator(Interrogator):
    """Interrogator backed by SmilingWolf's WD14 ONNX tagger models."""

    def __init__(
        self,
        name: str,
        model_path='model.onnx',
        tags_path='selected_tags.csv',
        **kwargs
    ) -> None:
        super().__init__(name)
        self.model_path = model_path
        self.tags_path = tags_path
        # remaining kwargs (e.g. repo_id) are forwarded to hf_hub_download
        self.kwargs = kwargs

    def download(self) -> Tuple[os.PathLike, os.PathLike]:
        """Fetch the ONNX model and tag CSV from the HF hub; return local paths."""
        print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")
        model_path = Path(hf_hub_download(
            **self.kwargs, filename=self.model_path))
        tags_path = Path(hf_hub_download(
            **self.kwargs, filename=self.tags_path))
        return model_path, tags_path

    def load(self) -> None:
        """Download (if needed) the model files and create the inference session."""
        model_path, tags_path = self.download()

        # only one of these packages should be installed at a time in any one environment
        # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
        # TODO: remove old package when the environment changes?
        # from mikazuki.launch_utils import is_installed, run_pip
        # if not is_installed('onnxruntime'):
        #     package = os.environ.get(
        #         'ONNXRUNTIME_PACKAGE',
        #         'onnxruntime-gpu'
        #     )
        #     run_pip(f'install {package}', 'onnxruntime')

        # Load torch to load cuda libs built in torch for onnxruntime, do not delete this.
        # (import order matters: torch must come before InferenceSession)
        import torch
        from onnxruntime import InferenceSession

        # https://onnxruntime.ai/docs/execution-providers/
        # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.model = InferenceSession(str(model_path), providers=providers)
        print(f'Loaded {self.name} model from {model_path}')
        # tag table: one row per model output, with a 'name' column
        self.tags = pd.read_csv(tags_path)

    def interrogate(
        self,
        image: Image
    ) -> Tuple[
        Dict[str, float],  # rating confidents
        Dict[str, float]   # tag confidents
    ]:
        """Run the tagger on *image*; return (rating confidences, tag confidences)."""
        # init model lazily on first use
        if not hasattr(self, 'model') or self.model is None:
            self.load()

        # code for converting the image and running the model is taken from the link below
        # thanks, SmilingWolf!
        # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py

        # convert an image to fit the model; input is square, side = height
        _, height, _, _ = self.model.get_inputs()[0].shape

        # alpha to white: composite any transparency onto a white background
        image = image.convert('RGBA')
        new_image = Image.new('RGBA', image.size, 'WHITE')
        new_image.paste(image, mask=image)
        image = new_image.convert('RGB')
        image = np.asarray(image)

        # PIL RGB to OpenCV BGR
        image = image[:, :, ::-1]

        image = dbimutils.make_square(image, height)
        image = dbimutils.smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)

        # evaluate model
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        confidents = self.model.run([label_name], {input_name: image})[0]

        # pair each tag name with its confidence from the model output
        tags = self.tags[:][['name']]
        tags['confidents'] = confidents[0]

        # first 4 items are for rating (general, sensitive, questionable, explicit)
        ratings = dict(tags[:4].values)

        # rest are regular tags
        tags = dict(tags[4:].values)

        return ratings, tags
available_interrogators = {
'wd-convnext-v3': WaifuDiffusionInterrogator(
'wd-convnext-v3',
@ -239,6 +59,12 @@ available_interrogators = {
'wd-vit-large-tagger-v3',
repo_id='SmilingWolf/wd-vit-large-tagger-v3',
),
'cl_tagger_1_01': CLTaggerInterrogator(
'cl_tagger_1_01',
repo_id='cella110n/cl_tagger',
model_path='cl_tagger_1_01/model.onnx',
tag_mapping_path='cl_tagger_1_01/tag_mapping.json',
),
}
@ -257,7 +83,13 @@ def on_interrogate(
batch_output_save_json: bool,
interrogator: Interrogator,
threshold: float,
character_threshold: float,
add_rating_tag: bool,
add_model_tag: bool,
additional_tags: str,
exclude_tags: str,
sort_by_alphabetical_order: bool,
@ -270,6 +102,9 @@ def on_interrogate(
):
postprocess_opts = (
threshold,
character_threshold,
add_rating_tag,
add_model_tag,
split_str(additional_tags),
split_str(exclude_tags),
sort_by_alphabetical_order,
@ -361,7 +196,7 @@ def on_interrogate(
print(f'skipping {path}')
continue
ratings, tags = interrogator.interrogate(image)
tags = interrogator.interrogate(image)
processed_tags = Interrogator.postprocess_tags(
tags,
*postprocess_opts
@ -398,7 +233,7 @@ def on_interrogate(
if batch_output_save_json:
output_path.with_suffix('.json').write_text(
json.dumps([ratings, tags])
json.dumps(tags)
)
print('all done / 识别完成')

View File

@ -0,0 +1,109 @@
import re
from typing import Dict, List, Tuple
from PIL import Image
tag_escape_pattern = re.compile(r'([\\()])')


class Interrogator:
    """Base class for categorized tag interrogators (WD14 / cl_tagger style)."""

    @staticmethod
    def postprocess_tags(
        tags: Dict[str, List[Tuple[str, float]]],
        threshold=0.35,
        character_threshold=0.6,
        add_rating_tag=False,
        add_model_tag=False,
        additional_tags: "Optional[List[str]]" = None,
        exclude_tags: "Optional[List[str]]" = None,
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=False,
        replace_underscore_excludes: "Optional[List[str]]" = None,
        escape_tag=False
    ) -> Dict[str, float]:
        """Flatten categorized tags into a filtered, formatted tag->confidence dict.

        :param tags: category -> [(tag, confidence), ...] as produced by
            ``interrogate``; the caller's dict is left unmodified.
        :param threshold: minimum confidence for general/meta/etc. tags.
        :param character_threshold: separate (usually stricter) minimum for
            tags in the 'character' category.
        :param add_rating_tag: keep tags from the 'rating' category.
        :param add_model_tag: keep tags from the 'model' category.
        :param additional_tags: tags force-added with confidence 1.0.
        :param exclude_tags: tags dropped regardless of confidence.
        :param sort_by_alphabetical_order: sort by name instead of confidence.
        :param add_confident_as_weight: emit tags as '(tag:confidence)'.
        :param replace_underscore: turn '_' into ' ' (except excluded tags).
        :param escape_tag: backslash-escape '\\', '(' and ')' for prompt syntax.
        :return: mapping of (possibly reformatted) tag -> confidence.
        """
        # BUG FIX: mutable list defaults ([]) are shared across calls; use
        # None sentinels and normalize here.
        additional_tags = additional_tags or []
        exclude_tags = set(exclude_tags or [])
        replace_underscore_excludes = replace_underscore_excludes or []

        # shallow-copy so dropping categories does not mutate the caller's dict
        tags = dict(tags)
        ok_tags = {}

        if not add_rating_tag:
            tags.pop('rating', None)
        if not add_model_tag:
            tags.pop('model', None)

        # character tags get their own threshold
        for t, c in tags.pop('character', []):
            if c >= character_threshold:
                ok_tags[t] = c

        for t in additional_tags:
            ok_tags[t] = 1.0

        for category in tags:
            for t, c in tags[category]:
                if c >= threshold:
                    ok_tags[t] = c

        # BUG FIX: 'del ok_tags[e]' raised KeyError for excluded tags that
        # were not present; drop them tolerantly instead.
        for e in exclude_tags:
            ok_tags.pop(e, None)

        if sort_by_alphabetical_order:
            ok_tags = dict(sorted(ok_tags.items()))
        else:
            # sort tag by confidence, descending
            ok_tags = dict(sorted(ok_tags.items(), key=lambda item: item[1], reverse=True))

        # BUG FIX: the previous revision iterated the raw category dict here
        # ('for tag in list(tags)'), which discarded all the filtering above
        # and paired tags with list values; format the filtered ok_tags.
        new_tags = []
        for tag, conf in ok_tags.items():
            new_tag = tag
            if replace_underscore and tag not in replace_underscore_excludes:
                new_tag = new_tag.replace('_', ' ')
            if escape_tag:
                new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)
            if add_confident_as_weight:
                new_tag = f'({new_tag}:{conf})'
            new_tags.append((new_tag, conf))
        return dict(new_tags)

    def __init__(self, name: str) -> None:
        self.name = name

    def load(self):
        """Load model weights; implemented by subclasses."""
        raise NotImplementedError()

    def unload(self) -> bool:
        """Release the model and tag table; return True if a model was unloaded."""
        unloaded = False
        if hasattr(self, 'model') and self.model is not None:
            del self.model
            unloaded = True
            print(f'Unloaded {self.name}')
        if hasattr(self, 'tags'):
            del self.tags
        return unloaded

    def interrogate(
        self,
        image: "Image"
    ) -> Dict[str, List[Tuple[str, float]]]:
        """
        Interrogate the given image and return tags with their confidence scores.
        :param image: The input image to be interrogated.
        :return: A dictionary with categories as keys and lists of (tag, confidence)
            categories: "rating", "general", "character", "copyright", "artist", "meta", "quality", "model"
        """
        raise NotImplementedError()

View File

@ -0,0 +1,268 @@
import json
import os
import re
from collections import OrderedDict
from glob import glob
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from PIL import Image
from PIL import UnidentifiedImageError
from huggingface_hub import hf_hub_download
from dataclasses import dataclass
from mikazuki.tagger import dbimutils, format
from mikazuki.tagger.interrogators.base import Interrogator
@dataclass
class LabelData:
    """Tag vocabulary for cl_tagger: tag names plus per-category output indices."""
    # names[i] is the tag string for model output index i (None if unmapped)
    names: list[str]
    # each field below holds the model output indices belonging to that category
    # NOTE(review): annotated as lists, but populated with np.ndarray(dtype=int64)
    # by CLTaggerInterrogator.load_tag_mapping — confirm before tightening types
    rating: list[np.int64]
    general: list[np.int64]
    artist: list[np.int64]
    character: list[np.int64]
    copyright: list[np.int64]
    meta: list[np.int64]
    quality: list[np.int64]
    model: list[np.int64]
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* as RGB, compositing any alpha channel onto white."""
    if image.mode not in ("RGB", "RGBA"):
        # keep transparency info (if any) so it can be flattened below
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)
    if image.mode == "RGBA":
        # flatten the alpha channel against a white canvas
        canvas = Image.new("RGB", image.size, (255, 255, 255))
        canvas.paste(image, mask=image.split()[3])
        return canvas
    return image
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Pad *image* with white to a centered square of side max(width, height)."""
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    # preserve the original mode so non-RGB inputs survive unchanged
    padded = Image.new(image.mode, (side, side), (255, 255, 255))
    padded.paste(image, ((side - width) // 2, (side - height) // 2))
    return padded
def _select_top_tag(probs, indices, names, category):
    """Return the (name, confidence) of the max-probability valid index, or None.

    Indices outside ``probs`` and indices without a name in ``names`` are
    skipped with a warning, mirroring the original defensive checks.
    """
    if len(indices) == 0:
        return None
    valid_indices = indices[indices < len(probs)]
    if len(valid_indices) == 0:
        print(f"Warning: No valid indices found for {category} tags within probs length.")
        return None
    category_probs = probs[valid_indices]
    # valid_indices is non-empty, so category_probs cannot be empty here
    idx_local = np.argmax(category_probs)
    idx_global = valid_indices[idx_local]
    if idx_global >= len(names) or names[idx_global] is None:
        print(f"Warning: Invalid global index {idx_global} for {category} tag.")
        return None
    return names[idx_global], float(category_probs[idx_local])


def get_tags(probs, labels: "LabelData"):
    """Group per-index probabilities into named tag categories.

    :param probs: 1-D array of per-label probabilities (post-sigmoid).
    :param labels: LabelData carrying ``names`` and per-category index arrays.
    :return: dict of category -> list of (tag, confidence) sorted by
        confidence descending; 'rating' and 'quality' keep only the top tag.
    """
    result = {
        "rating": [],
        "general": [],
        "character": [],
        "copyright": [],
        "artist": [],
        "meta": [],
        "quality": [],
        "model": []
    }

    # Rating / Quality: select only the single most likely tag.
    # (Previously two copy-pasted blocks; factored into _select_top_tag.)
    for category, indices in (("rating", labels.rating), ("quality", labels.quality)):
        top = _select_top_tag(probs, indices, labels.names, category)
        if top is not None:
            result[category].append(top)

    # All tags for each remaining category (no threshold applied here;
    # thresholding happens in Interrogator.postprocess_tags)
    category_map = {
        "general": labels.general,
        "character": labels.character,
        "copyright": labels.copyright,
        "artist": labels.artist,
        "meta": labels.meta,
        "model": labels.model
    }
    for category, indices in category_map.items():
        if len(indices) == 0:
            continue
        valid_indices = indices[indices < len(probs)]
        category_probs = probs[valid_indices]
        for idx_local, idx_global in enumerate(valid_indices):
            if idx_global < len(labels.names) and labels.names[idx_global] is not None:
                result[category].append((labels.names[idx_global], float(category_probs[idx_local])))
            else:
                print(f"Warning: Invalid global index {idx_global} for {category} tag.")

    # Sort every category by probability (descending)
    for k in result:
        result[k] = sorted(result[k], key=lambda x: x[1], reverse=True)
    return result
class CLTaggerInterrogator(Interrogator):
    """Interrogator for the cl_tagger ONNX models (cella110n/cl_tagger)."""

    def __init__(
        self,
        name: str,
        model_path='model.onnx',
        tag_mapping_path='tag_mapping.json',
        **kwargs
    ) -> None:
        super().__init__(name)
        self.model_path = model_path
        self.tag_mapping_path = tag_mapping_path
        # remaining kwargs (e.g. repo_id) are forwarded to hf_hub_download
        self.kwargs = kwargs

    def download(self) -> Tuple[os.PathLike, os.PathLike]:
        """Fetch the ONNX model and tag mapping from the HF hub; return local paths."""
        print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")
        model_path = Path(hf_hub_download(
            **self.kwargs, filename=self.model_path))
        tag_mapping_path = Path(hf_hub_download(
            **self.kwargs, filename=self.tag_mapping_path))
        return model_path, tag_mapping_path

    def load(self) -> None:
        """Download (if needed) the model files and create the inference session."""
        model_path, tag_mapping_path = self.download()
        # Load torch first so onnxruntime can reuse the CUDA libs shipped with it.
        import torch
        from onnxruntime import InferenceSession
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.model = InferenceSession(str(model_path), providers=providers)
        print(f'Loaded {self.name} model from {model_path}')
        # NOTE: self.tags is the 3-tuple returned by load_tag_mapping:
        # (LabelData, idx_to_tag, tag_to_category)
        self.tags = self.load_tag_mapping(tag_mapping_path)

    def load_tag_mapping(self, mapping_path):
        """Parse tag_mapping.json into (LabelData, idx_to_tag, tag_to_category).

        Two mapping formats are supported:
        - {"idx_to_tag": {...}, "tag_to_category": {...}}
        - {index: {"tag": ..., "category": ...}, ...}
        :raises ValueError: if the file matches neither format.
        """
        with open(mapping_path, 'r', encoding='utf-8') as f:
            tag_mapping_data = json.load(f)

        if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
            idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
            tag_to_category = tag_mapping_data["tag_to_category"]
        elif isinstance(tag_mapping_data, dict):
            # dict-of-dicts format: {index: {"tag": ..., "category": ...}}
            try:
                tag_mapping_data_int_keys = {int(k): v for k, v in tag_mapping_data.items()}
                idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data_int_keys.items()}
                tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data_int_keys.values()}
            except (KeyError, ValueError) as e:
                raise ValueError(f"Unsupported tag mapping format (dict): {e}. Expected int keys with 'tag' and 'category'.")
        else:
            raise ValueError("Unsupported tag mapping format: Expected a dictionary.")

        # names is indexed by model output position; gaps stay None
        names = [None] * (max(idx_to_tag.keys()) + 1)
        rating, general, artist, character, copyright, meta, quality, model_name = [], [], [], [], [], [], [], []
        for idx, tag in idx_to_tag.items():
            if idx >= len(names):
                names.extend([None] * (idx - len(names) + 1))
            names[idx] = tag
            category = tag_to_category.get(tag, 'Unknown')  # tolerate missing category mapping
            idx_int = int(idx)
            if category == 'Rating':
                rating.append(idx_int)
            elif category == 'General':
                general.append(idx_int)
            elif category == 'Artist':
                artist.append(idx_int)
            elif category == 'Character':
                character.append(idx_int)
            elif category == 'Copyright':
                copyright.append(idx_int)
            elif category == 'Meta':
                meta.append(idx_int)
            elif category == 'Quality':
                quality.append(idx_int)
            elif category == 'Model':
                model_name.append(idx_int)
        return LabelData(
            names=names,
            rating=np.array(rating, dtype=np.int64),
            general=np.array(general, dtype=np.int64),
            artist=np.array(artist, dtype=np.int64),
            character=np.array(character, dtype=np.int64),
            copyright=np.array(copyright, dtype=np.int64),
            meta=np.array(meta, dtype=np.int64),
            quality=np.array(quality, dtype=np.int64),
            model=np.array(model_name, dtype=np.int64),
        ), idx_to_tag, tag_to_category

    def preprocess_image(self, image: Image.Image, target_size=(448, 448)):
        """Pad, resize and normalize *image*; return (padded PIL image, NCHW tensor)."""
        image = pil_ensure_rgb(image)
        image = pil_pad_square(image)
        image_resized = image.resize(target_size, Image.BICUBIC)
        img_array = np.array(image_resized, dtype=np.float32) / 255.0
        img_array = img_array.transpose(2, 0, 1)  # HWC -> CHW
        # RGB -> BGR channel flip, kept per upstream user feedback
        # NOTE(review): confirm the exported model's expected channel order
        img_array = img_array[::-1, :, :]
        mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
        std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
        img_array = (img_array - mean) / std
        img_array = np.expand_dims(img_array, axis=0)  # add batch dimension
        return image, img_array

    def interrogate(
        self,
        image: "Image"
    ) -> dict[str, list]:
        """Run the model on *image*; return category -> [(tag, confidence), ...]."""
        # init model lazily on first use
        if not hasattr(self, 'model') or self.model is None:
            self.load()
        input_name = self.model.get_inputs()[0].name
        output_name = self.model.get_outputs()[0].name
        _, input_tensor = self.preprocess_image(image)
        input_tensor = input_tensor.astype(np.float32)
        outputs = self.model.run([output_name], {input_name: input_tensor})[0]
        if np.isnan(outputs).any() or np.isinf(outputs).any():
            print("Warning: NaN or Inf detected in model output. Clamping...")
            outputs = np.nan_to_num(outputs, nan=0.0, posinf=1.0, neginf=0.0)  # Clamp to 0-1 range

        # Outputs are logits; apply a numerically stable sigmoid
        def stable_sigmoid(x):
            return 1 / (1 + np.exp(-np.clip(x, -30, 30)))  # Clip to avoid overflow

        probs = stable_sigmoid(outputs[0])  # Assuming batch size 1
        # self.tags[0] is the LabelData from load_tag_mapping
        return get_tags(probs, self.tags[0])

View File

@ -0,0 +1,131 @@
# from https://github.com/toriato/stable-diffusion-webui-wd14-tagger
import json
import os
import re
from collections import OrderedDict
from glob import glob
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from PIL import Image
from PIL import UnidentifiedImageError
from huggingface_hub import hf_hub_download
from mikazuki.tagger.interrogators.base import Interrogator
from mikazuki.tagger import dbimutils, format
class WaifuDiffusionInterrogator(Interrogator):
    """Interrogator backed by SmilingWolf's WD14 ONNX tagger models,
    adapted to the categorized-result interface of the Interrogator base."""

    def __init__(
        self,
        name: str,
        model_path='model.onnx',
        tags_path='selected_tags.csv',
        **kwargs
    ) -> None:
        super().__init__(name)
        self.model_path = model_path
        self.tags_path = tags_path
        # remaining kwargs (e.g. repo_id) are forwarded to hf_hub_download
        self.kwargs = kwargs

    def download(self) -> Tuple[os.PathLike, os.PathLike]:
        """Fetch the ONNX model and tag CSV from the HF hub; return local paths."""
        print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")
        model_path = Path(hf_hub_download(
            **self.kwargs, filename=self.model_path))
        tags_path = Path(hf_hub_download(
            **self.kwargs, filename=self.tags_path))
        return model_path, tags_path

    def load(self) -> None:
        """Download (if needed) the model files and create the inference session."""
        model_path, tags_path = self.download()

        # only one of these packages should be installed at a time in any one environment
        # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
        # TODO: remove old package when the environment changes?
        # from mikazuki.launch_utils import is_installed, run_pip
        # if not is_installed('onnxruntime'):
        #     package = os.environ.get(
        #         'ONNXRUNTIME_PACKAGE',
        #         'onnxruntime-gpu'
        #     )
        #     run_pip(f'install {package}', 'onnxruntime')

        # Load torch to load cuda libs built in torch for onnxruntime, do not delete this.
        # (import order matters: torch must come before InferenceSession)
        import torch
        from onnxruntime import InferenceSession

        # https://onnxruntime.ai/docs/execution-providers/
        # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.model = InferenceSession(str(model_path), providers=providers)
        print(f'Loaded {self.name} model from {model_path}')
        # tag table: one row per model output, with a 'name' column
        self.tags = pd.read_csv(tags_path)

    def interrogate(
        self,
        image: Image
    ) -> Dict[str, List[Tuple[str, float]]]:
        """Run the tagger on *image*; return category -> [(tag, confidence), ...].

        WD14 models only distinguish ratings from regular tags, so only the
        'rating' and 'general' categories are populated; the others stay empty.
        """
        # init model lazily on first use
        if not hasattr(self, 'model') or self.model is None:
            self.load()

        # code for converting the image and running the model is taken from the link below
        # thanks, SmilingWolf!
        # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py

        # convert an image to fit the model; input is square, side = height
        _, height, _, _ = self.model.get_inputs()[0].shape

        # alpha to white: composite any transparency onto a white background
        image = image.convert('RGBA')
        new_image = Image.new('RGBA', image.size, 'WHITE')
        new_image.paste(image, mask=image)
        image = new_image.convert('RGB')
        image = np.asarray(image)

        # PIL RGB to OpenCV BGR
        image = image[:, :, ::-1]

        image = dbimutils.make_square(image, height)
        image = dbimutils.smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)

        # evaluate model
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        confidents = self.model.run([label_name], {input_name: image})[0]

        # pair each tag name with its confidence from the model output
        tags = self.tags[:][['name']]
        tags['confidents'] = confidents[0]

        # first 4 items are for rating (general, sensitive, questionable, explicit)
        ratings = dict(tags[:4].values)

        # rest are regular tags
        tags = dict(tags[4:].values)

        # adapt the (ratings, tags) pair to the categorized-result shape
        result = {
            "rating": [],
            "general": [],
            "character": [],
            "copyright": [],
            "artist": [],
            "meta": [],
            "quality": [],
            "model": []
        }
        for tag, conf in ratings.items():
            result["rating"].append((tag, conf))
        for tag, conf in tags.items():
            result["general"].append((tag, conf))
        return result