stable-diffusion-webui-data.../scripts/tagger.py

53 lines
1.4 KiB
Python

import re
from typing import Optional, Generator, Any
from PIL import Image
from modules import shared, lowvram, devices
from modules import deepbooru as db
# Custom tagger classes have to inherit from this class
class Tagger:
def __enter__(self):
lowvram.send_everything_to_cpu()
devices.torch_gc()
self.start()
return self
def __exit__(self, exception_type, exception_value, traceback):
self.stop()
pass
def start(self):
pass
def stop(self):
pass
# predict tags of one image
def predict(self, image: Image.Image, threshold: Optional[float] = None) -> list[str]:
raise NotImplementedError()
# Please implement if you want to use more efficient data loading system
# None input will come to check if this function is implemented
def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None) -> Generator[list[str], Any, None]:
raise NotImplementedError()
# Visible name in UI
def name(self):
raise NotImplementedError()
def get_replaced_tag(tag: str):
use_spaces = shared.opts.deepbooru_use_spaces
use_escape = shared.opts.deepbooru_escape
if use_spaces:
tag = tag.replace('_', ' ')
if use_escape:
tag = re.sub(db.re_special, r'\\\1', tag)
return tag
def get_arranged_tags(probs: dict[str, float]):
return [tag for tag, _ in sorted(probs.items(), key=lambda x: -x[1])]