mirror of https://github.com/vladmandic/automatic
major refactoring of modules
Signed-off-by: Vladimir Mandic <mandic00@live.com>pull/4013/head
parent
772a5c9ad3
commit
c4d9338d2e
56
.pylintrc
56
.pylintrc
|
|
@ -7,32 +7,27 @@ fail-on=
|
|||
fail-under=10
|
||||
ignore=CVS
|
||||
ignore-paths=/usr/lib/.*$,
|
||||
venv,
|
||||
.git,
|
||||
.ruff_cache,
|
||||
.vscode,
|
||||
modules/apg,
|
||||
modules/consistory,
|
||||
modules/cfgzero,
|
||||
modules/control/proc,
|
||||
modules/control/units,
|
||||
modules/ctrlx,
|
||||
modules/dml,
|
||||
modules/freescale,
|
||||
modules/flash_attn_triton_amd,
|
||||
modules/ggml,
|
||||
modules/hidiffusion,
|
||||
modules/hijack,
|
||||
modules/instantir,
|
||||
modules/hijack/ddpm_edit.py,
|
||||
modules/intel,
|
||||
modules/intel/ipex,
|
||||
modules/intel/openvino,
|
||||
modules/k-diffusion,
|
||||
modules/flex2,
|
||||
modules/ldsr,
|
||||
modules/hidream,
|
||||
modules/meissonic,
|
||||
modules/mod,
|
||||
modules/omnigen,
|
||||
modules/omnigen2,
|
||||
modules/onnx_impl,
|
||||
modules/pag,
|
||||
modules/pixelsmith,
|
||||
modules/postprocess/aurasr_arch.py,
|
||||
modules/prompt_parser_xhinker.py,
|
||||
modules/pulid/eva_clip,
|
||||
modules/ras,
|
||||
modules/rife,
|
||||
modules/schedulers,
|
||||
|
|
@ -40,15 +35,26 @@ ignore-paths=/usr/lib/.*$,
|
|||
modules/teacache,
|
||||
modules/todo,
|
||||
modules/unipc,
|
||||
modules/xadapter,
|
||||
modules/cfgzero,
|
||||
modules/infiniteyou,
|
||||
modules/flash_attn_triton_amd,
|
||||
scripts/softfill.py,
|
||||
pipelines/flex2,
|
||||
pipelines/hidream,
|
||||
pipelines/meissonic,
|
||||
pipelines/omnigen2,
|
||||
pipelines/segmoe,
|
||||
scripts/consistory,
|
||||
scripts/ctrlx,
|
||||
scripts/demofusion,
|
||||
scripts/freescale,
|
||||
scripts/infiniteyou,
|
||||
scripts/instantir,
|
||||
scripts/mod,
|
||||
scripts/pixelsmith,
|
||||
scripts/differential_diffusion.py,
|
||||
scripts/pulid,
|
||||
scripts/xadapter,
|
||||
repositories,
|
||||
extensions-builtin/Lora,
|
||||
extensions-builtin/sd-webui-agent-scheduler,
|
||||
extensions-builtin/sd-extension-chainner/nodes,
|
||||
extensions-builtin/sd-webui-agent-scheduler,
|
||||
extensions-builtin/sdnext-modernui/node_modules,
|
||||
ignore-patterns=.*test*.py$,
|
||||
.*_model.py$,
|
||||
|
|
@ -158,8 +164,8 @@ disable=abstract-method,
|
|||
consider-using-generator,
|
||||
consider-using-get,
|
||||
consider-using-in,
|
||||
consider-using-min-builtin,
|
||||
consider-using-max-builtin,
|
||||
consider-using-min-builtin,
|
||||
consider-using-sys-exit,
|
||||
cyclic-import,
|
||||
dangerous-default-value,
|
||||
|
|
@ -175,6 +181,7 @@ disable=abstract-method,
|
|||
missing-class-docstring,
|
||||
missing-function-docstring,
|
||||
missing-module-docstring,
|
||||
no-else-raise,
|
||||
no-else-return,
|
||||
not-callable,
|
||||
pointless-string-statement,
|
||||
|
|
@ -185,16 +192,17 @@ disable=abstract-method,
|
|||
too-many-instance-attributes,
|
||||
too-many-locals,
|
||||
too-many-nested-blocks,
|
||||
too-many-statements,
|
||||
too-many-positional-arguments,
|
||||
too-many-statements,
|
||||
unidiomatic-typecheck,
|
||||
unknown-option-value,
|
||||
unnecessary-dict-index-lookup,
|
||||
unnecessary-dunder-call,
|
||||
unnecessary-lambda,
|
||||
unnecessary-lambda-assigment,
|
||||
unnecessary-lambda,
|
||||
unused-wildcard-import,
|
||||
use-dict-literal,
|
||||
use-symbolic-message-instead,
|
||||
unknown-option-value,
|
||||
useless-suppression,
|
||||
wrong-import-position,
|
||||
enable=c-extension-no-member
|
||||
|
|
|
|||
57
.ruff.toml
57
.ruff.toml
|
|
@ -3,44 +3,33 @@ exclude = [
|
|||
".git",
|
||||
".ruff_cache",
|
||||
".vscode",
|
||||
"modules/apg",
|
||||
"modules/consistory",
|
||||
|
||||
"modules/cfgzero",
|
||||
"modules/flash_attn_triton_amd",
|
||||
"modules/hidiffusion",
|
||||
"modules/intel/ipex",
|
||||
"modules/k-diffusion",
|
||||
"modules/pag",
|
||||
"modules/schedulers",
|
||||
"modules/teacache",
|
||||
|
||||
"modules/control/proc",
|
||||
"modules/control/units",
|
||||
"modules/freescale",
|
||||
"modules/flex2",
|
||||
"modules/ggml",
|
||||
"modules/hidiffusion",
|
||||
"modules/hijack",
|
||||
"modules/instantir",
|
||||
"modules/intel/ipex",
|
||||
"modules/intel/openvino",
|
||||
"modules/k-diffusion",
|
||||
"modules/ldsr",
|
||||
"modules/meissonic",
|
||||
"modules/mod",
|
||||
"modules/omnigen",
|
||||
"modules/omnigen2",
|
||||
"modules/hidream",
|
||||
"modules/pag",
|
||||
"modules/pixelsmith",
|
||||
"modules/control/units/xs_pipe.py",
|
||||
"modules/postprocess/aurasr_arch.py",
|
||||
"modules/prompt_parser_xhinker.py",
|
||||
"modules/pulid/eva_clip",
|
||||
"modules/ras",
|
||||
"modules/rife",
|
||||
"modules/schedulers",
|
||||
"modules/segmoe",
|
||||
"modules/taesd",
|
||||
"modules/teacache",
|
||||
"modules/todo",
|
||||
"modules/unipc",
|
||||
"modules/cfgzero",
|
||||
"modules/xadapter",
|
||||
"modules/infiniteyou",
|
||||
"modules/flash_attn_triton_amd",
|
||||
"scripts/softfill.py",
|
||||
|
||||
"pipelines/meissonic",
|
||||
"pipelines/omnigen2",
|
||||
"pipelines/segmoe",
|
||||
|
||||
"scripts/xadapter",
|
||||
"scripts/pulid",
|
||||
"scripts/instantir",
|
||||
"scripts/freescale",
|
||||
"scripts/consistory",
|
||||
|
||||
"repositories",
|
||||
|
||||
"extensions-builtin/Lora",
|
||||
"extensions-builtin/sd-extension-chainner/nodes",
|
||||
"extensions-builtin/sd-webui-agent-scheduler",
|
||||
|
|
|
|||
18
CHANGELOG.md
18
CHANGELOG.md
|
|
@ -1,6 +1,6 @@
|
|||
# Change Log for SD.Next
|
||||
|
||||
## Update for 2025-07-02
|
||||
## Update for 2025-07-03
|
||||
|
||||
- **Models**
|
||||
- Add **FLUX.1-Kontext-Dev** inpaint workflow
|
||||
|
|
@ -12,16 +12,20 @@
|
|||
enable in *settings -> compute settings -> sdp options*
|
||||
*note*: SD.Next will use either SageAttention v1 or v2, depending which one is installed
|
||||
until authors provide pre-build wheels for v2, you need to install it manually or SD.Next will auto-install v1
|
||||
- **Core**
|
||||
- override `gradio` installer
|
||||
- major refactoring of requirements and dependencies to unblock `numpy>=2.1.0`
|
||||
- patch `insightface`
|
||||
- patch `k-diffusion`
|
||||
- better handle startup import errors
|
||||
- **Fixes**
|
||||
- allow theme type `None` to be set in config
|
||||
- installer dont cache installed state
|
||||
- fix Cosmos-Predict2 retrying TAESD download
|
||||
- better handle startup import errors
|
||||
- **Refactoring**
|
||||
- override `gradio` installer
|
||||
- major refactoring of requirements and dependencies to unblock `numpy>=2.1.0`
|
||||
- patch `insightface`
|
||||
- patch `k-diffusion`
|
||||
- cleanup `/modules`: move pipeline loaders to `/pipelines` root
|
||||
- cleanup `/modules`: move code folders used by scripts to `/scripts/<script>` folder
|
||||
- cleanup `/modules`: global rename `modules.scripts` to avoid conflict with `/scripts`
|
||||
- stronger lint rules
|
||||
|
||||
## Update for 2025-06-30
|
||||
|
||||
|
|
|
|||
22
TODO.md
22
TODO.md
|
|
@ -8,9 +8,14 @@ Main ToDo list can be found at [GitHub projects](https://github.com/users/vladma
|
|||
|
||||
- Refactor: Move `model_*` stuff into subfolder
|
||||
- Refactor: sampler options
|
||||
- Common repo for `T5` and `CLiP`
|
||||
- Feature: Common repo for `T5` and `CLiP`
|
||||
- Feature: LoRA add OMI format support for SD35/FLUX.1
|
||||
- Feature: Merge FramePack into core
|
||||
- Remove: legacy LoRA loader
|
||||
- Remove: Original backend
|
||||
- Remove: Agent Scheduler
|
||||
- Video: API support
|
||||
- LoRA: add OMI format support for SD35/FLUX.1
|
||||
- ModernUI: Lite vs Expert mode
|
||||
|
||||
### Blocked items
|
||||
|
||||
|
|
@ -66,11 +71,13 @@ Main ToDo list can be found at [GitHub projects](https://github.com/users/vladma
|
|||
|
||||
## Code TODO
|
||||
|
||||
> pnpm lint | grep W0511 | awk -F'TODO ' '{print "- "$NF}' | sed 's/ (fixme)//g'
|
||||
> pnpm lint | grep W0511 | awk -F'TODO ' '{print "- "$NF}' | sed 's/ (fixme)//g' | sort
|
||||
|
||||
- control: support scripts via api
|
||||
- fc: autodetect distilled based on model
|
||||
- fc: autodetect tensor format based on model
|
||||
- flux transformer from-single-file with quant
|
||||
- flux: loader for civitai nf4 models
|
||||
- hypertile: vae breaks when using non-standard sizes
|
||||
- install: enable ROCm for windows when available
|
||||
- loader: load receipe
|
||||
|
|
@ -78,11 +85,12 @@ Main ToDo list can be found at [GitHub projects](https://github.com/users/vladma
|
|||
- lora: add other quantization types
|
||||
- lora: add t5 key support for sd35/f1
|
||||
- lora: maybe force imediate quantization
|
||||
- lora: support pre-quantized flux
|
||||
- model fix: cogview4: balanced offload does not work for GlmModel
|
||||
- model load: add ChromaFillPipeline, ChromaControlPipeline, ChromaImg2ImgPipeline etc when available
|
||||
- model load: chroma transformer from-single-file with quant
|
||||
- model load: force-reloading entire model as loading transformers only leads to massive memory usage
|
||||
- model loader: implement model in-memory caching
|
||||
- model load: implement model in-memory caching
|
||||
- modernui: monkey-patch for missing tabs.select event
|
||||
- modules/lora/lora_extract.py:188:9: W0511: TODO: lora: support pre-quantized flux
|
||||
- nunchaku: batch support
|
||||
- nunchaku: cache-dir for transformer and t5 loader
|
||||
- processing: remove duplicate mask params
|
||||
- resize image: enable full VAE mode for resize-latent
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import network_overrides
|
|||
import lora_convert
|
||||
import torch
|
||||
import diffusers.models.lora
|
||||
from modules import shared, devices, sd_models, sd_models_compile, errors, scripts, files_cache, model_quant
|
||||
from modules import shared, devices, sd_models, sd_models_compile, errors, scripts_manager, files_cache, model_quant
|
||||
|
||||
|
||||
debug = os.environ.get('SD_LORA_DEBUG', None) is not None
|
||||
|
|
@ -580,7 +580,7 @@ def list_available_networks():
|
|||
|
||||
|
||||
def infotext_pasted(infotext, params): # pylint: disable=W0613
|
||||
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
|
||||
if "AddNet Module 1" in [x[1] for x in scripts_manager.scripts_txt2img.infotext_fields]:
|
||||
return # if the other extension is active, it will handle those fields, no need to do anything
|
||||
added = []
|
||||
for k in params:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
Subproject commit c2c0c27939e20303d63c30ea7a5e179697af8eaf
|
||||
Subproject commit c006b67aeabd8bd404de41a446e1a4f6ce4a892f
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
from threading import Lock
|
||||
from fastapi.responses import JSONResponse
|
||||
from modules import errors, shared, scripts, ui
|
||||
from modules import errors, shared, scripts_manager, ui
|
||||
from modules.api import models, script, helpers
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ class APIGenerate():
|
|||
|
||||
def post_text2img(self, txt2imgreq: models.ReqTxt2Img):
|
||||
self.prepare_face_module(txt2imgreq)
|
||||
script_runner = scripts.scripts_txt2img
|
||||
script_runner = scripts_manager.scripts_txt2img
|
||||
if not script_runner.scripts:
|
||||
script_runner.initialize_scripts(False)
|
||||
ui.create_ui(None)
|
||||
|
|
@ -113,10 +113,10 @@ class APIGenerate():
|
|||
script_args = script.init_script_args(p, txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
if selectable_scripts is not None:
|
||||
processed = scripts.scripts_txt2img.run(p, *script_args) # Need to pass args as list here
|
||||
processed = scripts_manager.scripts_txt2img.run(p, *script_args) # Need to pass args as list here
|
||||
else:
|
||||
processed = process_images(p)
|
||||
processed = scripts.scripts_txt2img.after(p, processed, *script_args)
|
||||
processed = scripts_manager.scripts_txt2img.after(p, processed, *script_args)
|
||||
p.close()
|
||||
shared.state.end(api=False)
|
||||
if processed is None or processed.images is None or len(processed.images) == 0:
|
||||
|
|
@ -135,7 +135,7 @@ class APIGenerate():
|
|||
mask = img2imgreq.mask
|
||||
if mask:
|
||||
mask = helpers.decode_base64_to_image(mask)
|
||||
script_runner = scripts.scripts_img2img
|
||||
script_runner = scripts_manager.scripts_img2img
|
||||
if not script_runner.scripts:
|
||||
script_runner.initialize_scripts(True)
|
||||
ui.create_ui(None)
|
||||
|
|
@ -165,10 +165,10 @@ class APIGenerate():
|
|||
script_args = script.init_script_args(p, img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
if selectable_scripts is not None:
|
||||
processed = scripts.scripts_img2img.run(p, *script_args) # Need to pass args as list here
|
||||
processed = scripts_manager.scripts_img2img.run(p, *script_args) # Need to pass args as list here
|
||||
else:
|
||||
processed = process_images(p)
|
||||
processed = scripts.scripts_img2img.after(p, processed, *script_args)
|
||||
processed = scripts_manager.scripts_img2img.after(p, processed, *script_args)
|
||||
p.close()
|
||||
shared.state.end(api=False)
|
||||
if processed is None or processed.images is None or len(processed.images) == 0:
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ class APIProcess():
|
|||
seed = processing_helpers.get_fixed_seed(seed)
|
||||
prompt = ''
|
||||
if req.type == 'text':
|
||||
from modules.scripts import scripts_txt2img
|
||||
from modules.scripts_manager import scripts_txt2img
|
||||
model = 'google/gemma-3-1b-it' if req.model is None or len(req.model) < 4 else req.model
|
||||
instance = [s for s in scripts_txt2img.scripts if 'prompt_enhance.py' in s.filename][0]
|
||||
prompt = instance.enhance(
|
||||
|
|
@ -149,7 +149,7 @@ class APIProcess():
|
|||
nsfw=req.nsfw,
|
||||
)
|
||||
elif req.type == 'image':
|
||||
from modules.scripts import scripts_txt2img
|
||||
from modules.scripts_manager import scripts_txt2img
|
||||
model = 'google/gemma-3-4b-it' if req.model is None or len(req.model) < 4 else req.model
|
||||
instance = [s for s in scripts_txt2img.scripts if 'prompt_enhance.py' in s.filename][0]
|
||||
prompt = instance.enhance(
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ from typing import Optional
|
|||
from fastapi.exceptions import HTTPException
|
||||
import gradio as gr
|
||||
from modules.api import models
|
||||
from modules import scripts
|
||||
from modules.errors import log
|
||||
from modules import scripts_manager
|
||||
|
||||
|
||||
def script_name_to_index(name, scripts_list):
|
||||
|
|
@ -30,15 +30,15 @@ def get_selectable_script(script_name, script_runner):
|
|||
|
||||
|
||||
def get_scripts_list():
|
||||
t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
|
||||
i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
|
||||
control = [script.name for script in scripts.scripts_control.scripts if script.name is not None]
|
||||
t2ilist = [script.name for script in scripts_manager.scripts_txt2img.scripts if script.name is not None]
|
||||
i2ilist = [script.name for script in scripts_manager.scripts_img2img.scripts if script.name is not None]
|
||||
control = [script.name for script in scripts_manager.scripts_control.scripts if script.name is not None]
|
||||
return models.ResScripts(txt2img = t2ilist, img2img = i2ilist, control = control)
|
||||
|
||||
|
||||
def get_script_info(script_name: Optional[str] = None):
|
||||
res = []
|
||||
for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts, scripts.scripts_control.scripts]:
|
||||
for script_list in [scripts_manager.scripts_txt2img.scripts, scripts_manager.scripts_img2img.scripts, scripts_manager.scripts_control.scripts]:
|
||||
for script in script_list:
|
||||
if script.api_info is not None and (script_name is None or script_name == script.api_info.name):
|
||||
res.append(script.api_info)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from modules.control.units import xs # VisLearn ControlNet-XS
|
|||
from modules.control.units import lite # Kohya ControlLLLite
|
||||
from modules.control.units import t2iadapter # TencentARC T2I-Adapter
|
||||
from modules.control.units import reference # ControlNet-Reference
|
||||
from modules import devices, shared, errors, processing, images, sd_models, scripts, masking
|
||||
from modules import devices, shared, errors, processing, images, sd_models, scripts_manager, masking
|
||||
from modules.processing_class import StableDiffusionProcessingControl
|
||||
from modules.ui_common import infotext_to_html
|
||||
from modules.api import script
|
||||
|
|
@ -737,10 +737,10 @@ def control_run(state: str = '',
|
|||
if sd_models.get_diffusers_task(pipe) != sd_models.DiffusersTaskType.TEXT_2_IMAGE: # force vae back to gpu if not in txt2img mode
|
||||
sd_models.move_model(pipe.vae, devices.device)
|
||||
|
||||
p.scripts = scripts.scripts_control
|
||||
p.scripts = scripts_manager.scripts_control
|
||||
p.script_args = input_script_args or []
|
||||
if len(p.script_args) == 0:
|
||||
script_runner = scripts.scripts_control
|
||||
script_runner = scripts_manager.scripts_control
|
||||
if not script_runner.scripts:
|
||||
script_runner.initialize_scripts(False)
|
||||
p.script_args = script.init_default_script_args(script_runner)
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class Extension:
|
|||
self.remote = None
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts
|
||||
from modules import scripts_manager
|
||||
dirpath = os.path.join(self.path, subdir)
|
||||
if not os.path.isdir(dirpath):
|
||||
return []
|
||||
|
|
@ -89,7 +89,7 @@ class Extension:
|
|||
if os.path.isfile(os.path.join(dirpath, "..", ".priority")):
|
||||
with open(os.path.join(dirpath, "..", ".priority"), "r", encoding="utf-8") as f:
|
||||
priority = str(f.read().strip())
|
||||
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
|
||||
res.append(scripts_manager.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
|
||||
if priority != '50':
|
||||
shared.log.debug(f'Extension priority override: {os.path.dirname(dirpath)}:{priority}')
|
||||
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from modules import scripts, processing, shared, images
|
||||
from modules import scripts_manager, processing, shared, images
|
||||
|
||||
|
||||
debug = shared.log.trace if os.environ.get('SD_FACE_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
class Script(scripts_manager.Script):
|
||||
original_pipeline = None
|
||||
original_prompt_attention = None
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class MetaData():
|
|||
rotary_cos: Optional[torch.Tensor] = None
|
||||
rotary_interleaved: bool = False
|
||||
rotary_conjunction: bool = False
|
||||
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"MetaData(\n"
|
||||
|
|
@ -161,7 +161,7 @@ def generate_varlen_tensor(
|
|||
if batch_size is None:
|
||||
valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen]
|
||||
batch_size = random.choice(valid_batch_sizes)
|
||||
|
||||
|
||||
# get seqlens
|
||||
if equal_seqlens:
|
||||
seqlens = torch.full(
|
||||
|
|
@ -241,14 +241,14 @@ def input_helper(
|
|||
TOTAL_SEQLENS_Q = BATCH * N_CTX_Q
|
||||
TOTAL_SEQLENS_K = BATCH * N_CTX_K
|
||||
equal_seqlens=False
|
||||
|
||||
|
||||
# gen tensors
|
||||
# TODO: the gen functions should maybe have different gen modes like random, ones, increasing seqlen
|
||||
q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
|
||||
k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
|
||||
v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
|
||||
do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q)
|
||||
|
||||
|
||||
# setup metadata
|
||||
if DEBUG_INPUT:
|
||||
sm_scale = 1
|
||||
|
|
@ -369,7 +369,7 @@ def get_shape_from_layout(
|
|||
raise ValueError("cu_seqlens must be provided for varlen (thd) layout")
|
||||
if max_seqlen is None:
|
||||
raise ValueError("max_seqlen must be provided for varlen (thd) layout")
|
||||
|
||||
|
||||
batch, max_seqlen_final, num_heads, head_dim = len(cu_seqlens) - 1, max_seqlen, num_heads, head_dim
|
||||
else:
|
||||
assert False, "Got unsupported layout."
|
||||
|
|
@ -380,7 +380,7 @@ def get_shape_from_layout(
|
|||
def get_shapes_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None):
|
||||
batch_q, seqlen_q, nheads_q, head_size_q = get_shape_from_layout(q, layout, cu_seqlens_q, max_seqlen_q)
|
||||
batch_k, seqlen_k, nheads_k, head_size_k = get_shape_from_layout(k, layout, cu_seqlens_k, max_seqlen_k)
|
||||
|
||||
|
||||
# assert
|
||||
assert batch_q == batch_k
|
||||
assert head_size_q == head_size_k
|
||||
|
|
@ -458,22 +458,22 @@ def write_dropout_mask(x, tensor_name = "tensor"):
|
|||
if True:
|
||||
BLOCK_M = 64
|
||||
BLOCK_N = 64
|
||||
|
||||
|
||||
# Calculate number of blocks in each dimension
|
||||
m_blocks = math.ceil(seqlen_m / BLOCK_M)
|
||||
n_blocks = math.ceil(seqlen_n / BLOCK_N)
|
||||
|
||||
|
||||
# Process each block
|
||||
for m_block in range(m_blocks):
|
||||
# Calculate row range for current block
|
||||
row_start = m_block * BLOCK_M
|
||||
row_end = min(row_start + BLOCK_M, seqlen_m)
|
||||
|
||||
|
||||
for n_block in range(n_blocks):
|
||||
# Calculate column range for current block
|
||||
col_start = n_block * BLOCK_N
|
||||
col_end = min(col_start + BLOCK_N, seqlen_n)
|
||||
|
||||
|
||||
# Extract and write the current block
|
||||
for row_idx in range(row_start, row_end):
|
||||
row_data = dropout_mask[row_idx][col_start:col_end]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from PIL import Image
|
||||
import gradio as gr
|
||||
import gradio.processing_utils
|
||||
from modules import scripts, patches, gr_tempdir
|
||||
from modules import scripts_manager, patches, gr_tempdir
|
||||
|
||||
|
||||
hijacked = False
|
||||
|
|
@ -44,14 +44,14 @@ def add_classes_to_gradio_component(comp):
|
|||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
self.webui_tooltip = kwargs.pop('tooltip', None)
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.before_component(self, **kwargs)
|
||||
scripts.script_callbacks.before_component_callback(self, **kwargs)
|
||||
if scripts_manager.scripts_current is not None:
|
||||
scripts_manager.scripts_current.before_component(self, **kwargs)
|
||||
scripts_manager.script_callbacks.before_component_callback(self, **kwargs)
|
||||
res = original_IOComponent_init(self, *args, **kwargs) # pylint: disable=assignment-from-no-return
|
||||
add_classes_to_gradio_component(self)
|
||||
scripts.script_callbacks.after_component_callback(self, **kwargs)
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.after_component(self, **kwargs)
|
||||
scripts_manager.script_callbacks.after_component_callback(self, **kwargs)
|
||||
if scripts_manager.scripts_current is not None:
|
||||
scripts_manager.scripts_current.after_component(self, **kwargs)
|
||||
return res
|
||||
|
||||
|
||||
|
|
@ -65,14 +65,14 @@ def Block_get_config(self):
|
|||
|
||||
|
||||
def BlockContext_init(self, *args, **kwargs):
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.before_component(self, **kwargs)
|
||||
scripts.script_callbacks.before_component_callback(self, **kwargs)
|
||||
if scripts_manager.scripts_current is not None:
|
||||
scripts_manager.scripts_current.before_component(self, **kwargs)
|
||||
scripts_manager.script_callbacks.before_component_callback(self, **kwargs)
|
||||
res = original_BlockContext_init(self, *args, **kwargs) # pylint: disable=assignment-from-no-return
|
||||
add_classes_to_gradio_component(self)
|
||||
scripts.script_callbacks.after_component_callback(self, **kwargs)
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.after_component(self, **kwargs)
|
||||
scripts_manager.script_callbacks.after_component_callback(self, **kwargs)
|
||||
if scripts_manager.scripts_current is not None:
|
||||
scripts_manager.scripts_current.after_component(self, **kwargs)
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,14 +5,14 @@ def isinstance_str(x: object, cls_name: str):
|
|||
"""
|
||||
Checks whether x has any class *named* cls_name in its ancestry.
|
||||
Doesn't require access to the class's implementation.
|
||||
|
||||
|
||||
Useful for patching!
|
||||
"""
|
||||
|
||||
for _cls in x.__class__.__mro__:
|
||||
if _cls.__name__ == cls_name:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -29,4 +29,3 @@ def init_generator(device: torch.device, fallback: torch.Generator=None):
|
|||
return init_generator(torch.device("cpu"))
|
||||
else:
|
||||
return fallback
|
||||
|
||||
|
|
@ -3,8 +3,7 @@ import itertools # SBM Batch frames
|
|||
import numpy as np
|
||||
import filetype
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
||||
import modules.scripts
|
||||
from modules import shared, processing, images
|
||||
from modules import scripts_manager, shared, processing, images
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.ui_common import plaintext_to_html
|
||||
from modules.memstats import memory_stats
|
||||
|
|
@ -100,7 +99,7 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args)
|
|||
|
||||
batch_image_files = batch_image_files * btcrept # List used for naming later.
|
||||
|
||||
processed = modules.scripts.scripts_img2img.run(p, *args)
|
||||
processed = scripts_manager.scripts_img2img.run(p, *args)
|
||||
if processed is None:
|
||||
processed = processing.process_images(p)
|
||||
|
||||
|
|
@ -124,7 +123,7 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args)
|
|||
for k, v in items.items():
|
||||
image.info[k] = v
|
||||
images.save_image(image, path=output_dir, basename=basename, seed=None, prompt=None, extension=ext, info=geninfo, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=image.info, forced_filename=forced_filename)
|
||||
processed = modules.scripts.scripts_img2img.after(p, processed, *args)
|
||||
processed = scripts_manager.scripts_img2img.after(p, processed, *args)
|
||||
shared.log.debug(f'Processed: images={len(batch_image_files)} memory={memory_stats()} batch')
|
||||
|
||||
|
||||
|
|
@ -289,7 +288,7 @@ def img2img(id_task: str, state: str, mode: int,
|
|||
# override
|
||||
override_settings=override_settings,
|
||||
)
|
||||
p.scripts = modules.scripts.scripts_img2img
|
||||
p.scripts = scripts_manager.scripts_img2img
|
||||
p.script_args = args
|
||||
p.state = state
|
||||
if mask:
|
||||
|
|
@ -304,10 +303,10 @@ def img2img(id_task: str, state: str, mode: int,
|
|||
process_batch(p, img2img_batch_files, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
|
||||
processed = processing.Processed(p, [], p.seed, "")
|
||||
else:
|
||||
processed = modules.scripts.scripts_img2img.run(p, *args)
|
||||
processed = scripts_manager.scripts_img2img.run(p, *args)
|
||||
if processed is None:
|
||||
processed = processing.process_images(p)
|
||||
processed = modules.scripts.scripts_img2img.after(p, processed, *args)
|
||||
processed = scripts_manager.scripts_img2img.after(p, processed, *args)
|
||||
p.close()
|
||||
generation_info_js = processed.js() if processed is not None else ''
|
||||
if processed is None:
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ class VQModel(pl.LightningModule):
|
|||
# do the first few batches with max size to avoid later oom
|
||||
new_resize = upper_size
|
||||
else:
|
||||
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) # noqa: NPY002
|
||||
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
||||
if new_resize != x.shape[2]:
|
||||
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||
x = x.detach()
|
||||
|
|
|
|||
|
|
@ -347,7 +347,7 @@ def sdnq_quantize_model(model, op=None, sd_model=None, do_gc: bool = True, weigh
|
|||
return_device = None
|
||||
|
||||
if getattr(model, "_keep_in_fp32_modules", None) is not None:
|
||||
modules_to_not_convert.extend(model._keep_in_fp32_modules)
|
||||
modules_to_not_convert.extend(model._keep_in_fp32_modules) # pylint: disable=protected-access
|
||||
if model.__class__.__name__ == "ChromaTransformer2DModel":
|
||||
modules_to_not_convert.append("distilled_guidance_layer")
|
||||
|
||||
|
|
|
|||
|
|
@ -70,13 +70,13 @@ def load_t5(name=None, cache_dir=None):
|
|||
|
||||
elif 'int8' in name.lower():
|
||||
from modules.model_quant import create_sdnq_config
|
||||
quantization_config = create_sdnq_config(kwargs=None, allow_sdnq=True, module='any', weights_dtype='int8')
|
||||
quantization_config = create_sdnq_config(kwargs=None, allow=True, module='any', weights_dtype='int8')
|
||||
if quantization_config is not None:
|
||||
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
|
||||
|
||||
elif 'uint4' in name.lower():
|
||||
from modules.model_quant import create_sdnq_config
|
||||
quantization_config = create_sdnq_config(kwargs=None, allow_sdnq=True, module='any', weights_dtype='uint4')
|
||||
quantization_config = create_sdnq_config(kwargs=None, allow=True, module='any', weights_dtype='uint4')
|
||||
if quantization_config is not None:
|
||||
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import List
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from modules import shared, images, devices, scripts, scripts_postprocessing, infotext
|
||||
from modules import shared, images, devices, scripts_manager, scripts_postprocessing, infotext
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
|
|
@ -70,7 +70,7 @@ def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemp
|
|||
continue
|
||||
shared.state.textinfo = name
|
||||
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
||||
scripts.scripts_postproc.run(pp, args)
|
||||
scripts_manager.scripts_postproc.run(pp, args)
|
||||
geninfo, items = images.read_info_from_image(image)
|
||||
params = infotext.parse(geninfo)
|
||||
for k, v in items.items():
|
||||
|
|
@ -89,7 +89,7 @@ def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemp
|
|||
if extras_mode != 2 or show_extras_results:
|
||||
outputs.append(pp.image)
|
||||
image.close()
|
||||
scripts.scripts_postproc.postprocess(processed_images, args)
|
||||
scripts_manager.scripts_postproc.postprocess(processed_images, args)
|
||||
|
||||
devices.torch_gc()
|
||||
return outputs, info, params
|
||||
|
|
@ -98,7 +98,7 @@ def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemp
|
|||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): #pylint: disable=unused-argument
|
||||
"""old handler for API"""
|
||||
|
||||
args = scripts.scripts_postproc.create_args_for_run({
|
||||
args = scripts_manager.scripts_postproc.create_args_for_run({
|
||||
"Upscale": {
|
||||
"upscale_mode": resize_mode,
|
||||
"upscale_by": upscaling_resize,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import time
|
|||
from contextlib import nullcontext
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
from modules import shared, devices, errors, images, scripts, memstats, lowvram, script_callbacks, extra_networks, detailer, sd_models, sd_checkpoint, sd_vae, processing_helpers, timer, face_restoration, token_merge
|
||||
from modules import shared, devices, errors, images, scripts_manager, memstats, lowvram, script_callbacks, extra_networks, detailer, sd_models, sd_checkpoint, sd_vae, processing_helpers, timer, face_restoration, token_merge
|
||||
from modules.sd_hijack_hypertile import context_hypertile_vae, context_hypertile_unet
|
||||
from modules.processing_class import StableDiffusionProcessing, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, StableDiffusionProcessingControl, StableDiffusionProcessingVideo # pylint: disable=unused-import
|
||||
from modules.processing_info import create_infotext
|
||||
|
|
@ -128,7 +128,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
debug(f'Process images: {vars(p)}')
|
||||
if not hasattr(p.sd_model, 'sd_checkpoint_info'):
|
||||
return None
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner):
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
p.scripts.before_process(p)
|
||||
stored_opts = {}
|
||||
for k, v in p.override_settings.copy().items():
|
||||
|
|
@ -290,7 +290,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
process_init(p)
|
||||
if not shared.native and os.path.exists(shared.opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||
modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=False)
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner):
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
p.scripts.process(p)
|
||||
|
||||
ema_scope_context = p.sd_model.ema_scope if not shared.native else nullcontext
|
||||
|
|
@ -324,19 +324,19 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n+1) * p.batch_size]
|
||||
p.seeds = p.all_seeds[n * p.batch_size:(n+1) * p.batch_size]
|
||||
p.subseeds = p.all_subseeds[n * p.batch_size:(n+1) * p.batch_size]
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner):
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
||||
if len(p.prompts) == 0:
|
||||
break
|
||||
p.prompts, p.network_data = extra_networks.parse_prompts(p.prompts)
|
||||
if not shared.native:
|
||||
extra_networks.activate(p, p.network_data)
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner):
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
||||
|
||||
samples = None
|
||||
timer.process.record('init')
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner):
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
processed = p.scripts.process_images(p)
|
||||
if processed is not None:
|
||||
samples = processed.images
|
||||
|
|
@ -358,12 +358,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if not shared.native and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner):
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
p.scripts.postprocess_batch(p, samples, batch_number=n)
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner):
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
p.prompts = p.all_prompts[n * p.batch_size:(n+1) * p.batch_size]
|
||||
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n+1) * p.batch_size]
|
||||
batch_params = scripts.PostprocessBatchListArgs(list(samples))
|
||||
batch_params = scripts_manager.PostprocessBatchListArgs(list(samples))
|
||||
p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
|
||||
samples = batch_params.images
|
||||
|
||||
|
|
@ -402,8 +402,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i)
|
||||
images.save_image(image_without_cc, path=p.outpath_samples, basename="", seed=p.seeds[i], prompt=p.prompts[i], extension=shared.opts.samples_format, info=info, p=p, suffix="-before-color-correct")
|
||||
image = apply_color_correction(p.color_corrections[i], image)
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner):
|
||||
pp = scripts.PostprocessImageArgs(image)
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
pp = scripts_manager.PostprocessImageArgs(image)
|
||||
p.scripts.postprocess_image(p, pp)
|
||||
if pp.image is not None:
|
||||
image = pp.image
|
||||
|
|
@ -496,7 +496,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
index_of_first_image=index_of_first_image,
|
||||
infotexts=infotexts,
|
||||
)
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner) and not (shared.state.interrupted or shared.state.skipped):
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner) and not (shared.state.interrupted or shared.state.skipped):
|
||||
p.scripts.postprocess(p, processed)
|
||||
timer.process.record('post')
|
||||
if not p.disable_extra_networks:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch
|
|||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image, ImageOps
|
||||
from modules import shared, devices, images, scripts, masking, sd_samplers, sd_models, processing_helpers
|
||||
from modules import shared, devices, images, scripts_manager, masking, sd_samplers, sd_models, processing_helpers
|
||||
from modules.sd_hijack_hypertile import hypertile_set
|
||||
|
||||
|
||||
|
|
@ -278,7 +278,7 @@ class StableDiffusionProcessing:
|
|||
self.prompt_for_display: str = None
|
||||
|
||||
# scripts
|
||||
self.scripts_value: scripts.ScriptRunner = field(default=None, init=False)
|
||||
self.scripts_value: scripts_manager.ScriptRunner = field(default=None, init=False)
|
||||
self.script_args_value: list = field(default=None, init=False)
|
||||
self.scripts_setup_complete: bool = field(default=False, init=False)
|
||||
self.script_args = script_args
|
||||
|
|
|
|||
|
|
@ -731,7 +731,7 @@ def get_weighted_text_embeddings_sdxl_refiner(
|
|||
|
||||
for z in range(len(neg_weight_tensor_2)):
|
||||
if neg_weight_tensor_2[z] != 1.0:
|
||||
ow = neg_weight_tensor_2[z] - 1
|
||||
# ow = neg_weight_tensor_2[z] - 1
|
||||
# neg_weight = 1 + (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2
|
||||
|
||||
# add weight method 1:
|
||||
|
|
@ -1330,7 +1330,6 @@ def get_weighted_text_embeddings_sd3(
|
|||
sd3_neg_prompt_embeds = torch.cat([clip_neg_prompt_embeds, t5_neg_prompt_embeds], dim=-2)
|
||||
|
||||
# padding
|
||||
import torch.nn.functional as F
|
||||
size_diff = sd3_neg_prompt_embeds.size(1) - sd3_prompt_embeds.size(1)
|
||||
# Calculate padding. Format for pad is (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
|
||||
# Since we are padding along the second dimension (axis=1), we need (0, 0, padding_top, padding_bottom, 0, 0)
|
||||
|
|
|
|||
|
|
@ -1,730 +1,2 @@
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
import gradio as gr
|
||||
from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
|
||||
from installer import control_extensions
|
||||
|
||||
|
||||
AlwaysVisible = object()
|
||||
time_component = {}
|
||||
time_setup = {}
|
||||
debug = errors.log.trace if os.environ.get('SD_SCRIPT_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
class PostprocessImageArgs:
|
||||
def __init__(self, image):
|
||||
self.image = image
|
||||
|
||||
|
||||
class PostprocessBatchListArgs:
|
||||
def __init__(self, images):
|
||||
self.images = images
|
||||
|
||||
|
||||
@dataclass
|
||||
class OnComponent:
|
||||
component: gr.blocks.Block
|
||||
|
||||
|
||||
class Script:
|
||||
parent = None
|
||||
name = None
|
||||
filename = None
|
||||
args_from = 0
|
||||
args_to = 0
|
||||
alwayson = False
|
||||
is_txt2img = False
|
||||
is_img2img = False
|
||||
api_info = None
|
||||
group = None
|
||||
infotext_fields = None
|
||||
paste_field_names = None
|
||||
section = None
|
||||
standalone = False
|
||||
on_before_component_elem_id = [] # list of callbacks to be called before a component with an elem_id is created
|
||||
on_after_component_elem_id = [] # list of callbacks to be called after a component with an elem_id is created
|
||||
|
||||
def title(self):
|
||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||
raise NotImplementedError
|
||||
|
||||
def ui(self, is_img2img):
|
||||
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||
The return value should be an array of all components that are used in processing.
|
||||
Values of those returned components will be passed to run() and process() functions.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def show(self, is_img2img): # pylint: disable=unused-argument
|
||||
"""
|
||||
is_img2img is True if this function is called for the img2img interface, and False otherwise
|
||||
This function should return:
|
||||
- False if the script should not be shown in UI at all
|
||||
- True if the script should be shown in UI if it's selected in the scripts dropdown
|
||||
- script.AlwaysVisible if the script should be shown in UI at all times
|
||||
"""
|
||||
return True
|
||||
|
||||
def run(self, p, *args):
|
||||
"""
|
||||
This function is called if the script has been selected in the script dropdown.
|
||||
It must do all processing and return the Processed object with results, same as
|
||||
one returned by processing.process_images.
|
||||
Usually the processing is done by calling the processing.process_images function.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def setup(self, p, *args):
|
||||
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
|
||||
args contains all values returned by components from ui().
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_process(self, p, *args):
|
||||
"""
|
||||
This function is called very early during processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process(self, p, *args):
|
||||
"""
|
||||
This function is called before processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process_images(self, p, *args):
|
||||
"""
|
||||
This function is called instead of main processing for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_process_batch(self, p, *args, **kwargs):
|
||||
"""
|
||||
Called before extra networks are parsed from the prompt, so you can add
|
||||
new extra network keywords to the prompt with this callback.
|
||||
**kwargs will have those items:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
|
||||
- seeds - list of seeds for current batch
|
||||
- subseeds - list of subseeds for current batch
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process_batch(self, p, *args, **kwargs):
|
||||
"""
|
||||
Same as process(), but called for every batch.
|
||||
**kwargs will have those items:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
|
||||
- seeds - list of seeds for current batch
|
||||
- subseeds - list of subseeds for current batch
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_batch(self, p, *args, **kwargs):
|
||||
"""
|
||||
Same as process_batch(), but called for every batch after it has been generated.
|
||||
**kwargs will have same items as process_batch, and also:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
- images - torch tensor with all generated images, with values ranging from 0 to 1;
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
|
||||
"""
|
||||
Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
|
||||
This is useful when you want to update the entire batch instead of individual images.
|
||||
You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
|
||||
If the number of images is different from the batch size when returning,
|
||||
then the script has the responsibility to also update the following attributes in the processing object (p):
|
||||
- p.prompts
|
||||
- p.negative_prompts
|
||||
- p.seeds
|
||||
- p.subseeds
|
||||
**kwargs will have same items as process_batch, and also:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
"""
|
||||
Called before a component is created.
|
||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||
This can be useful to inject your own components somewhere in the middle of vanilla UI.
|
||||
You can return created components in the ui() function to add them to the list of arguments for your processing functions
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
"""
|
||||
Called after a component is created. Same as above.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def describe(self):
|
||||
"""unused"""
|
||||
return ""
|
||||
|
||||
def elem_id(self, item_id):
|
||||
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
||||
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
||||
return f'script_{self.parent}_{title}_{item_id}'
|
||||
|
||||
|
||||
current_basedir = paths.script_path
|
||||
|
||||
|
||||
def basedir():
|
||||
"""returns the base directory for the current script. For scripts in the main scripts directory,
|
||||
this is the main directory (where webui.py resides), and for scripts in extensions directory
|
||||
(ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
|
||||
"""
|
||||
return current_basedir
|
||||
|
||||
|
||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path", "priority"])
|
||||
scripts_data = []
|
||||
postprocessing_scripts_data = []
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||
|
||||
|
||||
def list_scripts(scriptdirname, extension):
|
||||
tmp_list = []
|
||||
base = os.path.join(paths.script_path, scriptdirname)
|
||||
if os.path.exists(base):
|
||||
for filename in sorted(os.listdir(base)):
|
||||
tmp_list.append(ScriptFile(paths.script_path, filename, os.path.join(base, filename), '50'))
|
||||
for ext in extensions.active():
|
||||
tmp_list += ext.list_files(scriptdirname, extension)
|
||||
priority_list = []
|
||||
for script in tmp_list:
|
||||
if os.path.splitext(script.path)[1].lower() == extension and os.path.isfile(script.path):
|
||||
if script.basedir == paths.script_path:
|
||||
priority = '0'
|
||||
elif script.basedir.startswith(os.path.join(paths.script_path, 'scripts')):
|
||||
priority = '1'
|
||||
elif script.basedir.startswith(os.path.join(paths.script_path, 'extensions-builtin')):
|
||||
priority = '2'
|
||||
elif script.basedir.startswith(os.path.join(paths.script_path, 'extensions')):
|
||||
priority = '3'
|
||||
else:
|
||||
priority = '9'
|
||||
if os.path.isfile(os.path.join(base, "..", ".priority")):
|
||||
with open(os.path.join(base, "..", ".priority"), "r", encoding="utf-8") as f:
|
||||
priority = priority + str(f.read().strip())
|
||||
errors.log.debug(f'Script priority override: ${script.name}:{priority}')
|
||||
else:
|
||||
priority = priority + script.priority
|
||||
priority_list.append(ScriptFile(script.basedir, script.filename, script.path, priority))
|
||||
debug(f'Adding script: folder="{script.basedir}" file="{script.filename}" full="{script.path}" priority={priority}')
|
||||
priority_sort = sorted(priority_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
|
||||
return priority_sort
|
||||
|
||||
|
||||
def list_files_with_name(filename):
|
||||
res = []
|
||||
dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
|
||||
for dirpath in dirs:
|
||||
if not os.path.isdir(dirpath):
|
||||
continue
|
||||
path = os.path.join(dirpath, filename)
|
||||
if os.path.isfile(path):
|
||||
res.append(path)
|
||||
return res
|
||||
|
||||
|
||||
def load_scripts():
|
||||
t = timer.Timer()
|
||||
t0 = time.time()
|
||||
global current_basedir # pylint: disable=global-statement
|
||||
scripts_data.clear()
|
||||
postprocessing_scripts_data.clear()
|
||||
script_callbacks.clear_callbacks()
|
||||
scripts_list = list_scripts('scripts', '.py') + list_scripts(os.path.join('modules', 'face'), '.py')
|
||||
scripts_list = sorted(scripts_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
|
||||
syspath = sys.path
|
||||
|
||||
def register_scripts_from_module(module, scriptfile):
|
||||
for script_class in module.__dict__.values():
|
||||
if type(script_class) != type:
|
||||
continue
|
||||
debug(f'Registering script: path="{scriptfile.path}"')
|
||||
if issubclass(script_class, Script):
|
||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
|
||||
for scriptfile in scripts_list:
|
||||
try:
|
||||
if scriptfile.basedir != paths.script_path:
|
||||
sys.path = [scriptfile.basedir] + sys.path
|
||||
current_basedir = scriptfile.basedir
|
||||
script_module = script_loading.load_module(scriptfile.path)
|
||||
register_scripts_from_module(script_module, scriptfile)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Load script: {scriptfile.filename}')
|
||||
finally:
|
||||
current_basedir = paths.script_path
|
||||
t.record(os.path.basename(scriptfile.basedir) if scriptfile.basedir != paths.script_path else scriptfile.filename)
|
||||
sys.path = syspath
|
||||
|
||||
global scripts_txt2img, scripts_img2img, scripts_control, scripts_postproc # pylint: disable=global-statement
|
||||
scripts_txt2img = ScriptRunner('txt2img')
|
||||
scripts_img2img = ScriptRunner('img2img')
|
||||
scripts_control = ScriptRunner('control')
|
||||
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
|
||||
return t, time.time()-t0
|
||||
|
||||
|
||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
return res
|
||||
except Exception as e:
|
||||
errors.display(e, f'Calling script: {filename}/{funcname}')
|
||||
return default
|
||||
|
||||
|
||||
class ScriptSummary:
|
||||
def __init__(self, op):
|
||||
self.start = time.time()
|
||||
self.update = time.time()
|
||||
self.op = op
|
||||
self.time = {}
|
||||
|
||||
def record(self, script):
|
||||
self.update = time.time()
|
||||
self.time[script] = round(time.time() - self.update, 2)
|
||||
|
||||
def report(self):
|
||||
total = sum(self.time.values())
|
||||
if total == 0:
|
||||
return
|
||||
scripts = [f'{k}:{v}' for k, v in self.time.items() if v > 0]
|
||||
errors.log.debug(f'Script: op={self.op} total={total} scripts={scripts}')
|
||||
|
||||
|
||||
class ScriptRunner:
|
||||
def __init__(self, name=''):
|
||||
self.name = name
|
||||
self.scripts = []
|
||||
self.selectable_scripts = []
|
||||
self.alwayson_scripts = []
|
||||
self.auto_processing_scripts = []
|
||||
self.titles = []
|
||||
self.infotext_fields = []
|
||||
self.paste_field_names = []
|
||||
self.script_load_ctr = 0
|
||||
self.is_img2img = False
|
||||
self.inputs = [None]
|
||||
self.time = 0
|
||||
|
||||
def add_script(self, script_class, path, is_img2img, is_control):
|
||||
try:
|
||||
script = script_class()
|
||||
script.filename = path
|
||||
script.is_txt2img = not is_img2img
|
||||
script.is_img2img = is_img2img
|
||||
if is_control: # this is messy but show is a legacy function that is not aware of control tab
|
||||
v1 = script.show(script.is_txt2img)
|
||||
v2 = script.show(script.is_img2img)
|
||||
if v1 == AlwaysVisible or v2 == AlwaysVisible:
|
||||
visibility = AlwaysVisible
|
||||
else:
|
||||
visibility = v1 or v2
|
||||
else:
|
||||
visibility = script.show(script.is_img2img)
|
||||
if visibility == AlwaysVisible:
|
||||
self.scripts.append(script)
|
||||
self.alwayson_scripts.append(script)
|
||||
script.alwayson = True
|
||||
elif visibility:
|
||||
self.scripts.append(script)
|
||||
self.selectable_scripts.append(script)
|
||||
except Exception as e:
|
||||
errors.log.error(f'Script initialize: {path} {e}')
|
||||
|
||||
def initialize_scripts(self, is_img2img=False, is_control=False):
|
||||
from modules import scripts_auto_postprocessing
|
||||
|
||||
self.scripts.clear()
|
||||
self.selectable_scripts.clear()
|
||||
self.alwayson_scripts.clear()
|
||||
self.titles.clear()
|
||||
self.infotext_fields.clear()
|
||||
self.paste_field_names.clear()
|
||||
self.script_load_ctr = 0
|
||||
self.is_img2img = is_img2img
|
||||
self.scripts.clear()
|
||||
self.alwayson_scripts.clear()
|
||||
self.selectable_scripts.clear()
|
||||
self.auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||
|
||||
sorted_scripts = sorted(scripts_data, key=lambda x: x.script_class().title().lower())
|
||||
for script_class, path, _basedir, _script_module in sorted_scripts:
|
||||
self.add_script(script_class, path, is_img2img, is_control)
|
||||
sorted_scripts = sorted(self.auto_processing_scripts, key=lambda x: x.script_class().title().lower())
|
||||
for script_class, path, _basedir, _script_module in sorted_scripts:
|
||||
self.add_script(script_class, path, is_img2img, is_control)
|
||||
|
||||
def prepare_ui(self):
|
||||
self.inputs = [None]
|
||||
|
||||
def setup_ui(self, parent='unknown', accordion=True):
|
||||
import modules.api.models as api_models
|
||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||
|
||||
inputs = []
|
||||
inputs_alwayson = [True]
|
||||
|
||||
def create_script_ui(script: Script, inputs, inputs_alwayson):
|
||||
script.parent = parent
|
||||
script.args_from = len(inputs)
|
||||
script.args_to = len(inputs)
|
||||
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
||||
if controls is None:
|
||||
return
|
||||
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
|
||||
api_args = []
|
||||
for control in controls:
|
||||
debug(f'Script control: parent={script.parent} script="{script.name}" label="{control.label}" type={control} id={control.elem_id}')
|
||||
if hasattr(gr.components, 'IOComponent'):
|
||||
if not isinstance(control, gr.components.IOComponent):
|
||||
errors.log.error(f'Invalid script control: "{script.filename}" control={control}')
|
||||
continue
|
||||
else:
|
||||
if not isinstance(control, gr.components.Component):
|
||||
errors.log.error(f'Invalid script control: "{script.filename}" control={control}')
|
||||
continue
|
||||
control.custom_script_source = os.path.basename(script.filename)
|
||||
arg_info = api_models.ScriptArg(label=control.label or "")
|
||||
for field in ("value", "minimum", "maximum", "step", "choices"):
|
||||
v = getattr(control, field, None)
|
||||
if v is not None:
|
||||
setattr(arg_info, field, v)
|
||||
api_args.append(arg_info)
|
||||
|
||||
script.api_info = api_models.ItemScript(
|
||||
name=script.name,
|
||||
is_img2img=script.is_img2img,
|
||||
is_alwayson=script.alwayson,
|
||||
args=api_args,
|
||||
)
|
||||
if script.infotext_fields is not None:
|
||||
self.infotext_fields += script.infotext_fields
|
||||
if script.paste_field_names is not None:
|
||||
self.paste_field_names += script.paste_field_names
|
||||
inputs += controls
|
||||
inputs_alwayson += [script.alwayson for _ in controls]
|
||||
script.args_to = len(inputs)
|
||||
|
||||
with gr.Row():
|
||||
dropdown = gr.Dropdown(label="Script", elem_id=f'{parent}_script_list', choices=["None"] + self.titles, value="None", type="index")
|
||||
inputs.insert(0, dropdown)
|
||||
|
||||
with gr.Row():
|
||||
for script in self.alwayson_scripts:
|
||||
if not script.standalone:
|
||||
continue
|
||||
if (self.name == 'control') and (script.name not in control_extensions) and (script.title() not in control_extensions):
|
||||
errors.log.debug(f'Script: fn="{script.filename}" type={self.name} skip')
|
||||
continue
|
||||
t0 = time.time()
|
||||
with gr.Group(elem_id=f'{parent}_script_{script.title().lower().replace(" ", "_")}', elem_classes=['group-extension']) as group:
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
script.group = group
|
||||
time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Accordion(label="Extensions", elem_id=f'{parent}_script_alwayson') if accordion else gr.Group():
|
||||
for script in self.alwayson_scripts:
|
||||
if script.standalone:
|
||||
continue
|
||||
if (self.name == 'control') and (paths.extensions_dir in script.filename) and (script.title() not in control_extensions):
|
||||
errors.log.debug(f'Script: fn="{script.filename}" type={self.name} skip')
|
||||
continue
|
||||
t0 = time.time()
|
||||
with gr.Group(elem_id=f'{parent}_script_{script.title().lower().replace(" ", "_")}', elem_classes=['group-extension']) as group:
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
script.group = group
|
||||
time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)
|
||||
|
||||
for script in self.selectable_scripts:
|
||||
if (self.name == 'control') and (paths.extensions_dir in script.filename) and (script.title() not in control_extensions):
|
||||
errors.log.debug(f'Script: fn="{script.filename}" type={self.name} skip')
|
||||
continue
|
||||
with gr.Group(elem_id=f'{parent}_script_{script.title().lower().replace(" ", "_")}', elem_classes=['group-scripts'], visible=False) as group:
|
||||
t0 = time.time()
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)
|
||||
script.group = group
|
||||
|
||||
def select_script(script_index):
|
||||
if script_index is None:
|
||||
return [gr.update(visible=False) for script in self.selectable_scripts]
|
||||
selected_script = self.selectable_scripts[script_index - 1] if script_index > 0 else None
|
||||
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
|
||||
|
||||
def init_field(title):
|
||||
if title == 'None': # called when an initial value is set from ui-config.json to show script's UI components
|
||||
return
|
||||
if title not in self.titles:
|
||||
errors.log.error(f'Script not found: {title}')
|
||||
return
|
||||
script_index = self.titles.index(title)
|
||||
self.selectable_scripts[script_index].group.visible = True
|
||||
|
||||
dropdown.init_field = init_field
|
||||
dropdown.change(fn=select_script, inputs=[dropdown], outputs=[script.group for script in self.selectable_scripts if script.group is not None])
|
||||
|
||||
def onload_script_visibility(params):
|
||||
title = params.get('Script', None)
|
||||
if title:
|
||||
title_index = self.titles.index(title)
|
||||
visibility = title_index == self.script_load_ctr
|
||||
self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
|
||||
return gr.update(visible=visibility)
|
||||
else:
|
||||
return gr.update(visible=False)
|
||||
|
||||
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
|
||||
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts if script.group is not None])
|
||||
return inputs
|
||||
|
||||
def run(self, p, *args):
|
||||
s = ScriptSummary('run')
|
||||
script_index = args[0] if len(args) > 0 else 0
|
||||
if script_index == 0:
|
||||
return None
|
||||
script = self.selectable_scripts[script_index-1]
|
||||
if script is None:
|
||||
return None
|
||||
if 'upscale' in script.title():
|
||||
if not hasattr(p, 'init_images') and p.task_args.get('image', None) is not None:
|
||||
p.init_images = p.task_args['image']
|
||||
parsed = []
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from'):
|
||||
parsed = p.per_script_args.get(script.title(), args[script.args_from:script.args_to])
|
||||
if hasattr(script, 'run'):
|
||||
processed = script.run(p, *parsed)
|
||||
else:
|
||||
processed = None
|
||||
errors.log.error(f'Script: file="{script.filename}" no run function defined')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
return processed
|
||||
|
||||
def after(self, p, processed, *args):
|
||||
s = ScriptSummary('after')
|
||||
script_index = args[0] if len(args) > 0 else 0
|
||||
if script_index == 0:
|
||||
return processed
|
||||
script = self.selectable_scripts[script_index-1]
|
||||
if script is None or not hasattr(script, 'after'):
|
||||
return processed
|
||||
parsed = []
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from'):
|
||||
parsed = p.per_script_args.get(script.title(), args[script.args_from:script.args_to])
|
||||
after_processed = script.after(p, processed, *parsed)
|
||||
if after_processed is not None:
|
||||
processed = after_processed
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
return processed
|
||||
|
||||
def before_process(self, p, **kwargs):
|
||||
s = ScriptSummary('before-process')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.before_process(p, *args, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f"Error running before process: {script.filename}")
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process(self, p, **kwargs):
|
||||
s = ScriptSummary('process')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.process(p, *args, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script process: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process_images(self, p, **kwargs):
|
||||
s = ScriptSummary('process_images')
|
||||
processed = None
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
_processed = script.process_images(p, *args, **kwargs)
|
||||
if _processed is not None:
|
||||
processed = _processed
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script process images: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
return processed
|
||||
|
||||
def before_process_batch(self, p, **kwargs):
|
||||
s = ScriptSummary('before-process-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.before_process_batch(p, *args, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script before process batch: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process_batch(self, p, **kwargs):
|
||||
s = ScriptSummary('process-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.process_batch(p, *args, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script process batch: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess(self, p, processed):
|
||||
s = ScriptSummary('postprocess')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.postprocess(p, processed, *args)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script postprocess: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_batch(self, p, images, **kwargs):
|
||||
s = ScriptSummary('postprocess-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.postprocess_batch(p, *args, images=images, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script before postprocess batch: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
|
||||
s = ScriptSummary('postprocess-batch-list')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.postprocess_batch_list(p, pp, *args, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script before postprocess batch list: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||
s = ScriptSummary('postprocess-image')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.postprocess_image(p, pp, *args)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script postprocess image: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
s = ScriptSummary('before-component')
|
||||
for script in self.scripts:
|
||||
try:
|
||||
script.before_component(component, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script before component: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
s = ScriptSummary('after-component')
|
||||
for script in self.scripts:
|
||||
for elem_id, callback in script.on_after_component_elem_id:
|
||||
if elem_id == kwargs.get("elem_id"):
|
||||
try:
|
||||
callback(OnComponent(component=component))
|
||||
except Exception as e:
|
||||
errors.display(e, f"Running script before_component_elem_id: {script.filename}")
|
||||
try:
|
||||
script.after_component(component, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script after component: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def reload_sources(self, cache):
|
||||
s = ScriptSummary('reload-sources')
|
||||
for si, script in list(enumerate(self.scripts)):
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from'):
|
||||
args_from = script.args_from
|
||||
args_to = script.args_to
|
||||
filename = script.filename
|
||||
module = cache.get(filename, None)
|
||||
if module is None:
|
||||
module = script_loading.load_module(script.filename)
|
||||
cache[filename] = module
|
||||
for script_class in module.__dict__.values():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
self.scripts[si] = script_class()
|
||||
self.scripts[si].filename = filename
|
||||
self.scripts[si].args_from = args_from
|
||||
self.scripts[si].args_to = args_to
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
|
||||
scripts_txt2img: ScriptRunner = None
|
||||
scripts_img2img: ScriptRunner = None
|
||||
scripts_control: ScriptRunner = None
|
||||
scripts_current: ScriptRunner = None
|
||||
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
|
||||
reload_scripts = load_scripts # compatibility alias
|
||||
|
||||
|
||||
def reload_script_body_only():
|
||||
cache = {}
|
||||
scripts_txt2img.reload_sources(cache)
|
||||
scripts_img2img.reload_sources(cache)
|
||||
scripts_control.reload_sources(cache)
|
||||
# compatibility with extensions that import scripts directly
|
||||
from modules.scripts_manager import * # noqa: F403
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from modules import scripts, scripts_postprocessing, shared
|
||||
from modules import scripts_manager, scripts_postprocessing, shared
|
||||
|
||||
|
||||
class ScriptPostprocessingForMainUI(scripts.Script):
|
||||
class ScriptPostprocessingForMainUI(scripts_manager.Script):
|
||||
def __init__(self, script_postproc):
|
||||
self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
|
||||
self.postprocessing_controls = None
|
||||
|
|
@ -10,7 +10,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
|
|||
return self.script.name
|
||||
|
||||
def show(self, is_img2img): # pylint: disable=unused-argument
|
||||
return scripts.AlwaysVisible
|
||||
return scripts_manager.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img): # pylint: disable=unused-argument
|
||||
self.postprocessing_controls = self.script.ui()
|
||||
|
|
@ -28,9 +28,9 @@ class ScriptPostprocessingForMainUI(scripts.Script):
|
|||
def create_auto_preprocessing_script_data():
|
||||
res = []
|
||||
for name in shared.opts.postprocessing_enable_in_main_ui:
|
||||
script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
|
||||
script = next(iter([x for x in scripts_manager.postprocessing_scripts_data if x.script_class.name == name]), None)
|
||||
if script is None:
|
||||
continue
|
||||
constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class()) # pylint: disable=unnecessary-lambda-assignment
|
||||
res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
|
||||
res.append(scripts_manager.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -0,0 +1,730 @@
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
import gradio as gr
|
||||
from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
|
||||
from installer import control_extensions
|
||||
|
||||
|
||||
AlwaysVisible = object()
|
||||
time_component = {}
|
||||
time_setup = {}
|
||||
debug = errors.log.trace if os.environ.get('SD_SCRIPT_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
class PostprocessImageArgs:
|
||||
def __init__(self, image):
|
||||
self.image = image
|
||||
|
||||
|
||||
class PostprocessBatchListArgs:
|
||||
def __init__(self, images):
|
||||
self.images = images
|
||||
|
||||
|
||||
@dataclass
|
||||
class OnComponent:
|
||||
component: gr.blocks.Block
|
||||
|
||||
|
||||
class Script:
|
||||
parent = None
|
||||
name = None
|
||||
filename = None
|
||||
args_from = 0
|
||||
args_to = 0
|
||||
alwayson = False
|
||||
is_txt2img = False
|
||||
is_img2img = False
|
||||
api_info = None
|
||||
group = None
|
||||
infotext_fields = None
|
||||
paste_field_names = None
|
||||
section = None
|
||||
standalone = False
|
||||
on_before_component_elem_id = [] # list of callbacks to be called before a component with an elem_id is created
|
||||
on_after_component_elem_id = [] # list of callbacks to be called after a component with an elem_id is created
|
||||
|
||||
def title(self):
|
||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||
raise NotImplementedError
|
||||
|
||||
def ui(self, is_img2img):
|
||||
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||
The return value should be an array of all components that are used in processing.
|
||||
Values of those returned components will be passed to run() and process() functions.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def show(self, is_img2img): # pylint: disable=unused-argument
|
||||
"""
|
||||
is_img2img is True if this function is called for the img2img interface, and False otherwise
|
||||
This function should return:
|
||||
- False if the script should not be shown in UI at all
|
||||
- True if the script should be shown in UI if it's selected in the scripts dropdown
|
||||
- script.AlwaysVisible if the script should be shown in UI at all times
|
||||
"""
|
||||
return True
|
||||
|
||||
def run(self, p, *args):
|
||||
"""
|
||||
This function is called if the script has been selected in the script dropdown.
|
||||
It must do all processing and return the Processed object with results, same as
|
||||
one returned by processing.process_images.
|
||||
Usually the processing is done by calling the processing.process_images function.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def setup(self, p, *args):
|
||||
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
|
||||
args contains all values returned by components from ui().
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_process(self, p, *args):
|
||||
"""
|
||||
This function is called very early during processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process(self, p, *args):
|
||||
"""
|
||||
This function is called before processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process_images(self, p, *args):
|
||||
"""
|
||||
This function is called instead of main processing for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_process_batch(self, p, *args, **kwargs):
|
||||
"""
|
||||
Called before extra networks are parsed from the prompt, so you can add
|
||||
new extra network keywords to the prompt with this callback.
|
||||
**kwargs will have those items:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
|
||||
- seeds - list of seeds for current batch
|
||||
- subseeds - list of subseeds for current batch
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process_batch(self, p, *args, **kwargs):
|
||||
"""
|
||||
Same as process(), but called for every batch.
|
||||
**kwargs will have those items:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
|
||||
- seeds - list of seeds for current batch
|
||||
- subseeds - list of subseeds for current batch
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_batch(self, p, *args, **kwargs):
|
||||
"""
|
||||
Same as process_batch(), but called for every batch after it has been generated.
|
||||
**kwargs will have same items as process_batch, and also:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
- images - torch tensor with all generated images, with values ranging from 0 to 1;
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
|
||||
"""
|
||||
Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
|
||||
This is useful when you want to update the entire batch instead of individual images.
|
||||
You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
|
||||
If the number of images is different from the batch size when returning,
|
||||
then the script has the responsibility to also update the following attributes in the processing object (p):
|
||||
- p.prompts
|
||||
- p.negative_prompts
|
||||
- p.seeds
|
||||
- p.subseeds
|
||||
**kwargs will have same items as process_batch, and also:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
"""
|
||||
Called before a component is created.
|
||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||
This can be useful to inject your own components somewhere in the middle of vanilla UI.
|
||||
You can return created components in the ui() function to add them to the list of arguments for your processing functions
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
"""
|
||||
Called after a component is created. Same as above.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def describe(self):
|
||||
"""unused"""
|
||||
return ""
|
||||
|
||||
def elem_id(self, item_id):
|
||||
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
||||
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
||||
return f'script_{self.parent}_{title}_{item_id}'
|
||||
|
||||
|
||||
current_basedir = paths.script_path
|
||||
|
||||
|
||||
def basedir():
|
||||
"""returns the base directory for the current script. For scripts in the main scripts directory,
|
||||
this is the main directory (where webui.py resides), and for scripts in extensions directory
|
||||
(ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
|
||||
"""
|
||||
return current_basedir
|
||||
|
||||
|
||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path", "priority"])
|
||||
scripts_data = []
|
||||
postprocessing_scripts_data = []
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||
|
||||
|
||||
def list_scripts(scriptdirname, extension):
|
||||
tmp_list = []
|
||||
base = os.path.join(paths.script_path, scriptdirname)
|
||||
if os.path.exists(base):
|
||||
for filename in sorted(os.listdir(base)):
|
||||
tmp_list.append(ScriptFile(paths.script_path, filename, os.path.join(base, filename), '50'))
|
||||
for ext in extensions.active():
|
||||
tmp_list += ext.list_files(scriptdirname, extension)
|
||||
priority_list = []
|
||||
for script in tmp_list:
|
||||
if os.path.splitext(script.path)[1].lower() == extension and os.path.isfile(script.path):
|
||||
if script.basedir == paths.script_path:
|
||||
priority = '0'
|
||||
elif script.basedir.startswith(os.path.join(paths.script_path, 'scripts')):
|
||||
priority = '1'
|
||||
elif script.basedir.startswith(os.path.join(paths.script_path, 'extensions-builtin')):
|
||||
priority = '2'
|
||||
elif script.basedir.startswith(os.path.join(paths.script_path, 'extensions')):
|
||||
priority = '3'
|
||||
else:
|
||||
priority = '9'
|
||||
if os.path.isfile(os.path.join(base, "..", ".priority")):
|
||||
with open(os.path.join(base, "..", ".priority"), "r", encoding="utf-8") as f:
|
||||
priority = priority + str(f.read().strip())
|
||||
errors.log.debug(f'Script priority override: ${script.name}:{priority}')
|
||||
else:
|
||||
priority = priority + script.priority
|
||||
priority_list.append(ScriptFile(script.basedir, script.filename, script.path, priority))
|
||||
debug(f'Adding script: folder="{script.basedir}" file="{script.filename}" full="{script.path}" priority={priority}')
|
||||
priority_sort = sorted(priority_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
|
||||
return priority_sort
|
||||
|
||||
|
||||
def list_files_with_name(filename):
|
||||
res = []
|
||||
dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
|
||||
for dirpath in dirs:
|
||||
if not os.path.isdir(dirpath):
|
||||
continue
|
||||
path = os.path.join(dirpath, filename)
|
||||
if os.path.isfile(path):
|
||||
res.append(path)
|
||||
return res
|
||||
|
||||
|
||||
def load_scripts():
|
||||
t = timer.Timer()
|
||||
t0 = time.time()
|
||||
global current_basedir # pylint: disable=global-statement
|
||||
scripts_data.clear()
|
||||
postprocessing_scripts_data.clear()
|
||||
script_callbacks.clear_callbacks()
|
||||
scripts_list = list_scripts('scripts', '.py') + list_scripts(os.path.join('modules', 'face'), '.py')
|
||||
scripts_list = sorted(scripts_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
|
||||
syspath = sys.path
|
||||
|
||||
def register_scripts_from_module(module, scriptfile):
|
||||
for script_class in module.__dict__.values():
|
||||
if type(script_class) != type:
|
||||
continue
|
||||
debug(f'Registering script: path="{scriptfile.path}"')
|
||||
if issubclass(script_class, Script):
|
||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
|
||||
for scriptfile in scripts_list:
|
||||
try:
|
||||
if scriptfile.basedir != paths.script_path:
|
||||
sys.path = [scriptfile.basedir] + sys.path
|
||||
current_basedir = scriptfile.basedir
|
||||
script_module = script_loading.load_module(scriptfile.path)
|
||||
register_scripts_from_module(script_module, scriptfile)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Load script: {scriptfile.filename}')
|
||||
finally:
|
||||
current_basedir = paths.script_path
|
||||
t.record(os.path.basename(scriptfile.basedir) if scriptfile.basedir != paths.script_path else scriptfile.filename)
|
||||
sys.path = syspath
|
||||
|
||||
global scripts_txt2img, scripts_img2img, scripts_control, scripts_postproc # pylint: disable=global-statement
|
||||
scripts_txt2img = ScriptRunner('txt2img')
|
||||
scripts_img2img = ScriptRunner('img2img')
|
||||
scripts_control = ScriptRunner('control')
|
||||
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
|
||||
return t, time.time()-t0
|
||||
|
||||
|
||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
return res
|
||||
except Exception as e:
|
||||
errors.display(e, f'Calling script: {filename}/{funcname}')
|
||||
return default
|
||||
|
||||
|
||||
class ScriptSummary:
|
||||
def __init__(self, op):
|
||||
self.start = time.time()
|
||||
self.update = time.time()
|
||||
self.op = op
|
||||
self.time = {}
|
||||
|
||||
def record(self, script):
|
||||
self.update = time.time()
|
||||
self.time[script] = round(time.time() - self.update, 2)
|
||||
|
||||
def report(self):
|
||||
total = sum(self.time.values())
|
||||
if total == 0:
|
||||
return
|
||||
scripts = [f'{k}:{v}' for k, v in self.time.items() if v > 0]
|
||||
errors.log.debug(f'Script: op={self.op} total={total} scripts={scripts}')
|
||||
|
||||
|
||||
class ScriptRunner:
|
||||
def __init__(self, name=''):
|
||||
self.name = name
|
||||
self.scripts = []
|
||||
self.selectable_scripts = []
|
||||
self.alwayson_scripts = []
|
||||
self.auto_processing_scripts = []
|
||||
self.titles = []
|
||||
self.infotext_fields = []
|
||||
self.paste_field_names = []
|
||||
self.script_load_ctr = 0
|
||||
self.is_img2img = False
|
||||
self.inputs = [None]
|
||||
self.time = 0
|
||||
|
||||
def add_script(self, script_class, path, is_img2img, is_control):
|
||||
try:
|
||||
script = script_class()
|
||||
script.filename = path
|
||||
script.is_txt2img = not is_img2img
|
||||
script.is_img2img = is_img2img
|
||||
if is_control: # this is messy but show is a legacy function that is not aware of control tab
|
||||
v1 = script.show(script.is_txt2img)
|
||||
v2 = script.show(script.is_img2img)
|
||||
if v1 == AlwaysVisible or v2 == AlwaysVisible:
|
||||
visibility = AlwaysVisible
|
||||
else:
|
||||
visibility = v1 or v2
|
||||
else:
|
||||
visibility = script.show(script.is_img2img)
|
||||
if visibility == AlwaysVisible:
|
||||
self.scripts.append(script)
|
||||
self.alwayson_scripts.append(script)
|
||||
script.alwayson = True
|
||||
elif visibility:
|
||||
self.scripts.append(script)
|
||||
self.selectable_scripts.append(script)
|
||||
except Exception as e:
|
||||
errors.log.error(f'Script initialize: {path} {e}')
|
||||
|
||||
def initialize_scripts(self, is_img2img=False, is_control=False):
|
||||
from modules import scripts_auto_postprocessing
|
||||
|
||||
self.scripts.clear()
|
||||
self.selectable_scripts.clear()
|
||||
self.alwayson_scripts.clear()
|
||||
self.titles.clear()
|
||||
self.infotext_fields.clear()
|
||||
self.paste_field_names.clear()
|
||||
self.script_load_ctr = 0
|
||||
self.is_img2img = is_img2img
|
||||
self.scripts.clear()
|
||||
self.alwayson_scripts.clear()
|
||||
self.selectable_scripts.clear()
|
||||
self.auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||
|
||||
sorted_scripts = sorted(scripts_data, key=lambda x: x.script_class().title().lower())
|
||||
for script_class, path, _basedir, _script_module in sorted_scripts:
|
||||
self.add_script(script_class, path, is_img2img, is_control)
|
||||
sorted_scripts = sorted(self.auto_processing_scripts, key=lambda x: x.script_class().title().lower())
|
||||
for script_class, path, _basedir, _script_module in sorted_scripts:
|
||||
self.add_script(script_class, path, is_img2img, is_control)
|
||||
|
||||
def prepare_ui(self):
|
||||
self.inputs = [None]
|
||||
|
||||
def setup_ui(self, parent='unknown', accordion=True):
|
||||
import modules.api.models as api_models
|
||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||
|
||||
inputs = []
|
||||
inputs_alwayson = [True]
|
||||
|
||||
def create_script_ui(script: Script, inputs, inputs_alwayson):
|
||||
script.parent = parent
|
||||
script.args_from = len(inputs)
|
||||
script.args_to = len(inputs)
|
||||
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
||||
if controls is None:
|
||||
return
|
||||
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
|
||||
api_args = []
|
||||
for control in controls:
|
||||
debug(f'Script control: parent={script.parent} script="{script.name}" label="{control.label}" type={control} id={control.elem_id}')
|
||||
if hasattr(gr.components, 'IOComponent'):
|
||||
if not isinstance(control, gr.components.IOComponent):
|
||||
errors.log.error(f'Invalid script control: "{script.filename}" control={control}')
|
||||
continue
|
||||
else:
|
||||
if not isinstance(control, gr.components.Component):
|
||||
errors.log.error(f'Invalid script control: "{script.filename}" control={control}')
|
||||
continue
|
||||
control.custom_script_source = os.path.basename(script.filename)
|
||||
arg_info = api_models.ScriptArg(label=control.label or "")
|
||||
for field in ("value", "minimum", "maximum", "step", "choices"):
|
||||
v = getattr(control, field, None)
|
||||
if v is not None:
|
||||
setattr(arg_info, field, v)
|
||||
api_args.append(arg_info)
|
||||
|
||||
script.api_info = api_models.ItemScript(
|
||||
name=script.name,
|
||||
is_img2img=script.is_img2img,
|
||||
is_alwayson=script.alwayson,
|
||||
args=api_args,
|
||||
)
|
||||
if script.infotext_fields is not None:
|
||||
self.infotext_fields += script.infotext_fields
|
||||
if script.paste_field_names is not None:
|
||||
self.paste_field_names += script.paste_field_names
|
||||
inputs += controls
|
||||
inputs_alwayson += [script.alwayson for _ in controls]
|
||||
script.args_to = len(inputs)
|
||||
|
||||
with gr.Row():
|
||||
dropdown = gr.Dropdown(label="Script", elem_id=f'{parent}_script_list', choices=["None"] + self.titles, value="None", type="index")
|
||||
inputs.insert(0, dropdown)
|
||||
|
||||
with gr.Row():
|
||||
for script in self.alwayson_scripts:
|
||||
if not script.standalone:
|
||||
continue
|
||||
if (self.name == 'control') and (script.name not in control_extensions) and (script.title() not in control_extensions):
|
||||
errors.log.debug(f'Script: fn="{script.filename}" type={self.name} skip')
|
||||
continue
|
||||
t0 = time.time()
|
||||
with gr.Group(elem_id=f'{parent}_script_{script.title().lower().replace(" ", "_")}', elem_classes=['group-extension']) as group:
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
script.group = group
|
||||
time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Accordion(label="Extensions", elem_id=f'{parent}_script_alwayson') if accordion else gr.Group():
|
||||
for script in self.alwayson_scripts:
|
||||
if script.standalone:
|
||||
continue
|
||||
if (self.name == 'control') and (paths.extensions_dir in script.filename) and (script.title() not in control_extensions):
|
||||
errors.log.debug(f'Script: fn="{script.filename}" type={self.name} skip')
|
||||
continue
|
||||
t0 = time.time()
|
||||
with gr.Group(elem_id=f'{parent}_script_{script.title().lower().replace(" ", "_")}', elem_classes=['group-extension']) as group:
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
script.group = group
|
||||
time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)
|
||||
|
||||
for script in self.selectable_scripts:
|
||||
if (self.name == 'control') and (paths.extensions_dir in script.filename) and (script.title() not in control_extensions):
|
||||
errors.log.debug(f'Script: fn="{script.filename}" type={self.name} skip')
|
||||
continue
|
||||
with gr.Group(elem_id=f'{parent}_script_{script.title().lower().replace(" ", "_")}', elem_classes=['group-scripts'], visible=False) as group:
|
||||
t0 = time.time()
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)
|
||||
script.group = group
|
||||
|
||||
def select_script(script_index):
|
||||
if script_index is None:
|
||||
return [gr.update(visible=False) for script in self.selectable_scripts]
|
||||
selected_script = self.selectable_scripts[script_index - 1] if script_index > 0 else None
|
||||
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
|
||||
|
||||
def init_field(title):
|
||||
if title == 'None': # called when an initial value is set from ui-config.json to show script's UI components
|
||||
return
|
||||
if title not in self.titles:
|
||||
errors.log.error(f'Script not found: {title}')
|
||||
return
|
||||
script_index = self.titles.index(title)
|
||||
self.selectable_scripts[script_index].group.visible = True
|
||||
|
||||
dropdown.init_field = init_field
|
||||
dropdown.change(fn=select_script, inputs=[dropdown], outputs=[script.group for script in self.selectable_scripts if script.group is not None])
|
||||
|
||||
def onload_script_visibility(params):
|
||||
title = params.get('Script', None)
|
||||
if title:
|
||||
title_index = self.titles.index(title)
|
||||
visibility = title_index == self.script_load_ctr
|
||||
self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
|
||||
return gr.update(visible=visibility)
|
||||
else:
|
||||
return gr.update(visible=False)
|
||||
|
||||
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
|
||||
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts if script.group is not None])
|
||||
return inputs
|
||||
|
||||
def run(self, p, *args):
|
||||
s = ScriptSummary('run')
|
||||
script_index = args[0] if len(args) > 0 else 0
|
||||
if script_index == 0:
|
||||
return None
|
||||
script = self.selectable_scripts[script_index-1]
|
||||
if script is None:
|
||||
return None
|
||||
if 'upscale' in script.title():
|
||||
if not hasattr(p, 'init_images') and p.task_args.get('image', None) is not None:
|
||||
p.init_images = p.task_args['image']
|
||||
parsed = []
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from'):
|
||||
parsed = p.per_script_args.get(script.title(), args[script.args_from:script.args_to])
|
||||
if hasattr(script, 'run'):
|
||||
processed = script.run(p, *parsed)
|
||||
else:
|
||||
processed = None
|
||||
errors.log.error(f'Script: file="{script.filename}" no run function defined')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
return processed
|
||||
|
||||
def after(self, p, processed, *args):
|
||||
s = ScriptSummary('after')
|
||||
script_index = args[0] if len(args) > 0 else 0
|
||||
if script_index == 0:
|
||||
return processed
|
||||
script = self.selectable_scripts[script_index-1]
|
||||
if script is None or not hasattr(script, 'after'):
|
||||
return processed
|
||||
parsed = []
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from'):
|
||||
parsed = p.per_script_args.get(script.title(), args[script.args_from:script.args_to])
|
||||
after_processed = script.after(p, processed, *parsed)
|
||||
if after_processed is not None:
|
||||
processed = after_processed
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
return processed
|
||||
|
||||
def before_process(self, p, **kwargs):
|
||||
s = ScriptSummary('before-process')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.before_process(p, *args, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f"Error running before process: {script.filename}")
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process(self, p, **kwargs):
|
||||
s = ScriptSummary('process')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.process(p, *args, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script process: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process_images(self, p, **kwargs):
|
||||
s = ScriptSummary('process_images')
|
||||
processed = None
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
_processed = script.process_images(p, *args, **kwargs)
|
||||
if _processed is not None:
|
||||
processed = _processed
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script process images: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
return processed
|
||||
|
||||
def before_process_batch(self, p, **kwargs):
|
||||
s = ScriptSummary('before-process-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.before_process_batch(p, *args, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script before process batch: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process_batch(self, p, **kwargs):
|
||||
s = ScriptSummary('process-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.process_batch(p, *args, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script process batch: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess(self, p, processed):
|
||||
s = ScriptSummary('postprocess')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.postprocess(p, processed, *args)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script postprocess: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_batch(self, p, images, **kwargs):
|
||||
s = ScriptSummary('postprocess-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.postprocess_batch(p, *args, images=images, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script before postprocess batch: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
|
||||
s = ScriptSummary('postprocess-batch-list')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.postprocess_batch_list(p, pp, *args, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script before postprocess batch list: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||
s = ScriptSummary('postprocess-image')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from') and (script.args_to > 0) and (script.args_to >= script.args_from):
|
||||
args = p.per_script_args.get(script.title(), p.script_args[script.args_from:script.args_to])
|
||||
script.postprocess_image(p, pp, *args)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script postprocess image: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
s = ScriptSummary('before-component')
|
||||
for script in self.scripts:
|
||||
try:
|
||||
script.before_component(component, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script before component: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
s = ScriptSummary('after-component')
|
||||
for script in self.scripts:
|
||||
for elem_id, callback in script.on_after_component_elem_id:
|
||||
if elem_id == kwargs.get("elem_id"):
|
||||
try:
|
||||
callback(OnComponent(component=component))
|
||||
except Exception as e:
|
||||
errors.display(e, f"Running script before_component_elem_id: {script.filename}")
|
||||
try:
|
||||
script.after_component(component, **kwargs)
|
||||
except Exception as e:
|
||||
errors.display(e, f'Running script after component: {script.filename}')
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def reload_sources(self, cache):
|
||||
s = ScriptSummary('reload-sources')
|
||||
for si, script in list(enumerate(self.scripts)):
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from'):
|
||||
args_from = script.args_from
|
||||
args_to = script.args_to
|
||||
filename = script.filename
|
||||
module = cache.get(filename, None)
|
||||
if module is None:
|
||||
module = script_loading.load_module(script.filename)
|
||||
cache[filename] = module
|
||||
for script_class in module.__dict__.values():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
self.scripts[si] = script_class()
|
||||
self.scripts[si].filename = filename
|
||||
self.scripts[si].args_from = args_from
|
||||
self.scripts[si].args_to = args_to
|
||||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
|
||||
scripts_txt2img: ScriptRunner = None
|
||||
scripts_img2img: ScriptRunner = None
|
||||
scripts_control: ScriptRunner = None
|
||||
scripts_current: ScriptRunner = None
|
||||
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
|
||||
reload_scripts = load_scripts # compatibility alias
|
||||
|
||||
|
||||
def reload_script_body_only():
|
||||
cache = {}
|
||||
scripts_txt2img.reload_sources(cache)
|
||||
scripts_img2img.reload_sources(cache)
|
||||
scripts_control.reload_sources(cache)
|
||||
|
|
@ -72,8 +72,8 @@ class ScriptPostprocessingRunner:
|
|||
|
||||
def scripts_in_preferred_order(self):
|
||||
if self.scripts is None:
|
||||
import modules.scripts
|
||||
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
|
||||
import modules.scripts_manager
|
||||
self.initialize_scripts(modules.scripts_manager.postprocessing_scripts_data)
|
||||
scripts_order = shared.opts.postprocessing_operation_order
|
||||
|
||||
def script_score(name):
|
||||
|
|
|
|||
|
|
@ -306,7 +306,7 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
|
|||
shared.sd_model = None
|
||||
try:
|
||||
if model_type in ['Stable Cascade']: # forced pipeline
|
||||
from modules.model_stablecascade import load_cascade_combined
|
||||
from pipelines.model_stablecascade import load_cascade_combined
|
||||
sd_model = load_cascade_combined(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = True
|
||||
elif model_type in ['InstaFlow']: # forced pipeline
|
||||
|
|
@ -315,77 +315,77 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
|
|||
sd_model = pipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
|
||||
allow_post_quant = True
|
||||
elif model_type in ['SegMoE']: # forced pipeline
|
||||
from modules.segmoe.segmoe_model import SegMoEPipeline
|
||||
from pipelines.segmoe.segmoe_model import SegMoEPipeline
|
||||
sd_model = SegMoEPipeline(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
|
||||
sd_model = sd_model.pipe # segmoe pipe does its stuff in __init__ and __call__ is the original pipeline
|
||||
allow_post_quant = True
|
||||
shared_items.pipelines['SegMoE'] = SegMoEPipeline
|
||||
elif model_type in ['PixArt Sigma']: # forced pipeline
|
||||
from modules.model_pixart import load_pixart
|
||||
from pipelines.model_pixart import load_pixart
|
||||
sd_model = load_pixart(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['Sana']: # forced pipeline
|
||||
from modules.model_sana import load_sana
|
||||
from pipelines.model_sana import load_sana
|
||||
sd_model = load_sana(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['Lumina-Next']: # forced pipeline
|
||||
from modules.model_lumina import load_lumina
|
||||
from pipelines.model_lumina import load_lumina
|
||||
sd_model = load_lumina(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = True
|
||||
elif model_type in ['Kolors']: # forced pipeline
|
||||
from modules.model_kolors import load_kolors
|
||||
from pipelines.model_kolors import load_kolors
|
||||
sd_model = load_kolors(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = True
|
||||
elif model_type in ['AuraFlow']: # forced pipeline
|
||||
from modules.model_auraflow import load_auraflow
|
||||
from pipelines.model_auraflow import load_auraflow
|
||||
sd_model = load_auraflow(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = True
|
||||
elif model_type in ['FLUX']:
|
||||
from modules.model_flux import load_flux
|
||||
from pipelines.model_flux import load_flux
|
||||
sd_model = load_flux(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['FLEX']:
|
||||
from modules.model_flex import load_flex
|
||||
from pipelines.model_flex import load_flex
|
||||
sd_model = load_flex(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['Chroma']:
|
||||
from modules.model_chroma import load_chroma
|
||||
from pipelines.model_chroma import load_chroma
|
||||
sd_model = load_chroma(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['Lumina 2']:
|
||||
from modules.model_lumina import load_lumina2
|
||||
from pipelines.model_lumina import load_lumina2
|
||||
sd_model = load_lumina2(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['Stable Diffusion 3']:
|
||||
from modules.model_sd3 import load_sd3
|
||||
from pipelines.model_sd3 import load_sd3
|
||||
sd_model = load_sd3(checkpoint_info, cache_dir=shared.opts.diffusers_dir, config=diffusers_load_config.get('config', None))
|
||||
allow_post_quant = False
|
||||
elif model_type in ['CogView 3']: # forced pipeline
|
||||
from modules.model_cogview import load_cogview3
|
||||
from pipelines.model_cogview import load_cogview3
|
||||
sd_model = load_cogview3(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['CogView 4']: # forced pipeline
|
||||
from modules.model_cogview import load_cogview4
|
||||
from pipelines.model_cogview import load_cogview4
|
||||
sd_model = load_cogview4(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['Meissonic']: # forced pipeline
|
||||
from modules.model_meissonic import load_meissonic
|
||||
from pipelines.model_meissonic import load_meissonic
|
||||
sd_model = load_meissonic(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = True
|
||||
elif model_type in ['OmniGen2']: # forced pipeline
|
||||
from modules.model_omnigen2 import load_omnigen2
|
||||
from pipelines.model_omnigen2 import load_omnigen2
|
||||
sd_model = load_omnigen2(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['OmniGen']: # forced pipeline
|
||||
from modules.model_omnigen import load_omnigen
|
||||
from pipelines.model_omnigen import load_omnigen
|
||||
sd_model = load_omnigen(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['HiDream']:
|
||||
from modules.model_hidream import load_hidream
|
||||
from pipelines.model_hidream import load_hidream
|
||||
sd_model = load_hidream(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['Cosmos']:
|
||||
from modules.model_cosmos import load_cosmos_t2i
|
||||
from pipelines.model_cosmos import load_cosmos_t2i
|
||||
sd_model = load_cosmos_t2i(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
except Exception as e:
|
||||
|
|
@ -480,7 +480,7 @@ def load_diffuser_file(model_type, pipeline, checkpoint_info, diffusers_load_con
|
|||
shared.log.debug(f'Load {op}: config="{model_config}"')
|
||||
diffusers_load_config['config'] = model_config
|
||||
if model_type.startswith('Stable Diffusion 3'):
|
||||
from modules.model_sd3 import load_sd3
|
||||
from pipelines.model_sd3 import load_sd3
|
||||
sd_model = load_sd3(checkpoint_info=checkpoint_info, cache_dir=shared.opts.diffusers_dir, config=diffusers_load_config.get('config', None))
|
||||
elif hasattr(pipeline, 'from_single_file'):
|
||||
diffusers.loaders.single_file_utils.CHECKPOINT_KEY_NAMES["clip"] = "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight" # patch for diffusers==0.28.0
|
||||
|
|
@ -1069,7 +1069,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model',
|
|||
unload_model_weights(op=op)
|
||||
sd_model = None
|
||||
timer = Timer()
|
||||
# TODO model loader: implement model in-memory caching
|
||||
# TODO model load: implement model in-memory caching
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer) if not shared.native else None
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
timer.record("config")
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def load_unet(model):
|
|||
if shared.opts.sd_unet == loaded_unet or shared.opts.sd_unet in failed_unet:
|
||||
pass
|
||||
elif "StableCascade" in model.__class__.__name__:
|
||||
from modules.model_stablecascade import load_prior
|
||||
from pipelines.model_stablecascade import load_prior
|
||||
prior_unet, prior_text_encoder = load_prior(unet_dict[shared.opts.sd_unet], config_file=config_file)
|
||||
loaded_unet = shared.opts.sd_unet
|
||||
if prior_unet is not None:
|
||||
|
|
@ -38,7 +38,7 @@ def load_unet(model):
|
|||
loaded_unet = shared.opts.sd_unet
|
||||
sd_models.load_diffuser() # TODO model load: force-reloading entire model as loading transformers only leads to massive memory usage
|
||||
"""
|
||||
from modules.model_flux import load_transformer
|
||||
from pipelines.model_flux import load_transformer
|
||||
transformer = load_transformer(unet_dict[shared.opts.sd_unet])
|
||||
if transformer is not None:
|
||||
model.transformer = None
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from .dequantizer import dequantizer_dict
|
|||
from .forward import get_forward_func
|
||||
|
||||
|
||||
def sdnq_quantize_layer(layer, weights_dtype="int8", torch_dtype=None, group_size=0, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, dequantize_fp32=False, quantization_device=None, return_device=None, param_name=None):
|
||||
def sdnq_quantize_layer(layer, weights_dtype="int8", torch_dtype=None, group_size=0, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, dequantize_fp32=False, quantization_device=None, return_device=None, param_name=None): # pylint: disable=unused-argument
|
||||
layer_class_name = layer.__class__.__name__
|
||||
if layer_class_name in allowed_types:
|
||||
is_conv_type = False
|
||||
|
|
|
|||
|
|
@ -61,8 +61,8 @@ onnx_pipelines = {
|
|||
|
||||
|
||||
def postprocessing_scripts():
|
||||
import modules.scripts
|
||||
return modules.scripts.scripts_postproc.scripts
|
||||
import modules.scripts_manager
|
||||
return modules.scripts_manager.scripts_postproc.scripts
|
||||
|
||||
|
||||
def sd_vae_items():
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
from modules import shared, processing, scripts
|
||||
from modules import shared, processing, scripts_manager
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.ui_common import plaintext_to_html
|
||||
|
||||
|
|
@ -90,13 +90,13 @@ def txt2img(id_task, state,
|
|||
hdr_boundary=hdr_boundary, hdr_threshold=hdr_threshold, hdr_maximize=hdr_maximize, hdr_max_center=hdr_max_center, hdr_max_boundry=hdr_max_boundry, hdr_color_picker=hdr_color_picker, hdr_tint_ratio=hdr_tint_ratio,
|
||||
override_settings=override_settings,
|
||||
)
|
||||
p.scripts = scripts.scripts_txt2img
|
||||
p.scripts = scripts_manager.scripts_txt2img
|
||||
p.script_args = args
|
||||
p.state = state
|
||||
processed: processing.Processed = scripts.scripts_txt2img.run(p, *args)
|
||||
processed: processing.Processed = scripts_manager.scripts_txt2img.run(p, *args)
|
||||
if processed is None:
|
||||
processed = processing.process_images(p)
|
||||
processed = scripts.scripts_txt2img.after(p, processed, *args)
|
||||
processed = scripts_manager.scripts_txt2img.after(p, processed, *args)
|
||||
p.close()
|
||||
if processed is None:
|
||||
return [], '', '', 'Error: processing failed'
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import mimetypes
|
|||
import gradio as gr
|
||||
import gradio.routes
|
||||
import gradio.utils
|
||||
from modules import errors, timer, gr_hijack, shared, script_callbacks, ui_common, ui_symbols, ui_javascript, ui_sections, generation_parameters_copypaste, call_queue, scripts
|
||||
from modules import errors, timer, gr_hijack, shared, script_callbacks, ui_common, ui_symbols, ui_javascript, ui_sections, generation_parameters_copypaste, call_queue, scripts_manager
|
||||
from modules.paths import script_path, data_path # pylint: disable=unused-import
|
||||
|
||||
|
||||
|
|
@ -76,7 +76,7 @@ def create_ui(startup_timer = None):
|
|||
timer.startup = timer.Timer()
|
||||
ui_javascript.reload_javascript()
|
||||
generation_parameters_copypaste.reset()
|
||||
scripts.scripts_current = None
|
||||
scripts_manager.scripts_current = None
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
from modules import ui_txt2img
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import platform
|
|||
import subprocess
|
||||
from functools import reduce
|
||||
import gradio as gr
|
||||
from modules import call_queue, shared, prompt_parser, ui_sections, ui_symbols, ui_components, generation_parameters_copypaste, images, scripts, script_callbacks, infotext
|
||||
from modules import call_queue, shared, prompt_parser, ui_sections, ui_symbols, ui_components, generation_parameters_copypaste, images, scripts_manager, script_callbacks, infotext
|
||||
|
||||
|
||||
folder_symbol = ui_symbols.folder
|
||||
|
|
@ -304,11 +304,11 @@ def create_output_panel(tabname, preview=True, prompt=None, height=None, transfe
|
|||
)
|
||||
|
||||
if tabname == "txt2img":
|
||||
paste_field_names = scripts.scripts_txt2img.paste_field_names
|
||||
paste_field_names = scripts_manager.scripts_txt2img.paste_field_names
|
||||
elif tabname == "img2img":
|
||||
paste_field_names = scripts.scripts_img2img.paste_field_names
|
||||
paste_field_names = scripts_manager.scripts_img2img.paste_field_names
|
||||
elif tabname == "control":
|
||||
paste_field_names = scripts.scripts_control.paste_field_names
|
||||
paste_field_names = scripts_manager.scripts_control.paste_field_names
|
||||
else:
|
||||
paste_field_names = []
|
||||
debug(f'Paste field: tab={tabname} fields={paste_field_names}')
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from modules.control.units import xs # vislearn ControlNet-XS
|
|||
from modules.control.units import lite # vislearn ControlNet-XS
|
||||
from modules.control.units import t2iadapter # TencentARC T2I-Adapter
|
||||
from modules.control.units import reference # reference pipeline
|
||||
from modules import errors, shared, progress, ui_components, ui_symbols, ui_common, ui_sections, generation_parameters_copypaste, call_queue, scripts, masking, images, processing_vae, timer # pylint: disable=ungrouped-imports
|
||||
from modules import errors, shared, progress, ui_components, ui_symbols, ui_common, ui_sections, generation_parameters_copypaste, call_queue, scripts_manager, masking, images, processing_vae, timer # pylint: disable=ungrouped-imports
|
||||
from modules import ui_control_helpers as helpers
|
||||
|
||||
|
||||
|
|
@ -539,7 +539,7 @@ def create_ui(_blocks: gr.Blocks=None):
|
|||
setting.change(fn=processors.update_settings, inputs=settings, outputs=[])
|
||||
|
||||
with gr.Row(elem_id="control_script_container"):
|
||||
input_script_args = scripts.scripts_current.setup_ui(parent='control', accordion=True)
|
||||
input_script_args = scripts_manager.scripts_current.setup_ui(parent='control', accordion=True)
|
||||
|
||||
# handlers
|
||||
for btn in input_buttons:
|
||||
|
|
@ -709,7 +709,7 @@ def create_ui(_blocks: gr.Blocks=None):
|
|||
# hidden
|
||||
(seed_resize_from_w, "Seed resize from-1"),
|
||||
(seed_resize_from_h, "Seed resize from-2"),
|
||||
*scripts.scripts_control.infotext_fields
|
||||
*scripts_manager.scripts_control.infotext_fields
|
||||
]
|
||||
generation_parameters_copypaste.add_paste_fields("control", input_image, paste_fields, override_settings)
|
||||
bindings = generation_parameters_copypaste.ParamBinding(paste_button=btn_paste, tabname="control", source_text_component=prompt, source_image_component=output_gallery)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from modules import shared, scripts, masking, video # pylint: disable=ungrouped-imports
|
||||
from modules import shared, scripts_manager, masking, video # pylint: disable=ungrouped-imports
|
||||
|
||||
|
||||
gr_height = None
|
||||
|
|
@ -43,8 +43,8 @@ def initialize():
|
|||
os.makedirs(masking.cache_dir, exist_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
scripts.scripts_current = scripts.scripts_control
|
||||
scripts.scripts_control.initialize_scripts(is_img2img=False, is_control=True)
|
||||
scripts_manager.scripts_current = scripts_manager.scripts_control
|
||||
scripts_manager.scripts_control.initialize_scripts(is_img2img=False, is_control=True)
|
||||
|
||||
|
||||
def interrogate():
|
||||
|
|
|
|||
|
|
@ -35,8 +35,8 @@ def process_interrogate(mode, ii_input_files, ii_input_dir, ii_output_dir, *ii_s
|
|||
def create_ui():
|
||||
shared.log.debug('UI initialize: img2img')
|
||||
import modules.img2img # pylint: disable=redefined-outer-name
|
||||
modules.scripts.scripts_current = modules.scripts.scripts_img2img
|
||||
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True, is_control=False)
|
||||
modules.scripts_manager.scripts_current = modules.scripts_manager.scripts_img2img
|
||||
modules.scripts_manager.scripts_img2img.initialize_scripts(is_img2img=True, is_control=False)
|
||||
with gr.Blocks(analytics_enabled=False) as _img2img_interface:
|
||||
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, img2img_submit, img2img_reprocess, img2img_paste, img2img_extra_networks_button, img2img_token_counter, img2img_token_button, img2img_negative_token_counter, img2img_negative_token_button = ui_sections.create_toprow(is_img2img=True, id_part="img2img")
|
||||
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
|
||||
|
|
@ -156,7 +156,7 @@ def create_ui():
|
|||
override_settings = ui_common.create_override_inputs('img2img')
|
||||
|
||||
with gr.Group(elem_id="img2img_script_container"):
|
||||
img2img_script_inputs = modules.scripts.scripts_img2img.setup_ui(parent='img2img', accordion=True)
|
||||
img2img_script_inputs = modules.scripts_manager.scripts_img2img.setup_ui(parent='img2img', accordion=True)
|
||||
|
||||
img2img_gallery, img2img_generation_info, img2img_html_info, _img2img_html_info_formatted, img2img_html_log = ui_common.create_output_panel("img2img", prompt=img2img_prompt)
|
||||
|
||||
|
|
@ -304,7 +304,7 @@ def create_ui():
|
|||
# hidden
|
||||
(seed_resize_from_w, "Seed resize from-1"),
|
||||
(seed_resize_from_h, "Seed resize from-2"),
|
||||
*modules.scripts.scripts_img2img.infotext_fields
|
||||
*modules.scripts_manager.scripts_img2img.infotext_fields
|
||||
]
|
||||
generation_parameters_copypaste.add_paste_fields("img2img", img_init, img2img_paste_fields, override_settings)
|
||||
generation_parameters_copypaste.add_paste_fields("sketch", img_sketch, img2img_paste_fields, override_settings)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import gradio.routes
|
|||
import gradio.utils
|
||||
from modules import shared, theme
|
||||
from modules.paths import script_path, data_path
|
||||
import modules.scripts
|
||||
import modules.scripts_manager
|
||||
|
||||
|
||||
def webpath(fn):
|
||||
|
|
@ -22,12 +22,12 @@ def html_head():
|
|||
script_js = os.path.join(script_path, "javascript", js)
|
||||
head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
|
||||
added = []
|
||||
for script in modules.scripts.list_scripts("javascript", ".js"):
|
||||
for script in modules.scripts_manager.list_scripts("javascript", ".js"):
|
||||
if script.filename in main or script.filename in skip:
|
||||
continue
|
||||
head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
|
||||
added.append(script.path)
|
||||
for script in modules.scripts.list_scripts("javascript", ".mjs"):
|
||||
for script in modules.scripts_manager.list_scripts("javascript", ".mjs"):
|
||||
head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
|
||||
added.append(script.path)
|
||||
added = [a.replace(script_path, '').replace('\\', '/') for a in added]
|
||||
|
|
@ -59,7 +59,7 @@ def html_css(css: str):
|
|||
head = ''
|
||||
if css is not None:
|
||||
head += stylesheet(os.path.join(script_path, 'javascript', css))
|
||||
for cssfile in modules.scripts.list_files_with_name("style.css"):
|
||||
for cssfile in modules.scripts_manager.list_files_with_name("style.css"):
|
||||
if not os.path.isfile(cssfile):
|
||||
continue
|
||||
head += stylesheet(cssfile)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import gradio as gr
|
||||
from modules import scripts, shared, ui_common, postprocessing, call_queue, generation_parameters_copypaste
|
||||
from modules import scripts_manager, shared, ui_common, postprocessing, call_queue, generation_parameters_copypaste
|
||||
|
||||
|
||||
def submit_info(image):
|
||||
|
|
@ -32,7 +32,7 @@ def create_ui():
|
|||
with gr.Row():
|
||||
save_output = gr.Checkbox(label='Save output', value=True, elem_id="extras_save_output")
|
||||
|
||||
script_inputs = scripts.scripts_postproc.setup_ui()
|
||||
script_inputs = scripts_manager.scripts_postproc.setup_ui()
|
||||
with gr.Column():
|
||||
id_part = 'extras'
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
|
||||
|
|
@ -56,7 +56,7 @@ def create_ui():
|
|||
tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
|
||||
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
|
||||
extras_image.change(fn=submit_info, inputs=[extras_image], outputs=[html_info_formatted, exif_info, gen_info])
|
||||
extras_image.change(fn=scripts.scripts_postproc.image_changed, inputs=[], outputs=[])
|
||||
extras_image.change(fn=scripts_manager.scripts_postproc.image_changed, inputs=[], outputs=[])
|
||||
submit.click(
|
||||
_js="submit_postprocessing",
|
||||
fn=call_queue.wrap_gradio_gpu_call(submit_process, extra_outputs=[None, ''], name='Postprocess'),
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ def calc_resolution_hires(width, height, hr_scale, hr_resize_x, hr_resize_y, hr_
|
|||
def create_ui():
|
||||
shared.log.debug('UI initialize: txt2img')
|
||||
import modules.txt2img # pylint: disable=redefined-outer-name
|
||||
modules.scripts.scripts_current = modules.scripts.scripts_txt2img
|
||||
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False, is_control=False)
|
||||
modules.scripts_manager.scripts_current = modules.scripts_manager.scripts_txt2img
|
||||
modules.scripts_manager.scripts_txt2img.initialize_scripts(is_img2img=False, is_control=False)
|
||||
with gr.Blocks(analytics_enabled=False) as _txt2img_interface:
|
||||
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, txt2img_submit, txt2img_reprocess, txt2img_paste, txt2img_extra_networks_button, txt2img_token_counter, txt2img_token_button, txt2img_negative_token_counter, txt2img_negative_token_button = ui_sections.create_toprow(is_img2img=False, id_part="txt2img")
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ def create_ui():
|
|||
state = gr.Textbox(value='', visible=False)
|
||||
|
||||
with gr.Group(elem_id="txt2img_script_container"):
|
||||
txt2img_script_inputs = modules.scripts.scripts_txt2img.setup_ui(parent='txt2img', accordion=True)
|
||||
txt2img_script_inputs = modules.scripts_manager.scripts_txt2img.setup_ui(parent='txt2img', accordion=True)
|
||||
|
||||
txt2img_gallery, txt2img_generation_info, txt2img_html_info, _txt2img_html_info_formatted, txt2img_html_log = ui_common.create_output_panel("txt2img", preview=True, prompt=txt2img_prompt)
|
||||
ui_common.connect_reuse_seed(seed, reuse_seed, txt2img_generation_info, is_subseed=False)
|
||||
|
|
@ -158,7 +158,7 @@ def create_ui():
|
|||
# hidden
|
||||
(seed_resize_from_w, "Seed resize from-1"),
|
||||
(seed_resize_from_h, "Seed resize from-2"),
|
||||
*modules.scripts.scripts_txt2img.infotext_fields
|
||||
*modules.scripts_manager.scripts_txt2img.infotext_fields
|
||||
]
|
||||
generation_parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
|
||||
txt2img_bindings = generation_parameters_copypaste.ParamBinding(paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@
|
|||
"localize": "node cli/localize.js",
|
||||
"ruff": ". venv/bin/activate && ruff check",
|
||||
"eslint": "eslint javascript/ extensions-builtin/sdnext-modernui/javascript/",
|
||||
"pylint": ". venv/bin/activate && pylint *.py modules/ extensions-builtin/ | grep -v '^*'",
|
||||
"pylint": ". venv/bin/activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/ | grep -v '^*'",
|
||||
"lint": "npm run eslint && npm run ruff && npm run pylint",
|
||||
"test": "cli/test.sh"
|
||||
},
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ def load_transformer(file_path): # triggered by opts.sd_unet change
|
|||
if transformer is not None:
|
||||
return transformer
|
||||
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=none dtype={devices.dtype}')
|
||||
# TODO chroma transformer from-single-file with quant
|
||||
# TODO model load: chroma transformer from-single-file with quant
|
||||
# shared.log.warning('Load module: type=UNet/Transformer does not support load-time quantization')
|
||||
# transformer = diffusers.ChromaTransformer2DModel.from_single_file(file_path, **diffusers_load_config)
|
||||
if transformer is None:
|
||||
|
|
@ -261,8 +261,8 @@ def load_chroma(checkpoint_info, diffusers_load_config): # triggered by opts.sd_
|
|||
if vae is not None:
|
||||
kwargs['vae'] = vae
|
||||
|
||||
# TODO add ChromaFillPipeline, ChromaControlPipeline, ChromaImg2ImgPipeline etc when available
|
||||
# TODO Chroma will support inpainting *after* its training has finished: https://huggingface.co/lodestones/Chroma/discussions/28#6826dd2ed86f53ff983add5c
|
||||
# TODO model load: add ChromaFillPipeline, ChromaControlPipeline, ChromaImg2ImgPipeline etc when available
|
||||
# Chroma will support inpainting *after* its training has finished: https://huggingface.co/lodestones/Chroma/discussions/28#6826dd2ed86f53ff983add5c
|
||||
cls = diffusers.ChromaPipeline
|
||||
shared.log.debug(f'Load model: type=Chroma cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
|
||||
for c in kwargs:
|
||||
|
|
@ -76,6 +76,6 @@ def load_cogview4(checkpoint_info, diffusers_load_config={}):
|
|||
if shared.opts.diffusers_eval:
|
||||
pipe.text_encoder.eval()
|
||||
pipe.transformer.eval()
|
||||
pipe.enable_model_cpu_offload() # TODO cogview4: balanced offload does not work for GlmModel
|
||||
pipe.enable_model_cpu_offload() # TODO model fix: cogview4: balanced offload does not work for GlmModel
|
||||
devices.torch_gc()
|
||||
return pipe
|
||||
|
|
@ -24,9 +24,6 @@ def load_transformer(repo_id, diffusers_load_config={}):
|
|||
elif fn is not None and 'safetensors' in fn.lower():
|
||||
shared.log.debug(f'Load model: type=FLEX transformer="{repo_id}" quant="{model_quant.get_quant(repo_id)}" args={load_args}')
|
||||
transformer = diffusers.FluxTransformer2DModel.from_single_file(fn, cache_dir=shared.opts.hfcache_dir, **load_args)
|
||||
# elif model_quant.check_nunchaku('Model'):
|
||||
# shared.log.error(f'Load model: type=HiDream transformer="{repo_id}" quant="Nunchaku" unsupported')
|
||||
# transformer = None
|
||||
else:
|
||||
shared.log.debug(f'Load model: type=FLEX transformer="{repo_id}" quant="{model_quant.get_quant_type(quant_args)}" args={load_args}')
|
||||
transformer = diffusers.FluxTransformer2DModel.from_pretrained(
|
||||
|
|
@ -71,7 +68,7 @@ def load_flex(checkpoint_info, diffusers_load_config={}):
|
|||
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model')
|
||||
shared.log.debug(f'Load model: type=FLEX model="{checkpoint_info.name}" repo="{repo_id}" offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
|
||||
|
||||
from modules.flex2 import Flex2Pipeline
|
||||
from pipelines.flex2 import Flex2Pipeline
|
||||
pipe = Flex2Pipeline.from_pretrained(
|
||||
repo_id,
|
||||
# custom_pipeline=repo_id,
|
||||
|
|
@ -175,7 +175,7 @@ def load_transformer(file_path): # triggered by opts.sd_unet change
|
|||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
elif 'nf4' in quant: # TODO flux: loader for civitai nf4 models
|
||||
from modules.model_flux_nf4 import load_flux_nf4
|
||||
from pipelines.model_flux_nf4 import load_flux_nf4
|
||||
_transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
|
|
@ -183,7 +183,7 @@ def load_transformer(file_path): # triggered by opts.sd_unet change
|
|||
quant_args = model_quant.create_bnb_config({})
|
||||
if quant_args:
|
||||
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
|
||||
from modules.model_flux_nf4 import load_flux_nf4
|
||||
from pipelines.model_flux_nf4 import load_flux_nf4
|
||||
transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
|
||||
if transformer is not None:
|
||||
return transformer
|
||||
|
|
@ -275,7 +275,7 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
|
|||
# load quantized components if any
|
||||
if prequantized == 'nf4':
|
||||
try:
|
||||
from modules.model_flux_nf4 import load_flux_nf4
|
||||
from pipelines.model_flux_nf4 import load_flux_nf4
|
||||
_transformer, _text_encoder = load_flux_nf4(checkpoint_info)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
|
|
@ -102,7 +102,7 @@ def load_hidream(checkpoint_info, diffusers_load_config={}):
|
|||
if 'I1' in repo_id:
|
||||
cls = diffusers.HiDreamImagePipeline
|
||||
elif 'E1' in repo_id:
|
||||
from modules.hidream.pipeline_hidream_image_editing import HiDreamImageEditingPipeline
|
||||
from pipelines.hidream.pipeline_hidream_image_editing import HiDreamImageEditingPipeline
|
||||
cls = HiDreamImageEditingPipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["hidream-e1"] = diffusers.HiDreamImagePipeline
|
||||
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["hidream-e1"] = HiDreamImageEditingPipeline
|
||||
|
|
@ -4,11 +4,11 @@ import diffusers
|
|||
|
||||
def load_meissonic(checkpoint_info, diffusers_load_config={}):
|
||||
from modules import shared, devices, modelloader, sd_models, shared_items
|
||||
from modules.meissonic.transformer import Transformer2DModel as TransformerMeissonic
|
||||
from modules.meissonic.scheduler import Scheduler as MeissonicScheduler
|
||||
from modules.meissonic.pipeline import Pipeline as PipelineMeissonic
|
||||
from modules.meissonic.pipeline_img2img import Img2ImgPipeline as PipelineMeissonicImg2Img
|
||||
from modules.meissonic.pipeline_inpaint import InpaintPipeline as PipelineMeissonicInpaint
|
||||
from pipelines.meissonic.transformer import Transformer2DModel as TransformerMeissonic
|
||||
from pipelines.meissonic.scheduler import Scheduler as MeissonicScheduler
|
||||
from pipelines.meissonic.pipeline import Pipeline as PipelineMeissonic
|
||||
from pipelines.meissonic.pipeline_img2img import Img2ImgPipeline as PipelineMeissonicImg2Img
|
||||
from pipelines.meissonic.pipeline_inpaint import InpaintPipeline as PipelineMeissonicInpaint
|
||||
shared_items.pipelines['Meissonic'] = PipelineMeissonic
|
||||
|
||||
modelloader.hf_login()
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
import diffusers
|
||||
from modules import errors, shared, devices, sd_models, model_quant
|
||||
from modules import shared, devices, sd_models, model_quant
|
||||
|
||||
debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
|
@ -7,7 +7,7 @@ debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None el
|
|||
def load_omnigen2(checkpoint_info, diffusers_load_config={}): # pylint: disable=unused-argument
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info.name)
|
||||
|
||||
from modules.omnigen2 import OmniGen2Pipeline, OmniGen2Transformer2DModel, Qwen2_5_VLForConditionalGeneration
|
||||
from pipelines.omnigen2 import OmniGen2Pipeline, OmniGen2Transformer2DModel, Qwen2_5_VLForConditionalGeneration
|
||||
import diffusers
|
||||
from diffusers import pipelines
|
||||
diffusers.OmniGen2Pipeline = OmniGen2Pipeline # monkey-pathch
|
||||
|
|
@ -2,7 +2,7 @@ import time
|
|||
import gradio as gr
|
||||
import transformers
|
||||
import diffusers
|
||||
from modules import scripts, processing, shared, images, devices, sd_models, sd_checkpoint, model_quant, timer, sd_hijack_te
|
||||
from modules import scripts_manager, processing, shared, images, devices, sd_models, sd_checkpoint, model_quant, timer, sd_hijack_te
|
||||
|
||||
|
||||
repo_id = 'rhymes-ai/Allegro'
|
||||
|
|
@ -19,7 +19,7 @@ def hijack_decode(*args, **kwargs):
|
|||
return res
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
class Script(scripts_manager.Script):
|
||||
def title(self):
|
||||
return 'Video: Allegro (Legacy)'
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os
|
|||
import gradio as gr
|
||||
import diffusers
|
||||
from safetensors.torch import load_file
|
||||
from modules import scripts, processing, shared, devices, sd_models
|
||||
from modules import scripts_manager, processing, shared, devices, sd_models
|
||||
|
||||
|
||||
# config
|
||||
|
|
@ -189,12 +189,12 @@ def set_free_noise(frames):
|
|||
shared.sd_model.enable_free_noise(context_length=context_length, context_stride=context_stride)
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
class Script(scripts_manager.Script):
|
||||
def title(self):
|
||||
return 'Video: AnimateDiff'
|
||||
|
||||
def show(self, is_img2img):
|
||||
# return scripts.AlwaysVisible if shared.native else False
|
||||
# return scripts_manager.AlwaysVisible if shared.native else False
|
||||
return not is_img2img
|
||||
|
||||
|
||||
|
|
@ -231,7 +231,7 @@ class Script(scripts.Script):
|
|||
lora = LORAS[lora_index]
|
||||
set_adapter(adapter)
|
||||
if motion_adapter is None:
|
||||
return
|
||||
return None
|
||||
set_scheduler(p, adapter, override_scheduler)
|
||||
set_lora(p, lora, strength)
|
||||
set_free_init(fi_method, fi_iters, fi_order, fi_spatial, fi_temporal)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import gradio as gr
|
||||
from modules import scripts, processing, shared, sd_models
|
||||
from modules import scripts_manager, processing, shared, sd_models
|
||||
|
||||
|
||||
registered = False
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
class Script(scripts_manager.Script):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.orig_pipe = None
|
||||
|
|
@ -71,6 +71,7 @@ class Script(scripts.Script):
|
|||
shared.log.info(f'APG apply: guidance={p.cfg_scale} momentum={apg.momentum} eta={apg.eta} threshold={apg.threshold} class={shared.sd_model.__class__.__name__}')
|
||||
p.extra_generation_params["APG"] = f'ETA={apg.eta} Momentum={apg.momentum} Threshold={apg.threshold}'
|
||||
# processed = processing.process_images(p)
|
||||
return None
|
||||
|
||||
def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, eta, momentum, threshold): # pylint: disable=arguments-differ, unused-argument
|
||||
from modules import apg
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import gradio as gr
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from modules import shared, scripts, processing, masking
|
||||
from modules import shared, scripts_manager, processing, masking
|
||||
|
||||
"""
|
||||
Automatic Color Inpaint Script for SD.NEXT - SD & SDXL Support
|
||||
|
|
@ -28,7 +28,7 @@ img2img = True
|
|||
|
||||
### Script definition
|
||||
|
||||
class Script(scripts.Script):
|
||||
class Script(scripts_manager.Script):
|
||||
def title(self):
|
||||
return title
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import gradio as gr
|
||||
from modules import scripts, processing, shared, sd_models
|
||||
from modules import scripts_manager, processing, shared, sd_models
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
class Script(scripts_manager.Script):
|
||||
def title(self):
|
||||
return 'BLIP Diffusion: Controllable Generation and Editing'
|
||||
|
||||
|
|
|
|||
|
|
@ -13,14 +13,14 @@ import torch
|
|||
from torchvision import transforms
|
||||
import diffusers
|
||||
import numpy as np
|
||||
from modules import scripts, shared, devices, errors, sd_models, processing
|
||||
from modules import scripts_manager, shared, devices, errors, sd_models, processing
|
||||
from modules.processing_callbacks import diffusers_callback, set_callbacks_p
|
||||
|
||||
|
||||
debug = (os.environ.get('SD_LOAD_DEBUG', None) is not None) or (os.environ.get('SD_PROCESS_DEBUG', None) is not None)
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
class Script(scripts_manager.Script):
|
||||
def title(self):
|
||||
return 'Video: CogVideoX (Legacy)'
|
||||
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ ported to modules/consistory
|
|||
import time
|
||||
import gradio as gr
|
||||
import diffusers
|
||||
from modules import scripts, devices, errors, processing, shared, sd_models, sd_samplers
|
||||
from modules import scripts_manager, devices, errors, processing, shared, sd_models, sd_samplers
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
class Script(scripts_manager.Script):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.anchor_cache_first_stage = None
|
||||
|
|
@ -66,7 +66,7 @@ class Script(scripts.Script):
|
|||
|
||||
def create_model(self):
|
||||
diffusers.models.embeddings.PositionNet = diffusers.models.embeddings.GLIGENTextBoundingboxProjection # patch as renamed in https://github.com/huggingface/diffusers/pull/6244/files
|
||||
import modules.consistory as cs
|
||||
import scripts.consistory as cs
|
||||
if shared.sd_model.__class__.__name__ != 'ConsistoryExtendAttnSDXLPipeline':
|
||||
shared.log.debug('ConsiStory init')
|
||||
t0 = time.time()
|
||||
|
|
@ -128,7 +128,7 @@ class Script(scripts.Script):
|
|||
return concepts, anchors, prompts, alpha, steps, seed
|
||||
|
||||
def create_anchors(self, anchors, concepts, seed, steps, dropout, same, queries, sdsa, injection, alpha):
|
||||
import modules.consistory as cs
|
||||
import scripts.consistory as cs
|
||||
t0 = time.time()
|
||||
if len(anchors) == 0:
|
||||
shared.log.warning('ConsiStory: no anchors')
|
||||
|
|
@ -159,7 +159,7 @@ class Script(scripts.Script):
|
|||
return images
|
||||
|
||||
def create_extra(self, prompt, concepts, seed, steps, dropout, same, queries, sdsa, injection, alpha):
|
||||
import modules.consistory as cs
|
||||
import scripts.consistory as cs
|
||||
t0 = time.time()
|
||||
images = []
|
||||
shared.log.debug(f'ConsiStory extra: concepts={concepts} prompt="{prompt}"')
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@
|
|||
|
||||
import gradio as gr
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from modules import shared, scripts, processing, processing_helpers, sd_models, devices
|
||||
from modules import shared, scripts_manager, processing, processing_helpers, sd_models, devices
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
class Script(scripts_manager.Script):
|
||||
def title(self):
|
||||
return 'Ctrl-X: Controlling Structure and Appearance'
|
||||
|
||||
|
|
@ -44,9 +44,9 @@ class Script(scripts.Script):
|
|||
return None
|
||||
|
||||
import yaml
|
||||
from modules.ctrlx import CtrlXStableDiffusionXLPipeline
|
||||
from modules.ctrlx.sdxl import get_control_config, register_control
|
||||
from modules.ctrlx.utils import get_self_recurrence_schedule
|
||||
from scripts.ctrlx import CtrlXStableDiffusionXLPipeline
|
||||
from scripts.ctrlx.sdxl import get_control_config, register_control
|
||||
from scripts.ctrlx.utils import get_self_recurrence_schedule
|
||||
|
||||
orig_prompt_attention = shared.opts.prompt_attention
|
||||
shared.opts.data['prompt_attention'] = 'fixed'
|
||||
|
|
@ -1,8 +1,7 @@
|
|||
import copy
|
||||
import ast
|
||||
import gradio as gr
|
||||
import modules.scripts as scripts
|
||||
|
||||
from modules import scripts_manager
|
||||
from modules.processing import Processed
|
||||
from modules.shared import opts, cmd_opts, state # pylint: disable=unused-import
|
||||
|
||||
|
|
@ -28,14 +27,15 @@ def exec_with_return(code, module):
|
|||
last_ast = copy.deepcopy(code_ast)
|
||||
last_ast.body = code_ast.body[-1:]
|
||||
|
||||
exec(compile(init_ast, "<ast>", "exec"), module.__dict__)
|
||||
exec(compile(init_ast, "<ast>", "exec"), module.__dict__) # pylint: disable=exec-used
|
||||
if type(last_ast.body[0]) == ast.Expr:
|
||||
return eval(compile(convertExpr2Expression(last_ast.body[0]), "<ast>", "eval"), module.__dict__)
|
||||
return eval(compile(convertExpr2Expression(last_ast.body[0]), "<ast>", "eval"), module.__dict__) # pylint: disable=eval-used
|
||||
else:
|
||||
exec(compile(last_ast, "<ast>", "exec"), module.__dict__)
|
||||
exec(compile(last_ast, "<ast>", "exec"), module.__dict__) # pylint: disable=exec-used
|
||||
return None
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
class Script(scripts_manager.Script):
|
||||
|
||||
def title(self):
|
||||
return "Custom code"
|
||||
|
|
@ -60,7 +60,7 @@ return process_images(p)
|
|||
|
||||
return [code, indent_level]
|
||||
|
||||
def run(self, p, code, indent_level):
|
||||
def run(self, p, code, indent_level): # pylint: disable=arguments-differ
|
||||
assert cmd_opts.allow_code, '--allow-code option must be enabled'
|
||||
|
||||
display_result_data = [[], -1, ""]
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
|
|||
from diffusers.utils import is_accelerate_available, is_accelerate_version
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from modules import scripts, processing, shared, sd_models, devices
|
||||
from modules import scripts_manager, processing, shared, sd_models, devices
|
||||
|
||||
|
||||
### Class definition
|
||||
|
|
@ -1219,7 +1219,7 @@ class DemoFusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderM
|
|||
|
||||
### Script definition
|
||||
|
||||
class Script(scripts.Script):
|
||||
class Script(scripts_manager.Script):
|
||||
def title(self):
|
||||
return 'DemoFusion: High-Resolution Image Generation'
|
||||
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue