# sd_smartprocess/interrogators/interrogator.py
# (source listing metadata: 115 lines, 3.4 KiB, Python)

# Borrowed from https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/master/tagger/interrogator.py
import importlib
import logging
import pkgutil
import re
import sys
from abc import abstractmethod

import torch
from PIL import Image

import extensions.sd_smartprocess.interrogators as interrogators
from extensions.sd_smartprocess.process_params import ProcessParams
class Interrogator:
    """Base class for image interrogators (captioners / taggers).

    Creating an instance registers it with the singleton
    ``InterrogatorRegistry`` under the instance's class name.
    Subclasses are expected to assign ``self.model`` and override
    :meth:`interrogate`.
    """

    def __init__(self, params: ProcessParams) -> None:
        """Store *params*, choose a device and register this instance.

        :param params: processing options for this interrogator.
        """
        self.params = params
        # Assigned by subclasses; ``None`` means there is nothing to move.
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # InterrogatorRegistry is a singleton, so this always adds to the
        # same shared registry.
        InterrogatorRegistry().register(self)

    # Marks the method subclasses must provide. The class does not use
    # ABCMeta, so this is documentation only and is not enforced when
    # instantiating (same behaviour as before).
    @abstractmethod
    def interrogate(self, image: Image.Image, params: ProcessParams, unload: bool = False) -> str:
        """Return a text description of *image*.

        :param image: the PIL image to describe.
        :param params: processing options for this call.
        :param unload: meaning is defined by subclasses (presumably "offload
            the model afterwards" — verify against implementations).
        :raises NotImplementedError: always, on this base class.
        """
        raise NotImplementedError

    def unload(self):
        """Move the model to the CPU, best effort (failures are logged, not raised)."""
        self._move_model("cpu")

    def load(self):
        """Move the model to ``self.device``, best effort (failures are logged, not raised)."""
        self._move_model(self.device)

    def _move_model(self, device: str) -> None:
        """Reassign ``self.model`` to ``self.model.to(device)`` if a model is set."""
        if self.model is None:
            return
        try:
            self.model = self.model.to(device)
        except Exception:
            # Deliberately best effort: e.g. a model object without ``.to()``
            # stays where it is. Logged at debug level because the registry
            # calls load/unload on every interrogator.
            logging.getLogger(__name__).debug(
                "Could not move %s model to %s", type(self).__name__, device, exc_info=True
            )
# Pattern text for ``re_special``: a character class matching one backslash,
# "(" or ")", wrapped in capture group 1.
_RE_SPECIAL_PATTERN = r'([\\()])'
# Not referenced in the code visible here; presumably imported elsewhere to
# escape these characters in tag text, e.g. ``re_special.sub(r'\\\1', tag)``
# — verify against callers.
re_special = re.compile(_RE_SPECIAL_PATTERN)
class InterrogatorRegistry:
    """Process-wide singleton mapping interrogator class names to instances.

    Every ``Interrogator`` registers itself here on construction. Entries are
    keyed by class name, so a newer instance of the same class replaces the
    older one.
    """

    _instance = None  # The singleton instance, created lazily by __new__.

    def __new__(cls):
        """Return the shared instance, creating and initialising it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Initialise here, not in __init__: Python re-runs __init__ on
            # every InterrogatorRegistry() call, which would wipe the entries.
            cls._instance._init()
        return cls._instance

    def _init(self):
        """Create the empty name -> instance mapping (runs once per singleton)."""
        self.interrogators = {}

    def register(self, interrogator: Interrogator):
        """Store *interrogator* under its class name, replacing any previous entry."""
        self.interrogators[interrogator.__class__.__name__] = interrogator

    def get_interrogator(self, interrogator_name: str) -> Interrogator:
        """Return the instance registered under *interrogator_name*.

        :raises KeyError: if no instance with that class name is registered.
        """
        return self.interrogators[interrogator_name]

    def get_interrogators(self):
        """Return the live name -> instance mapping (not a copy)."""
        return self.interrogators

    def unload(self):
        """Call ``unload()`` on every registered interrogator."""
        for interrogator in self.interrogators.values():
            interrogator.unload()

    def load(self):
        """Call ``load()`` on every registered interrogator."""
        for interrogator in self.interrogators.values():
            interrogator.load()

    @staticmethod
    def _import_interrogator_modules():
        """Import every module of the interrogators package so its subclasses get defined.

        A module that fails to import is skipped and logged, so one broken
        interrogator does not hide the others.
        """
        package = interrogators
        for _finder, module_name, _is_pkg in pkgutil.iter_modules(
            package.__path__, package.__name__ + "."
        ):
            try:
                importlib.import_module(module_name)
            except Exception:
                logging.getLogger(__name__).warning(
                    "Skipping interrogator module %s: import failed", module_name, exc_info=True
                )

    @staticmethod
    def list_interrogators():
        """Return ``{class name: params}`` for the discovered interrogator classes.

        ``params`` is the class's ``params`` attribute when it has one at
        class level, else ``{}``. Only *direct* subclasses of ``Interrogator``
        are included (``__subclasses__`` is not recursive).
        """
        InterrogatorRegistry._import_interrogator_modules()
        return {cls.__name__: getattr(cls, "params", {}) for cls in Interrogator.__subclasses__()}

    @staticmethod
    def get_all_interrogators():
        """Return ``{class name: class}`` for the discovered interrogator classes.

        Only *direct* subclasses of ``Interrogator`` are included
        (``__subclasses__`` is not recursive).
        """
        InterrogatorRegistry._import_interrogator_modules()
        return {cls.__name__: cls for cls in Interrogator.__subclasses__()}