mirror of https://github.com/vladmandic/automatic
commit
668a94141d
|
|
@ -42,7 +42,7 @@ Resut should always be: list[ResGPU]
|
|||
class ResGPU(BaseModel):
|
||||
name: str = Field(title="GPU Name")
|
||||
data: dict = Field(title="Name/Value data")
|
||||
chart: list[float, float] = Field(title="Exactly two items to place on chart")
|
||||
chart: tuple[float, float] = Field(title="Exactly two items to place on chart")
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -484,22 +484,22 @@ else:
|
|||
FlagsModel = create_model("Flags", __config__=pydantic_config, **flags)
|
||||
|
||||
class ResEmbeddings(BaseModel):
|
||||
loaded: list = Field(default=None, title="loaded", description="List of loaded embeddings")
|
||||
skipped: list = Field(default=None, title="skipped", description="List of skipped embeddings")
|
||||
loaded: list = Field(title="loaded", description="List of loaded embeddings")
|
||||
skipped: list = Field(title="skipped", description="List of skipped embeddings")
|
||||
|
||||
class ResMemory(BaseModel):
|
||||
ram: dict = Field(title="RAM", description="System memory stats")
|
||||
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
||||
|
||||
class ResScripts(BaseModel):
|
||||
txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
|
||||
img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
|
||||
control: list = Field(default=None, title="Control", description="Titles of scripts (control)")
|
||||
txt2img: list[str] = Field(title="Txt2img", description="Titles of scripts (txt2img)")
|
||||
img2img: list[str] = Field(title="Img2img", description="Titles of scripts (img2img)")
|
||||
control: list[str] = Field(title="Control", description="Titles of scripts (control)")
|
||||
|
||||
class ResGPU(BaseModel): # definition of http response
|
||||
name: str = Field(title="GPU Name", description="GPU device name")
|
||||
data: dict = Field(title="Name/Value data", description="Key-value pairs of GPU metrics (utilization, temperature, clocks, memory, etc.)")
|
||||
chart: list[float, float] = Field(title="Exactly two items to place on chart", description="Two numeric values for chart display (e.g., GPU utilization %, VRAM usage %)")
|
||||
chart: tuple[float, float] = Field(title="Exactly two items to place on chart", description="Two numeric values for chart display (e.g., GPU utilization %, VRAM usage %)")
|
||||
|
||||
class ItemLoadedModel(BaseModel):
|
||||
name: str = Field(title="Model Name", description="Model or component name")
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ def get_nvml():
|
|||
"System load": f'GPU {load.gpu}% | VRAM {load.memory}% | Temp {pynvml.nvmlDeviceGetTemperature(dev, 0)}C | Fan {pynvml.nvmlDeviceGetFanSpeed(dev)}%',
|
||||
'State': get_reason(pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(dev)),
|
||||
}
|
||||
chart = [load.memory, load.gpu]
|
||||
chart = (load.memory, load.gpu)
|
||||
devices.append({
|
||||
'name': name,
|
||||
'data': data,
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ def get_rocm_smi():
|
|||
'Throttle reason': str(ThrottleStatus(int(rocm_smi_data[key].get("throttle_status", 0)))),
|
||||
}
|
||||
name = rocm_smi_data[key].get('Device Name', 'unknown')
|
||||
chart = [load["memory"], load["gpu"]]
|
||||
chart = (load["memory"], load["gpu"])
|
||||
devices.append({
|
||||
'name': name,
|
||||
'data': data,
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def get_xpu_smi():
|
|||
"VRAM usage": f'{round(100 * load["memory"] / total)}% | {load["memory"]} MB used | {total - load["memory"]} MB free | {total} MB total',
|
||||
"RAM usage": f'{round(100 * ram["used"] / ram["total"])}% | {round(1024 * ram["used"])} MB used | {round(1024 * ram["free"])} MB free | {round(1024 * ram["total"])} MB total',
|
||||
}
|
||||
chart = [load["memory"], load["gpu"]]
|
||||
chart = (load["memory"], load["gpu"])
|
||||
devices.append({
|
||||
'name': torch.xpu.get_device_name(),
|
||||
'data': data,
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def ts2utc(timestamp: int) -> datetime:
|
|||
except Exception:
|
||||
return "unknown"
|
||||
|
||||
def active():
|
||||
def active() -> list[Extension]:
|
||||
if shared.opts.disable_all_extensions == "all":
|
||||
return []
|
||||
elif shared.opts.disable_all_extensions == "user":
|
||||
|
|
@ -185,12 +185,12 @@ class Extension:
|
|||
log.error(f"Extension: failed reading data from git repo={self.name}: {ex}")
|
||||
self.remote = None
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts_manager
|
||||
def list_files(self, subdir: str, extension: str):
|
||||
from modules.scripts_manager import ScriptFile
|
||||
dirpath = os.path.join(self.path, subdir)
|
||||
res: list[ScriptFile] = []
|
||||
if not os.path.isdir(dirpath):
|
||||
return []
|
||||
res = []
|
||||
return res
|
||||
for filename in sorted(os.listdir(dirpath)):
|
||||
if not filename.endswith(".py") and not filename.endswith(".js") and not filename.endswith(".mjs"):
|
||||
continue
|
||||
|
|
@ -198,7 +198,7 @@ class Extension:
|
|||
if os.path.isfile(os.path.join(dirpath, "..", ".priority")):
|
||||
with open(os.path.join(dirpath, "..", ".priority"), encoding="utf-8") as f:
|
||||
priority = str(f.read().strip())
|
||||
res.append(scripts_manager.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
|
||||
res.append(ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
|
||||
if priority != '50':
|
||||
log.debug(f'Extension priority override: {os.path.dirname(dirpath)}:{priority}')
|
||||
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
|
|
|||
|
|
@ -1,15 +1,23 @@
|
|||
import math
|
||||
from collections import namedtuple
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFont, ImageDraw
|
||||
from modules import shared, script_callbacks
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from modules import script_callbacks, shared
|
||||
from modules.logger import log
|
||||
|
||||
|
||||
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
|
||||
class Grid(NamedTuple):
|
||||
tiles: list
|
||||
tile_w: int
|
||||
tile_h: int
|
||||
image_w: int
|
||||
image_h: int
|
||||
overlap: int
|
||||
|
||||
|
||||
def check_grid_size(imgs):
|
||||
def check_grid_size(imgs: list[Image.Image] | list[list[Image.Image]] | None):
|
||||
if imgs is None or len(imgs) == 0:
|
||||
return False
|
||||
mp = 0
|
||||
|
|
@ -26,39 +34,36 @@ def check_grid_size(imgs):
|
|||
return ok
|
||||
|
||||
|
||||
def get_grid_size(imgs, batch_size=1, rows: int | None = None, cols: int | None = None):
|
||||
if rows and rows > len(imgs):
|
||||
rows = len(imgs)
|
||||
if cols and cols > len(imgs):
|
||||
cols = len(imgs)
|
||||
def get_grid_size(imgs: list, batch_size=1, rows: int | None = None, cols: int | None = None):
|
||||
rows_int, cols_int = len(imgs), len(imgs)
|
||||
if rows is None and cols is None:
|
||||
if shared.opts.n_rows > 0:
|
||||
rows = shared.opts.n_rows
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
elif shared.opts.n_rows == 0:
|
||||
rows = batch_size
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
elif shared.opts.n_cols > 0:
|
||||
cols = shared.opts.n_cols
|
||||
rows = math.ceil(len(imgs) / cols)
|
||||
elif shared.opts.n_cols == 0:
|
||||
cols = batch_size
|
||||
rows = math.ceil(len(imgs) / cols)
|
||||
n_rows, n_cols = shared.opts.n_rows, shared.opts.n_cols
|
||||
if n_rows >= 0:
|
||||
rows_int: int = batch_size if n_rows == 0 else n_rows
|
||||
cols_int = math.ceil(len(imgs) / rows_int)
|
||||
elif n_cols >= 0:
|
||||
cols_int: int = batch_size if n_cols == 0 else n_cols
|
||||
rows_int = math.ceil(len(imgs) / cols_int)
|
||||
else:
|
||||
rows = math.floor(math.sqrt(len(imgs)))
|
||||
while len(imgs) % rows != 0:
|
||||
rows -= 1
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
elif rows is not None and cols is None:
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
elif rows is None and cols is not None:
|
||||
rows = math.ceil(len(imgs) / cols)
|
||||
else:
|
||||
pass
|
||||
return rows, cols
|
||||
rows_int = math.floor(math.sqrt(len(imgs)))
|
||||
while len(imgs) % rows_int != 0:
|
||||
rows_int -= 1
|
||||
cols_int = math.ceil(len(imgs) / rows_int)
|
||||
return rows_int, cols_int
|
||||
# Set limits
|
||||
if rows is not None:
|
||||
rows_int = max(min(rows, len(imgs)), 1)
|
||||
if cols is not None:
|
||||
cols_int = max(min(cols, len(imgs)), 1)
|
||||
# Calculate
|
||||
if rows is None:
|
||||
rows_int = math.ceil(len(imgs) / cols_int)
|
||||
if cols is None:
|
||||
cols_int = math.ceil(len(imgs) / rows_int)
|
||||
return rows_int, cols_int
|
||||
|
||||
|
||||
def image_grid(imgs, batch_size=1, rows: int | None = None, cols: int | None = None):
|
||||
def image_grid(imgs: list, batch_size=1, rows: int | None = None, cols: int | None = None):
|
||||
rows, cols = get_grid_size(imgs, batch_size, rows=rows, cols=cols)
|
||||
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
|
||||
script_callbacks.image_grid_callback(params)
|
||||
|
|
@ -73,7 +78,7 @@ def image_grid(imgs, batch_size=1, rows: int | None = None, cols: int | None = N
|
|||
return grid
|
||||
|
||||
|
||||
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
||||
def split_grid(image: Image.Image, tile_w=512, tile_h=512, overlap=64):
|
||||
w = image.width
|
||||
h = image.height
|
||||
non_overlap_width = tile_w - overlap
|
||||
|
|
@ -98,7 +103,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
|||
return grid
|
||||
|
||||
|
||||
def combine_grid(grid):
|
||||
def combine_grid(grid: Grid):
|
||||
def make_mask_image(r):
|
||||
r = r * 255 / grid.overlap
|
||||
r = r.astype(np.uint8)
|
||||
|
|
@ -127,18 +132,18 @@ class GridAnnotation:
|
|||
def __init__(self, text='', is_active=True):
|
||||
self.text = str(text)
|
||||
self.is_active = is_active
|
||||
self.size = None
|
||||
self.size: tuple[int, int] = (10, 10) # Placeholder values
|
||||
|
||||
|
||||
def get_font(fontsize):
|
||||
def get_font(fontsize: float):
|
||||
try:
|
||||
return ImageFont.truetype(shared.opts.font or "javascript/notosans-nerdfont-regular.ttf", fontsize)
|
||||
except Exception:
|
||||
return ImageFont.truetype("javascript/notosans-nerdfont-regular.ttf", fontsize)
|
||||
|
||||
|
||||
def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=None):
|
||||
def wrap(drawing, text, font, line_length):
|
||||
def draw_grid_annotations(im: Image.Image, width: int, height: int, x_texts: list[list[GridAnnotation]], y_texts: list[list[GridAnnotation]], margin=0, title: list[GridAnnotation] | None = None):
|
||||
def wrap(drawing: ImageDraw.ImageDraw, text, font, line_length):
|
||||
lines = ['']
|
||||
for word in text.split():
|
||||
line = f'{lines[-1]} {word}'.strip()
|
||||
|
|
@ -148,7 +153,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
|
|||
lines.append(word)
|
||||
return lines
|
||||
|
||||
def draw_texts(drawing: ImageDraw, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
||||
def draw_texts(drawing: ImageDraw.ImageDraw, draw_x: float, draw_y: float, lines, initial_fnt: ImageFont.FreeTypeFont, initial_fontsize: int):
|
||||
for line in lines:
|
||||
font = initial_fnt
|
||||
fontsize = initial_fontsize
|
||||
|
|
@ -185,9 +190,10 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
|
|||
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in x_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in y_texts]
|
||||
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
||||
title_text_heights = []
|
||||
title_pad = 0
|
||||
if title:
|
||||
title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts] # pylint: disable=unsubscriptable-object
|
||||
title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts]
|
||||
title_pad = 0 if sum(title_text_heights) == 0 else max(title_text_heights) + line_spacing * 2
|
||||
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + title_pad + margin * (rows-1)), shared.opts.grid_background)
|
||||
for row in range(rows):
|
||||
|
|
@ -210,7 +216,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
|
|||
return result
|
||||
|
||||
|
||||
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
||||
def draw_prompt_matrix(im: Image.Image, width: int, height: int, all_prompts: list[str], margin=0):
|
||||
prompts = all_prompts[1:]
|
||||
boundary = math.ceil(len(prompts) / 2)
|
||||
prompts_horiz = prompts[:boundary]
|
||||
|
|
|
|||
|
|
@ -1,14 +1,23 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
import gradio as gr
|
||||
from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
|
||||
from modules.logger import log
|
||||
from installer import control_extensions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
from modules.api.models import ItemScript
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import IOComponent
|
||||
from modules.processing import Processed, StableDiffusionProcessing
|
||||
|
||||
|
||||
AlwaysVisible = object()
|
||||
time_component = {}
|
||||
|
|
@ -28,22 +37,22 @@ class PostprocessBatchListArgs:
|
|||
|
||||
@dataclass
|
||||
class OnComponent:
|
||||
component: gr.blocks.Block
|
||||
component: Block
|
||||
|
||||
|
||||
class Script:
|
||||
parent = None
|
||||
name = None
|
||||
filename = None
|
||||
parent: str | None = None
|
||||
name: str | None = None
|
||||
filename: str | None = None
|
||||
args_from = 0
|
||||
args_to = 0
|
||||
alwayson = False
|
||||
is_txt2img = False
|
||||
is_img2img = False
|
||||
api_info = None
|
||||
api_info: ItemScript | None = None
|
||||
group = None
|
||||
infotext_fields = None
|
||||
paste_field_names = None
|
||||
infotext_fields: list | None = None
|
||||
paste_field_names: list[str] | None = None
|
||||
section = None
|
||||
standalone = False
|
||||
external = False
|
||||
|
|
@ -57,14 +66,14 @@ class Script:
|
|||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||
raise NotImplementedError
|
||||
|
||||
def ui(self, is_img2img):
|
||||
def ui(self, is_img2img) -> list[IOComponent]:
|
||||
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||
The return value should be an array of all components that are used in processing.
|
||||
Values of those returned components will be passed to run() and process() functions.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def show(self, is_img2img): # pylint: disable=unused-argument
|
||||
def show(self, is_img2img) -> bool | AlwaysVisible: # pylint: disable=unused-argument
|
||||
"""
|
||||
is_img2img is True if this function is called for the img2img interface, and False otherwise
|
||||
This function should return:
|
||||
|
|
@ -74,7 +83,7 @@ class Script:
|
|||
"""
|
||||
return True
|
||||
|
||||
def run(self, p, *args):
|
||||
def run(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called if the script has been selected in the script dropdown.
|
||||
It must do all processing and return the Processed object with results, same as
|
||||
|
|
@ -84,13 +93,13 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def setup(self, p, *args):
|
||||
def setup(self, p: StableDiffusionProcessing, *args):
|
||||
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
|
||||
args contains all values returned by components from ui().
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_process(self, p, *args):
|
||||
def before_process(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called very early during processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
|
|
@ -98,7 +107,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process(self, p, *args):
|
||||
def process(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called before processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
|
|
@ -106,7 +115,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process_images(self, p, *args):
|
||||
def process_images(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called instead of main processing for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
|
|
@ -114,7 +123,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_process_batch(self, p, *args, **kwargs):
|
||||
def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
|
||||
"""
|
||||
Called before extra networks are parsed from the prompt, so you can add
|
||||
new extra network keywords to the prompt with this callback.
|
||||
|
|
@ -126,7 +135,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process_batch(self, p, *args, **kwargs):
|
||||
def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
|
||||
"""
|
||||
Same as process(), but called for every batch.
|
||||
**kwargs will have those items:
|
||||
|
|
@ -137,7 +146,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_batch(self, p, *args, **kwargs):
|
||||
def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
|
||||
"""
|
||||
Same as process_batch(), but called for every batch after it has been generated.
|
||||
**kwargs will have same items as process_batch, and also:
|
||||
|
|
@ -146,13 +155,13 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||
def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
|
||||
def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, *args, **kwargs):
|
||||
"""
|
||||
Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
|
||||
This is useful when you want to update the entire batch instead of individual images.
|
||||
|
|
@ -168,14 +177,14 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
def postprocess(self, p: StableDiffusionProcessing, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
def before_component(self, component: IOComponent, **kwargs):
|
||||
"""
|
||||
Called before a component is created.
|
||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||
|
|
@ -184,7 +193,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
def after_component(self, component: IOComponent, **kwargs):
|
||||
"""
|
||||
Called after a component is created. Same as above.
|
||||
"""
|
||||
|
|
@ -194,7 +203,7 @@ class Script:
|
|||
"""unused"""
|
||||
return ""
|
||||
|
||||
def elem_id(self, item_id):
|
||||
def elem_id(self, item_id: str):
|
||||
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
||||
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
||||
return f'script_{self.parent}_{title}_{item_id}'
|
||||
|
|
@ -211,21 +220,33 @@ def basedir():
|
|||
return current_basedir
|
||||
|
||||
|
||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path", "priority"])
|
||||
class ScriptFile(NamedTuple):
|
||||
basedir: str
|
||||
filename: str
|
||||
path: str
|
||||
priority: str
|
||||
|
||||
|
||||
class ScriptClassData(NamedTuple):
|
||||
script_class: type[Script] | type[scripts_postprocessing.ScriptPostprocessing]
|
||||
path: str
|
||||
basedir: str
|
||||
module: ModuleType
|
||||
|
||||
|
||||
scripts_data = []
|
||||
postprocessing_scripts_data = []
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||
|
||||
|
||||
def list_scripts(scriptdirname, extension):
|
||||
tmp_list = []
|
||||
def list_scripts(scriptdirname: str, extension: str):
|
||||
tmp_list: list[ScriptFile] = []
|
||||
base = os.path.join(paths.script_path, scriptdirname)
|
||||
if os.path.exists(base):
|
||||
for filename in sorted(os.listdir(base)):
|
||||
tmp_list.append(ScriptFile(paths.script_path, filename, os.path.join(base, filename), '50'))
|
||||
for ext in extensions.active():
|
||||
tmp_list += ext.list_files(scriptdirname, extension)
|
||||
priority_list = []
|
||||
priority_list: list[ScriptFile] = []
|
||||
for script in tmp_list:
|
||||
if os.path.splitext(script.path)[1].lower() == extension and os.path.isfile(script.path):
|
||||
if script.basedir == paths.script_path:
|
||||
|
|
@ -250,7 +271,7 @@ def list_scripts(scriptdirname, extension):
|
|||
return priority_sort
|
||||
|
||||
|
||||
def list_files_with_name(filename):
|
||||
def list_files_with_name(filename: str):
|
||||
res = []
|
||||
dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
|
||||
for dirpath in dirs:
|
||||
|
|
@ -273,7 +294,7 @@ def load_scripts():
|
|||
scripts_list = sorted(scripts_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
|
||||
syspath = sys.path
|
||||
|
||||
def register_scripts_from_module(module, scriptfile):
|
||||
def register_scripts_from_module(module: ModuleType, scriptfile):
|
||||
for script_class in module.__dict__.values():
|
||||
if type(script_class) != type:
|
||||
continue
|
||||
|
|
@ -305,7 +326,7 @@ def load_scripts():
|
|||
return t, time.time()-t0
|
||||
|
||||
|
||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
def wrap_call(func: Callable, filename: str, funcname: str, *args, default=None, **kwargs):
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
return res
|
||||
|
|
@ -315,7 +336,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
|||
|
||||
|
||||
class ScriptSummary:
|
||||
def __init__(self, op):
|
||||
def __init__(self, op: str):
|
||||
self.start = time.time()
|
||||
self.update = time.time()
|
||||
self.op = op
|
||||
|
|
@ -336,17 +357,17 @@ class ScriptSummary:
|
|||
class ScriptRunner:
|
||||
def __init__(self, name=''):
|
||||
self.name = name
|
||||
self.scripts = []
|
||||
self.selectable_scripts = []
|
||||
self.alwayson_scripts = []
|
||||
self.auto_processing_scripts = []
|
||||
self.titles = []
|
||||
self.alwayson_titles = []
|
||||
self.infotext_fields = []
|
||||
self.paste_field_names = []
|
||||
self.scripts: list[Script] = []
|
||||
self.selectable_scripts: list[Script] = []
|
||||
self.alwayson_scripts: list[Script] = []
|
||||
self.auto_processing_scripts: list[ScriptClassData] = []
|
||||
self.titles: list[str] = []
|
||||
self.alwayson_titles: list[str] = []
|
||||
self.infotext_fields: list[tuple[IOComponent, str]] = []
|
||||
self.paste_field_names: list[str] = []
|
||||
self.script_load_ctr = 0
|
||||
self.is_img2img = False
|
||||
self.inputs = [None]
|
||||
self.inputs: list = [None]
|
||||
self.time = 0
|
||||
|
||||
def add_script(self, script_class, path, is_img2img, is_control):
|
||||
|
|
@ -557,7 +578,7 @@ class ScriptRunner:
|
|||
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts if script.group is not None])
|
||||
return inputs
|
||||
|
||||
def run(self, p, *args):
|
||||
def run(self, p: StableDiffusionProcessing, *args) -> Processed | None:
|
||||
s = ScriptSummary('run')
|
||||
script_index = args[0] if len(args) > 0 else 0
|
||||
if (script_index is None) or (script_index == 0):
|
||||
|
|
@ -582,7 +603,7 @@ class ScriptRunner:
|
|||
s.report()
|
||||
return processed
|
||||
|
||||
def after(self, p, processed, *args):
|
||||
def after(self, p: StableDiffusionProcessing, processed: Processed, *args):
|
||||
s = ScriptSummary('after')
|
||||
script_index = args[0] if len(args) > 0 else 0
|
||||
if (script_index is None) or (script_index == 0):
|
||||
|
|
@ -600,7 +621,7 @@ class ScriptRunner:
|
|||
s.report()
|
||||
return processed
|
||||
|
||||
def before_process(self, p, **kwargs):
|
||||
def before_process(self, p: StableDiffusionProcessing, **kwargs):
|
||||
s = ScriptSummary('before-process')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -612,7 +633,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process(self, p, **kwargs):
|
||||
def process(self, p: StableDiffusionProcessing, **kwargs):
|
||||
s = ScriptSummary('process')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -624,7 +645,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process_images(self, p, **kwargs):
|
||||
def process_images(self, p: StableDiffusionProcessing, **kwargs):
|
||||
s = ScriptSummary('process_images')
|
||||
processed = None
|
||||
for script in self.alwayson_scripts:
|
||||
|
|
@ -640,7 +661,7 @@ class ScriptRunner:
|
|||
s.report()
|
||||
return processed
|
||||
|
||||
def before_process_batch(self, p, **kwargs):
|
||||
def before_process_batch(self, p: StableDiffusionProcessing, **kwargs):
|
||||
s = ScriptSummary('before-process-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -652,7 +673,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process_batch(self, p, **kwargs):
|
||||
def process_batch(self, p: StableDiffusionProcessing, **kwargs):
|
||||
s = ScriptSummary('process-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -664,7 +685,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess(self, p, processed):
|
||||
def postprocess(self, p: StableDiffusionProcessing, processed):
|
||||
s = ScriptSummary('postprocess')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -676,7 +697,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_batch(self, p, images, **kwargs):
|
||||
def postprocess_batch(self, p: StableDiffusionProcessing, images, **kwargs):
|
||||
s = ScriptSummary('postprocess-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -688,7 +709,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
|
||||
def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, **kwargs):
|
||||
s = ScriptSummary('postprocess-batch-list')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -700,7 +721,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||
def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs):
|
||||
s = ScriptSummary('postprocess-image')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -712,7 +733,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
def before_component(self, component: IOComponent, **kwargs):
|
||||
s = ScriptSummary('before-component')
|
||||
for script in self.scripts:
|
||||
try:
|
||||
|
|
@ -722,7 +743,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
def after_component(self, component: IOComponent, **kwargs):
|
||||
s = ScriptSummary('after-component')
|
||||
for script in self.scripts:
|
||||
for elem_id, callback in script.on_after_component_elem_id:
|
||||
|
|
@ -738,7 +759,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def reload_sources(self, cache):
|
||||
def reload_sources(self, cache: dict):
|
||||
s = ScriptSummary('reload-sources')
|
||||
for si, script in list(enumerate(self.scripts)):
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from'):
|
||||
|
|
|
|||
|
|
@ -1418,7 +1418,7 @@ def hf_auth_check(checkpoint_info, force:bool=False):
|
|||
return False
|
||||
|
||||
|
||||
def save_model(name: str, path: str | None = None, shard: str | None = None, overwrite = False):
|
||||
def save_model(name: str, path: str | None = None, shard: str = "5GB", overwrite = False):
|
||||
if (name is None) or len(name.strip()) == 0:
|
||||
log.error('Save model: invalid model name')
|
||||
return 'Invalid model name'
|
||||
|
|
@ -1432,6 +1432,8 @@ def save_model(name: str, path: str | None = None, shard: str | None = None, ove
|
|||
if os.path.exists(model_name) and not overwrite:
|
||||
log.error(f'Save model: path="{model_name}" exists')
|
||||
return f'Path exists: {model_name}'
|
||||
if not shard.strip():
|
||||
shard = "5GB" # Guard against empty input
|
||||
try:
|
||||
t0 = time.time()
|
||||
save_sdnq_model(
|
||||
|
|
|
|||
|
|
@ -38,8 +38,8 @@ def load_unet_sdxl_nunchaku(repo_id):
|
|||
def load_unet(model, repo_id: str | None = None):
|
||||
global loaded_unet # pylint: disable=global-statement
|
||||
|
||||
if ("StableDiffusionXLPipeline" in model.__class__.__name__) and (('stable-diffusion-xl-base' in repo_id) or ('sdxl-turbo' in repo_id)):
|
||||
if model_quant.check_nunchaku('Model'):
|
||||
if ("StableDiffusionXLPipeline" in model.__class__.__name__) and repo_id is not None and (("stable-diffusion-xl-base" in repo_id) or ("sdxl-turbo" in repo_id)):
|
||||
if model_quant.check_nunchaku("Model"):
|
||||
unet = load_unet_sdxl_nunchaku(repo_id)
|
||||
if unet is not None:
|
||||
model.unet = unet
|
||||
|
|
|
|||
|
|
@ -1,4 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from modules.logger import log
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
|
||||
|
||||
repo_id = 'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf'
|
||||
policy_template = """Provide a safety assessment for the content provided by the user. Determine compliance or violation against our safety policy by reviewing the following policy categories:
|
||||
Hate:
|
||||
|
|
@ -89,11 +96,11 @@ To provide your assessment use the following json template for each category:
|
|||
"rationale": str,
|
||||
}.
|
||||
"""
|
||||
model = None
|
||||
processor = None
|
||||
model: LlavaOnevisionForConditionalGeneration | None = None
|
||||
processor: AutoProcessor | None = None
|
||||
|
||||
|
||||
def image_guard(image, policy:str | None=None) -> str:
|
||||
def image_guard(image, policy:str | None=None):
|
||||
global model, processor # pylint: disable=global-statement
|
||||
import json
|
||||
from installer import install
|
||||
|
|
|
|||
|
|
@ -49,14 +49,14 @@ def create_ui(accordion=True):
|
|||
|
||||
# main processing used in both modes
|
||||
def process(
|
||||
p: processing.StableDiffusionProcessing=None,
|
||||
pp: scripts.PostprocessImageArgs=None,
|
||||
p: processing.StableDiffusionProcessing | None = None,
|
||||
pp: scripts.PostprocessImageArgs | scripts_postprocessing.PostprocessedImage | None = None,
|
||||
enabled=True,
|
||||
lang=False,
|
||||
policy=False,
|
||||
banned=False,
|
||||
metadata=True,
|
||||
copy=False,
|
||||
copy=False, # Compatability
|
||||
score=0.2,
|
||||
blocks=3,
|
||||
censor=[],
|
||||
|
|
@ -74,7 +74,7 @@ def process(
|
|||
nudes = nudenet.detector.censor(image=pp.image, method=method, min_score=score, censor=censor, blocks=blocks, overlay=overlay)
|
||||
t1 = time.time()
|
||||
if len(nudes.censored) > 0: # Check if there are any censored areas
|
||||
if not copy:
|
||||
if p is None:
|
||||
pp.image = nudes.output
|
||||
else:
|
||||
info = processing.create_infotext(p)
|
||||
|
|
@ -85,7 +85,7 @@ def process(
|
|||
if metadata and p is not None:
|
||||
p.extra_generation_params["NudeNet"] = meta
|
||||
p.extra_generation_params["NSFW"] = nsfw
|
||||
if metadata and hasattr(pp, 'info'):
|
||||
if metadata and isinstance(pp, scripts_postprocessing.PostprocessedImage):
|
||||
pp.info['NudeNet'] = meta
|
||||
pp.info['NSFW'] = nsfw
|
||||
log.debug(f'NudeNet detect: {dct} nsfw={nsfw} time={(t1 - t0):.2f}')
|
||||
|
|
@ -118,7 +118,7 @@ def process(
|
|||
if metadata and p is not None:
|
||||
p.extra_generation_params["Rating"] = res.get('rating', 'N/A')
|
||||
p.extra_generation_params["Category"] = res.get('category', 'N/A')
|
||||
if metadata and hasattr(pp, 'info'):
|
||||
if metadata and isinstance(pp, scripts_postprocessing.PostprocessedImage):
|
||||
pp.info["Rating"] = res.get('rating', 'N/A')
|
||||
pp.info["Category"] = res.get('category', 'N/A')
|
||||
|
||||
|
|
@ -131,11 +131,11 @@ class ScriptNudeNet(scripts.Script):
|
|||
def title(self):
|
||||
return 'NudeNet'
|
||||
|
||||
def show(self, _is_img2img):
|
||||
def show(self, *args, **kwargs):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
# return signature is array of gradio components
|
||||
def ui(self, _is_img2img):
|
||||
def ui(self, *args, **kwargs):
|
||||
return create_ui(accordion=True)
|
||||
|
||||
# triggered by callback
|
||||
|
|
|
|||
Loading…
Reference in New Issue