mirror of https://github.com/vladmandic/automatic
Upgrade to typed NamedTuples
parent
1e9bef8d56
commit
8ef8074467
|
|
@ -1,14 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
import gradio as gr
|
||||
from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
|
||||
from modules.logger import log
|
||||
from installer import control_extensions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
AlwaysVisible = object()
|
||||
time_component = {}
|
||||
|
|
@ -211,10 +216,22 @@ def basedir():
|
|||
return current_basedir
|
||||
|
||||
|
||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path", "priority"])
|
||||
class ScriptFile(NamedTuple):
|
||||
basedir: str
|
||||
filename: str
|
||||
path: str
|
||||
priority: str
|
||||
|
||||
|
||||
class ScriptClassData(NamedTuple):
|
||||
script_class: type[Script] | type[scripts_postprocessing.ScriptPostprocessing]
|
||||
path: str
|
||||
basedir: str
|
||||
module: ModuleType
|
||||
|
||||
|
||||
scripts_data = []
|
||||
postprocessing_scripts_data = []
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||
|
||||
|
||||
def list_scripts(scriptdirname, extension):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from modules.logger import log
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
|
||||
|
||||
repo_id = 'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf'
|
||||
policy_template = """Provide a safety assessment for the content provided by the user. Determine compliance or violation against our safety policy by reviewing the following policy categories:
|
||||
Hate:
|
||||
|
|
@ -89,11 +96,11 @@ To provide your assessment use the following json template for each category:
|
|||
"rationale": str,
|
||||
}.
|
||||
"""
|
||||
model = None
|
||||
processor = None
|
||||
model: LlavaOnevisionForConditionalGeneration | None = None
|
||||
processor: AutoProcessor | None = None
|
||||
|
||||
|
||||
def image_guard(image, policy:str | None=None) -> str:
|
||||
def image_guard(image, policy:str | None=None):
|
||||
global model, processor # pylint: disable=global-statement
|
||||
import json
|
||||
from installer import install
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import time
|
||||
import gradio as gr
|
||||
from modules import scripts, scripts_postprocessing, processing, images
|
||||
from modules import scripts, processing, images
|
||||
from modules.scripts_postprocessing import PostprocessedImage, ScriptPostprocessing
|
||||
from scripts.nudenet import nudenet # pylint: disable=no-name-in-module
|
||||
from scripts.nudenet import langdetect # pylint: disable=no-name-in-module
|
||||
from scripts.nudenet import imageguard # pylint: disable=no-name-in-module
|
||||
|
|
@ -49,14 +50,14 @@ def create_ui(accordion=True):
|
|||
|
||||
# main processing used in both modes
|
||||
def process(
|
||||
p: processing.StableDiffusionProcessing=None,
|
||||
pp: scripts.PostprocessImageArgs=None,
|
||||
p: processing.StableDiffusionProcessing | None = None,
|
||||
pp: scripts.PostprocessImageArgs | PostprocessedImage | None = None,
|
||||
enabled=True,
|
||||
lang=False,
|
||||
policy=False,
|
||||
banned=False,
|
||||
metadata=True,
|
||||
copy=False,
|
||||
copy=False, # Compatability
|
||||
score=0.2,
|
||||
blocks=3,
|
||||
censor=[],
|
||||
|
|
@ -74,7 +75,7 @@ def process(
|
|||
nudes = nudenet.detector.censor(image=pp.image, method=method, min_score=score, censor=censor, blocks=blocks, overlay=overlay)
|
||||
t1 = time.time()
|
||||
if len(nudes.censored) > 0: # Check if there are any censored areas
|
||||
if not copy:
|
||||
if p is None:
|
||||
pp.image = nudes.output
|
||||
else:
|
||||
info = processing.create_infotext(p)
|
||||
|
|
@ -85,7 +86,7 @@ def process(
|
|||
if metadata and p is not None:
|
||||
p.extra_generation_params["NudeNet"] = meta
|
||||
p.extra_generation_params["NSFW"] = nsfw
|
||||
if metadata and hasattr(pp, 'info'):
|
||||
if metadata and isinstance(pp, PostprocessedImage):
|
||||
pp.info['NudeNet'] = meta
|
||||
pp.info['NSFW'] = nsfw
|
||||
log.debug(f'NudeNet detect: {dct} nsfw={nsfw} time={(t1 - t0):.2f}')
|
||||
|
|
@ -118,7 +119,7 @@ def process(
|
|||
if metadata and p is not None:
|
||||
p.extra_generation_params["Rating"] = res.get('rating', 'N/A')
|
||||
p.extra_generation_params["Category"] = res.get('category', 'N/A')
|
||||
if metadata and hasattr(pp, 'info'):
|
||||
if metadata and isinstance(pp, PostprocessedImage):
|
||||
pp.info["Rating"] = res.get('rating', 'N/A')
|
||||
pp.info["Category"] = res.get('category', 'N/A')
|
||||
|
||||
|
|
@ -131,11 +132,11 @@ class ScriptNudeNet(scripts.Script):
|
|||
def title(self):
|
||||
return 'NudeNet'
|
||||
|
||||
def show(self, _is_img2img):
|
||||
def show(self, *args, **kwargs):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
# return signature is array of gradio components
|
||||
def ui(self, _is_img2img):
|
||||
def ui(self, *args, **kwargs):
|
||||
return create_ui(accordion=True)
|
||||
|
||||
# triggered by callback
|
||||
|
|
@ -148,7 +149,7 @@ class ScriptNudeNet(scripts.Script):
|
|||
|
||||
|
||||
# defines postprocessing script for dual-mode usage
|
||||
class ScriptPostprocessingNudeNet(scripts_postprocessing.ScriptPostprocessing):
|
||||
class ScriptPostprocessingNudeNet(ScriptPostprocessing):
|
||||
name = 'NudeNet'
|
||||
order = 10000
|
||||
|
||||
|
|
@ -158,5 +159,5 @@ class ScriptPostprocessingNudeNet(scripts_postprocessing.ScriptPostprocessing):
|
|||
return { 'enabled': enabled, 'lang': lang, 'policy': policy, 'banned': banned, 'metadata': metadata, 'copy': copy, 'score': score, 'blocks': blocks, 'censor': censor, 'method': method, 'overlay': overlay, 'allowed': allowed, 'alphabet': alphabet, 'words': words}
|
||||
|
||||
# triggered by callback
|
||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, enabled, lang, policy, banned, metadata, copy, score, blocks, censor, method, overlay, allowed, alphabet, words): # pylint: disable=arguments-differ
|
||||
def process(self, pp: PostprocessedImage, enabled, lang, policy, banned, metadata, copy, score, blocks, censor, method, overlay, allowed, alphabet, words): # pylint: disable=arguments-differ
|
||||
process(None, pp, enabled, lang, policy, banned, metadata, copy, score, blocks, censor, method, overlay, allowed, alphabet, words)
|
||||
|
|
|
|||
Loading…
Reference in New Issue