diff --git a/modules/scripts_manager.py b/modules/scripts_manager.py index 7672c9cf6..8957a2592 100644 --- a/modules/scripts_manager.py +++ b/modules/scripts_manager.py @@ -5,7 +5,7 @@ import re import sys import time from dataclasses import dataclass -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, Callable, NamedTuple import gradio as gr from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer from modules.logger import log @@ -13,6 +13,10 @@ from installer import control_extensions if TYPE_CHECKING: from types import ModuleType + from modules.api.models import ItemScript + from gradio.blocks import Block + from gradio.components import IOComponent + from modules.processing import Processed, StableDiffusionProcessing AlwaysVisible = object() @@ -33,22 +37,22 @@ class PostprocessBatchListArgs: @dataclass class OnComponent: - component: gr.blocks.Block + component: Block class Script: - parent = None - name = None - filename = None + parent: str | None = None + name: str | None = None + filename: str | None = None args_from = 0 args_to = 0 alwayson = False is_txt2img = False is_img2img = False - api_info = None + api_info: ItemScript | None = None group = None - infotext_fields = None - paste_field_names = None + infotext_fields: list | None = None + paste_field_names: list[str] | None = None section = None standalone = False external = False @@ -62,14 +66,14 @@ class Script: """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" raise NotImplementedError - def ui(self, is_img2img): + def ui(self, is_img2img) -> list[IOComponent]: """this function should create gradio UI elements. See https://gradio.app/docs/#components The return value should be an array of all components that are used in processing. Values of those returned components will be passed to run() and process() functions. """ pass # pylint: disable=unnecessary-pass - def show(self, is_img2img): # pylint: disable=unused-argument + def show(self, is_img2img) -> bool | AlwaysVisible: # pylint: disable=unused-argument """ is_img2img is True if this function is called for the img2img interface, and False otherwise This function should return: @@ -79,7 +83,7 @@ class Script: """ return True - def run(self, p, *args): + def run(self, p: StableDiffusionProcessing, *args): """ This function is called if the script has been selected in the script dropdown. It must do all processing and return the Processed object with results, same as @@ -89,13 +93,13 @@ class Script: """ pass # pylint: disable=unnecessary-pass - def setup(self, p, *args): + def setup(self, p: StableDiffusionProcessing, *args): """For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts. args contains all values returned by components from ui(). """ pass # pylint: disable=unnecessary-pass - def before_process(self, p, *args): + def before_process(self, p: StableDiffusionProcessing, *args): """ This function is called very early during processing begins for AlwaysVisible scripts. You can modify the processing object (p) here, inject hooks, etc. @@ -103,7 +107,7 @@ class Script: """ pass # pylint: disable=unnecessary-pass - def process(self, p, *args): + def process(self, p: StableDiffusionProcessing, *args): """ This function is called before processing begins for AlwaysVisible scripts. You can modify the processing object (p) here, inject hooks, etc. @@ -111,7 +115,7 @@ class Script: """ pass # pylint: disable=unnecessary-pass - def process_images(self, p, *args): + def process_images(self, p: StableDiffusionProcessing, *args): """ This function is called instead of main processing for AlwaysVisible scripts. You can modify the processing object (p) here, inject hooks, etc. @@ -119,7 +123,7 @@ class Script: """ pass # pylint: disable=unnecessary-pass - def before_process_batch(self, p, *args, **kwargs): + def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs): """ Called before extra networks are parsed from the prompt, so you can add new extra network keywords to the prompt with this callback. @@ -131,7 +135,7 @@ class Script: """ pass # pylint: disable=unnecessary-pass - def process_batch(self, p, *args, **kwargs): + def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs): """ Same as process(), but called for every batch. **kwargs will have those items: @@ -142,7 +146,7 @@ class Script: """ pass # pylint: disable=unnecessary-pass - def postprocess_batch(self, p, *args, **kwargs): + def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs): """ Same as process_batch(), but called for every batch after it has been generated. **kwargs will have same items as process_batch, and also: @@ -151,13 +155,13 @@ class Script: """ pass # pylint: disable=unnecessary-pass - def postprocess_image(self, p, pp: PostprocessImageArgs, *args): + def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs, *args): """ Called for every image after it has been generated. """ pass # pylint: disable=unnecessary-pass - def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs): + def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, *args, **kwargs): """ Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor. This is useful when you want to update the entire batch instead of individual images. @@ -173,14 +177,14 @@ class Script: """ pass # pylint: disable=unnecessary-pass - def postprocess(self, p, processed, *args): + def postprocess(self, p: StableDiffusionProcessing, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. args contains all values returned by components from ui() """ pass # pylint: disable=unnecessary-pass - def before_component(self, component, **kwargs): + def before_component(self, component: IOComponent, **kwargs): """ Called before a component is created. Use elem_id/label fields of kwargs to figure out which component it is. @@ -189,7 +193,7 @@ class Script: """ pass # pylint: disable=unnecessary-pass - def after_component(self, component, **kwargs): + def after_component(self, component: IOComponent, **kwargs): """ Called after a component is created. Same as above. """ @@ -199,7 +203,7 @@ class Script: """unused""" return "" - def elem_id(self, item_id): + def elem_id(self, item_id: str): """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id""" title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower())) return f'script_{self.parent}_{title}_{item_id}' @@ -234,15 +238,15 @@ scripts_data = [] postprocessing_scripts_data = [] -def list_scripts(scriptdirname, extension): - tmp_list = [] +def list_scripts(scriptdirname: str, extension: str): + tmp_list: list[ScriptFile] = [] base = os.path.join(paths.script_path, scriptdirname) if os.path.exists(base): for filename in sorted(os.listdir(base)): tmp_list.append(ScriptFile(paths.script_path, filename, os.path.join(base, filename), '50')) for ext in extensions.active(): tmp_list += ext.list_files(scriptdirname, extension) - priority_list = [] + priority_list: list[ScriptFile] = [] for script in tmp_list: if os.path.splitext(script.path)[1].lower() == extension and os.path.isfile(script.path): if script.basedir == paths.script_path: @@ -290,7 +294,7 @@ def load_scripts(): scripts_list = sorted(scripts_list, key=lambda item: item.priority + item.path.lower(), reverse=False) syspath = sys.path - def register_scripts_from_module(module, scriptfile): + def register_scripts_from_module(module: ModuleType, scriptfile): for script_class in module.__dict__.values(): if type(script_class) != type: continue @@ -322,7 +326,7 @@ def load_scripts(): return t, time.time()-t0 -def wrap_call(func, filename, funcname, *args, default=None, **kwargs): +def wrap_call(func: Callable, filename: str, funcname, *args, default=None, **kwargs): try: res = func(*args, **kwargs) return res @@ -353,14 +357,14 @@ class ScriptSummary: class ScriptRunner: def __init__(self, name=''): self.name = name - self.scripts = [] - self.selectable_scripts = [] - self.alwayson_scripts = [] - self.auto_processing_scripts = [] + self.scripts: list[Script] = [] + self.selectable_scripts: list[Script] = [] + self.alwayson_scripts: list[Script] = [] + self.auto_processing_scripts: list[ScriptClassData] = [] self.titles = [] self.alwayson_titles = [] self.infotext_fields = [] - self.paste_field_names = [] + self.paste_field_names: list[str] = [] self.script_load_ctr = 0 self.is_img2img = False self.inputs = [None] @@ -574,7 +578,7 @@ class ScriptRunner: self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts if script.group is not None]) return inputs - def run(self, p, *args): + def run(self, p: StableDiffusionProcessing, *args) -> Processed | None: s = ScriptSummary('run') script_index = args[0] if len(args) > 0 else 0 if (script_index is None) or (script_index == 0): @@ -599,7 +603,7 @@ class ScriptRunner: s.report() return processed - def after(self, p, processed, *args): + def after(self, p: StableDiffusionProcessing, processed: Processed, *args): s = ScriptSummary('after') script_index = args[0] if len(args) > 0 else 0 if (script_index is None) or (script_index == 0):