Merge pull request #4722 from awsr/various1

typing and typechecks updates
pull/4729/head
Vladimir Mandic 2026-04-02 08:36:41 +02:00 committed by GitHub
commit 668a94141d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 165 additions and 129 deletions

View File

@ -42,7 +42,7 @@ Resut should always be: list[ResGPU]
class ResGPU(BaseModel):
name: str = Field(title="GPU Name")
data: dict = Field(title="Name/Value data")
chart: list[float, float] = Field(title="Exactly two items to place on chart")
chart: tuple[float, float] = Field(title="Exactly two items to place on chart")
"""
if __name__ == '__main__':

View File

@ -484,22 +484,22 @@ else:
FlagsModel = create_model("Flags", __config__=pydantic_config, **flags)
class ResEmbeddings(BaseModel):
loaded: list = Field(default=None, title="loaded", description="List of loaded embeddings")
skipped: list = Field(default=None, title="skipped", description="List of skipped embeddings")
loaded: list = Field(title="loaded", description="List of loaded embeddings")
skipped: list = Field(title="skipped", description="List of skipped embeddings")
class ResMemory(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
class ResScripts(BaseModel):
txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
control: list = Field(default=None, title="Control", description="Titles of scripts (control)")
txt2img: list[str] = Field(title="Txt2img", description="Titles of scripts (txt2img)")
img2img: list[str] = Field(title="Img2img", description="Titles of scripts (img2img)")
control: list[str] = Field(title="Control", description="Titles of scripts (control)")
class ResGPU(BaseModel): # definition of http response
name: str = Field(title="GPU Name", description="GPU device name")
data: dict = Field(title="Name/Value data", description="Key-value pairs of GPU metrics (utilization, temperature, clocks, memory, etc.)")
chart: list[float, float] = Field(title="Exactly two items to place on chart", description="Two numeric values for chart display (e.g., GPU utilization %, VRAM usage %)")
chart: tuple[float, float] = Field(title="Exactly two items to place on chart", description="Two numeric values for chart display (e.g., GPU utilization %, VRAM usage %)")
class ItemLoadedModel(BaseModel):
name: str = Field(title="Model Name", description="Model or component name")

View File

@ -72,7 +72,7 @@ def get_nvml():
"System load": f'GPU {load.gpu}% | VRAM {load.memory}% | Temp {pynvml.nvmlDeviceGetTemperature(dev, 0)}C | Fan {pynvml.nvmlDeviceGetFanSpeed(dev)}%',
'State': get_reason(pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(dev)),
}
chart = [load.memory, load.gpu]
chart = (load.memory, load.gpu)
devices.append({
'name': name,
'data': data,

View File

@ -96,7 +96,7 @@ def get_rocm_smi():
'Throttle reason': str(ThrottleStatus(int(rocm_smi_data[key].get("throttle_status", 0)))),
}
name = rocm_smi_data[key].get('Device Name', 'unknown')
chart = [load["memory"], load["gpu"]]
chart = (load["memory"], load["gpu"])
devices.append({
'name': name,
'data': data,

View File

@ -29,7 +29,7 @@ def get_xpu_smi():
"VRAM usage": f'{round(100 * load["memory"] / total)}% | {load["memory"]} MB used | {total - load["memory"]} MB free | {total} MB total',
"RAM usage": f'{round(100 * ram["used"] / ram["total"])}% | {round(1024 * ram["used"])} MB used | {round(1024 * ram["free"])} MB free | {round(1024 * ram["total"])} MB total',
}
chart = [load["memory"], load["gpu"]]
chart = (load["memory"], load["gpu"])
devices.append({
'name': torch.xpu.get_device_name(),
'data': data,

View File

@ -40,7 +40,7 @@ def ts2utc(timestamp: int) -> datetime:
except Exception:
return "unknown"
def active():
def active() -> list[Extension]:
if shared.opts.disable_all_extensions == "all":
return []
elif shared.opts.disable_all_extensions == "user":
@ -185,12 +185,12 @@ class Extension:
log.error(f"Extension: failed reading data from git repo={self.name}: {ex}")
self.remote = None
def list_files(self, subdir, extension):
from modules import scripts_manager
def list_files(self, subdir: str, extension: str):
from modules.scripts_manager import ScriptFile
dirpath = os.path.join(self.path, subdir)
res: list[ScriptFile] = []
if not os.path.isdir(dirpath):
return []
res = []
return res
for filename in sorted(os.listdir(dirpath)):
if not filename.endswith(".py") and not filename.endswith(".js") and not filename.endswith(".mjs"):
continue
@ -198,7 +198,7 @@ class Extension:
if os.path.isfile(os.path.join(dirpath, "..", ".priority")):
with open(os.path.join(dirpath, "..", ".priority"), encoding="utf-8") as f:
priority = str(f.read().strip())
res.append(scripts_manager.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
res.append(ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
if priority != '50':
log.debug(f'Extension priority override: {os.path.dirname(dirpath)}:{priority}')
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

View File

@ -1,15 +1,23 @@
import math
from collections import namedtuple
from typing import NamedTuple
import numpy as np
from PIL import Image, ImageFont, ImageDraw
from modules import shared, script_callbacks
from PIL import Image, ImageDraw, ImageFont
from modules import script_callbacks, shared
from modules.logger import log
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
class Grid(NamedTuple):
tiles: list
tile_w: int
tile_h: int
image_w: int
image_h: int
overlap: int
def check_grid_size(imgs):
def check_grid_size(imgs: list[Image.Image] | list[list[Image.Image]] | None):
if imgs is None or len(imgs) == 0:
return False
mp = 0
@ -26,39 +34,36 @@ def check_grid_size(imgs):
return ok
def get_grid_size(imgs, batch_size=1, rows: int | None = None, cols: int | None = None):
if rows and rows > len(imgs):
rows = len(imgs)
if cols and cols > len(imgs):
cols = len(imgs)
def get_grid_size(imgs: list, batch_size=1, rows: int | None = None, cols: int | None = None):
rows_int, cols_int = len(imgs), len(imgs)
if rows is None and cols is None:
if shared.opts.n_rows > 0:
rows = shared.opts.n_rows
cols = math.ceil(len(imgs) / rows)
elif shared.opts.n_rows == 0:
rows = batch_size
cols = math.ceil(len(imgs) / rows)
elif shared.opts.n_cols > 0:
cols = shared.opts.n_cols
rows = math.ceil(len(imgs) / cols)
elif shared.opts.n_cols == 0:
cols = batch_size
rows = math.ceil(len(imgs) / cols)
n_rows, n_cols = shared.opts.n_rows, shared.opts.n_cols
if n_rows >= 0:
rows_int: int = batch_size if n_rows == 0 else n_rows
cols_int = math.ceil(len(imgs) / rows_int)
elif n_cols >= 0:
cols_int: int = batch_size if n_cols == 0 else n_cols
rows_int = math.ceil(len(imgs) / cols_int)
else:
rows = math.floor(math.sqrt(len(imgs)))
while len(imgs) % rows != 0:
rows -= 1
cols = math.ceil(len(imgs) / rows)
elif rows is not None and cols is None:
cols = math.ceil(len(imgs) / rows)
elif rows is None and cols is not None:
rows = math.ceil(len(imgs) / cols)
else:
pass
return rows, cols
rows_int = math.floor(math.sqrt(len(imgs)))
while len(imgs) % rows_int != 0:
rows_int -= 1
cols_int = math.ceil(len(imgs) / rows_int)
return rows_int, cols_int
# Set limits
if rows is not None:
rows_int = max(min(rows, len(imgs)), 1)
if cols is not None:
cols_int = max(min(cols, len(imgs)), 1)
# Calculate
if rows is None:
rows_int = math.ceil(len(imgs) / cols_int)
if cols is None:
cols_int = math.ceil(len(imgs) / rows_int)
return rows_int, cols_int
def image_grid(imgs, batch_size=1, rows: int | None = None, cols: int | None = None):
def image_grid(imgs: list, batch_size=1, rows: int | None = None, cols: int | None = None):
rows, cols = get_grid_size(imgs, batch_size, rows=rows, cols=cols)
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
script_callbacks.image_grid_callback(params)
@ -73,7 +78,7 @@ def image_grid(imgs, batch_size=1, rows: int | None = None, cols: int | None = N
return grid
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
def split_grid(image: Image.Image, tile_w=512, tile_h=512, overlap=64):
w = image.width
h = image.height
non_overlap_width = tile_w - overlap
@ -98,7 +103,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
return grid
def combine_grid(grid):
def combine_grid(grid: Grid):
def make_mask_image(r):
r = r * 255 / grid.overlap
r = r.astype(np.uint8)
@ -127,18 +132,18 @@ class GridAnnotation:
def __init__(self, text='', is_active=True):
self.text = str(text)
self.is_active = is_active
self.size = None
self.size: tuple[int, int] = (10, 10) # Placeholder values
def get_font(fontsize):
def get_font(fontsize: float):
try:
return ImageFont.truetype(shared.opts.font or "javascript/notosans-nerdfont-regular.ttf", fontsize)
except Exception:
return ImageFont.truetype("javascript/notosans-nerdfont-regular.ttf", fontsize)
def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=None):
def wrap(drawing, text, font, line_length):
def draw_grid_annotations(im: Image.Image, width: int, height: int, x_texts: list[list[GridAnnotation]], y_texts: list[list[GridAnnotation]], margin=0, title: list[GridAnnotation] | None = None):
def wrap(drawing: ImageDraw.ImageDraw, text, font, line_length):
lines = ['']
for word in text.split():
line = f'{lines[-1]} {word}'.strip()
@ -148,7 +153,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
lines.append(word)
return lines
def draw_texts(drawing: ImageDraw, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
def draw_texts(drawing: ImageDraw.ImageDraw, draw_x: float, draw_y: float, lines, initial_fnt: ImageFont.FreeTypeFont, initial_fontsize: int):
for line in lines:
font = initial_fnt
fontsize = initial_fontsize
@ -185,9 +190,10 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in x_texts]
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in y_texts]
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
title_text_heights = []
title_pad = 0
if title:
title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts] # pylint: disable=unsubscriptable-object
title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts]
title_pad = 0 if sum(title_text_heights) == 0 else max(title_text_heights) + line_spacing * 2
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + title_pad + margin * (rows-1)), shared.opts.grid_background)
for row in range(rows):
@ -210,7 +216,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
return result
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
def draw_prompt_matrix(im: Image.Image, width: int, height: int, all_prompts: list[str], margin=0):
prompts = all_prompts[1:]
boundary = math.ceil(len(prompts) / 2)
prompts_horiz = prompts[:boundary]

View File

@ -1,14 +1,23 @@
from __future__ import annotations
import os
import re
import sys
import time
from collections import namedtuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, NamedTuple
import gradio as gr
from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
from modules.logger import log
from installer import control_extensions
if TYPE_CHECKING:
from types import ModuleType
from modules.api.models import ItemScript
from gradio.blocks import Block
from gradio.components import IOComponent
from modules.processing import Processed, StableDiffusionProcessing
AlwaysVisible = object()
time_component = {}
@ -28,22 +37,22 @@ class PostprocessBatchListArgs:
@dataclass
class OnComponent:
component: gr.blocks.Block
component: Block
class Script:
parent = None
name = None
filename = None
parent: str | None = None
name: str | None = None
filename: str | None = None
args_from = 0
args_to = 0
alwayson = False
is_txt2img = False
is_img2img = False
api_info = None
api_info: ItemScript | None = None
group = None
infotext_fields = None
paste_field_names = None
infotext_fields: list | None = None
paste_field_names: list[str] | None = None
section = None
standalone = False
external = False
@ -57,14 +66,14 @@ class Script:
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
raise NotImplementedError
def ui(self, is_img2img):
def ui(self, is_img2img) -> list[IOComponent]:
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
The return value should be an array of all components that are used in processing.
Values of those returned components will be passed to run() and process() functions.
"""
pass # pylint: disable=unnecessary-pass
def show(self, is_img2img): # pylint: disable=unused-argument
def show(self, is_img2img) -> bool | AlwaysVisible: # pylint: disable=unused-argument
"""
is_img2img is True if this function is called for the img2img interface, and False otherwise
This function should return:
@ -74,7 +83,7 @@ class Script:
"""
return True
def run(self, p, *args):
def run(self, p: StableDiffusionProcessing, *args):
"""
This function is called if the script has been selected in the script dropdown.
It must do all processing and return the Processed object with results, same as
@ -84,13 +93,13 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def setup(self, p, *args):
def setup(self, p: StableDiffusionProcessing, *args):
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
args contains all values returned by components from ui().
"""
pass # pylint: disable=unnecessary-pass
def before_process(self, p, *args):
def before_process(self, p: StableDiffusionProcessing, *args):
"""
This function is called very early during processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
@ -98,7 +107,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def process(self, p, *args):
def process(self, p: StableDiffusionProcessing, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
@ -106,7 +115,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def process_images(self, p, *args):
def process_images(self, p: StableDiffusionProcessing, *args):
"""
This function is called instead of main processing for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
@ -114,7 +123,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def before_process_batch(self, p, *args, **kwargs):
def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
"""
Called before extra networks are parsed from the prompt, so you can add
new extra network keywords to the prompt with this callback.
@ -126,7 +135,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def process_batch(self, p, *args, **kwargs):
def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
"""
Same as process(), but called for every batch.
**kwargs will have those items:
@ -137,7 +146,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def postprocess_batch(self, p, *args, **kwargs):
def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
"""
Same as process_batch(), but called for every batch after it has been generated.
**kwargs will have same items as process_batch, and also:
@ -146,13 +155,13 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
"""
pass # pylint: disable=unnecessary-pass
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, *args, **kwargs):
"""
Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
This is useful when you want to update the entire batch instead of individual images.
@ -168,14 +177,14 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def postprocess(self, p, processed, *args):
def postprocess(self, p: StableDiffusionProcessing, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
args contains all values returned by components from ui()
"""
pass # pylint: disable=unnecessary-pass
def before_component(self, component, **kwargs):
def before_component(self, component: IOComponent, **kwargs):
"""
Called before a component is created.
Use elem_id/label fields of kwargs to figure out which component it is.
@ -184,7 +193,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def after_component(self, component, **kwargs):
def after_component(self, component: IOComponent, **kwargs):
"""
Called after a component is created. Same as above.
"""
@ -194,7 +203,7 @@ class Script:
"""unused"""
return ""
def elem_id(self, item_id):
def elem_id(self, item_id: str):
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
return f'script_{self.parent}_{title}_{item_id}'
@ -211,21 +220,33 @@ def basedir():
return current_basedir
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path", "priority"])
class ScriptFile(NamedTuple):
basedir: str
filename: str
path: str
priority: str
class ScriptClassData(NamedTuple):
script_class: type[Script] | type[scripts_postprocessing.ScriptPostprocessing]
path: str
basedir: str
module: ModuleType
scripts_data = []
postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def list_scripts(scriptdirname, extension):
tmp_list = []
def list_scripts(scriptdirname: str, extension: str):
tmp_list: list[ScriptFile] = []
base = os.path.join(paths.script_path, scriptdirname)
if os.path.exists(base):
for filename in sorted(os.listdir(base)):
tmp_list.append(ScriptFile(paths.script_path, filename, os.path.join(base, filename), '50'))
for ext in extensions.active():
tmp_list += ext.list_files(scriptdirname, extension)
priority_list = []
priority_list: list[ScriptFile] = []
for script in tmp_list:
if os.path.splitext(script.path)[1].lower() == extension and os.path.isfile(script.path):
if script.basedir == paths.script_path:
@ -250,7 +271,7 @@ def list_scripts(scriptdirname, extension):
return priority_sort
def list_files_with_name(filename):
def list_files_with_name(filename: str):
res = []
dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
for dirpath in dirs:
@ -273,7 +294,7 @@ def load_scripts():
scripts_list = sorted(scripts_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
syspath = sys.path
def register_scripts_from_module(module, scriptfile):
def register_scripts_from_module(module: ModuleType, scriptfile):
for script_class in module.__dict__.values():
if type(script_class) != type:
continue
@ -305,7 +326,7 @@ def load_scripts():
return t, time.time()-t0
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
def wrap_call(func: Callable, filename: str, funcname: str, *args, default=None, **kwargs):
try:
res = func(*args, **kwargs)
return res
@ -315,7 +336,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptSummary:
def __init__(self, op):
def __init__(self, op: str):
self.start = time.time()
self.update = time.time()
self.op = op
@ -336,17 +357,17 @@ class ScriptSummary:
class ScriptRunner:
def __init__(self, name=''):
self.name = name
self.scripts = []
self.selectable_scripts = []
self.alwayson_scripts = []
self.auto_processing_scripts = []
self.titles = []
self.alwayson_titles = []
self.infotext_fields = []
self.paste_field_names = []
self.scripts: list[Script] = []
self.selectable_scripts: list[Script] = []
self.alwayson_scripts: list[Script] = []
self.auto_processing_scripts: list[ScriptClassData] = []
self.titles: list[str] = []
self.alwayson_titles: list[str] = []
self.infotext_fields: list[tuple[IOComponent, str]] = []
self.paste_field_names: list[str] = []
self.script_load_ctr = 0
self.is_img2img = False
self.inputs = [None]
self.inputs: list = [None]
self.time = 0
def add_script(self, script_class, path, is_img2img, is_control):
@ -557,7 +578,7 @@ class ScriptRunner:
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts if script.group is not None])
return inputs
def run(self, p, *args):
def run(self, p: StableDiffusionProcessing, *args) -> Processed | None:
s = ScriptSummary('run')
script_index = args[0] if len(args) > 0 else 0
if (script_index is None) or (script_index == 0):
@ -582,7 +603,7 @@ class ScriptRunner:
s.report()
return processed
def after(self, p, processed, *args):
def after(self, p: StableDiffusionProcessing, processed: Processed, *args):
s = ScriptSummary('after')
script_index = args[0] if len(args) > 0 else 0
if (script_index is None) or (script_index == 0):
@ -600,7 +621,7 @@ class ScriptRunner:
s.report()
return processed
def before_process(self, p, **kwargs):
def before_process(self, p: StableDiffusionProcessing, **kwargs):
s = ScriptSummary('before-process')
for script in self.alwayson_scripts:
try:
@ -612,7 +633,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def process(self, p, **kwargs):
def process(self, p: StableDiffusionProcessing, **kwargs):
s = ScriptSummary('process')
for script in self.alwayson_scripts:
try:
@ -624,7 +645,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def process_images(self, p, **kwargs):
def process_images(self, p: StableDiffusionProcessing, **kwargs):
s = ScriptSummary('process_images')
processed = None
for script in self.alwayson_scripts:
@ -640,7 +661,7 @@ class ScriptRunner:
s.report()
return processed
def before_process_batch(self, p, **kwargs):
def before_process_batch(self, p: StableDiffusionProcessing, **kwargs):
s = ScriptSummary('before-process-batch')
for script in self.alwayson_scripts:
try:
@ -652,7 +673,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def process_batch(self, p, **kwargs):
def process_batch(self, p: StableDiffusionProcessing, **kwargs):
s = ScriptSummary('process-batch')
for script in self.alwayson_scripts:
try:
@ -664,7 +685,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def postprocess(self, p, processed):
def postprocess(self, p: StableDiffusionProcessing, processed):
s = ScriptSummary('postprocess')
for script in self.alwayson_scripts:
try:
@ -676,7 +697,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def postprocess_batch(self, p, images, **kwargs):
def postprocess_batch(self, p: StableDiffusionProcessing, images, **kwargs):
s = ScriptSummary('postprocess-batch')
for script in self.alwayson_scripts:
try:
@ -688,7 +709,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, **kwargs):
s = ScriptSummary('postprocess-batch-list')
for script in self.alwayson_scripts:
try:
@ -700,7 +721,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def postprocess_image(self, p, pp: PostprocessImageArgs):
def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs):
s = ScriptSummary('postprocess-image')
for script in self.alwayson_scripts:
try:
@ -712,7 +733,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def before_component(self, component, **kwargs):
def before_component(self, component: IOComponent, **kwargs):
s = ScriptSummary('before-component')
for script in self.scripts:
try:
@ -722,7 +743,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def after_component(self, component, **kwargs):
def after_component(self, component: IOComponent, **kwargs):
s = ScriptSummary('after-component')
for script in self.scripts:
for elem_id, callback in script.on_after_component_elem_id:
@ -738,7 +759,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def reload_sources(self, cache):
def reload_sources(self, cache: dict):
s = ScriptSummary('reload-sources')
for si, script in list(enumerate(self.scripts)):
if hasattr(script, 'args_to') and hasattr(script, 'args_from'):

View File

@ -1418,7 +1418,7 @@ def hf_auth_check(checkpoint_info, force:bool=False):
return False
def save_model(name: str, path: str | None = None, shard: str | None = None, overwrite = False):
def save_model(name: str, path: str | None = None, shard: str = "5GB", overwrite = False):
if (name is None) or len(name.strip()) == 0:
log.error('Save model: invalid model name')
return 'Invalid model name'
@ -1432,6 +1432,8 @@ def save_model(name: str, path: str | None = None, shard: str | None = None, ove
if os.path.exists(model_name) and not overwrite:
log.error(f'Save model: path="{model_name}" exists')
return f'Path exists: {model_name}'
if not shard.strip():
shard = "5GB" # Guard against empty input
try:
t0 = time.time()
save_sdnq_model(

View File

@ -38,8 +38,8 @@ def load_unet_sdxl_nunchaku(repo_id):
def load_unet(model, repo_id: str | None = None):
global loaded_unet # pylint: disable=global-statement
if ("StableDiffusionXLPipeline" in model.__class__.__name__) and (('stable-diffusion-xl-base' in repo_id) or ('sdxl-turbo' in repo_id)):
if model_quant.check_nunchaku('Model'):
if ("StableDiffusionXLPipeline" in model.__class__.__name__) and repo_id is not None and (("stable-diffusion-xl-base" in repo_id) or ("sdxl-turbo" in repo_id)):
if model_quant.check_nunchaku("Model"):
unet = load_unet_sdxl_nunchaku(repo_id)
if unet is not None:
model.unet = unet

View File

@ -1,4 +1,11 @@
from __future__ import annotations
from modules.logger import log
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
repo_id = 'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf'
policy_template = """Provide a safety assessment for the content provided by the user. Determine compliance or violation against our safety policy by reviewing the following policy categories:
Hate:
@ -89,11 +96,11 @@ To provide your assessment use the following json template for each category:
"rationale": str,
}.
"""
model = None
processor = None
model: LlavaOnevisionForConditionalGeneration | None = None
processor: AutoProcessor | None = None
def image_guard(image, policy:str | None=None) -> str:
def image_guard(image, policy:str | None=None):
global model, processor # pylint: disable=global-statement
import json
from installer import install

View File

@ -49,14 +49,14 @@ def create_ui(accordion=True):
# main processing used in both modes
def process(
p: processing.StableDiffusionProcessing=None,
pp: scripts.PostprocessImageArgs=None,
p: processing.StableDiffusionProcessing | None = None,
pp: scripts.PostprocessImageArgs | scripts_postprocessing.PostprocessedImage | None = None,
enabled=True,
lang=False,
policy=False,
banned=False,
metadata=True,
copy=False,
copy=False, # Compatability
score=0.2,
blocks=3,
censor=[],
@ -74,7 +74,7 @@ def process(
nudes = nudenet.detector.censor(image=pp.image, method=method, min_score=score, censor=censor, blocks=blocks, overlay=overlay)
t1 = time.time()
if len(nudes.censored) > 0: # Check if there are any censored areas
if not copy:
if p is None:
pp.image = nudes.output
else:
info = processing.create_infotext(p)
@ -85,7 +85,7 @@ def process(
if metadata and p is not None:
p.extra_generation_params["NudeNet"] = meta
p.extra_generation_params["NSFW"] = nsfw
if metadata and hasattr(pp, 'info'):
if metadata and isinstance(pp, scripts_postprocessing.PostprocessedImage):
pp.info['NudeNet'] = meta
pp.info['NSFW'] = nsfw
log.debug(f'NudeNet detect: {dct} nsfw={nsfw} time={(t1 - t0):.2f}')
@ -118,7 +118,7 @@ def process(
if metadata and p is not None:
p.extra_generation_params["Rating"] = res.get('rating', 'N/A')
p.extra_generation_params["Category"] = res.get('category', 'N/A')
if metadata and hasattr(pp, 'info'):
if metadata and isinstance(pp, scripts_postprocessing.PostprocessedImage):
pp.info["Rating"] = res.get('rating', 'N/A')
pp.info["Category"] = res.get('category', 'N/A')
@ -131,11 +131,11 @@ class ScriptNudeNet(scripts.Script):
def title(self):
return 'NudeNet'
def show(self, _is_img2img):
def show(self, *args, **kwargs):
return scripts.AlwaysVisible
# return signature is array of gradio components
def ui(self, _is_img2img):
def ui(self, *args, **kwargs):
return create_ui(accordion=True)
# triggered by callback