53 lines
1.4 KiB
Python
53 lines
1.4 KiB
Python
import re
|
|
from typing import Optional, Generator, Any
|
|
|
|
from PIL import Image
|
|
|
|
from modules import shared, lowvram, devices
|
|
from modules import deepbooru as db
|
|
|
|
# Custom tagger classes have to inherit from this class
|
|
class Tagger:
|
|
def __enter__(self):
|
|
lowvram.send_everything_to_cpu()
|
|
devices.torch_gc()
|
|
self.start()
|
|
return self
|
|
|
|
def __exit__(self, exception_type, exception_value, traceback):
|
|
self.stop()
|
|
pass
|
|
|
|
def start(self):
|
|
pass
|
|
|
|
def stop(self):
|
|
pass
|
|
|
|
# predict tags of one image
|
|
def predict(self, image: Image.Image, threshold: Optional[float] = None) -> list[str]:
|
|
raise NotImplementedError()
|
|
|
|
# Please implement if you want to use more efficient data loading system
|
|
# None input will come to check if this function is implemented
|
|
def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None) -> Generator[list[str], Any, None]:
|
|
raise NotImplementedError()
|
|
|
|
# Visible name in UI
|
|
def name(self):
|
|
raise NotImplementedError()
|
|
|
|
|
|
def get_replaced_tag(tag: str):
|
|
use_spaces = shared.opts.deepbooru_use_spaces
|
|
use_escape = shared.opts.deepbooru_escape
|
|
if use_spaces:
|
|
tag = tag.replace('_', ' ')
|
|
if use_escape:
|
|
tag = re.sub(db.re_special, r'\\\1', tag)
|
|
return tag
|
|
|
|
|
|
def get_arranged_tags(probs: dict[str, float]):
|
|
return [tag for tag, _ in sorted(probs.items(), key=lambda x: -x[1])]
|