mirror of https://github.com/vladmandic/automatic
Typing updates
parent
8ef8074467
commit
b2b6fdf9d5
|
|
@ -5,7 +5,7 @@ import re
|
|||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
import gradio as gr
|
||||
from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
|
||||
from modules.logger import log
|
||||
|
|
@ -13,6 +13,10 @@ from installer import control_extensions
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
from modules.api.models import ItemScript
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import IOComponent
|
||||
from modules.processing import Processed, StableDiffusionProcessing
|
||||
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
|
@ -33,22 +37,22 @@ class PostprocessBatchListArgs:
|
|||
|
||||
@dataclass
|
||||
class OnComponent:
|
||||
component: gr.blocks.Block
|
||||
component: Block
|
||||
|
||||
|
||||
class Script:
|
||||
parent = None
|
||||
name = None
|
||||
filename = None
|
||||
parent: str | None = None
|
||||
name: str | None = None
|
||||
filename: str | None = None
|
||||
args_from = 0
|
||||
args_to = 0
|
||||
alwayson = False
|
||||
is_txt2img = False
|
||||
is_img2img = False
|
||||
api_info = None
|
||||
api_info: ItemScript | None = None
|
||||
group = None
|
||||
infotext_fields = None
|
||||
paste_field_names = None
|
||||
infotext_fields: list | None = None
|
||||
paste_field_names: list[str] | None = None
|
||||
section = None
|
||||
standalone = False
|
||||
external = False
|
||||
|
|
@ -62,14 +66,14 @@ class Script:
|
|||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||
raise NotImplementedError
|
||||
|
||||
def ui(self, is_img2img):
|
||||
def ui(self, is_img2img) -> list[IOComponent]:
|
||||
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||
The return value should be an array of all components that are used in processing.
|
||||
Values of those returned components will be passed to run() and process() functions.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def show(self, is_img2img): # pylint: disable=unused-argument
|
||||
def show(self, is_img2img) -> bool | AlwaysVisible: # pylint: disable=unused-argument
|
||||
"""
|
||||
is_img2img is True if this function is called for the img2img interface, and False otherwise
|
||||
This function should return:
|
||||
|
|
@ -79,7 +83,7 @@ class Script:
|
|||
"""
|
||||
return True
|
||||
|
||||
def run(self, p, *args):
|
||||
def run(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called if the script has been selected in the script dropdown.
|
||||
It must do all processing and return the Processed object with results, same as
|
||||
|
|
@ -89,13 +93,13 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def setup(self, p, *args):
|
||||
def setup(self, p: StableDiffusionProcessing, *args):
|
||||
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
|
||||
args contains all values returned by components from ui().
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_process(self, p, *args):
|
||||
def before_process(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called very early during processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
|
|
@ -103,7 +107,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process(self, p, *args):
|
||||
def process(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called before processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
|
|
@ -111,7 +115,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process_images(self, p, *args):
|
||||
def process_images(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called instead of main processing for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
|
|
@ -119,7 +123,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_process_batch(self, p, *args, **kwargs):
|
||||
def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
|
||||
"""
|
||||
Called before extra networks are parsed from the prompt, so you can add
|
||||
new extra network keywords to the prompt with this callback.
|
||||
|
|
@ -131,7 +135,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process_batch(self, p, *args, **kwargs):
|
||||
def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
|
||||
"""
|
||||
Same as process(), but called for every batch.
|
||||
**kwargs will have those items:
|
||||
|
|
@ -142,7 +146,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_batch(self, p, *args, **kwargs):
|
||||
def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
|
||||
"""
|
||||
Same as process_batch(), but called for every batch after it has been generated.
|
||||
**kwargs will have same items as process_batch, and also:
|
||||
|
|
@ -151,13 +155,13 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||
def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
|
||||
def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, *args, **kwargs):
|
||||
"""
|
||||
Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
|
||||
This is useful when you want to update the entire batch instead of individual images.
|
||||
|
|
@ -173,14 +177,14 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
def postprocess(self, p: StableDiffusionProcessing, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
def before_component(self, component: IOComponent, **kwargs):
|
||||
"""
|
||||
Called before a component is created.
|
||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||
|
|
@ -189,7 +193,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
def after_component(self, component: IOComponent, **kwargs):
|
||||
"""
|
||||
Called after a component is created. Same as above.
|
||||
"""
|
||||
|
|
@ -199,7 +203,7 @@ class Script:
|
|||
"""unused"""
|
||||
return ""
|
||||
|
||||
def elem_id(self, item_id):
|
||||
def elem_id(self, item_id: str):
|
||||
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
||||
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
||||
return f'script_{self.parent}_{title}_{item_id}'
|
||||
|
|
@ -234,15 +238,15 @@ scripts_data = []
|
|||
postprocessing_scripts_data = []
|
||||
|
||||
|
||||
def list_scripts(scriptdirname, extension):
|
||||
tmp_list = []
|
||||
def list_scripts(scriptdirname: str, extension: str):
|
||||
tmp_list: list[ScriptFile] = []
|
||||
base = os.path.join(paths.script_path, scriptdirname)
|
||||
if os.path.exists(base):
|
||||
for filename in sorted(os.listdir(base)):
|
||||
tmp_list.append(ScriptFile(paths.script_path, filename, os.path.join(base, filename), '50'))
|
||||
for ext in extensions.active():
|
||||
tmp_list += ext.list_files(scriptdirname, extension)
|
||||
priority_list = []
|
||||
priority_list: list[ScriptFile] = []
|
||||
for script in tmp_list:
|
||||
if os.path.splitext(script.path)[1].lower() == extension and os.path.isfile(script.path):
|
||||
if script.basedir == paths.script_path:
|
||||
|
|
@ -290,7 +294,7 @@ def load_scripts():
|
|||
scripts_list = sorted(scripts_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
|
||||
syspath = sys.path
|
||||
|
||||
def register_scripts_from_module(module, scriptfile):
|
||||
def register_scripts_from_module(module: ModuleType, scriptfile):
|
||||
for script_class in module.__dict__.values():
|
||||
if type(script_class) != type:
|
||||
continue
|
||||
|
|
@ -322,7 +326,7 @@ def load_scripts():
|
|||
return t, time.time()-t0
|
||||
|
||||
|
||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
def wrap_call(func: Callable, filename: str, funcname, *args, default=None, **kwargs):
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
return res
|
||||
|
|
@ -353,14 +357,14 @@ class ScriptSummary:
|
|||
class ScriptRunner:
|
||||
def __init__(self, name=''):
|
||||
self.name = name
|
||||
self.scripts = []
|
||||
self.selectable_scripts = []
|
||||
self.alwayson_scripts = []
|
||||
self.auto_processing_scripts = []
|
||||
self.scripts: list[Script] = []
|
||||
self.selectable_scripts: list[Script] = []
|
||||
self.alwayson_scripts: list[Script] = []
|
||||
self.auto_processing_scripts: list[ScriptClassData] = []
|
||||
self.titles = []
|
||||
self.alwayson_titles = []
|
||||
self.infotext_fields = []
|
||||
self.paste_field_names = []
|
||||
self.paste_field_names: list[str] = []
|
||||
self.script_load_ctr = 0
|
||||
self.is_img2img = False
|
||||
self.inputs = [None]
|
||||
|
|
@ -574,7 +578,7 @@ class ScriptRunner:
|
|||
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts if script.group is not None])
|
||||
return inputs
|
||||
|
||||
def run(self, p, *args):
|
||||
def run(self, p: StableDiffusionProcessing, *args) -> Processed | None:
|
||||
s = ScriptSummary('run')
|
||||
script_index = args[0] if len(args) > 0 else 0
|
||||
if (script_index is None) or (script_index == 0):
|
||||
|
|
@ -599,7 +603,7 @@ class ScriptRunner:
|
|||
s.report()
|
||||
return processed
|
||||
|
||||
def after(self, p, processed, *args):
|
||||
def after(self, p: StableDiffusionProcessing, processed: Processed, *args):
|
||||
s = ScriptSummary('after')
|
||||
script_index = args[0] if len(args) > 0 else 0
|
||||
if (script_index is None) or (script_index == 0):
|
||||
|
|
|
|||
Loading…
Reference in New Issue