diff --git a/modules/scripts_manager.py b/modules/scripts_manager.py index e9381f53b..7672c9cf6 100644 --- a/modules/scripts_manager.py +++ b/modules/scripts_manager.py @@ -1,14 +1,19 @@ +from __future__ import annotations + import os import re import sys import time -from collections import namedtuple from dataclasses import dataclass +from typing import TYPE_CHECKING, NamedTuple import gradio as gr from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer from modules.logger import log from installer import control_extensions +if TYPE_CHECKING: + from types import ModuleType + AlwaysVisible = object() time_component = {} @@ -211,10 +216,22 @@ def basedir(): return current_basedir -ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path", "priority"]) +class ScriptFile(NamedTuple): + basedir: str + filename: str + path: str + priority: str + + +class ScriptClassData(NamedTuple): + script_class: type[Script] | type[scripts_postprocessing.ScriptPostprocessing] + path: str + basedir: str + module: ModuleType + + scripts_data = [] postprocessing_scripts_data = [] -ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) def list_scripts(scriptdirname, extension): diff --git a/scripts/nudenet/imageguard.py b/scripts/nudenet/imageguard.py index 068d21133..c92dd43ee 100644 --- a/scripts/nudenet/imageguard.py +++ b/scripts/nudenet/imageguard.py @@ -1,4 +1,11 @@ +from __future__ import annotations + from modules.logger import log +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration + repo_id = 'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf' policy_template = """Provide a safety assessment for the content provided by the user. Determine compliance or violation against our safety policy by reviewing the following policy categories: Hate: @@ -89,11 +96,11 @@ To provide your assessment use the following json template for each category: "rationale": str, }. """ -model = None -processor = None +model: LlavaOnevisionForConditionalGeneration | None = None +processor: AutoProcessor | None = None -def image_guard(image, policy:str | None=None) -> str: +def image_guard(image, policy:str | None=None): global model, processor # pylint: disable=global-statement import json from installer import install diff --git a/scripts/nudenet_ext.py b/scripts/nudenet_ext.py index 3a735ba32..971007cb5 100644 --- a/scripts/nudenet_ext.py +++ b/scripts/nudenet_ext.py @@ -1,6 +1,7 @@ import time import gradio as gr -from modules import scripts, scripts_postprocessing, processing, images +from modules import scripts, processing, images +from modules.scripts_postprocessing import PostprocessedImage, ScriptPostprocessing from scripts.nudenet import nudenet # pylint: disable=no-name-in-module from scripts.nudenet import langdetect # pylint: disable=no-name-in-module from scripts.nudenet import imageguard # pylint: disable=no-name-in-module @@ -49,14 +50,14 @@ def create_ui(accordion=True): # main processing used in both modes def process( - p: processing.StableDiffusionProcessing=None, - pp: scripts.PostprocessImageArgs=None, + p: processing.StableDiffusionProcessing | None = None, + pp: scripts.PostprocessImageArgs | PostprocessedImage | None = None, enabled=True, lang=False, policy=False, banned=False, metadata=True, - copy=False, + copy=False, # Compatability score=0.2, blocks=3, censor=[], @@ -74,7 +75,7 @@ def process( nudes = nudenet.detector.censor(image=pp.image, method=method, min_score=score, censor=censor, blocks=blocks, overlay=overlay) t1 = time.time() if len(nudes.censored) > 0: # Check if there are any censored areas - if not copy: + if p is None: pp.image = nudes.output else: info = processing.create_infotext(p) @@ -85,7 +86,7 @@ def process( if metadata and p is not None: p.extra_generation_params["NudeNet"] = meta p.extra_generation_params["NSFW"] = nsfw - if metadata and hasattr(pp, 'info'): + if metadata and isinstance(pp, PostprocessedImage): pp.info['NudeNet'] = meta pp.info['NSFW'] = nsfw log.debug(f'NudeNet detect: {dct} nsfw={nsfw} time={(t1 - t0):.2f}') @@ -118,7 +119,7 @@ def process( if metadata and p is not None: p.extra_generation_params["Rating"] = res.get('rating', 'N/A') p.extra_generation_params["Category"] = res.get('category', 'N/A') - if metadata and hasattr(pp, 'info'): + if metadata and isinstance(pp, PostprocessedImage): pp.info["Rating"] = res.get('rating', 'N/A') pp.info["Category"] = res.get('category', 'N/A') @@ -131,11 +132,11 @@ class ScriptNudeNet(scripts.Script): def title(self): return 'NudeNet' - def show(self, _is_img2img): + def show(self, *args, **kwargs): return scripts.AlwaysVisible # return signature is array of gradio components - def ui(self, _is_img2img): + def ui(self, *args, **kwargs): return create_ui(accordion=True) # triggered by callback @@ -148,7 +149,7 @@ class ScriptNudeNet(scripts.Script): # defines postprocessing script for dual-mode usage -class ScriptPostprocessingNudeNet(scripts_postprocessing.ScriptPostprocessing): +class ScriptPostprocessingNudeNet(ScriptPostprocessing): name = 'NudeNet' order = 10000 @@ -158,5 +159,5 @@ class ScriptPostprocessingNudeNet(scripts_postprocessing.ScriptPostprocessing): return { 'enabled': enabled, 'lang': lang, 'policy': policy, 'banned': banned, 'metadata': metadata, 'copy': copy, 'score': score, 'blocks': blocks, 'censor': censor, 'method': method, 'overlay': overlay, 'allowed': allowed, 'alphabet': alphabet, 'words': words} # triggered by callback - def process(self, pp: scripts_postprocessing.PostprocessedImage, enabled, lang, policy, banned, metadata, copy, score, blocks, censor, method, overlay, allowed, alphabet, words): # pylint: disable=arguments-differ + def process(self, pp: PostprocessedImage, enabled, lang, policy, banned, metadata, copy, score, blocks, censor, method, overlay, allowed, alphabet, words): # pylint: disable=arguments-differ process(None, pp, enabled, lang, policy, banned, metadata, copy, score, blocks, censor, method, overlay, allowed, alphabet, words)