240 lines
7.6 KiB
Python
240 lines
7.6 KiB
Python
import logging
|
|
|
|
import PIL
|
|
import cv2
|
|
import huggingface_hub
|
|
import numpy as np
|
|
import pandas as pd
|
|
from PIL.Image import Image
|
|
from onnxruntime import InferenceSession
|
|
|
|
from extensions.sd_smartprocess.interrogators.interrogator import Interrogator
|
|
from extensions.sd_smartprocess.model_download import fetch_model
|
|
from extensions.sd_smartprocess.process_params import ProcessParams
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
|
MODEL_FILENAME = "model.onnx"
|
|
LABEL_FILENAME = "selected_tags.csv"
|
|
WOLF_PARAMS = {
|
|
"group": "WOLF",
|
|
"threshold": 0.75,
|
|
"char_threshold": 0.75
|
|
}
|
|
|
|
|
|
class MoatInterrogator(Interrogator):
|
|
params = WOLF_PARAMS
|
|
|
|
def __init__(self, params: ProcessParams) -> None:
|
|
super().__init__(params)
|
|
self._setup()
|
|
|
|
def _setup(self):
|
|
model_path = "SmilingWolf/wd-v1-4-moat-tagger-v2"
|
|
self.model = load_model(model_path, self.device)
|
|
|
|
def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
|
|
threshold = params.threshold
|
|
char_threshold = params.char_threshold
|
|
a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold)
|
|
return a
|
|
|
|
|
|
class SwinInterrogator(Interrogator):
|
|
params = WOLF_PARAMS
|
|
|
|
def __init__(self, params: ProcessParams) -> None:
|
|
super().__init__(params)
|
|
self._setup()
|
|
|
|
def _setup(self):
|
|
model_path = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
|
|
self.model = load_model(model_path, self.device)
|
|
|
|
def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
|
|
threshold = params.threshold
|
|
char_threshold = params.char_threshold
|
|
a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold)
|
|
return a
|
|
|
|
|
|
class ConvInterrogator(Interrogator):
|
|
params = WOLF_PARAMS
|
|
|
|
def __init__(self, params: ProcessParams) -> None:
|
|
super().__init__(params)
|
|
self._setup()
|
|
|
|
def _setup(self):
|
|
model_path = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
|
self.model = load_model(model_path, self.device)
|
|
|
|
def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
|
|
threshold = params.threshold
|
|
char_threshold = params.char_threshold
|
|
a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold)
|
|
return a
|
|
|
|
|
|
class Conv2Interrogator(Interrogator):
|
|
params = WOLF_PARAMS
|
|
|
|
def __init__(self, params: ProcessParams) -> None:
|
|
super().__init__(params)
|
|
self._setup()
|
|
|
|
def _setup(self):
|
|
model_path = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
|
self.model = load_model(model_path, self.device)
|
|
|
|
def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
|
|
threshold = params.threshold
|
|
char_threshold = params.char_threshold
|
|
a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold)
|
|
return a
|
|
|
|
|
|
class VitInterrogator(Interrogator):
|
|
params = WOLF_PARAMS
|
|
|
|
def __init__(self, params: ProcessParams) -> None:
|
|
super().__init__(params)
|
|
self._setup()
|
|
|
|
def _setup(self):
|
|
model_path = "SmilingWolf/wd-v1-4-vit-tagger-v2"
|
|
self.model = load_model(model_path, self.device)
|
|
|
|
def interrogate(self, image: Image, params: ProcessParams, unload: bool = False) -> str:
|
|
threshold = params.threshold
|
|
char_threshold = params.char_threshold
|
|
a, c, rating, character_res, general_res = predict(image, self.model, threshold, char_threshold)
|
|
return a
|
|
|
|
|
|
def predict(
|
|
image: Image,
|
|
model,
|
|
general_threshold: float = 0.5,
|
|
character_threshold: float = 0.5
|
|
):
|
|
tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
|
|
|
|
_, height, width, _ = model.get_inputs()[0].shape
|
|
|
|
# Alpha to white
|
|
image = image.convert("RGBA")
|
|
new_image = PIL.Image.new("RGBA", image.size, "WHITE")
|
|
new_image.paste(image, mask=image)
|
|
image = new_image.convert("RGB")
|
|
image = np.asarray(image)
|
|
|
|
# PIL RGB to OpenCV BGR
|
|
image = image[:, :, ::-1]
|
|
|
|
image = make_square(image, height)
|
|
image = smart_resize(image, height)
|
|
image = image.astype(np.float32)
|
|
image = np.expand_dims(image, 0)
|
|
|
|
input_name = model.get_inputs()[0].name
|
|
label_name = model.get_outputs()[0].name
|
|
probs = model.run([label_name], {input_name: image})[0]
|
|
|
|
labels = list(zip(tag_names, probs[0].astype(float)))
|
|
|
|
# First 4 labels are actually ratings: pick one with argmax
|
|
ratings_names = [labels[i] for i in rating_indexes]
|
|
rating = dict(ratings_names)
|
|
|
|
# Then we have general tags: pick any where prediction confidence > threshold
|
|
general_names = [labels[i] for i in general_indexes]
|
|
general_res = [x for x in general_names if x[1] > general_threshold]
|
|
general_res = dict(general_res)
|
|
|
|
# Everything else is characters: pick any where prediction confidence > threshold
|
|
character_names = [labels[i] for i in character_indexes]
|
|
character_res = [x for x in character_names if x[1] > character_threshold]
|
|
character_res = dict(character_res)
|
|
|
|
b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
|
|
a = (
|
|
", ".join(list(b.keys()))
|
|
.replace("_", " ")
|
|
.replace("(", "\(")
|
|
.replace(")", "\)")
|
|
)
|
|
c = ", ".join(list(b.keys()))
|
|
|
|
return a, c, rating, character_res, general_res
|
|
|
|
|
|
def load_model(model_path, device):
|
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
|
if device == "cpu":
|
|
providers.pop(0)
|
|
path = fetch_model(model_path, "wolf", True)
|
|
model = InferenceSession(path, providers=providers)
|
|
return model
|
|
|
|
|
|
def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
|
|
if img.endswith(".gif"):
|
|
img = PIL.Image.open(img)
|
|
img = img.convert("RGB")
|
|
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
|
else:
|
|
img = cv2.imread(img, flag)
|
|
return img
|
|
|
|
|
|
def smart_24bit(img):
|
|
if img.dtype is np.dtype(np.uint16):
|
|
img = (img / 257).astype(np.uint8)
|
|
|
|
if len(img.shape) == 2:
|
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
|
elif img.shape[2] == 4:
|
|
trans_mask = img[:, :, 3] == 0
|
|
img[trans_mask] = [255, 255, 255, 255]
|
|
img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
|
|
return img
|
|
|
|
|
|
def make_square(img, target_size):
|
|
old_size = img.shape[:2]
|
|
desired_size = max(old_size)
|
|
desired_size = max(desired_size, target_size)
|
|
|
|
delta_w = desired_size - old_size[1]
|
|
delta_h = desired_size - old_size[0]
|
|
top, bottom = delta_h // 2, delta_h - (delta_h // 2)
|
|
left, right = delta_w // 2, delta_w - (delta_w // 2)
|
|
|
|
color = [255, 255, 255]
|
|
new_im = cv2.copyMakeBorder(
|
|
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
|
|
)
|
|
return new_im
|
|
|
|
|
|
def smart_resize(img, size):
|
|
# Assumes the image has already gone through make_square
|
|
if img.shape[0] > size:
|
|
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
|
|
elif img.shape[0] < size:
|
|
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
|
|
return img
|
|
|
|
|
|
def load_labels() -> list[str]:
|
|
path = huggingface_hub.hf_hub_download(CONV2_MODEL_REPO, LABEL_FILENAME)
|
|
df = pd.read_csv(path)
|
|
tag_names = df["name"].tolist()
|
|
rating_indexes = list(np.where(df["category"] == 9)[0])
|
|
general_indexes = list(np.where(df["category"] == 0)[0])
|
|
character_indexes = list(np.where(df["category"] == 4)[0])
|
|
return tag_names, rating_indexes, general_indexes, character_indexes
|