Typing updates

pull/4722/head
awsr 2026-03-28 17:15:48 -07:00
parent 8ef8074467
commit b2b6fdf9d5
No known key found for this signature in database
1 changed files with 40 additions and 36 deletions

View File

@ -5,7 +5,7 @@ import re
import sys
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, Callable, NamedTuple
import gradio as gr
from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
from modules.logger import log
@ -13,6 +13,10 @@ from installer import control_extensions
if TYPE_CHECKING:
from types import ModuleType
from modules.api.models import ItemScript
from gradio.blocks import Block
from gradio.components import IOComponent
from modules.processing import Processed, StableDiffusionProcessing
AlwaysVisible = object()
@ -33,22 +37,22 @@ class PostprocessBatchListArgs:
@dataclass
class OnComponent:
component: gr.blocks.Block
component: Block
class Script:
parent = None
name = None
filename = None
parent: str | None = None
name: str | None = None
filename: str | None = None
args_from = 0
args_to = 0
alwayson = False
is_txt2img = False
is_img2img = False
api_info = None
api_info: ItemScript | None = None
group = None
infotext_fields = None
paste_field_names = None
infotext_fields: list | None = None
paste_field_names: list[str] | None = None
section = None
standalone = False
external = False
@ -62,14 +66,14 @@ class Script:
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
raise NotImplementedError
def ui(self, is_img2img):
def ui(self, is_img2img) -> list[IOComponent]:
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
The return value should be an array of all components that are used in processing.
Values of those returned components will be passed to run() and process() functions.
"""
pass # pylint: disable=unnecessary-pass
def show(self, is_img2img): # pylint: disable=unused-argument
def show(self, is_img2img) -> bool | AlwaysVisible: # pylint: disable=unused-argument
"""
is_img2img is True if this function is called for the img2img interface, and False otherwise
This function should return:
@ -79,7 +83,7 @@ class Script:
"""
return True
def run(self, p, *args):
def run(self, p: StableDiffusionProcessing, *args):
"""
This function is called if the script has been selected in the script dropdown.
It must do all processing and return the Processed object with results, same as
@ -89,13 +93,13 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def setup(self, p, *args):
def setup(self, p: StableDiffusionProcessing, *args):
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
args contains all values returned by components from ui().
"""
pass # pylint: disable=unnecessary-pass
def before_process(self, p, *args):
def before_process(self, p: StableDiffusionProcessing, *args):
"""
This function is called very early during processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
@ -103,7 +107,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def process(self, p, *args):
def process(self, p: StableDiffusionProcessing, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
@ -111,7 +115,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def process_images(self, p, *args):
def process_images(self, p: StableDiffusionProcessing, *args):
"""
This function is called instead of main processing for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
@ -119,7 +123,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def before_process_batch(self, p, *args, **kwargs):
def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
"""
Called before extra networks are parsed from the prompt, so you can add
new extra network keywords to the prompt with this callback.
@ -131,7 +135,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def process_batch(self, p, *args, **kwargs):
def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
"""
Same as process(), but called for every batch.
**kwargs will have those items:
@ -142,7 +146,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def postprocess_batch(self, p, *args, **kwargs):
def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
"""
Same as process_batch(), but called for every batch after it has been generated.
**kwargs will have same items as process_batch, and also:
@ -151,13 +155,13 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
"""
pass # pylint: disable=unnecessary-pass
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, *args, **kwargs):
"""
Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
This is useful when you want to update the entire batch instead of individual images.
@ -173,14 +177,14 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def postprocess(self, p, processed, *args):
def postprocess(self, p: StableDiffusionProcessing, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
args contains all values returned by components from ui()
"""
pass # pylint: disable=unnecessary-pass
def before_component(self, component, **kwargs):
def before_component(self, component: IOComponent, **kwargs):
"""
Called before a component is created.
Use elem_id/label fields of kwargs to figure out which component it is.
@ -189,7 +193,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def after_component(self, component, **kwargs):
def after_component(self, component: IOComponent, **kwargs):
"""
Called after a component is created. Same as above.
"""
@ -199,7 +203,7 @@ class Script:
"""unused"""
return ""
def elem_id(self, item_id):
def elem_id(self, item_id: str):
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
return f'script_{self.parent}_{title}_{item_id}'
@ -234,15 +238,15 @@ scripts_data = []
postprocessing_scripts_data = []
def list_scripts(scriptdirname, extension):
tmp_list = []
def list_scripts(scriptdirname: str, extension: str):
tmp_list: list[ScriptFile] = []
base = os.path.join(paths.script_path, scriptdirname)
if os.path.exists(base):
for filename in sorted(os.listdir(base)):
tmp_list.append(ScriptFile(paths.script_path, filename, os.path.join(base, filename), '50'))
for ext in extensions.active():
tmp_list += ext.list_files(scriptdirname, extension)
priority_list = []
priority_list: list[ScriptFile] = []
for script in tmp_list:
if os.path.splitext(script.path)[1].lower() == extension and os.path.isfile(script.path):
if script.basedir == paths.script_path:
@ -290,7 +294,7 @@ def load_scripts():
scripts_list = sorted(scripts_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
syspath = sys.path
def register_scripts_from_module(module, scriptfile):
def register_scripts_from_module(module: ModuleType, scriptfile):
for script_class in module.__dict__.values():
if type(script_class) != type:
continue
@ -322,7 +326,7 @@ def load_scripts():
return t, time.time()-t0
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
def wrap_call(func: Callable, filename: str, funcname, *args, default=None, **kwargs):
try:
res = func(*args, **kwargs)
return res
@ -353,14 +357,14 @@ class ScriptSummary:
class ScriptRunner:
def __init__(self, name=''):
self.name = name
self.scripts = []
self.selectable_scripts = []
self.alwayson_scripts = []
self.auto_processing_scripts = []
self.scripts: list[Script] = []
self.selectable_scripts: list[Script] = []
self.alwayson_scripts: list[Script] = []
self.auto_processing_scripts: list[ScriptClassData] = []
self.titles = []
self.alwayson_titles = []
self.infotext_fields = []
self.paste_field_names = []
self.paste_field_names: list[str] = []
self.script_load_ctr = 0
self.is_img2img = False
self.inputs = [None]
@ -574,7 +578,7 @@ class ScriptRunner:
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts if script.group is not None])
return inputs
def run(self, p, *args):
def run(self, p: StableDiffusionProcessing, *args) -> Processed | None:
s = ScriptSummary('run')
script_index = args[0] if len(args) > 0 else 0
if (script_index is None) or (script_index == 0):
@ -599,7 +603,7 @@ class ScriptRunner:
s.report()
return processed
def after(self, p, processed, *args):
def after(self, p: StableDiffusionProcessing, processed: Processed, *args):
s = ScriptSummary('after')
script_index = args[0] if len(args) > 0 else 0
if (script_index is None) or (script_index == 0):