253 lines
7.4 KiB
Python
253 lines
7.4 KiB
Python
# Borrowed from https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/master/tagger/interrogator.py
|
|
|
|
import os
|
|
import re
|
|
import sys
|
|
import traceback
|
|
from collections import namedtuple
|
|
from pathlib import Path
|
|
from typing import Tuple, Dict
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
import open_clip
|
|
from PIL import Image
|
|
from huggingface_hub import hf_hub_download
|
|
from torchvision import transforms
|
|
from torchvision.transforms.functional import InterpolationMode
|
|
|
|
import modules.deepbooru
|
|
import modules.shared as shared
|
|
from extensions.sd_smartprocess import dbimutils
|
|
from modules import devices, paths, lowvram, modelloader
|
|
from modules import images
|
|
from modules.deepbooru import re_special as tag_escape_pattern
|
|
|
|
# Image size constant for BLIP evaluation — presumably an edge length in
# pixels; it is not referenced in the visible part of this file (TODO confirm).
blip_image_eval_size = 384
# CLIP model identifier; not referenced in the visible part of this file.
clip_model_name = 'ViT-L/14'

# Lightweight record type with the fields ``name``, ``topn`` and ``items``.
Category = namedtuple("Category", ["name", "topn", "items"])

# Matches a literal ".top<digits>." fragment and captures the digits.
re_topn = re.compile(r"\.top(\d+)\.")
|
|
|
|
# Decide whether interrogation should run on the CPU.
#
# NOTE(review): the stock web UI appears to declare ``--use-cpu`` so that it
# yields a *list* of module names — confirm against the launcher.  A plain
# ``==`` comparison between a list and a string is always False, which would
# silently force the GPU path, so both a bare string and a collection of
# strings are accepted here.
_use_cpu_targets = shared.cmd_opts.use_cpu
if isinstance(_use_cpu_targets, str):
    _use_cpu_targets = [_use_cpu_targets]
elif not isinstance(_use_cpu_targets, (list, tuple, set, frozenset)):
    # None or any unexpected type: treat as "no CPU override requested".
    _use_cpu_targets = []
use_cpu = 'all' in _use_cpu_targets or 'interrogate' in _use_cpu_targets

# ``tf_device_name`` is a TensorFlow-style device string and
# ``onyx_providers`` is the provider list handed to onnxruntime's
# ``InferenceSession`` by ``WaifuDiffusionInterrogator.load``.
if use_cpu:
    tf_device_name = '/cpu:0'
    onyx_providers = ['CPUExecutionProvider']
else:
    tf_device_name = '/gpu:0'
    # CPU stays in the list after CUDA so a provider is still available
    # when CUDA is not.
    onyx_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

    # An explicit --device-id only makes sense on the GPU path.
    if shared.cmd_opts.device_id is not None:
        try:
            tf_device_name = f'/gpu:{int(shared.cmd_opts.device_id)}'
        except ValueError:
            print('--device-id is not a integer')
|
|
|
|
|
|
class Interrogator:
    """Base class for image taggers.

    Subclasses implement :meth:`interrogate`; :meth:`postprocess_tags` is a
    shared helper for filtering, ordering and formatting their output.
    """

    @staticmethod
    def postprocess_tags(
        tags: Dict[str, float],
        threshold=0.35,
        additional_tags=None,
        exclude_tags=None,
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=False,
        replace_underscore_excludes=None,
        escape_tag=False
    ) -> Dict[str, float]:
        """Filter, sort and format a tag -> confidence mapping.

        Args:
            tags: Mapping of tag name to confidence. It is not modified.
            threshold: Tags whose confidence is below this value are dropped.
            additional_tags: Tag names to add with confidence ``1.0``. If
                ``tags`` already contains the same name, the confidence from
                ``tags`` takes precedence.
            exclude_tags: Tag names (as they appear before any renaming) to
                drop from the result.
            sort_by_alphabetical_order: Sort ascending by tag name instead of
                descending by confidence.
            add_confident_as_weight: Rewrite each tag as ``(tag:confidence)``.
            replace_underscore: Replace ``_`` with a space in tag names.
            replace_underscore_excludes: Tag names that keep their
                underscores even when ``replace_underscore`` is set.
            escape_tag: Prefix characters matched by ``tag_escape_pattern``
                with a backslash.

        Returns:
            A new dict whose insertion order is the requested sort order.
            If two tags end up with the same name after rewriting, the one
            that comes first in the sort order is kept.
        """
        if additional_tags is None:
            additional_tags = []
        if exclude_tags is None:
            exclude_tags = []
        if replace_underscore_excludes is None:
            replace_underscore_excludes = []

        # Additional tags start at 1.0; entries from ``tags`` override them.
        merged = {
            **{t: 1.0 for t in additional_tags},
            **tags
        }

        # Sort by tag name (ascending) or by confidence (descending).
        sort_index = 0 if sort_by_alphabetical_order else 1
        ordered = sorted(
            merged.items(),
            key=lambda item: item[sort_index],
            reverse=not sort_by_alphabetical_order
        )

        # Build the result in one pass.  Renaming keys in place with
        # ``pop`` + re-insert would move every renamed tag to the end of the
        # dict and destroy the ordering whenever only some tags are renamed.
        result = {}
        for tag, confidence in ordered:
            if not (confidence >= threshold and tag not in exclude_tags):
                continue

            new_tag = tag

            if replace_underscore and tag not in replace_underscore_excludes:
                new_tag = new_tag.replace('_', ' ')

            if escape_tag:
                new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)

            if add_confident_as_weight:
                new_tag = f'({new_tag}:{confidence})'

            # First occurrence wins on a post-rename name collision.
            result.setdefault(new_tag, confidence)

        return result

    def interrogate(
        self,
        image: 'Image.Image'
    ) -> Tuple[
        Dict[str, float],  # rating confidence
        Dict[str, float]  # tag confidence
    ]:
        """Tag ``image``; the base implementation does nothing and returns None.

        Subclasses override this method with the actual model call.
        """
|
|
|
|
|
|
# Matches a single backslash or parenthesis and captures it, so that the
# character can be re-emitted behind an escaping backslash.
re_special = re.compile(r'([\\()])')
|
|
|
|
|
|
class BooruInterrogator(Interrogator):
    """Tagger backed by ``modules.deepbooru.DeepDanbooru``."""

    def __init__(self) -> None:
        """Create the DeepDanbooru wrapper, start it and keep its model."""
        self.tags = None
        self.booru = modules.deepbooru.DeepDanbooru()
        self.booru.start()
        self.model = self.booru.model

    def unload(self):
        """Stop the underlying DeepDanbooru wrapper."""
        self.booru.stop()

    def interrogate(self, pil_image) -> Dict[str, float]:
        """Return escaped tag -> probability, ordered by descending probability.

        The image is converted to RGB, passed through
        ``images.resize_image`` with a 512x512 target, divided by 255 and fed
        to the model as a batch of one.  Tags starting with ``"rating:"`` are
        left out; backslashes and parentheses in tag names are
        backslash-escaped.
        """
        rgb = pil_image.convert("RGB")
        resized = images.resize_image(2, rgb, 512, 512)
        batch = np.expand_dims(np.array(resized, dtype=np.float32), 0) / 255

        with torch.no_grad(), devices.autocast():
            tensor = torch.from_numpy(batch).to(devices.device)
            scores = self.model(tensor)[0].detach().cpu().numpy()

        # Keep everything except the rating pseudo-tags.
        scored = {
            name: score
            for name, score in zip(self.model.tags, scores)
            if not name.startswith("rating:")
        }

        # Highest probability first; dict insertion order carries the ranking.
        ranked = sorted(scored.items(), key=lambda pair: -pair[1])
        return {
            re_special.sub(r'\\\1', name): score
            for name, score in ranked
        }
|
|
|
|
|
|
class WaifuDiffusionInterrogator(Interrogator):
    """Tagger that runs a Waifu Diffusion ONNX model from the Hugging Face hub."""

    def __init__(
        self,
        repo='SmilingWolf/wd-v1-4-vit-tagger',
        model_path='model.onnx',
        tags_path='selected_tags.csv'
    ) -> None:
        """Remember where the model lives and load it immediately.

        Args:
            repo: Hugging Face repository id to download from.
            model_path: File name of the ONNX model inside ``repo``.
            tags_path: File name of the tag CSV inside ``repo``.
        """
        self.tags = None
        self.model = None
        self.repo = repo
        self.model_path = model_path
        self.tags_path = tags_path
        self.load()

    def download(self) -> Tuple[os.PathLike, os.PathLike]:
        """Fetch the model and tag files via ``hf_hub_download``.

        Returns:
            ``(model_path, tags_path)`` as ``Path`` objects.
        """
        print(f'Loading Waifu Diffusion tagger model file from {self.repo}')

        model_path = Path(hf_hub_download(self.repo, filename=self.model_path))
        tags_path = Path(hf_hub_download(self.repo, filename=self.tags_path))
        return model_path, tags_path

    def load(self) -> None:
        """Download the files, ensure onnxruntime is present and open a session."""
        model_path, tags_path = self.download()

        # Imported lazily: ``launch`` belongs to the host application.
        from launch import is_installed, run_pip
        if not is_installed('onnxruntime'):
            package_name = 'onnxruntime-gpu'

            if use_cpu or not torch.cuda.is_available():
                package_name = 'onnxruntime'

            # ONNXRUNTIME_PACKAGE lets the user override what pip installs.
            package = os.environ.get(
                'ONNXRUNTIME_PACKAGE',
                package_name
            )

            run_pip(f'install {package}', package_name)

        # Imported only after the optional install above.
        from onnxruntime import InferenceSession

        self.model = InferenceSession(str(model_path), providers=onyx_providers)

        print(f'Loaded Waifu Diffusion tagger model from {model_path}')
        self.tags = pd.read_csv(tags_path)

    def unload(self):
        """Drop the inference session; safe to call more than once.

        The attribute is reset to ``None`` rather than deleted so that later
        ``unload`` or ``interrogate`` calls do not hit ``AttributeError``.
        """
        self.model = None

    def interrogate(
        self,
        image: 'Image.Image'
    ) -> Tuple[
        Dict[str, float],  # rating confidence
        Dict[str, float]  # tag confidence
    ]:
        """Run the tagger on ``image``.

        Returns:
            ``(ratings, tags)``: two name -> confidence dicts, split after
            the first four rows of the tag table.
        """
        # Reload transparently if the model was released by ``unload``.
        if getattr(self, 'model', None) is None:
            self.load()

        # code for converting the image and running the model is taken from the link below
        # thanks, SmilingWolf!
        # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py

        # convert an image to fit the model
        model_input = self.model.get_inputs()[0]
        _, height, _, _ = model_input.shape

        # alpha to white
        image = image.convert('RGBA')
        new_image = Image.new('RGBA', image.size, 'WHITE')
        new_image.paste(image, mask=image)
        image = new_image.convert('RGB')
        image = np.asarray(image)

        # PIL RGB to OpenCV BGR
        image = image[:, :, ::-1]

        image = dbimutils.make_square(image, height)
        image = dbimutils.smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)

        # evaluate model
        input_name = model_input.name
        label_name = self.model.get_outputs()[0].name
        confidence = self.model.run([label_name], {input_name: image})[0]

        # ``assign`` returns a new frame, so the loaded tag table is never
        # written to and no chained-assignment warning is triggered.
        tags = self.tags[['name']].assign(confidence=confidence[0])

        # first 4 items are for rating (general, sensitive, questionable, explicit)
        ratings = dict(tags[:4].values)

        # rest are regular tags
        tags = dict(tags[4:].values)

        return ratings, tags
|