feat: support cl tagger
parent
9a09518786
commit
3c7802fc05
|
|
@ -207,6 +207,9 @@ async def run_interrogate(req: TaggerInterrogateRequest, background_tasks: Backg
|
|||
batch_output_save_json=False,
|
||||
interrogator=interrogator,
|
||||
threshold=req.threshold,
|
||||
character_threshold=req.character_threshold,
|
||||
add_rating_tag=req.add_rating_tag,
|
||||
add_model_tag=req.add_model_tag,
|
||||
additional_tags=req.additional_tags,
|
||||
exclude_tags=req.exclude_tags,
|
||||
sort_by_alphabetical_order=False,
|
||||
|
|
|
|||
|
|
@ -12,6 +12,13 @@ class TaggerInterrogateRequest(BaseModel):
|
|||
ge=0,
|
||||
le=1
|
||||
)
|
||||
character_threshold: float = Field(
|
||||
default=0.6,
|
||||
ge=0,
|
||||
le=1
|
||||
)
|
||||
add_rating_tag: bool = False
|
||||
add_model_tag: bool = False
|
||||
additional_tags: str = ""
|
||||
exclude_tags: str = ""
|
||||
escape_tag: bool = True
|
||||
|
|
|
|||
|
|
@ -14,193 +14,13 @@ from PIL import UnidentifiedImageError
|
|||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from mikazuki.tagger import dbimutils, format
|
||||
from mikazuki.tagger.interrogators.base import Interrogator
|
||||
from mikazuki.tagger.interrogators.wd14 import WaifuDiffusionInterrogator
|
||||
from mikazuki.tagger.interrogators.cl import CLTaggerInterrogator
|
||||
|
||||
tag_escape_pattern = re.compile(r'([\\()])')
|
||||
|
||||
|
||||
class Interrogator:
|
||||
@staticmethod
|
||||
def postprocess_tags(
|
||||
tags: Dict[str, float],
|
||||
|
||||
threshold=0.35,
|
||||
additional_tags: List[str] = [],
|
||||
exclude_tags: List[str] = [],
|
||||
sort_by_alphabetical_order=False,
|
||||
add_confident_as_weight=False,
|
||||
replace_underscore=False,
|
||||
replace_underscore_excludes: List[str] = [],
|
||||
escape_tag=False
|
||||
) -> Dict[str, float]:
|
||||
for t in additional_tags:
|
||||
tags[t] = 1.0
|
||||
|
||||
# those lines are totally not "pythonic" but looks better to me
|
||||
tags = {
|
||||
t: c
|
||||
|
||||
# sort by tag name or confident
|
||||
for t, c in sorted(
|
||||
tags.items(),
|
||||
key=lambda i: i[0 if sort_by_alphabetical_order else 1],
|
||||
reverse=not sort_by_alphabetical_order
|
||||
)
|
||||
|
||||
# filter tags
|
||||
if (
|
||||
c >= threshold
|
||||
and t not in exclude_tags
|
||||
)
|
||||
}
|
||||
|
||||
new_tags = []
|
||||
for tag in list(tags):
|
||||
new_tag = tag
|
||||
|
||||
if replace_underscore and tag not in replace_underscore_excludes:
|
||||
new_tag = new_tag.replace('_', ' ')
|
||||
|
||||
if escape_tag:
|
||||
new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)
|
||||
|
||||
if add_confident_as_weight:
|
||||
new_tag = f'({new_tag}:{tags[tag]})'
|
||||
|
||||
new_tags.append((new_tag, tags[tag]))
|
||||
tags = dict(new_tags)
|
||||
|
||||
return tags
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def load(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def unload(self) -> bool:
|
||||
unloaded = False
|
||||
|
||||
if hasattr(self, 'model') and self.model is not None:
|
||||
del self.model
|
||||
unloaded = True
|
||||
print(f'Unloaded {self.name}')
|
||||
|
||||
if hasattr(self, 'tags'):
|
||||
del self.tags
|
||||
|
||||
return unloaded
|
||||
|
||||
def interrogate(
|
||||
self,
|
||||
image: Image
|
||||
) -> Tuple[
|
||||
Dict[str, float], # rating confidents
|
||||
Dict[str, float] # tag confidents
|
||||
]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class WaifuDiffusionInterrogator(Interrogator):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model_path='model.onnx',
|
||||
tags_path='selected_tags.csv',
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(name)
|
||||
self.model_path = model_path
|
||||
self.tags_path = tags_path
|
||||
self.kwargs = kwargs
|
||||
|
||||
def download(self) -> Tuple[os.PathLike, os.PathLike]:
|
||||
print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")
|
||||
|
||||
model_path = Path(hf_hub_download(
|
||||
**self.kwargs, filename=self.model_path))
|
||||
tags_path = Path(hf_hub_download(
|
||||
**self.kwargs, filename=self.tags_path))
|
||||
return model_path, tags_path
|
||||
|
||||
def load(self) -> None:
|
||||
model_path, tags_path = self.download()
|
||||
|
||||
# only one of these packages should be installed at a time in any one environment
|
||||
# https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
|
||||
# TODO: remove old package when the environment changes?
|
||||
# from mikazuki.launch_utils import is_installed, run_pip
|
||||
# if not is_installed('onnxruntime'):
|
||||
# package = os.environ.get(
|
||||
# 'ONNXRUNTIME_PACKAGE',
|
||||
# 'onnxruntime-gpu'
|
||||
# )
|
||||
|
||||
# run_pip(f'install {package}', 'onnxruntime')
|
||||
|
||||
# Load torch to load cuda libs built in torch for onnxruntime, do not delete this.
|
||||
import torch
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
# https://onnxruntime.ai/docs/execution-providers/
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
|
||||
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
|
||||
self.model = InferenceSession(str(model_path), providers=providers)
|
||||
|
||||
print(f'Loaded {self.name} model from {model_path}')
|
||||
|
||||
self.tags = pd.read_csv(tags_path)
|
||||
|
||||
def interrogate(
|
||||
self,
|
||||
image: Image
|
||||
) -> Tuple[
|
||||
Dict[str, float], # rating confidents
|
||||
Dict[str, float] # tag confidents
|
||||
]:
|
||||
# init model
|
||||
if not hasattr(self, 'model') or self.model is None:
|
||||
self.load()
|
||||
|
||||
# code for converting the image and running the model is taken from the link below
|
||||
# thanks, SmilingWolf!
|
||||
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py
|
||||
|
||||
# convert an image to fit the model
|
||||
_, height, _, _ = self.model.get_inputs()[0].shape
|
||||
|
||||
# alpha to white
|
||||
image = image.convert('RGBA')
|
||||
new_image = Image.new('RGBA', image.size, 'WHITE')
|
||||
new_image.paste(image, mask=image)
|
||||
image = new_image.convert('RGB')
|
||||
image = np.asarray(image)
|
||||
|
||||
# PIL RGB to OpenCV BGR
|
||||
image = image[:, :, ::-1]
|
||||
|
||||
image = dbimutils.make_square(image, height)
|
||||
image = dbimutils.smart_resize(image, height)
|
||||
image = image.astype(np.float32)
|
||||
image = np.expand_dims(image, 0)
|
||||
|
||||
# evaluate model
|
||||
input_name = self.model.get_inputs()[0].name
|
||||
label_name = self.model.get_outputs()[0].name
|
||||
confidents = self.model.run([label_name], {input_name: image})[0]
|
||||
|
||||
tags = self.tags[:][['name']]
|
||||
tags['confidents'] = confidents[0]
|
||||
|
||||
# first 4 items are for rating (general, sensitive, questionable, explicit)
|
||||
ratings = dict(tags[:4].values)
|
||||
|
||||
# rest are regular tags
|
||||
tags = dict(tags[4:].values)
|
||||
|
||||
return ratings, tags
|
||||
|
||||
|
||||
available_interrogators = {
|
||||
'wd-convnext-v3': WaifuDiffusionInterrogator(
|
||||
'wd-convnext-v3',
|
||||
|
|
@ -239,6 +59,12 @@ available_interrogators = {
|
|||
'wd-vit-large-tagger-v3',
|
||||
repo_id='SmilingWolf/wd-vit-large-tagger-v3',
|
||||
),
|
||||
'cl_tagger_1_01': CLTaggerInterrogator(
|
||||
'cl_tagger_1_01',
|
||||
repo_id='cella110n/cl_tagger',
|
||||
model_path='cl_tagger_1_01/model.onnx',
|
||||
tag_mapping_path='cl_tagger_1_01/tag_mapping.json',
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -257,7 +83,13 @@ def on_interrogate(
|
|||
batch_output_save_json: bool,
|
||||
|
||||
interrogator: Interrogator,
|
||||
|
||||
threshold: float,
|
||||
character_threshold: float,
|
||||
|
||||
add_rating_tag: bool,
|
||||
add_model_tag: bool,
|
||||
|
||||
additional_tags: str,
|
||||
exclude_tags: str,
|
||||
sort_by_alphabetical_order: bool,
|
||||
|
|
@ -270,6 +102,9 @@ def on_interrogate(
|
|||
):
|
||||
postprocess_opts = (
|
||||
threshold,
|
||||
character_threshold,
|
||||
add_rating_tag,
|
||||
add_model_tag,
|
||||
split_str(additional_tags),
|
||||
split_str(exclude_tags),
|
||||
sort_by_alphabetical_order,
|
||||
|
|
@ -361,7 +196,7 @@ def on_interrogate(
|
|||
print(f'skipping {path}')
|
||||
continue
|
||||
|
||||
ratings, tags = interrogator.interrogate(image)
|
||||
tags = interrogator.interrogate(image)
|
||||
processed_tags = Interrogator.postprocess_tags(
|
||||
tags,
|
||||
*postprocess_opts
|
||||
|
|
@ -398,7 +233,7 @@ def on_interrogate(
|
|||
|
||||
if batch_output_save_json:
|
||||
output_path.with_suffix('.json').write_text(
|
||||
json.dumps([ratings, tags])
|
||||
json.dumps(tags)
|
||||
)
|
||||
|
||||
print('all done / 识别完成')
|
||||
|
|
|
|||
|
|
@ -0,0 +1,109 @@
|
|||
import re
|
||||
from typing import Dict, List, Tuple
|
||||
from PIL import Image
|
||||
|
||||
tag_escape_pattern = re.compile(r'([\\()])')


class Interrogator:
    """Base class for image taggers.

    Subclasses implement load()/interrogate(); postprocess_tags() turns the
    categorized interrogation result into a flat {tag: confidence} mapping.
    """

    @staticmethod
    def postprocess_tags(
        tags: Dict[str, List[Tuple[str, float]]],

        threshold=0.35,
        character_threshold=0.6,

        add_rating_tag=False,
        add_model_tag=False,

        # NOTE: these list defaults are shared between calls but are only
        # ever read, never mutated, so the mutable-default pitfall does not apply.
        additional_tags: List[str] = [],
        exclude_tags: List[str] = [],
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=False,
        replace_underscore_excludes: List[str] = [],
        escape_tag=False
    ) -> Dict[str, float]:
        """Flatten and filter categorized tags into {tag: confidence}.

        :param tags: mapping of category name ("rating", "character",
            "general", "model", ...) to a list of (tag, confidence) pairs.
            NOTE: the dict is mutated — consumed categories are removed.
        :param threshold: minimum confidence for regular (non-character) tags.
        :param character_threshold: minimum confidence for character tags.
        :param add_rating_tag: keep the "rating" category when True.
        :param add_model_tag: keep the "model" category when True.
        :param additional_tags: tags forced into the result with confidence 1.0.
        :param exclude_tags: tags dropped from the result when present.
        :param sort_by_alphabetical_order: sort by name instead of confidence.
        :param add_confident_as_weight: emit tags as "(tag:conf)".
        :param replace_underscore: replace "_" with " " except for excludes.
        :param escape_tag: backslash-escape parentheses/backslashes.
        :return: ordered {tag: confidence} dict.
        """
        ok_tags = {}

        if not add_rating_tag and 'rating' in tags:
            del tags['rating']

        if not add_model_tag and 'model' in tags:
            del tags['model']

        # character tags have their own, typically stricter, threshold
        if 'character' in tags:
            for t, c in tags['character']:
                if c >= character_threshold:
                    ok_tags[t] = c

            del tags['character']

        for t in additional_tags:
            ok_tags[t] = 1.0

        # every remaining category uses the generic threshold
        for category in tags:
            for t, c in tags[category]:
                if c >= threshold:
                    ok_tags[t] = c

        # Tolerate exclusions that did not make it past the thresholds —
        # the previous `del ok_tags[e]` raised KeyError for those.
        for e in exclude_tags:
            ok_tags.pop(e, None)

        if sort_by_alphabetical_order:
            ok_tags = dict(sorted(ok_tags.items()))
        else:
            # sort tag by confidence, descending
            ok_tags = dict(sorted(ok_tags.items(), key=lambda item: item[1], reverse=True))

        # Rewrite tag text (underscores, escaping, weighting).
        # BUG FIX: this loop previously iterated the *category* dict (`tags`)
        # instead of the flattened result, so the function returned category
        # names mapped to raw (tag, conf) lists.  Iterate ok_tags instead.
        new_tags = []
        for tag in list(ok_tags):
            new_tag = tag

            if replace_underscore and tag not in replace_underscore_excludes:
                new_tag = new_tag.replace('_', ' ')

            if escape_tag:
                new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)

            if add_confident_as_weight:
                new_tag = f'({new_tag}:{ok_tags[tag]})'

            new_tags.append((new_tag, ok_tags[tag]))

        return dict(new_tags)

    def __init__(self, name: str) -> None:
        self.name = name

    def load(self):
        """Load the model; implemented by subclasses."""
        raise NotImplementedError()

    def unload(self) -> bool:
        """Release the loaded model/tags.

        :return: True if a model instance was actually unloaded.
        """
        unloaded = False

        if hasattr(self, 'model') and self.model is not None:
            del self.model
            unloaded = True
            print(f'Unloaded {self.name}')

        if hasattr(self, 'tags'):
            del self.tags

        return unloaded

    def interrogate(
        self,
        # annotation quoted so creating this class does not require PIL
        image: 'Image'
    ) -> Dict[str, List[Tuple[str, float]]]:
        """
        Interrogate the given image and return tags with their confidence scores.
        :param image: The input image to be interrogated.
        :return: A dictionary with categories as keys and lists of (tag, confidence)

        categories: "rating", "general", "character", "copyright", "artist", "meta", "quality", "model"
        """

        raise NotImplementedError()
|
||||
|
|
@ -0,0 +1,268 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from PIL import UnidentifiedImageError
|
||||
from huggingface_hub import hf_hub_download
|
||||
from dataclasses import dataclass
|
||||
from mikazuki.tagger import dbimutils, format
|
||||
from mikazuki.tagger.interrogators.base import Interrogator
|
||||
|
||||
|
||||
@dataclass
class LabelData:
    # Index-aligned label table for the CL tagger ONNX model output.
    # names[i] is the tag string for model output index i; entries may be
    # None for indices that have no mapping (see load_tag_mapping).
    names: list[str]
    # The remaining fields hold the model output indices belonging to each
    # tag category; they are used to slice the probability vector in get_tags.
    # NOTE(review): despite the list[...] annotations they are built as
    # np.ndarray(dtype=int64) by load_tag_mapping — confirm before relying on
    # list semantics.
    rating: list[np.int64]
    general: list[np.int64]
    artist: list[np.int64]
    character: list[np.int64]
    copyright: list[np.int64]
    meta: list[np.int64]
    quality: list[np.int64]
    model: list[np.int64]
||||
|
||||
|
||||
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode, compositing any alpha channel over white."""
    if image.mode not in ["RGB", "RGBA"]:
        # keep transparency information when the source format carries it
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)
    if image.mode == "RGBA":
        canvas = Image.new("RGB", image.size, (255, 255, 255))
        canvas.paste(image, mask=image.split()[3])
        image = canvas
    return image
|
||||
|
||||
|
||||
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Pad *image* with white borders to a centered square; no-op if square."""
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    # same mode as the input so RGB/RGBA sources both work
    padded = Image.new(image.mode, (side, side), (255, 255, 255))
    padded.paste(image, ((side - width) // 2, (side - height) // 2))
    return padded
|
||||
|
||||
|
||||
def _append_top_tag(bucket, category, probs, indices, names):
    """Append the single best (tag, confidence) for an exclusive category.

    Used for "rating" and "quality", which keep only their argmax prediction.
    Emits the same warnings as the original inline code for out-of-range
    indices or missing names.
    """
    if len(indices) == 0:
        return
    valid_indices = indices[indices < len(probs)]
    if len(valid_indices) == 0:
        print(f"Warning: No valid indices found for {category} tags within probs length.")
        return
    cat_probs = probs[valid_indices]
    if len(cat_probs) == 0:
        print(f"Warning: {category}_probs became empty after filtering.")
        return
    idx_local = np.argmax(cat_probs)
    idx_global = valid_indices[idx_local]
    if idx_global < len(names) and names[idx_global] is not None:
        bucket.append((names[idx_global], float(cat_probs[idx_local])))
    else:
        print(f"Warning: Invalid global index {idx_global} for {category} tag.")


def get_tags(probs, labels: 'LabelData'):
    """Convert a per-index probability vector into per-category tag lists.

    :param probs: 1-D array of probabilities, one entry per model output index.
    :param labels: LabelData with the index->name table and per-category
        index arrays.
    :return: dict mapping category name to a list of (tag, confidence) tuples
        sorted by confidence descending.  "rating" and "quality" contain at
        most one entry (their argmax); the other categories are unthresholded
        (filtering happens later in Interrogator.postprocess_tags).
    """
    result = {
        "rating": [],
        "general": [],
        "character": [],
        "copyright": [],
        "artist": [],
        "meta": [],
        "quality": [],
        "model": []
    }

    # Exclusive categories: keep only the highest-probability tag.
    # (Previously two near-identical ~20-line blocks; factored into a helper.)
    _append_top_tag(result["rating"], "rating", probs, labels.rating, labels.names)
    _append_top_tag(result["quality"], "quality", probs, labels.quality, labels.names)

    # All tags for each remaining category (no threshold applied here)
    category_map = {
        "general": labels.general,
        "character": labels.character,
        "copyright": labels.copyright,
        "artist": labels.artist,
        "meta": labels.meta,
        "model": labels.model
    }
    for category, indices in category_map.items():
        if len(indices) == 0:
            continue
        valid_indices = indices[indices < len(probs)]
        if len(valid_indices) == 0:
            continue
        category_probs = probs[valid_indices]
        for idx_local, idx_global in enumerate(valid_indices):
            if idx_global < len(labels.names) and labels.names[idx_global] is not None:
                result[category].append((labels.names[idx_global], float(category_probs[idx_local])))
            else:
                print(f"Warning: Invalid global index {idx_global} for {category} tag.")

    # sort every category by probability, descending
    for k in result:
        result[k] = sorted(result[k], key=lambda x: x[1], reverse=True)
    return result
|
||||
|
||||
|
||||
class CLTaggerInterrogator(Interrogator):
    """Interrogator for the cl_tagger ONNX models (e.g. cella110n/cl_tagger).

    Differences from the WD14 taggers: tag metadata ships as a JSON
    index->tag mapping with per-tag categories (not a CSV), the input is a
    normalized NCHW tensor, and the raw model outputs are logits that need a
    sigmoid before thresholding.
    """

    def __init__(
        self,
        name: str,
        model_path='model.onnx',
        tag_mapping_path='tag_mapping.json',
        **kwargs
    ) -> None:
        super().__init__(name)
        self.model_path = model_path
        self.tag_mapping_path = tag_mapping_path
        # forwarded verbatim to hf_hub_download (expects at least repo_id)
        self.kwargs = kwargs

    def download(self) -> Tuple[os.PathLike, os.PathLike]:
        """Fetch the model and tag-mapping files from the Hugging Face Hub."""
        print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")

        model_path = Path(hf_hub_download(
            **self.kwargs, filename=self.model_path))
        tag_mapping_path = Path(hf_hub_download(
            **self.kwargs, filename=self.tag_mapping_path))
        return model_path, tag_mapping_path

    def load(self) -> None:
        """Create the onnxruntime session and parse the tag mapping."""
        model_path, tag_mapping_path = self.download()

        # Importing torch first exposes its bundled CUDA libraries to
        # onnxruntime; do not remove even though torch is otherwise unused.
        import torch
        from onnxruntime import InferenceSession

        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        self.model = InferenceSession(str(model_path), providers=providers)

        print(f'Loaded {self.name} model from {model_path}')

        # load_tag_mapping returns (LabelData, idx_to_tag, tag_to_category).
        # Unpack it here so interrogate() can use self.tags directly instead
        # of indexing into an anonymous tuple (the old `self.tags[0]`).
        self.tags, self.idx_to_tag, self.tag_to_category = self.load_tag_mapping(tag_mapping_path)

    def load_tag_mapping(self, mapping_path):
        """Parse tag_mapping.json into (LabelData, idx_to_tag, tag_to_category).

        Two formats are supported: a dict with explicit "idx_to_tag" /
        "tag_to_category" keys, or a dict of {index: {"tag": ..., "category": ...}}.

        :raises ValueError: when the JSON matches neither format.
        """
        with open(mapping_path, 'r', encoding='utf-8') as f:
            tag_mapping_data = json.load(f)
        if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
            idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
            tag_to_category = tag_mapping_data["tag_to_category"]
        elif isinstance(tag_mapping_data, dict):
            # assume the dict-of-dicts format
            try:
                tag_mapping_data_int_keys = {int(k): v for k, v in tag_mapping_data.items()}
                idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data_int_keys.items()}
                tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data_int_keys.values()}
            except (KeyError, ValueError) as e:
                raise ValueError(f"Unsupported tag mapping format (dict): {e}. Expected int keys with 'tag' and 'category'.")
        else:
            raise ValueError("Unsupported tag mapping format: Expected a dictionary.")

        # names is index-aligned with the model outputs; gaps stay None
        names = [None] * (max(idx_to_tag.keys()) + 1)
        rating, general, artist, character, copyright, meta, quality, model_name = [], [], [], [], [], [], [], []
        for idx, tag in idx_to_tag.items():
            if idx >= len(names):
                names.extend([None] * (idx - len(names) + 1))
            names[idx] = tag
            category = tag_to_category.get(tag, 'Unknown')  # tolerate a missing category mapping
            idx_int = int(idx)
            if category == 'Rating':
                rating.append(idx_int)
            elif category == 'General':
                general.append(idx_int)
            elif category == 'Artist':
                artist.append(idx_int)
            elif category == 'Character':
                character.append(idx_int)
            elif category == 'Copyright':
                copyright.append(idx_int)
            elif category == 'Meta':
                meta.append(idx_int)
            elif category == 'Quality':
                quality.append(idx_int)
            elif category == 'Model':
                model_name.append(idx_int)

        return LabelData(names=names, rating=np.array(rating, dtype=np.int64), general=np.array(general, dtype=np.int64), artist=np.array(artist, dtype=np.int64),
                         character=np.array(character, dtype=np.int64), copyright=np.array(copyright, dtype=np.int64), meta=np.array(meta, dtype=np.int64), quality=np.array(quality, dtype=np.int64), model=np.array(model_name, dtype=np.int64)), idx_to_tag, tag_to_category

    def preprocess_image(self, image: 'Image.Image', target_size=(448, 448)):
        """Return (rgb_image, NCHW float32 tensor) ready for the model.

        Pipeline: flatten to RGB over white, pad to square, bicubic resize,
        scale to [0, 1], HWC->CHW, RGB->BGR channel flip, normalize with
        mean=std=0.5, add batch dimension.
        """
        image = pil_ensure_rgb(image)
        image = pil_pad_square(image)
        image_resized = image.resize(target_size, Image.BICUBIC)
        img_array = np.array(image_resized, dtype=np.float32) / 255.0
        img_array = img_array.transpose(2, 0, 1)  # HWC -> CHW
        img_array = img_array[::-1, :, :]  # RGB -> BGR channel order (model expects BGR)
        mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
        std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
        img_array = (img_array - mean) / std
        img_array = np.expand_dims(img_array, axis=0)  # add batch dimension
        return image, img_array

    def interrogate(
        self,
        image: 'Image'
    ) -> dict[str, list]:
        """Run the model on *image*.

        :return: category -> list of (tag, confidence), as produced by
            get_tags (see Interrogator.interrogate for the category names).
        """
        # lazily create the session on first use
        if not hasattr(self, 'model') or self.model is None:
            self.load()

        input_name = self.model.get_inputs()[0].name
        output_name = self.model.get_outputs()[0].name

        original_pil_image, input_tensor = self.preprocess_image(image)
        input_tensor = input_tensor.astype(np.float32)

        outputs = self.model.run([output_name], {input_name: input_tensor})[0]

        if np.isnan(outputs).any() or np.isinf(outputs).any():
            print("Warning: NaN or Inf detected in model output. Clamping...")
            outputs = np.nan_to_num(outputs, nan=0.0, posinf=1.0, neginf=0.0)  # clamp to 0-1 range

        # outputs are logits; apply a numerically stable sigmoid
        def stable_sigmoid(x):
            return 1 / (1 + np.exp(-np.clip(x, -30, 30)))  # clip to avoid overflow
        probs = stable_sigmoid(outputs[0])  # assuming batch size 1

        # NOTE: removed a leftover debug `print(predictions)` and a large
        # block of dead commented-out formatting code.
        return get_tags(probs, self.tags)
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
# from https://github.com/toriato/stable-diffusion-webui-wd14-tagger
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from PIL import UnidentifiedImageError
|
||||
from huggingface_hub import hf_hub_download
|
||||
from mikazuki.tagger.interrogators.base import Interrogator
|
||||
from mikazuki.tagger import dbimutils, format
|
||||
|
||||
|
||||
class WaifuDiffusionInterrogator(Interrogator):
    """Interrogator backed by SmilingWolf's WD14 ONNX tagger models."""

    def __init__(
        self,
        name: str,
        model_path='model.onnx',
        tags_path='selected_tags.csv',
        **kwargs
    ) -> None:
        super().__init__(name)
        self.model_path = model_path
        self.tags_path = tags_path
        # forwarded to hf_hub_download (expects at least repo_id)
        self.kwargs = kwargs

    def download(self) -> Tuple[os.PathLike, os.PathLike]:
        """Fetch the ONNX model and the tag CSV from the Hugging Face Hub."""
        print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")

        downloaded_model = Path(hf_hub_download(
            **self.kwargs, filename=self.model_path))
        downloaded_tags = Path(hf_hub_download(
            **self.kwargs, filename=self.tags_path))
        return downloaded_model, downloaded_tags

    def load(self) -> None:
        """Download the model files and build the onnxruntime session."""
        model_path, tags_path = self.download()

        # Importing torch first makes its bundled CUDA libraries visible to
        # onnxruntime — do not delete this import.
        import torch
        from onnxruntime import InferenceSession

        # https://onnxruntime.ai/docs/execution-providers/
        # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        self.model = InferenceSession(str(model_path), providers=providers)

        print(f'Loaded {self.name} model from {model_path}')

        self.tags = pd.read_csv(tags_path)

    def interrogate(
        self,
        image: Image
    ) -> Dict[str, List[Tuple[str, float]]]:
        """Run the model on *image* and return category -> (tag, conf) lists.

        Only the "rating" and "general" buckets are populated; the other
        categories exist so the output shape matches the shared interface.
        """
        # lazily create the session on first use
        if not hasattr(self, 'model') or self.model is None:
            self.load()

        # Preprocessing mirrors SmilingWolf's reference implementation:
        # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py
        _, size, _, _ = self.model.get_inputs()[0].shape

        # flatten any transparency onto a white background
        rgba = image.convert('RGBA')
        flattened = Image.new('RGBA', rgba.size, 'WHITE')
        flattened.paste(rgba, mask=rgba)
        arr = np.asarray(flattened.convert('RGB'))

        # PIL yields RGB; the model was trained on OpenCV-style BGR
        arr = arr[:, :, ::-1]

        arr = dbimutils.make_square(arr, size)
        arr = dbimutils.smart_resize(arr, size)
        arr = np.expand_dims(arr.astype(np.float32), 0)

        # run inference
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        confidents = self.model.run([label_name], {input_name: arr})[0]

        frame = self.tags[:][['name']]
        frame['confidents'] = confidents[0]

        # rows 0-3 are the rating classes (general, sensitive, questionable, explicit)
        ratings = dict(frame[:4].values)
        general = dict(frame[4:].values)

        result = {key: [] for key in (
            "rating", "general", "character", "copyright",
            "artist", "meta", "quality", "model",
        )}
        result["rating"] = [(tag, conf) for tag, conf in ratings.items()]
        result["general"] = [(tag, conf) for tag, conf in general.items()]
        return result
|
||||
Loading…
Reference in New Issue