# Tagger interrogator implementations (DeepDanbooru, WaifuDiffusion/wd14).
import os
|
|
from pathlib import Path
|
|
from io import BytesIO
|
|
from hashlib import sha256
|
|
import json
|
|
from pandas import read_csv, read_json
|
|
from PIL import Image
|
|
from typing import Tuple, List, Dict, Callable
|
|
from numpy import asarray, float32, expand_dims
|
|
from tqdm import tqdm
|
|
|
|
from huggingface_hub import hf_hub_download
|
|
from modules import shared
|
|
|
|
from . import dbimutils
|
|
from tagger import settings
|
|
from tagger.uiset import QData, IOData, ItRetTP
|
|
|
|
# shorthand for the extension's interrogator settings (used for the
# kaomoji list in large-batch tag sanitizing)
Its = settings.InterrogatorSettings

# select a device to process
# CPU is forced when the webui was launched with --use-cpu all or
# --use-cpu interrogate
use_cpu = ('all' in shared.cmd_opts.use_cpu) or (
    'interrogate' in shared.cmd_opts.use_cpu)

if use_cpu:
    TF_DEVICE_NAME = '/cpu:0'
else:
    TF_DEVICE_NAME = '/gpu:0'

if shared.cmd_opts.device_id is not None:
    try:
        # NOTE(review): an explicit --device-id selects a gpu even when
        # use_cpu was set above — confirm this override is intended
        TF_DEVICE_NAME = f'/gpu:{int(shared.cmd_opts.device_id)}'
    except ValueError:
        print('--device-id is not an integer')
|
|
|
|
|
|
def get_file_interrogator_id(file_bytes: bytes, interrogator_name: str) -> str:
    """Return a unique query key for a (file content, interrogator) pair.

    The key is the sha256 hex digest of ``file_bytes`` with the
    interrogator name appended.

    Renamed the first parameter from ``bytes`` (shadowed the builtin) and
    dropped the redundant ``str()`` — ``hexdigest()`` already returns str.
    """
    hasher = sha256()
    hasher.update(file_bytes)
    return hasher.hexdigest() + interrogator_name
|
|
|
|
|
|
def split_str(string: str, separator=',') -> List[str]:
    """Split ``string`` on ``separator`` and strip each piece.

    Pieces that are empty *before* stripping are dropped; a piece that
    is only whitespace survives as ''.
    """
    pieces = filter(None, string.split(separator))
    return list(map(str.strip, pieces))
|
|
|
|
|
|
class Interrogator:
    """Base class for all taggers.

    UI state (`input`), the last result (`output`) and per-key validation
    errors (`err`) are *class* attributes: the UI and every interrogator
    subclass share them. Query bookkeeping itself lives in QData/IOData.
    """
    # the raw input and output.
    input = {
        "cumulative": False,    # keep results of previous queries in db
        "large_query": False,   # use large_batch_interrogate if available
        "unload_after": False,  # free the model after each interrogation
        "add": '',
        "keep": '',
        "exclude": '',
        "search": '',
        "replace": '',
        "paths": '',
        "input_glob": '',
        "output_dir": '',
    }
    # last finalized result tuple (ItRetTP), None before the first run
    output = None
    # per-input-key validation error messages
    err = {}
    # NOTE(review): not referenced in this file — purpose unclear, verify
    # against the UI code before removing
    odd_increment = 0

    @classmethod
    def flip(cls, key):
        """Return a no-arg closure that toggles boolean input ``key``."""
        def toggle():
            cls.input[key] = not cls.input[key]
        return toggle

    @classmethod
    def set(cls, key: str) -> Callable[[str], Tuple[str, str]]:
        """Return a setter closure for input ``key`` (a UI change callback).

        The setter validates the new value through IOData (for paths) or
        QData (for tag filters) and returns (stored value, error message).
        """
        def setter(val) -> Tuple[str, str]:
            # a new attempt clears any previous error for this key
            if key in cls.err:
                del cls.err[key]
            err = ''
            if val != cls.input[key]:
                if key == 'input_glob' or key == 'output_dir':
                    # path-like keys are validated by IOData.update_<key>
                    err = getattr(IOData, "update_" + key)(val)
                    if key == 'input_glob' and err == '':
                        # a new input set invalidates previous results
                        QData.tags.clear()
                        QData.ratings.clear()
                        QData.in_db.clear()
                else:
                    # everything else is a QData filter update
                    err = getattr(QData, "update_" + key)(val)
                if err:
                    cls.err[key] = err
                else:
                    # normalize a falsy (e.g. None) result to ''
                    err = ''
            else:
                err = ''
            cls.input[key] = val
            return (cls.input[key], err)

        return setter

    @classmethod
    def load_image(cls, path: str) -> Image:
        """Open an image file; return None (and log) on failure."""
        try:
            return Image.open(path)
        except Exception as e:
            # just in case, user has mysterious file...
            # NOTE(review): the '$' in the message looks like a leftover
            # from shell/JS-style interpolation
            print(f'${path} is not supported image type: {e}')
            return None

    def __init__(self, name: str) -> None:
        self.name = name
        # run_mode 0 is dry run, 1 means run (alternating), 2 means disabled
        self.run_mode = 0 if hasattr(self, "large_batch_interrogate") else 2

    def load(self):
        """Load the model; implemented by subclasses."""
        raise NotImplementedError()

    def unload(self) -> bool:
        """Drop the loaded model and tags; return True if a model was freed."""
        unloaded = False

        if hasattr(self, 'model') and self.model is not None:
            del self.model
            unloaded = True
            print(f'Unloaded {self.name}')

        if hasattr(self, 'tags'):
            del self.tags

        return unloaded

    def interrogate_image(self, image: Image) -> ItRetTP:
        """Interrogate a single (e.g. drag-and-dropped) image.

        Results are cached in QData.query keyed by content-hash + name,
        so a repeated query does not re-run the model.
        """
        sha = IOData.get_bytes_hash(image.tobytes())
        QData.tags.clear()
        QData.ratings.clear()
        if not Interrogator.input["cumulative"]:
            QData.in_db.clear()
        fi_key = sha + self.name
        ct = 0
        QData.for_tags_file.clear()

        if fi_key in QData.query:
            # this file was already queried for this interrogator.
            QData.single_data(fi_key)
        else:
            # single process
            ct += 1
            data = ('', '', fi_key) + self.interrogate(image)
            # When drag-dropping an image, the path [0] is not known
            if Interrogator.input["unload_after"]:
                self.unload()

            QData.query[fi_key] = ('', len(QData.query))
            QData.apply_filters(data)

        # re-apply filters over anything still held in the db
        for got in QData.in_db.values():
            QData.apply_filters(got)

        Interrogator.output = QData.finalize(ct)
        return Interrogator.output

    def batch_interrogate(self) -> ItRetTP:
        """Interrogate every image in IOData.paths.

        Dispatches to large_batch_interrogate when enabled, otherwise
        loops image-by-image, reusing cached queries where possible.
        """
        QData.tags.clear()
        QData.ratings.clear()
        if not Interrogator.input["cumulative"]:
            QData.in_db.clear()

        if Interrogator.input["large_query"] is True and self.run_mode < 2:
            # TODO: write specified tags files instead of simple .txt
            image_list = [str(x[0].resolve()) for x in IOData.paths]
            err = self.large_batch_interrogate(image_list, self.run_mode == 0)
            if err:
                return (None, None, None, err)

            # alternate between dry run (0) and real run (1)
            self.run_mode = (self.run_mode + 1) % 2
            Interrogator.output = QData.finalize()
            return Interrogator.output

        vb = getattr(shared.opts, 'tagger_verbose', True)
        ct = len(QData.query)

        for i in tqdm(range(len(IOData.paths)), disable=vb, desc='Tags'):
            # if outputpath is '', no tags file will be written
            # IOData.paths entries grow as data is learned about them:
            # 3 items = fresh, 4 = hash known, 5 = image kept in memory
            if len(IOData.paths[i]) == 5:
                path, out_path, output_dir, image_hash, image = IOData.paths[i]
            elif len(IOData.paths[i]) == 4:
                path, out_path, output_dir, image_hash = IOData.paths[i]
                image = Interrogator.load_image(path)
                # should work, we queried before to get the image_hash
            else:
                path, out_path, output_dir = IOData.paths[i]
                image = Interrogator.load_image(path)
                if image is None:
                    continue

                image_hash = IOData.get_bytes_hash(image.tobytes())
                IOData.paths[i].append(image_hash)
                if getattr(shared.opts, 'tagger_store_images', False):
                    IOData.paths[i].append(image)

                if output_dir:
                    output_dir.mkdir(0o755, True, True)
                    # next iteration we don't need to create the directory
                    IOData.paths[i][2] = ''

            abspath = str(path.absolute())
            fi_key = image_hash + self.name

            if fi_key in QData.query:
                # this file was already queried for this interrogator.
                index = QData.get_index(fi_key, abspath)
                # this file was already queried and stored
                QData.in_db[index] = (abspath, out_path, '', {}, {})
            else:
                data = (abspath, out_path, fi_key) + self.interrogate(image)
                QData.apply_filters(data)
                QData.had_new = True

        if Interrogator.input["unload_after"]:
            self.unload()

        # ct = number of *new* queries performed in this batch
        ct = len(QData.query) - ct
        Interrogator.output = QData.finalize_batch(ct)
        return Interrogator.output

    def interrogate(
        self,
        image: Image
    ) -> Tuple[
        Dict[str, float],  # rating confidences
        Dict[str, float]   # tag confidences
    ]:
        """Run the model on one image; implemented by subclasses."""
        raise NotImplementedError()
|
|
|
|
|
|
class DeepDanbooruInterrogator(Interrogator):
    """Interrogator backed by a DeepDanbooru (TensorFlow/Keras) project."""

    def __init__(self, name: str, project_path: os.PathLike) -> None:
        super().__init__(name)
        # directory of the DeepDanbooru project (model + tags)
        self.project_path = project_path

    def load(self) -> None:
        """Install deepdanbooru on demand, then load model and tags."""
        print(f'Loading {self.name} from {str(self.project_path)}')

        # deepdanbooru package is not include in web-sd anymore
        # https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/c81d440d876dfd2ab3560410f37442ef56fc663
        from launch import is_installed, run_pip
        if not is_installed('deepdanbooru'):
            package = os.environ.get(
                'DEEPDANBOORU_PACKAGE',
                'git+https://github.com/KichangKim/DeepDanbooru.'
                'git@d91a2963bf87c6a770d74894667e9ffa9f6de7ff'
            )

            run_pip(
                f'install {package} tensorflow tensorflow-io', 'deepdanbooru')

        import tensorflow as tf

        # tensorflow maps nearly all vram by default, so we limit this
        # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
        # TODO: only run on the first run
        for device in tf.config.experimental.list_physical_devices('GPU'):
            try:
                tf.config.experimental.set_memory_growth(device, True)
            except RuntimeError as e:
                print(e)

        with tf.device(TF_DEVICE_NAME):
            import deepdanbooru.project as ddp

            self.model = ddp.load_model_from_project(
                project_path=self.project_path,
                compile_model=False
            )

            print(f'Loaded {self.name} model from {str(self.project_path)}')

            # list of tag names; index-aligned with model outputs
            self.tags = ddp.load_tags_from_project(
                project_path=self.project_path
            )

    def unload(self) -> bool:
        """Intentionally a no-op: Keras cannot reliably release the model."""
        # unloaded = super().unload()

        # if unloaded:
        #     # tensorflow suck
        #     # https://github.com/keras-team/keras/issues/2102
        #     import tensorflow as tf
        #     tf.keras.backend.clear_session()
        #     gc.collect()

        # return unloaded

        # There is a bug in Keras where it is not possible to release a model
        # that has been loaded into memory. Downgrading to keras==2.1.6 may
        # solve the issue, but it may cause compatibility issues with other
        # packages. Using subprocess to create a new process may also solve the
        # problem, but it can be too complex (like Automatic1111 did). It seems
        # that for now, the best option is to keep the model in memory, as most
        # users use the Waifu Diffusion model with onnx.
        return False

    def interrogate(
        self,
        image: Image
    ) -> Tuple[
        Dict[str, float],  # rating confidences
        Dict[str, float]   # tag confidences
    ]:
        """Run DeepDanbooru on one PIL image; return (ratings, tags)."""
        # init model
        if not hasattr(self, 'model') or self.model is None:
            self.load()

        import deepdanbooru.data as ddd

        # convert an image to fit the model
        image_bufs = BytesIO()
        image.save(image_bufs, format='PNG')
        image = ddd.load_image_for_evaluate(
            image_bufs,
            self.model.input_shape[2],
            self.model.input_shape[1]
        )

        # add a batch dimension of 1
        image = image.reshape((1, *image.shape[0:3]))

        # evaluate model
        result = self.model.predict(image)

        confidences = result[0].tolist()
        ratings = {}
        tags = {}

        # tags prefixed "rating:" are routed to the ratings dict
        for i, tag in enumerate(self.tags):
            if tag[:7] != "rating:":
                tags[tag] = confidences[i]
            else:
                ratings[tag[7:]] = confidences[i]

        return ratings, tags
|
|
|
|
|
|
class WaifuDiffusionInterrogator(Interrogator):
    """Interrogator for SmilingWolf's WaifuDiffusion (wd14) ONNX taggers.

    Model and tag files are downloaded from the HuggingFace hub on first
    use and cached under ``<models>/interrogators``, with an index kept
    in ``model.json``.
    """

    def __init__(
        self,
        name: str,
        model_path='model.onnx',
        tags_path='selected_tags.csv',
        **kwargs
    ) -> None:
        """``kwargs`` are forwarded to hf_hub_download (incl. repo_id)."""
        super().__init__(name)
        self.model_path = model_path
        self.tags_path = tags_path
        self.kwargs = kwargs

    def download(self) -> Tuple[os.PathLike, os.PathLike]:
        """Download model and tags files and register them in model.json.

        Returns (model_path, tags_path) of the cached files.
        """
        print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")

        mdir = Path(shared.models_path, 'interrogators')
        model_path = Path(hf_hub_download(
            **self.kwargs, filename=self.model_path, cache_dir=mdir))
        tags_path = Path(hf_hub_download(
            **self.kwargs, filename=self.tags_path, cache_dir=mdir))

        download_model = {
            'name': self.name,
            'model_path': str(model_path),
            'tags_path': str(tags_path),
        }
        mpath = Path(mdir, 'model.json')

        # Default registry content. Fixes a NameError: previously `data`
        # was never assigned when mdir existed but model.json did not.
        data = [download_model]

        if not os.path.exists(mdir):
            # fixed: `os.makedir` does not exist; use os.makedirs
            os.makedirs(mdir)
        elif os.path.exists(mpath):
            with open(mpath, 'r') as f:
                try:
                    data = json.load(f)
                    data.append(download_model)
                except Exception as e:
                    print(f'Adding download_model {mpath} raised {repr(e)}')
                    data = [download_model]

        with open(mpath, 'w') as f:
            json.dump(data, f)

        return model_path, tags_path

    def get_model_path(self) -> Tuple[os.PathLike, os.PathLike]:
        """Look up this interrogator's files in model.json.

        Falls back to download() when the registry is missing, unreadable
        or has no entry for this name.
        """
        model_path = ''
        tags_path = ''
        mpath = Path(shared.models_path, 'interrogators', 'model.json')
        try:
            models = read_json(mpath).to_dict(orient='records')
            i = next(i for i in models if i['name'] == self.name)
            model_path = i['model_path']
            tags_path = i['tags_path']
        except Exception as e:
            print(f'{mpath}: requires a name, model_ and tags_path: {repr(e)}')
            model_path, tags_path = self.download()
        return model_path, tags_path

    def load(self) -> None:
        """Install onnxruntime on demand, then load model and tag csv."""
        if isinstance(self.model_path, str) or isinstance(self.tags_path, str):
            # str paths mean "not resolved yet": fetch (or re-use cache)
            model_path, tags_path = self.download()
        else:
            model_path = self.model_path
            tags_path = self.tags_path

        # only one of these packages should be installed at a time in any
        # one environment
        # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
        # TODO: remove old package when the environment changes?
        from launch import is_installed, run_pip
        if not is_installed('onnxruntime'):
            package = os.environ.get(
                'ONNXRUNTIME_PACKAGE',
                'onnxruntime-gpu'
            )

            run_pip(f'install {package}', 'onnxruntime')

        from onnxruntime import InferenceSession

        # https://onnxruntime.ai/docs/execution-providers/
        # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        if use_cpu:
            providers.pop(0)

        print(f'Loading {self.name} model from {model_path}, {tags_path}')
        self.model = InferenceSession(str(model_path), providers=providers)
        self.tags = read_csv(tags_path)

    def interrogate(
        self,
        image: Image
    ) -> Tuple[
        Dict[str, float],  # rating confidences
        Dict[str, float]   # tag confidences
    ]:
        """Run the onnx model on one PIL image; return (ratings, tags)."""
        # init model
        if not hasattr(self, 'model') or self.model is None:
            self.load()

        # code for converting the image and running the model is taken from the
        # link below. thanks, SmilingWolf!
        # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py

        # convert an image to fit the model
        _, height, _, _ = self.model.get_inputs()[0].shape

        # alpha to white
        image = image.convert('RGBA')
        new_image = Image.new('RGBA', image.size, 'WHITE')
        new_image.paste(image, mask=image)
        image = new_image.convert('RGB')
        image = asarray(image)

        # PIL RGB to OpenCV BGR
        image = image[:, :, ::-1]

        image = dbimutils.make_square(image, height)
        image = dbimutils.smart_resize(image, height)
        image = image.astype(float32)
        image = expand_dims(image, 0)

        # evaluate model
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        confidences = self.model.run([label_name], {input_name: image})[0]

        tags = self.tags[:][['name']]
        tags['confidences'] = confidences[0]

        # first 4 items are for rating (general, sensitive, questionable,
        # explicit)
        ratings = dict(tags[:4].values)

        # rest are regular tags
        tags = dict(tags[4:].values)

        return ratings, tags

    def large_batch_interrogate(self, images_list, dry_run=True) -> str:
        """Interrogate many images through a tf.data pipeline.

        On a dry run only the scheduled/read file lists are written; on a
        real run a .txt tag file is written next to each image. Returns ''
        on success.
        """
        # init model
        if not hasattr(self, 'model') or self.model is None:
            self.load()
        # NOTE(review): load() creates an onnx InferenceSession, while the
        # non-dry-run path below calls self.model(x, training=False) and
        # reads self.model.inputs — i.e. it assumes a TF model; confirm
        # which model type is in effect for large batches.

        os.environ["TF_XLA_FLAGS"] = \
            '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
        # Reduce logging
        # os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"

        import tensorflow as tf

        from tagger.Generator.TFDataReader import DataGenerator

        # tensorflow maps nearly all vram by default, so we limit this
        # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
        # TODO: only run on the first run
        gpus = tf.config.experimental.list_physical_devices("GPU")
        if gpus:
            for device in gpus:
                try:
                    tf.config.experimental.set_memory_growth(device, True)
                except RuntimeError as e:
                    print(e)

        if dry_run:  # dry run
            height, width = 224, 224

            def process_images(filepaths, images):
                # record which files the reader actually produced
                lines = []
                for image_path in filepaths:
                    image_path = image_path.numpy().decode("utf-8")
                    lines.append(f"{image_path}\n")
                with open("dry_run_read.txt", "a") as f:
                    f.writelines(lines)

            scheduled = [f"{image_path}\n" for image_path in images_list]

            # Truncate the file from previous runs
            print("updating dry_run_read.txt")
            open("dry_run_read.txt", "w").close()
            with open("dry_run_scheduled.txt", "w") as f:
                f.writelines(scheduled)
        else:
            _, height, width, _ = self.model.inputs[0].shape

            threshold = QData.threshold
            # keep kaomoji as-is, replace underscores elsewhere
            self.tags["sanitized_name"] = self.tags["name"].map(
                lambda x: x if x in Its.kaomojis else x.replace("_", " ")
            )

            @tf.function
            def pred_model(x):
                return self.model(x, training=False)

            def process_images(filepaths, images):
                preds = pred_model(images).numpy()

                for image_path, pred in zip(filepaths, preds):
                    image_path = image_path.numpy().decode("utf-8")

                    self.tags["preds"] = pred
                    general_tags = self.tags[self.tags["category"] == 0]
                    chosen_tags = general_tags[general_tags["preds"] > threshold]
                    chosen_tags = chosen_tags.sort_values(
                        by="preds", ascending=False)
                    tags_names = chosen_tags["sanitized_name"]

                    fi_key = image_path.split("/")[-1].split(".")[0] \
                        + "_" + self.name
                    QData.add_tags = tags_names
                    QData.apply_filters((image_path, '', {}, {}), fi_key, False)

                    tags_string = ", ".join(tags_names)
                    with open(Path(image_path).with_suffix(".txt"), "w") as f:
                        f.write(tags_string)

        batch_size = getattr(shared.opts, 'tagger_batch_size', 1024)
        generator = DataGenerator(
            file_list=images_list, target_height=height,
            target_width=width, batch_size=batch_size
        ).genDS()

        orig_add_tags = QData.add_tags
        for filepaths, images in tqdm(generator):
            process_images(filepaths, images)
        # fixed: was `QData.add_tag = ...`, a typo that created a new
        # attribute instead of restoring QData.add_tags
        QData.add_tags = orig_add_tags
        del os.environ["TF_XLA_FLAGS"]
        return ''
|