modernize typing

pull/4663/head
Vladimir Mandic 2026-02-19 09:15:37 +01:00
parent 7aded79e8a
commit bfe014f5da
222 changed files with 1538 additions and 1444 deletions

View File

@ -45,7 +45,7 @@ TBD
see [docs](https://vladmandic.github.io/sdnext-docs/Python/) for details
- remove hard-dependencies:
`clip, numba, skimage, torchsde, omegaconf, antlr, patch-ng, astunparse, addict, inflection, jsonmerge, kornia`,
`resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets`
`resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets, imp`
these are now installed on-demand when needed
- refactor: to/from image/tensor logic, thanks @CalamitousFelicitousness
- refactor: switch to `pyproject.toml` for tool configs
@ -57,6 +57,8 @@ TBD
- refactor: unified command line parsing
- refactor: launch use threads to async execute non-critical tasks
- refactor: switch from deprecated `pkg_resources` to `importlib`
- refactor: modernize typing and type annotations
- refactor: improve `pydantic==2.x` compatibility
- update `lint` rules, thanks @awsr
- remove requirements: `clip`, `open-clip`
- update `requirements`

View File

@ -1,4 +1,4 @@
from typing import overload, List, Optional
from typing import overload
import os
import sys
import json
@ -106,10 +106,12 @@ def str_to_bool(val: str | bool | None) -> bool | None:
return val
def install_traceback(suppress: list = []):
def install_traceback(suppress: list = None):
from rich.traceback import install as traceback_install
from rich.pretty import install as pretty_install
if suppress is None:
suppress = []
width = os.environ.get("SD_TRACEWIDTH", console.width if console else None)
if width is not None:
width = int(width)
@ -133,7 +135,7 @@ def setup_logging():
from functools import partial, partialmethod
from logging.handlers import RotatingFileHandler
try:
import rich # pylint: disable=unused-import
pass # pylint: disable=unused-import
except Exception:
log.error('Please restart SD.Next so changes take effect')
sys.exit(1)
@ -187,7 +189,7 @@ def setup_logging():
_Segment = Segment
left = _Segment(" " * self.left, style) if self.left else None
right = [_Segment.line()]
blank_line: Optional[List[Segment]] = None
blank_line: list[Segment] | None = None
if self.top:
blank_line = [_Segment(f'{" " * width}\n', style)]
yield from blank_line * self.top
@ -215,8 +217,10 @@ def setup_logging():
logging.Logger.trace = partialmethod(logging.Logger.log, logging.TRACE)
logging.trace = partial(logging.log, logging.TRACE)
def exception_hook(e: Exception, suppress=[]):
def exception_hook(e: Exception, suppress=None):
from rich.traceback import Traceback
if suppress is None:
suppress = []
tb = Traceback.from_exception(type(e), e, e.__traceback__, show_locals=False, max_frames=16, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=console.width)
# print-to-console, does not get printed-to-file
exc_type, exc_value, exc_traceback = sys.exc_info()
@ -416,7 +420,7 @@ def uninstall(package, quiet = False):
def run(cmd: str, arg: str):
result = subprocess.run(f'"{cmd}" {arg}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
result = subprocess.run(f'"{cmd}" {arg}', shell=True, check=False, env=os.environ, capture_output=True)
txt = result.stdout.decode(encoding="utf8", errors="ignore")
if len(result.stderr) > 0:
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
@ -461,7 +465,7 @@ def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True):
all_args = f'{pip_log}{arg} {env_args}'.strip()
if not quiet:
log.debug(f'Running: {pipCmd}="{all_args}"')
result = subprocess.run(f'"{sys.executable}" -m {pipCmd} {all_args}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
result = subprocess.run(f'"{sys.executable}" -m {pipCmd} {all_args}', shell=True, check=False, env=os.environ, capture_output=True)
txt = result.stdout.decode(encoding="utf8", errors="ignore")
if len(result.stderr) > 0:
if uv and result.returncode != 0:
@ -509,7 +513,7 @@ def git(arg: str, folder: str = None, ignore: bool = False, optional: bool = Fal
git_cmd = os.environ.get('GIT', "git")
if git_cmd != "git":
git_cmd = os.path.abspath(git_cmd)
result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder or '.')
result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, capture_output=True, cwd=folder or '.')
stdout = result.stdout.decode(encoding="utf8", errors="ignore")
if len(result.stderr) > 0:
stdout += ('\n' if len(stdout) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
@ -639,7 +643,11 @@ def get_platform():
# check python version
def check_python(supported_minors=[], experimental_minors=[], reason=None):
def check_python(supported_minors=None, experimental_minors=None, reason=None):
if experimental_minors is None:
experimental_minors = []
if supported_minors is None:
supported_minors = []
if supported_minors is None or len(supported_minors) == 0:
supported_minors = [10, 11, 12, 13]
experimental_minors = [14]
@ -911,8 +919,6 @@ def install_torch_addons():
if 'xformers' in xformers_package:
try:
install(xformers_package, ignore=True, no_deps=True)
import torch # pylint: disable=unused-import
import xformers # pylint: disable=unused-import
except Exception as e:
log.debug(f'xFormers cannot install: {e}')
elif not args.experimental and not args.use_xformers and opts.get('cross_attention_optimization', '') != 'xFormers':
@ -1126,7 +1132,7 @@ def run_extension_installer(folder):
if os.environ.get('PYTHONPATH', None) is not None:
seperator = ';' if sys.platform == 'win32' else ':'
env['PYTHONPATH'] += seperator + os.environ.get('PYTHONPATH', None)
result = subprocess.run(f'"{sys.executable}" "{path_installer}"', shell=True, env=env, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder)
result = subprocess.run(f'"{sys.executable}" "{path_installer}"', shell=True, env=env, check=False, capture_output=True, cwd=folder)
txt = result.stdout.decode(encoding="utf8", errors="ignore")
debug(f'Extension installer: file="{path_installer}" {txt}')
if result.returncode != 0:
@ -1265,7 +1271,7 @@ def ensure_base_requirements():
local_log = logging.getLogger('sdnext.installer')
global setuptools, distutils # pylint: disable=global-statement
# python may ship with incompatible setuptools
subprocess.run(f'"{sys.executable}" -m pip install setuptools=={setuptools_version}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
subprocess.run(f'"{sys.executable}" -m pip install setuptools=={setuptools_version}', shell=True, check=False, env=os.environ, capture_output=True)
# need to delete all references to modules to be able to reload them otherwise python will use cached version
modules = [m for m in sys.modules if m.startswith('setuptools') or m.startswith('distutils')]
for m in modules:
@ -1399,7 +1405,7 @@ def install_requirements():
log.info('Install: verifying requirements')
if args.new:
log.debug('Install: flag=new')
with open('requirements.txt', 'r', encoding='utf8') as f:
with open('requirements.txt', encoding='utf8') as f:
lines = [line.strip() for line in f.readlines() if line.strip() != '' and not line.startswith('#') and line is not None]
for line in lines:
if not installed(line, quiet=True):
@ -1495,20 +1501,20 @@ def get_version(force=False):
t_start = time.time()
if (version is None) or (version.get('branch', 'unknown') == 'unknown') or force:
try:
subprocess.run('git config log.showsignature false', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
subprocess.run('git config log.showsignature false', capture_output=True, shell=True, check=True)
except Exception:
pass
try:
res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', capture_output=True, shell=True, check=True)
ver = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ' '
commit, updated = ver.split(' ')
version['commit'], version['updated'] = commit, updated
except Exception as e:
log.warning(f'Version: where=commit {e}')
try:
res = subprocess.run('git remote get-url origin', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
res = subprocess.run('git remote get-url origin', capture_output=True, shell=True, check=True)
origin = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True)
branch_name = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
version['url'] = origin.replace('\n', '').removesuffix('.git') + '/tree/' + branch_name.replace('\n', '')
version['branch'] = branch_name.replace('\n', '')
@ -1520,7 +1526,7 @@ def get_version(force=False):
try:
if os.path.exists('extensions-builtin/sdnext-modernui'):
os.chdir('extensions-builtin/sdnext-modernui')
res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True)
branch_ui = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
branch_ui = 'dev' if 'dev' in branch_ui else 'main'
version['ui'] = branch_ui
@ -1536,7 +1542,7 @@ def get_version(force=False):
version['kanvas'] = 'disabled'
elif os.path.exists('extensions-builtin/sdnext-kanvas'):
os.chdir('extensions-builtin/sdnext-kanvas')
res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True)
branch_kanvas = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
branch_kanvas = 'dev' if 'dev' in branch_kanvas else 'main'
version['kanvas'] = branch_kanvas
@ -1723,7 +1729,7 @@ def check_timestamp():
ok = True
setup_time = -1
version_time = -1
with open(log_file, 'r', encoding='utf8') as f:
with open(log_file, encoding='utf8') as f:
lines = f.readlines()
for line in lines:
if 'Setup complete without errors' in line:
@ -1752,7 +1758,6 @@ def check_timestamp():
def add_args(parser):
import argparse
group_install = parser.add_argument_group('Install')
group_install.add_argument('--quick', default=os.environ.get("SD_QUICK",False), action='store_true', help="Bypass version checks, default: %(default)s")
group_install.add_argument('--reset', default=os.environ.get("SD_RESET",False), action='store_true', help="Reset main repository to latest version, default: %(default)s")
@ -1832,7 +1837,7 @@ def read_options():
t_start = time.time()
global opts # pylint: disable=global-statement
if os.path.isfile(args.config):
with open(args.config, "r", encoding="utf8") as file:
with open(args.config, encoding="utf8") as file:
try:
opts = json.load(file)
if type(opts) is str:

View File

@ -72,7 +72,7 @@ def get_custom_args():
rec('args')
@lru_cache()
@lru_cache
def commit_hash(): # compatibility function
global stored_commit_hash # pylint: disable=global-statement
if stored_commit_hash is not None:
@ -85,7 +85,7 @@ def commit_hash(): # compatbility function
return stored_commit_hash
@lru_cache()
@lru_cache
def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compatibility function
if desc is not None:
installer.log.info(desc)
@ -94,7 +94,7 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compat
if result.returncode != 0:
raise RuntimeError(f"""{errdesc or 'Error running command'} Command: {command} Error code: {result.returncode}""")
return ''
result = subprocess.run(command, stdout=subprocess.PIPE, check=False, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
result = subprocess.run(command, capture_output=True, check=False, shell=True, env=os.environ if custom_env is None else custom_env)
if result.returncode != 0:
raise RuntimeError(f"""{errdesc or 'Error running command'}: {command} code: {result.returncode}
{result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''}
@ -104,26 +104,26 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compat
def check_run(command): # compatibility function
result = subprocess.run(command, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
result = subprocess.run(command, check=False, capture_output=True, shell=True)
return result.returncode == 0
@lru_cache()
@lru_cache
def is_installed(pkg): # compatibility function
return installer.installed(pkg)
@lru_cache()
@lru_cache
def repo_dir(name): # compatibility function
return os.path.join(script_path, dir_repos, name)
@lru_cache()
@lru_cache
def run_python(code, desc=None, errdesc=None): # compatibility function
return run(f'"{sys.executable}" -c "{code}"', desc, errdesc)
@lru_cache()
@lru_cache
def run_pip(pkg, desc=None): # compatibility function
forbidden = ['onnxruntime', 'opencv-python']
if desc is None:
@ -136,7 +136,7 @@ def run_pip(pkg, desc=None): # compatbility function
return run(f'"{sys.executable}" -m pip {pkg} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
@lru_cache()
@lru_cache
def check_run_python(code): # compatibility function
return check_run(f'"{sys.executable}" -c "{code}"')

View File

@ -14,7 +14,7 @@
from dataclasses import dataclass
from math import ceil
from typing import Callable, Dict, List, Optional, Union
from collections.abc import Callable
import numpy as np
import PIL
@ -63,11 +63,11 @@ class StableCascadePriorPipelineOutput(BaseOutput):
Text embeddings for the negative prompt.
"""
image_embeddings: Union[torch.Tensor, np.ndarray]
prompt_embeds: Union[torch.Tensor, np.ndarray]
prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
negative_prompt_embeds: Union[torch.Tensor, np.ndarray]
negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
image_embeddings: torch.Tensor | np.ndarray
prompt_embeds: torch.Tensor | np.ndarray
prompt_embeds_pooled: torch.Tensor | np.ndarray
negative_prompt_embeds: torch.Tensor | np.ndarray
negative_prompt_embeds_pooled: torch.Tensor | np.ndarray
class StableCascadePriorPipelineAPG(DiffusionPipeline):
@ -109,8 +109,8 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
prior: StableCascadeUNet,
scheduler: DDPMWuerstchenScheduler,
resolution_multiple: float = 42.67,
feature_extractor: Optional[CLIPImageProcessor] = None,
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
feature_extractor: CLIPImageProcessor | None = None,
image_encoder: CLIPVisionModelWithProjection | None = None,
) -> None:
super().__init__()
self.register_modules(
@ -151,10 +151,10 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
do_classifier_free_guidance,
prompt=None,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_pooled: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_pooled: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds_pooled: torch.Tensor | None = None,
):
if prompt_embeds is None:
# get prompt text embeddings
@ -196,7 +196,7 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)
if negative_prompt_embeds is None and do_classifier_free_guidance:
uncond_tokens: List[str]
uncond_tokens: list[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
@ -367,26 +367,26 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
prompt: str | list[str] | None = None,
images: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image] = None,
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 20,
timesteps: List[float] = None,
timesteps: list[float] = None,
guidance_scale: float = 4.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_pooled: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pt",
negative_prompt: str | list[str] | None = None,
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_pooled: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds_pooled: torch.Tensor | None = None,
image_embeds: torch.Tensor | None = None,
num_images_per_prompt: int | None = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
output_type: str | None = "pt",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = None,
):
"""
Function invoked when calling the pipeline for generation.
@ -460,6 +460,8 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
"""
# 0. Define commonly used variables
if callback_on_step_end_tensor_inputs is None:
callback_on_step_end_tensor_inputs = ["latents"]
device = self._execution_device
dtype = next(self.prior.parameters()).dtype
self._guidance_scale = guidance_scale

View File

@ -13,7 +13,8 @@
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any
from collections.abc import Callable
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
@ -28,7 +29,6 @@ from diffusers.utils import USE_PEFT_BACKEND, deprecate, is_invisible_watermark_
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.models.attention_processor import Attention
from modules import apg
if is_invisible_watermark_available():
@ -76,10 +76,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
num_inference_steps: int | None = None,
device: str | torch.device | None = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
**kwargs,
):
"""
@ -217,7 +217,7 @@ class StableDiffusionXLPipelineAPG(
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
add_watermarker: bool | None = None,
):
super().__init__()
@ -248,18 +248,18 @@ class StableDiffusionXLPipelineAPG(
def encode_prompt(
self,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
prompt_2: str | None = None,
device: torch.device | None = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
negative_prompt: str | None = None,
negative_prompt_2: str | None = None,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
pooled_prompt_embeds: torch.Tensor | None = None,
negative_pooled_prompt_embeds: torch.Tensor | None = None,
lora_scale: float | None = None,
clip_skip: int | None = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@ -343,7 +343,7 @@ class StableDiffusionXLPipelineAPG(
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer)
@ -396,7 +396,7 @@ class StableDiffusionXLPipelineAPG(
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
uncond_tokens: List[str]
uncond_tokens: list[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@ -412,7 +412,7 @@ class StableDiffusionXLPipelineAPG(
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders, strict=False):
if isinstance(self, TextualInversionLoaderMixin):
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
@ -521,7 +521,7 @@ class StableDiffusionXLPipelineAPG(
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers, strict=False
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
@ -793,42 +793,40 @@ class StableDiffusionXLPipelineAPG(
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
prompt: str | list[str] = None,
prompt_2: str | list[str] | None = None,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
denoising_end: Optional[float] = None,
timesteps: list[int] = None,
sigmas: list[float] = None,
denoising_end: float | None = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
negative_prompt: str | list[str] | None = None,
negative_prompt_2: str | list[str] | None = None,
num_images_per_prompt: int | None = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
output_type: Optional[str] = "pil",
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
pooled_prompt_embeds: torch.Tensor | None = None,
negative_pooled_prompt_embeds: torch.Tensor | None = None,
ip_adapter_image: PipelineImageInput | None = None,
ip_adapter_image_embeds: list[torch.Tensor] | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
cross_attention_kwargs: dict[str, Any] | None = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
negative_original_size: Optional[Tuple[int, int]] = None,
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
original_size: tuple[int, int] | None = None,
crops_coords_top_left: tuple[int, int] = (0, 0),
target_size: tuple[int, int] | None = None,
negative_original_size: tuple[int, int] | None = None,
negative_crops_coords_top_left: tuple[int, int] = (0, 0),
negative_target_size: tuple[int, int] | None = None,
clip_skip: int | None = None,
callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
callback_on_step_end_tensor_inputs: list[str] = None,
**kwargs,
):
r"""
@ -976,6 +974,8 @@ class StableDiffusionXLPipelineAPG(
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if callback_on_step_end_tensor_inputs is None:
callback_on_step_end_tensor_inputs = ["latents"]
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)

View File

@ -13,7 +13,8 @@
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any
from collections.abc import Callable
import torch
from packaging import version
@ -71,10 +72,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
num_inference_steps: int | None = None,
device: str | torch.device | None = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
**kwargs,
):
"""
@ -273,9 +274,9 @@ class StableDiffusionPipelineAPG(
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
lora_scale: float | None = None,
**kwargs,
):
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
@ -305,10 +306,10 @@ class StableDiffusionPipelineAPG(
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
lora_scale: float | None = None,
clip_skip: int | None = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@ -421,7 +422,7 @@ class StableDiffusionPipelineAPG(
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
uncond_tokens: list[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
@ -520,7 +521,7 @@ class StableDiffusionPipelineAPG(
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers, strict=False
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
@ -748,31 +749,29 @@ class StableDiffusionPipelineAPG(
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
prompt: str | list[str] = None,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
timesteps: list[int] = None,
sigmas: list[float] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
negative_prompt: str | list[str] | None = None,
num_images_per_prompt: int | None = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
output_type: Optional[str] = "pil",
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
ip_adapter_image: PipelineImageInput | None = None,
ip_adapter_image_embeds: list[torch.Tensor] | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
cross_attention_kwargs: dict[str, Any] | None = None,
guidance_rescale: float = 0.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
clip_skip: int | None = None,
callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
callback_on_step_end_tensor_inputs: list[str] = None,
**kwargs,
):
r"""
@ -861,6 +860,8 @@ class StableDiffusionPipelineAPG(
"not-safe-for-work" (nsfw) content.
"""
if callback_on_step_end_tensor_inputs is None:
callback_on_step_end_tensor_inputs = ["latents"]
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)

View File

@ -1,4 +1,3 @@
from typing import List, Optional
from threading import Lock
from secrets import compare_digest
from fastapi import FastAPI, APIRouter, Depends, Request
@ -19,7 +18,7 @@ class Api:
user, password = auth.split(":")
self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip()
if shared.cmd_opts.auth_file:
with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file:
with open(shared.cmd_opts.auth_file, encoding="utf8") as file:
for line in file.readlines():
user, password = line.split(":")
self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip()
@ -41,7 +40,7 @@ class Api:
self.add_api_route("/js", server.get_js, methods=["GET"], auth=False)
# server api
self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=list[str])
self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"])
self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"])
self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"])
@ -56,7 +55,7 @@ class Api:
self.add_api_route("/sdapi/v1/options", server.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", server.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], response_model=List[models.ResGPU])
self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], response_model=list[models.ResGPU])
# core api using locking
self.add_api_route("/sdapi/v1/txt2img", self.generate.post_text2img, methods=["POST"], response_model=models.ResTxt2Img)
@ -71,21 +70,21 @@ class Api:
# api dealing with optional scripts
self.add_api_route("/sdapi/v1/scripts", script.get_scripts_list, methods=["GET"], response_model=models.ResScripts)
self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=List[models.ItemScript])
self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=list[models.ItemScript])
# enumerator api
self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=List[process.ItemPreprocess])
self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=list[process.ItemPreprocess])
self.add_api_route("/sdapi/v1/masking", self.process.get_mask, methods=["GET"], response_model=process.ItemMask)
self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=List[models.ItemSampler])
self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=List[models.ItemUpscaler])
self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=List[models.ItemModel])
self.add_api_route("/sdapi/v1/controlnets", endpoints.get_controlnets, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/detailers", endpoints.get_detailers, methods=["GET"], response_model=List[models.ItemDetailer])
self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=List[models.ItemStyle])
self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=list[models.ItemSampler])
self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=list[models.ItemUpscaler])
self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=list[models.ItemModel])
self.add_api_route("/sdapi/v1/controlnets", endpoints.get_controlnets, methods=["GET"], response_model=list[str])
self.add_api_route("/sdapi/v1/detailers", endpoints.get_detailers, methods=["GET"], response_model=list[models.ItemDetailer])
self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=list[models.ItemStyle])
self.add_api_route("/sdapi/v1/embeddings", endpoints.get_embeddings, methods=["GET"], response_model=models.ResEmbeddings)
self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae])
self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension])
self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork])
self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=list[models.ItemVae])
self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=list[models.ItemExtension])
self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=list[models.ItemExtraNetwork])
# functional api
self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo)
@ -96,7 +95,7 @@ class Api:
self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])
self.add_api_route("/sdapi/v1/lock-checkpoint", endpoints.post_lock_checkpoint, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"])
self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=list[str])
self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int)
self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"])
self.add_api_route("/sdapi/v1/sampler", endpoints.get_sampler, methods=["GET"], response_model=dict)
@ -146,7 +145,7 @@ class Api:
shared.log.error(f'API authentication: user="{credentials.username}"')
raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"})
def get_session_start(self, req: Request, agent: Optional[str] = None):
def get_session_start(self, req: Request, agent: str | None = None):
token = req.cookies.get("access-token") or req.cookies.get("access-token-unsecure")
user = self.app.tokens.get(token) if hasattr(self.app, 'tokens') else None
shared.log.info(f'Browser session: user={user} client={req.client.host} agent={agent}')

View File

@ -25,7 +25,7 @@ Core processing logic is shared between direct and dispatch handlers via
``do_openclip``, ``do_tagger``, and ``do_vqa`` functions to avoid duplication.
"""
from typing import Optional, List, Union, Literal, Annotated
from typing import Literal, Annotated
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from fastapi.exceptions import HTTPException
from modules import shared
@ -49,21 +49,21 @@ class ReqCaption(BaseModel):
mode: str = Field(default="best", title="Mode", description="Caption mode. 'best': Most thorough analysis, slowest but highest quality. 'fast': Quick caption with minimal flavor terms. 'classic': Standard captioning with balanced quality and speed. 'caption': BLIP caption only, no CLIP flavor matching. 'negative': Generate terms suitable for use as a negative prompt.")
analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed image analysis breakdown (medium, artist, movement, trending, flavor) in addition to caption.")
# Advanced settings (optional per-request overrides)
max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.")
chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up captioning but increase VRAM usage.")
min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.")
max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.")
flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.")
num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.")
max_length: int | None = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.")
chunk_size: int | None = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up captioning but increase VRAM usage.")
min_flavors: int | None = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.")
max_flavors: int | None = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.")
flavor_count: int | None = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.")
num_beams: int | None = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.")
class ResCaption(BaseModel):
"""Response model for image captioning results."""
caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption/prompt describing the image content and style.")
medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (e.g., 'oil painting', 'digital art', 'photograph'). Only returned when analyze=True.")
artist: Optional[str] = Field(default=None, title="Artist", description="Detected similar artist style (e.g., 'by greg rutkowski'). Only returned when analyze=True.")
movement: Optional[str] = Field(default=None, title="Movement", description="Detected art movement (e.g., 'art nouveau', 'impressionism'). Only returned when analyze=True.")
trending: Optional[str] = Field(default=None, title="Trending", description="Trending/platform tags (e.g., 'trending on artstation'). Only returned when analyze=True.")
flavor: Optional[str] = Field(default=None, title="Flavor", description="Additional descriptive elements (e.g., 'cinematic lighting', 'highly detailed'). Only returned when analyze=True.")
caption: str | None = Field(default=None, title="Caption", description="Generated caption/prompt describing the image content and style.")
medium: str | None = Field(default=None, title="Medium", description="Detected artistic medium (e.g., 'oil painting', 'digital art', 'photograph'). Only returned when analyze=True.")
artist: str | None = Field(default=None, title="Artist", description="Detected similar artist style (e.g., 'by greg rutkowski'). Only returned when analyze=True.")
movement: str | None = Field(default=None, title="Movement", description="Detected art movement (e.g., 'art nouveau', 'impressionism'). Only returned when analyze=True.")
trending: str | None = Field(default=None, title="Trending", description="Trending/platform tags (e.g., 'trending on artstation'). Only returned when analyze=True.")
flavor: str | None = Field(default=None, title="Flavor", description="Additional descriptive elements (e.g., 'cinematic lighting', 'highly detailed'). Only returned when analyze=True.")
class ReqVQA(BaseModel):
"""Request model for Vision-Language Model (VLM) captioning.
@ -74,32 +74,32 @@ class ReqVQA(BaseModel):
image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string containing the image data.")
model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="Select which model to use for Visual Language tasks. Use GET /sdapi/v1/vqa/models for full list. Models which support thinking mode are indicated in capabilities.")
question: str = Field(default="describe the image", title="Question/Task", description="Task for the model to perform. Common tasks: 'Short Caption', 'Normal Caption', 'Long Caption'. Set to 'Use Prompt' to pass custom text via the prompt field. Florence-2 tasks: 'Object Detection', 'OCR (Read Text)', 'Phrase Grounding', 'Dense Region Caption', 'Region Proposal', 'OCR with Regions'. PromptGen tasks: 'Analyze', 'Generate Tags', 'Mixed Caption'. Moondream tasks: 'Point at...', 'Detect all...', 'Detect Gaze' (Moondream 2 only). Use GET /sdapi/v1/vqa/prompts?model=<name> to list tasks available for a specific model.")
prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text. Required when question is 'Use Prompt'. For 'Point at...' tasks, specify what to find (e.g., 'the red car'). For 'Detect all...' tasks, specify what to detect (e.g., 'faces').")
prompt: str | None = Field(default=None, title="Prompt", description="Custom prompt text. Required when question is 'Use Prompt'. For 'Point at...' tasks, specify what to find (e.g., 'the red car'). For 'Detect all...' tasks, specify what to detect (e.g., 'faces').")
system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt controls behavior of the LLM. Processed first and persists throughout conversation. Has highest priority weighting and is always appended at the beginning of the sequence. Use for: Response formatting rules, role definition, style.")
include_annotated: bool = Field(default=False, title="Include Annotated Image", description="If True and the task produces detection results (object detection, point detection, gaze), returns annotated image with bounding boxes/points drawn. Only applicable for detection tasks on models like Florence-2 and Moondream.")
# LLM generation parameters (optional overrides)
max_tokens: Optional[int] = Field(default=None, title="Max Tokens", description="Maximum number of tokens the model can generate in its response. The model is not aware of this limit during generation; it simply sets the hard limit for the length and will forcefully cut off the response when reached.")
temperature: Optional[float] = Field(default=None, title="Temperature", description="Controls randomness in token selection. Lower values (e.g., 0.1) make outputs more focused and deterministic, always choosing high-probability tokens. Higher values (e.g., 0.9) increase creativity and diversity by allowing less probable tokens. Set to 0 for fully deterministic output.")
top_k: Optional[int] = Field(default=None, title="Top-K", description="Limits token selection to the K most likely candidates at each step. Lower values (e.g., 40) make outputs more focused and predictable, while higher values allow more diverse choices. Set to 0 to disable.")
top_p: Optional[float] = Field(default=None, title="Top-P", description="Selects tokens from the smallest set whose cumulative probability exceeds P (e.g., 0.9). Dynamically adapts the number of candidates based on model confidence; fewer options when certain, more when uncertain. Set to 1 to disable.")
num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Maintains multiple candidate paths simultaneously and selects the overall best sequence. More thorough but much slower and less creative than random sampling. Generally not recommended; most modern VLMs perform better with sampling methods. Set to 1 to disable.")
do_sample: Optional[bool] = Field(default=None, title="Use Samplers", description="Enable to use sampling (randomly selecting tokens based on sampling methods like Top-K or Top-P) or disable to use greedy decoding (selecting the most probable token at each step). Enabling makes outputs more diverse and creative but less deterministic.")
thinking_mode: Optional[bool] = Field(default=None, title="Thinking Mode", description="Enables thinking/reasoning, allowing the model to take more time to generate responses. Can lead to more thoughtful and detailed answers but increases response time. Only works with models that support this feature.")
prefill: Optional[str] = Field(default=None, title="Prefill Text", description="Pre-fills the start of the model's response to guide its output format or content by forcing it to continue the prefill text. Prefill is filtered out and does not appear in the final response unless keep_prefill is True. Leave empty to let the model generate from scratch.")
keep_thinking: Optional[bool] = Field(default=None, title="Keep Thinking Trace", description="Include the model's reasoning process in the final output. Useful for understanding how the model arrived at its answer. Only works with models that support thinking mode.")
keep_prefill: Optional[bool] = Field(default=None, title="Keep Prefill", description="Include the prefill text at the beginning of the final output. If disabled, the prefill text used to guide the model is removed from the result.")
max_tokens: int | None = Field(default=None, title="Max Tokens", description="Maximum number of tokens the model can generate in its response. The model is not aware of this limit during generation; it simply sets the hard limit for the length and will forcefully cut off the response when reached.")
temperature: float | None = Field(default=None, title="Temperature", description="Controls randomness in token selection. Lower values (e.g., 0.1) make outputs more focused and deterministic, always choosing high-probability tokens. Higher values (e.g., 0.9) increase creativity and diversity by allowing less probable tokens. Set to 0 for fully deterministic output.")
top_k: int | None = Field(default=None, title="Top-K", description="Limits token selection to the K most likely candidates at each step. Lower values (e.g., 40) make outputs more focused and predictable, while higher values allow more diverse choices. Set to 0 to disable.")
top_p: float | None = Field(default=None, title="Top-P", description="Selects tokens from the smallest set whose cumulative probability exceeds P (e.g., 0.9). Dynamically adapts the number of candidates based on model confidence; fewer options when certain, more when uncertain. Set to 1 to disable.")
num_beams: int | None = Field(default=None, title="Num Beams", description="Maintains multiple candidate paths simultaneously and selects the overall best sequence. More thorough but much slower and less creative than random sampling. Generally not recommended; most modern VLMs perform better with sampling methods. Set to 1 to disable.")
do_sample: bool | None = Field(default=None, title="Use Samplers", description="Enable to use sampling (randomly selecting tokens based on sampling methods like Top-K or Top-P) or disable to use greedy decoding (selecting the most probable token at each step). Enabling makes outputs more diverse and creative but less deterministic.")
thinking_mode: bool | None = Field(default=None, title="Thinking Mode", description="Enables thinking/reasoning, allowing the model to take more time to generate responses. Can lead to more thoughtful and detailed answers but increases response time. Only works with models that support this feature.")
prefill: str | None = Field(default=None, title="Prefill Text", description="Pre-fills the start of the model's response to guide its output format or content by forcing it to continue the prefill text. Prefill is filtered out and does not appear in the final response unless keep_prefill is True. Leave empty to let the model generate from scratch.")
keep_thinking: bool | None = Field(default=None, title="Keep Thinking Trace", description="Include the model's reasoning process in the final output. Useful for understanding how the model arrived at its answer. Only works with models that support thinking mode.")
keep_prefill: bool | None = Field(default=None, title="Keep Prefill", description="Include the prefill text at the beginning of the final output. If disabled, the prefill text used to guide the model is removed from the result.")
class ResVQA(BaseModel):
"""Response model for VLM captioning results."""
answer: Optional[str] = Field(default=None, title="Answer", description="Generated caption, answer, or analysis from the VLM. Format depends on the question/task type.")
annotated_image: Optional[str] = Field(default=None, title="Annotated Image", description="Base64 encoded PNG image with detection results drawn (bounding boxes, points). Only returned when include_annotated=True and the task produces detection results.")
answer: str | None = Field(default=None, title="Answer", description="Generated caption, answer, or analysis from the VLM. Format depends on the question/task type.")
annotated_image: str | None = Field(default=None, title="Annotated Image", description="Base64 encoded PNG image with detection results drawn (bounding boxes, points). Only returned when include_annotated=True and the task produces detection results.")
class ItemVLMModel(BaseModel):
"""VLM model information."""
name: str = Field(title="Name", description="Display name of the model")
repo: str = Field(title="Repository", description="HuggingFace repository ID")
prompts: List[str] = Field(title="Prompts", description="Available prompts/tasks for this model")
capabilities: List[str] = Field(title="Capabilities", description="Model capabilities. Possible values: 'caption' (image captioning), 'vqa' (visual question answering), 'detection' (object/point detection), 'ocr' (text recognition), 'thinking' (reasoning mode support).")
prompts: list[str] = Field(title="Prompts", description="Available prompts/tasks for this model")
capabilities: list[str] = Field(title="Capabilities", description="Model capabilities. Possible values: 'caption' (image captioning), 'vqa' (visual question answering), 'detection' (object/point detection), 'ocr' (text recognition), 'thinking' (reasoning mode support).")
class ResVLMPrompts(BaseModel):
"""Available VLM prompts grouped by category.
@ -107,12 +107,12 @@ class ResVLMPrompts(BaseModel):
When called without ``model`` parameter, returns all prompt categories.
When called with ``model``, returns only the ``available`` field with prompts for that model.
"""
common: Optional[List[str]] = Field(default=None, title="Common", description="Prompts available for all models: Use Prompt, Short/Normal/Long Caption.")
florence: Optional[List[str]] = Field(default=None, title="Florence", description="Florence-2 base model tasks: Phrase Grounding, Object Detection, Dense Region Caption, Region Proposal, OCR (Read Text), OCR with Regions.")
promptgen: Optional[List[str]] = Field(default=None, title="PromptGen", description="MiaoshouAI PromptGen fine-tune tasks: Analyze, Generate Tags, Mixed Caption, Mixed Caption+. Only available on PromptGen models.")
moondream: Optional[List[str]] = Field(default=None, title="Moondream", description="Moondream 2 and 3 tasks: Point at..., Detect all...")
moondream2_only: Optional[List[str]] = Field(default=None, title="Moondream 2 Only", description="Moondream 2 exclusive tasks: Detect Gaze. Not available in Moondream 3.")
available: Optional[List[str]] = Field(default=None, title="Available", description="Populated only when filtering by model. Contains the combined list of prompts available for the specified model.")
common: list[str] | None = Field(default=None, title="Common", description="Prompts available for all models: Use Prompt, Short/Normal/Long Caption.")
florence: list[str] | None = Field(default=None, title="Florence", description="Florence-2 base model tasks: Phrase Grounding, Object Detection, Dense Region Caption, Region Proposal, OCR (Read Text), OCR with Regions.")
promptgen: list[str] | None = Field(default=None, title="PromptGen", description="MiaoshouAI PromptGen fine-tune tasks: Analyze, Generate Tags, Mixed Caption, Mixed Caption+. Only available on PromptGen models.")
moondream: list[str] | None = Field(default=None, title="Moondream", description="Moondream 2 and 3 tasks: Point at..., Detect all...")
moondream2_only: list[str] | None = Field(default=None, title="Moondream 2 Only", description="Moondream 2 exclusive tasks: Detect Gaze. Not available in Moondream 3.")
available: list[str] | None = Field(default=None, title="Available", description="Populated only when filtering by model. Contains the combined list of prompts available for the specified model.")
class ItemTaggerModel(BaseModel):
"""Tagger model information."""
@ -136,7 +136,7 @@ class ReqTagger(BaseModel):
class ResTagger(BaseModel):
"""Response model for image tagging results."""
tags: str = Field(title="Tags", description="Comma-separated list of detected tags")
scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (when show_scores=True)")
scores: dict | None = Field(default=None, title="Scores", description="Tag confidence scores (when show_scores=True)")
# =============================================================================
@ -158,12 +158,12 @@ class ReqCaptionOpenCLIP(BaseModel):
blip_model: str = Field(default="blip-large", title="Caption Model", description="BLIP model used to generate the initial image caption.")
mode: str = Field(default="best", title="Mode", description="Caption mode: 'best' (highest quality, slowest), 'fast' (quick, fewer flavors), 'classic' (balanced), 'caption' (BLIP only, no CLIP matching), 'negative' (for negative prompts).")
analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed breakdown (medium, artist, movement, trending, flavor).")
max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum tokens in generated caption.")
chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing flavors.")
min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum descriptive tags to keep.")
max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum descriptive tags to keep.")
flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of intermediate candidate pool.")
num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beams for beam search during caption generation.")
max_length: int | None = Field(default=None, title="Max Length", description="Maximum tokens in generated caption.")
chunk_size: int | None = Field(default=None, title="Chunk Size", description="Batch size for processing flavors.")
min_flavors: int | None = Field(default=None, title="Min Flavors", description="Minimum descriptive tags to keep.")
max_flavors: int | None = Field(default=None, title="Max Flavors", description="Maximum descriptive tags to keep.")
flavor_count: int | None = Field(default=None, title="Intermediates", description="Size of intermediate candidate pool.")
num_beams: int | None = Field(default=None, title="Num Beams", description="Beams for beam search during caption generation.")
class ReqCaptionTagger(BaseModel):
@ -196,24 +196,24 @@ class ReqCaptionVLM(BaseModel):
image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string.")
model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="VLM model to use. See GET /sdapi/v1/vqa/models for full list.")
question: str = Field(default="describe the image", title="Question/Task", description="Task to perform: 'Short Caption', 'Normal Caption', 'Long Caption', 'Use Prompt' (custom text via prompt field). Model-specific tasks available via GET /sdapi/v1/vqa/prompts.")
prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text when question is 'Use Prompt'.")
prompt: str | None = Field(default=None, title="Prompt", description="Custom prompt text when question is 'Use Prompt'.")
system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt for LLM behavior.")
include_annotated: bool = Field(default=False, title="Include Annotated Image", description="Return annotated image for detection tasks.")
max_tokens: Optional[int] = Field(default=None, title="Max Tokens", description="Maximum tokens in response.")
temperature: Optional[float] = Field(default=None, title="Temperature", description="Randomness in token selection (0=deterministic, 0.9=creative).")
top_k: Optional[int] = Field(default=None, title="Top-K", description="Limit to K most likely tokens per step.")
top_p: Optional[float] = Field(default=None, title="Top-P", description="Nucleus sampling threshold.")
num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beam search width (1=disabled).")
do_sample: Optional[bool] = Field(default=None, title="Use Samplers", description="Enable sampling vs greedy decoding.")
thinking_mode: Optional[bool] = Field(default=None, title="Thinking Mode", description="Enable reasoning mode (supported models only).")
prefill: Optional[str] = Field(default=None, title="Prefill Text", description="Pre-fill response start to guide output.")
keep_thinking: Optional[bool] = Field(default=None, title="Keep Thinking Trace", description="Include reasoning in output.")
keep_prefill: Optional[bool] = Field(default=None, title="Keep Prefill", description="Keep prefill text in final output.")
max_tokens: int | None = Field(default=None, title="Max Tokens", description="Maximum tokens in response.")
temperature: float | None = Field(default=None, title="Temperature", description="Randomness in token selection (0=deterministic, 0.9=creative).")
top_k: int | None = Field(default=None, title="Top-K", description="Limit to K most likely tokens per step.")
top_p: float | None = Field(default=None, title="Top-P", description="Nucleus sampling threshold.")
num_beams: int | None = Field(default=None, title="Num Beams", description="Beam search width (1=disabled).")
do_sample: bool | None = Field(default=None, title="Use Samplers", description="Enable sampling vs greedy decoding.")
thinking_mode: bool | None = Field(default=None, title="Thinking Mode", description="Enable reasoning mode (supported models only).")
prefill: str | None = Field(default=None, title="Prefill Text", description="Pre-fill response start to guide output.")
keep_thinking: bool | None = Field(default=None, title="Keep Thinking Trace", description="Include reasoning in output.")
keep_prefill: bool | None = Field(default=None, title="Keep Prefill", description="Keep prefill text in final output.")
# Discriminated union for the dispatch endpoint
ReqCaptionDispatch = Annotated[
Union[ReqCaptionOpenCLIP, ReqCaptionTagger, ReqCaptionVLM],
ReqCaptionOpenCLIP | ReqCaptionTagger | ReqCaptionVLM,
Field(discriminator="backend")
]
@ -226,18 +226,18 @@ class ResCaptionDispatch(BaseModel):
# Common
backend: str = Field(title="Backend", description="The backend that processed the request: 'openclip', 'tagger', or 'vlm'.")
# OpenCLIP fields
caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption (OpenCLIP backend).")
medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (OpenCLIP with analyze=True).")
artist: Optional[str] = Field(default=None, title="Artist", description="Detected artist style (OpenCLIP with analyze=True).")
movement: Optional[str] = Field(default=None, title="Movement", description="Detected art movement (OpenCLIP with analyze=True).")
trending: Optional[str] = Field(default=None, title="Trending", description="Trending tags (OpenCLIP with analyze=True).")
flavor: Optional[str] = Field(default=None, title="Flavor", description="Flavor descriptors (OpenCLIP with analyze=True).")
caption: str | None = Field(default=None, title="Caption", description="Generated caption (OpenCLIP backend).")
medium: str | None = Field(default=None, title="Medium", description="Detected artistic medium (OpenCLIP with analyze=True).")
artist: str | None = Field(default=None, title="Artist", description="Detected artist style (OpenCLIP with analyze=True).")
movement: str | None = Field(default=None, title="Movement", description="Detected art movement (OpenCLIP with analyze=True).")
trending: str | None = Field(default=None, title="Trending", description="Trending tags (OpenCLIP with analyze=True).")
flavor: str | None = Field(default=None, title="Flavor", description="Flavor descriptors (OpenCLIP with analyze=True).")
# Tagger fields
tags: Optional[str] = Field(default=None, title="Tags", description="Comma-separated tags (Tagger backend).")
scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (Tagger with show_scores=True).")
tags: str | None = Field(default=None, title="Tags", description="Comma-separated tags (Tagger backend).")
scores: dict | None = Field(default=None, title="Scores", description="Tag confidence scores (Tagger with show_scores=True).")
# VLM fields
answer: Optional[str] = Field(default=None, title="Answer", description="VLM response (VLM backend).")
annotated_image: Optional[str] = Field(default=None, title="Annotated Image", description="Base64 annotated image (VLM with include_annotated=True).")
answer: str | None = Field(default=None, title="Answer", description="VLM response (VLM backend).")
annotated_image: str | None = Field(default=None, title="Annotated Image", description="Base64 annotated image (VLM with include_annotated=True).")
# =============================================================================
@ -596,7 +596,7 @@ def get_vqa_models():
return models_list
def get_vqa_prompts(model: Optional[str] = None):
def get_vqa_prompts(model: str | None = None):
"""
List available prompts/tasks for VLM models.
@ -653,11 +653,11 @@ def get_tagger_models():
def register_api():
from modules.shared import api
api.add_api_route("/sdapi/v1/openclip", get_caption, methods=["GET"], response_model=List[str], tags=["Caption"])
api.add_api_route("/sdapi/v1/openclip", get_caption, methods=["GET"], response_model=list[str], tags=["Caption"])
api.add_api_route("/sdapi/v1/caption", post_caption_dispatch, methods=["POST"], response_model=ResCaptionDispatch, tags=["Caption"])
api.add_api_route("/sdapi/v1/openclip", post_caption, methods=["POST"], response_model=ResCaption, tags=["Caption"])
api.add_api_route("/sdapi/v1/vqa", post_vqa, methods=["POST"], response_model=ResVQA, tags=["Caption"])
api.add_api_route("/sdapi/v1/vqa/models", get_vqa_models, methods=["GET"], response_model=List[ItemVLMModel], tags=["Caption"])
api.add_api_route("/sdapi/v1/vqa/models", get_vqa_models, methods=["GET"], response_model=list[ItemVLMModel], tags=["Caption"])
api.add_api_route("/sdapi/v1/vqa/prompts", get_vqa_prompts, methods=["GET"], response_model=ResVLMPrompts, tags=["Caption"])
api.add_api_route("/sdapi/v1/tagger", post_tagger, methods=["POST"], response_model=ResTagger, tags=["Caption"])
api.add_api_route("/sdapi/v1/tagger/models", get_tagger_models, methods=["GET"], response_model=List[ItemTaggerModel], tags=["Caption"])
api.add_api_route("/sdapi/v1/tagger/models", get_tagger_models, methods=["GET"], response_model=list[ItemTaggerModel], tags=["Caption"])

View File

@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Optional
from threading import Lock
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from modules import errors, shared, processing_helpers
@ -43,9 +43,9 @@ ReqControl = models.create_model_from_signature(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "ip_adapter", "type": Optional[List[models.ItemIPAdapter]], "default": None, "exclude": True},
{"key": "ip_adapter", "type": Optional[list[models.ItemIPAdapter]], "default": None, "exclude": True},
{"key": "face", "type": Optional[models.ItemFace], "default": None, "exclude": True},
{"key": "control", "type": Optional[List[ItemControl]], "default": [], "exclude": True},
{"key": "control", "type": Optional[list[ItemControl]], "default": [], "exclude": True},
{"key": "xyz", "type": Optional[ItemXYZ], "default": None, "exclude": True},
# {"key": "extra", "type": Optional[dict], "default": {}, "exclude": True},
]
@ -55,13 +55,13 @@ if not hasattr(ReqControl, "__config__"):
class ResControl(BaseModel):
images: List[str] = Field(default=None, title="Images", description="")
processed: List[str] = Field(default=None, title="Processed", description="")
images: list[str] = Field(default=None, title="Images", description="")
processed: list[str] = Field(default=None, title="Processed", description="")
params: dict = Field(default={}, title="Settings", description="")
info: str = Field(default="", title="Info", description="")
class APIControl():
class APIControl:
def __init__(self, queue_lock: Lock):
self.queue_lock = queue_lock
self.default_script_arg = []

View File

@ -1,4 +1,3 @@
from typing import Optional
from modules import shared
from modules.api import models, helpers
@ -43,7 +42,7 @@ def get_sd_models():
checkpoints.append(model)
return checkpoints
def get_controlnets(model_type: Optional[str] = None):
def get_controlnets(model_type: str | None = None):
from modules.control.units.controlnet import api_list_models
return api_list_models(model_type)
@ -60,7 +59,7 @@ def get_embeddings():
return models.ResEmbeddings(loaded=[], skipped=[])
return models.ResEmbeddings(loaded=list(db.word_embeddings.keys()), skipped=list(db.skipped_embeddings.keys()))
def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, filename: Optional[str] = None, title: Optional[str] = None, fullname: Optional[str] = None, hash: Optional[str] = None): # pylint: disable=redefined-builtin
def get_extra_networks(page: str | None = None, name: str | None = None, filename: str | None = None, title: str | None = None, fullname: str | None = None, hash: str | None = None): # pylint: disable=redefined-builtin
res = []
for pg in shared.extra_networks:
if page is not None and pg.name != page.lower():

View File

@ -2,7 +2,6 @@ import io
import os
import time
import base64
from typing import List, Union
from urllib.parse import quote, unquote
from fastapi import FastAPI
from fastapi.responses import JSONResponse
@ -52,7 +51,7 @@ class ConnectionManager:
debug(f'Browser WS disconnect: client={ws.client.host}')
self.active.remove(ws)
async def send(self, ws: WebSocket, data: Union[str, dict, bytes]):
async def send(self, ws: WebSocket, data: str | dict | bytes):
# debug(f'Browser WS send: client={ws.client.host} data={type(data)}')
if ws.client_state != WebSocketState.CONNECTED:
return
@ -65,7 +64,7 @@ class ConnectionManager:
else:
debug(f'Browser WS send: client={ws.client.host} data={type(data)} unknown')
async def broadcast(self, data: Union[str, dict, bytes]):
async def broadcast(self, data: str | dict | bytes):
for ws in self.active:
await self.send(ws, data)
@ -206,7 +205,7 @@ def register_api(app: FastAPI): # register api
shared.log.error(f'Gallery: {folder} {e}')
return []
shared.api.add_api_route("/sdapi/v1/browser/folders", get_folders, methods=["GET"], response_model=List[str])
shared.api.add_api_route("/sdapi/v1/browser/folders", get_folders, methods=["GET"], response_model=list[str])
shared.api.add_api_route("/sdapi/v1/browser/thumb", get_thumb, methods=["GET"], response_model=dict)
shared.api.add_api_route("/sdapi/v1/browser/files", ht_files, methods=["GET"], response_model=list)

View File

@ -9,7 +9,7 @@ from modules.paths import resolve_output_path
errors.install()
class APIGenerate():
class APIGenerate:
def __init__(self, queue_lock: Lock):
self.queue_lock = queue_lock
self.default_script_arg_txt2img = []

View File

@ -1,4 +1,3 @@
from typing import List
from fastapi.exceptions import HTTPException
@ -25,5 +24,5 @@ def post_refresh_loras():
def register_api():
from modules.shared import api
api.add_api_route("/sdapi/v1/lora", get_lora, methods=["GET"], response_model=dict)
api.add_api_route("/sdapi/v1/loras", get_loras, methods=["GET"], response_model=List[dict])
api.add_api_route("/sdapi/v1/loras", get_loras, methods=["GET"], response_model=list[dict])
api.add_api_route("/sdapi/v1/refresh-loras", post_refresh_loras, methods=["POST"])

View File

@ -1,7 +1,15 @@
import re
import inspect
from typing import Any, Optional, Dict, List, Type, Callable, Union
from pydantic import BaseModel, Field, create_model # pylint: disable=no-name-in-module
from typing import Any, Optional, Union
from collections.abc import Callable
import pydantic
from pydantic import BaseModel, Field, create_model
try:
from pydantic import ConfigDict
PYDANTIC_V2 = True
except ImportError:
ConfigDict = None
PYDANTIC_V2 = False
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
import modules.shared as shared
@ -41,8 +49,10 @@ class PydanticModelGenerator:
model_name: str = None,
class_instance = None,
additional_fields = None,
exclude_fields: List = [],
exclude_fields: list = None,
):
if exclude_fields is None:
exclude_fields = []
def field_type_generator(_k, v):
field_type = v.annotation
return Optional[field_type]
@ -80,12 +90,15 @@ class PydanticModelGenerator:
def generate_model(self):
model_fields = { d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def }
DynamicModel = create_model(self._model_name, **model_fields)
try:
DynamicModel.__config__.allow_population_by_field_name = True
DynamicModel.__config__.allow_mutation = True
except Exception:
pass
if PYDANTIC_V2:
config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True)
else:
class Config:
arbitrary_types_allowed = True
orm_mode = True
allow_population_by_field_name = True
config = Config
DynamicModel = create_model(self._model_name, __config__=config, **model_fields)
return DynamicModel
### item classes
@ -100,49 +113,49 @@ class ItemVae(BaseModel):
class ItemUpscaler(BaseModel):
name: str = Field(title="Name")
model_name: Optional[str] = Field(title="Model Name")
model_path: Optional[str] = Field(title="Path")
model_url: Optional[str] = Field(title="URL")
scale: Optional[float] = Field(title="Scale")
model_name: str | None = Field(title="Model Name")
model_path: str | None = Field(title="Path")
model_url: str | None = Field(title="URL")
scale: float | None = Field(title="Scale")
class ItemModel(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
filename: str = Field(title="Filename")
type: str = Field(title="Model type")
sha256: Optional[str] = Field(title="SHA256 hash")
hash: Optional[str] = Field(title="Short hash")
config: Optional[str] = Field(title="Config file")
sha256: str | None = Field(title="SHA256 hash")
hash: str | None = Field(title="Short hash")
config: str | None = Field(title="Config file")
class ItemHypernetwork(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
path: str | None = Field(title="Path")
class ItemDetailer(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
path: str | None = Field(title="Path")
class ItemGAN(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
scale: Optional[int] = Field(title="Scale")
path: str | None = Field(title="Path")
scale: int | None = Field(title="Scale")
class ItemStyle(BaseModel):
name: str = Field(title="Name")
prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt")
extra: Optional[str] = Field(title="Extra")
filename: Optional[str] = Field(title="Filename")
preview: Optional[str] = Field(title="Preview")
prompt: str | None = Field(title="Prompt")
negative_prompt: str | None = Field(title="Negative Prompt")
extra: str | None = Field(title="Extra")
filename: str | None = Field(title="Filename")
preview: str | None = Field(title="Preview")
class ItemExtraNetwork(BaseModel):
name: str = Field(title="Name")
type: str = Field(title="Type")
title: Optional[str] = Field(title="Title")
fullname: Optional[str] = Field(title="Fullname")
filename: Optional[str] = Field(title="Filename")
hash: Optional[str] = Field(title="Hash")
preview: Optional[str] = Field(title="Preview image URL")
title: str | None = Field(title="Title")
fullname: str | None = Field(title="Fullname")
filename: str | None = Field(title="Filename")
hash: str | None = Field(title="Hash")
preview: str | None = Field(title="Preview image URL")
class ItemArtist(BaseModel):
name: str = Field(title="Name")
@ -150,16 +163,16 @@ class ItemArtist(BaseModel):
category: str = Field(title="Category")
class ItemEmbedding(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
step: int | None = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
sd_checkpoint: str | None = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
sd_checkpoint_name: str | None = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class ItemIPAdapter(BaseModel):
adapter: str = Field(title="Adapter", default="Base", description="IP adapter name")
images: List[str] = Field(title="Image", default=[], description="IP adapter input images")
masks: Optional[List[str]] = Field(title="Mask", default=[], description="IP adapter mask images")
images: list[str] = Field(title="Image", default=[], description="IP adapter input images")
masks: list[str] | None = Field(title="Mask", default=[], description="IP adapter mask images")
scale: float = Field(title="Scale", default=0.5, ge=0, le=1, description="IP adapter scale")
start: float = Field(title="Start", default=0.0, ge=0, le=1, description="IP adapter start step")
end: float = Field(title="End", default=1.0, gt=0, le=1, description="IP adapter end step")
@ -183,17 +196,17 @@ class ItemFace(BaseModel):
class ScriptArg(BaseModel):
label: str = Field(default=None, title="Label", description="Name of the argument in UI")
value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
choices: Optional[Any] = Field(default=None, title="Choices", description="Possible values for the argument")
value: Any | None = Field(default=None, title="Value", description="Default value of the argument")
minimum: Any | None = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
maximum: Any | None = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
step: Any | None = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
choices: Any | None = Field(default=None, title="Choices", description="Possible values for the argument")
class ItemScript(BaseModel):
name: str = Field(default=None, title="Name", description="Script name")
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
class ItemExtension(BaseModel):
name: str = Field(title="Name", description="Extension name")
@ -201,13 +214,13 @@ class ItemExtension(BaseModel):
branch: str = Field(default="uknnown", title="Branch", description="Extension Repository Branch")
commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
version: str = Field(title="Version", description="Extension Version")
commit_date: Union[str, int] = Field(title="Commit Date", description="Extension Repository Commit Date")
commit_date: str | int = Field(title="Commit Date", description="Extension Repository Commit Date")
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
class ItemScheduler(BaseModel):
name: str = Field(title="Name", description="Scheduler name")
cls: str = Field(title="Class", description="Scheduler class name")
options: Dict[str, Any] = Field(title="Options", description="Dictionary of scheduler options")
options: dict[str, Any] = Field(title="Options", description="Dictionary of scheduler options")
### request/response classes
@ -223,7 +236,7 @@ ReqTxt2Img = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "ip_adapter", "type": Optional[List[ItemIPAdapter]], "default": None, "exclude": True},
{"key": "ip_adapter", "type": Optional[list[ItemIPAdapter]], "default": None, "exclude": True},
{"key": "face", "type": Optional[ItemFace], "default": None, "exclude": True},
{"key": "extra", "type": Optional[dict], "default": {}, "exclude": True},
]
@ -233,7 +246,7 @@ if not hasattr(ReqTxt2Img, "__config__"):
StableDiffusionTxt2ImgProcessingAPI = ReqTxt2Img
class ResTxt2Img(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
images: list[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
parameters: dict
info: str
@ -253,7 +266,7 @@ ReqImg2Img = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "ip_adapter", "type": Optional[List[ItemIPAdapter]], "default": None, "exclude": True},
{"key": "ip_adapter", "type": Optional[list[ItemIPAdapter]], "default": None, "exclude": True},
{"key": "face_id", "type": Optional[ItemFace], "default": None, "exclude": True},
{"key": "extra", "type": Optional[dict], "default": {}, "exclude": True},
]
@ -263,7 +276,7 @@ if not hasattr(ReqImg2Img, "__config__"):
StableDiffusionImg2ImgProcessingAPI = ReqImg2Img
class ResImg2Img(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
images: list[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
parameters: dict
info: str
@ -289,9 +302,9 @@ class ResProcess(BaseModel):
class ReqPromptEnhance(BaseModel):
prompt: str = Field(title="Prompt", description="Prompt to enhance")
type: str = Field(title="Type", default='text', description="Type of enhancement to perform")
model: Optional[str] = Field(title="Model", default=None, description="Model to use for enhancement")
system_prompt: Optional[str] = Field(title="System prompt", default=None, description="Model system prompt")
image: Optional[str] = Field(title="Image", default=None, description="Image to work on, must be a Base64 string containing the image's data.")
model: str | None = Field(title="Model", default=None, description="Model to use for enhancement")
system_prompt: str | None = Field(title="System prompt", default=None, description="Model system prompt")
image: str | None = Field(title="Image", default=None, description="Image to work on, must be a Base64 string containing the image's data.")
seed: int = Field(title="Seed", default=-1, description="Seed used to generate the prompt")
nsfw: bool = Field(title="NSFW", default=True, description="Should NSFW content be allowed?")
@ -306,10 +319,10 @@ class ResProcessImage(ResProcess):
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
class ReqProcessBatch(ReqProcess):
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ResProcessBatch(ResProcess):
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class ReqImageInfo(BaseModel):
image: str = Field(title="Image", description="The base64 encoded image")
@ -325,38 +338,38 @@ class ReqGetLog(BaseModel):
class ReqPostLog(BaseModel):
message: Optional[str] = Field(default=None, title="Message", description="The info message to log")
debug: Optional[str] = Field(default=None, title="Debug message", description="The debug message to log")
error: Optional[str] = Field(default=None, title="Error message", description="The error message to log")
message: str | None = Field(default=None, title="Message", description="The info message to log")
debug: str | None = Field(default=None, title="Debug message", description="The debug message to log")
error: str | None = Field(default=None, title="Error message", description="The error message to log")
class ReqHistory(BaseModel):
id: Union[int, str, None] = Field(default=None, title="Task ID", description="Task ID")
id: int | str | None = Field(default=None, title="Task ID", description="Task ID")
class ReqProgress(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
class ResProgress(BaseModel):
id: Union[int, str, None] = Field(title="TaskID", description="Task ID")
id: int | str | None = Field(title="TaskID", description="Task ID")
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
current_image: Optional[str] = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
textinfo: Optional[str] = Field(default=None, title="Info text", description="Info text used by WebUI.")
current_image: str | None = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
textinfo: str | None = Field(default=None, title="Info text", description="Info text used by WebUI.")
class ResHistory(BaseModel):
id: Union[int, str, None] = Field(title="ID", description="Task ID")
id: int | str | None = Field(title="ID", description="Task ID")
job: str = Field(title="Job", description="Job name")
op: str = Field(title="Operation", description="Job state")
timestamp: Union[float, None] = Field(title="Timestamp", description="Job timestamp")
duration: Union[float, None] = Field(title="Duration", description="Job duration")
outputs: List[str] = Field(title="Outputs", description="List of filenames")
timestamp: float | None = Field(title="Timestamp", description="Job timestamp")
duration: float | None = Field(title="Duration", description="Job duration")
outputs: list[str] = Field(title="Outputs", description="List of filenames")
class ResStatus(BaseModel):
status: str = Field(title="Status", description="Current status")
task: str = Field(title="Task", description="Current job")
timestamp: Optional[str] = Field(title="Timestamp", description="Timestamp of the current job")
timestamp: str | None = Field(title="Timestamp", description="Timestamp of the current job")
current: str = Field(title="Task", description="Current job")
id: Union[int, str, None] = Field(title="ID", description="ID of the current task")
id: int | str | None = Field(title="ID", description="ID of the current task")
job: int = Field(title="Job", description="Current job")
jobs: int = Field(title="Jobs", description="Total jobs")
total: int = Field(title="Total Jobs", description="Total jobs")
@ -364,9 +377,9 @@ class ResStatus(BaseModel):
steps: int = Field(title="Steps", description="Total steps")
queued: int = Field(title="Queued", description="Number of queued tasks")
uptime: int = Field(title="Uptime", description="Uptime of the server")
elapsed: Optional[float] = Field(default=None, title="Elapsed time")
eta: Optional[float] = Field(default=None, title="ETA in secs")
progress: Optional[float] = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
elapsed: float | None = Field(default=None, title="Elapsed time")
eta: float | None = Field(default=None, title="ETA in secs")
progress: float | None = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
class ReqLatentHistory(BaseModel):
name: str = Field(title="Name", description="Name of the history item to select")
@ -392,7 +405,15 @@ for key, metadata in shared.opts.data_labels.items():
else:
fields.update({key: (Optional[optType], Field())})
OptionsModel = create_model("Options", **fields)
if PYDANTIC_V2:
config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True)
else:
class Config:
arbitrary_types_allowed = True
orm_mode = True
allow_population_by_field_name = True
config = Config
OptionsModel = create_model("Options", __config__=config, **fields)
flags = {}
_options = vars(shared.parser)['_option_string_actions']
@ -404,7 +425,15 @@ for key in _options:
_type = type(_options[key].default)
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
FlagsModel = create_model("Flags", **flags)
if PYDANTIC_V2:
config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True)
else:
class Config:
arbitrary_types_allowed = True
orm_mode = True
allow_population_by_field_name = True
config = Config
FlagsModel = create_model("Flags", __config__=config, **flags)
class ResEmbeddings(BaseModel):
loaded: list = Field(default=None, title="loaded", description="List of loaded embeddings")
@ -426,9 +455,13 @@ class ResGPU(BaseModel): # definition of http response
# helper function
def create_model_from_signature(func: Callable, model_name: str, base_model: Type[BaseModel] = BaseModel, additional_fields: List = [], exclude_fields: List[str] = []) -> type[BaseModel]:
def create_model_from_signature(func: Callable, model_name: str, base_model: type[BaseModel] = BaseModel, additional_fields: list = None, exclude_fields: list[str] = None) -> type[BaseModel]:
from PIL import Image
if exclude_fields is None:
exclude_fields = []
if additional_fields is None:
additional_fields = []
class Config:
extra = 'allow'
@ -443,13 +476,13 @@ def create_model_from_signature(func: Callable, model_name: str, base_model: Typ
defaults = (...,) * non_default_args + defaults
keyword_only_params = {param: kwonlydefaults.get(param, Any) for param in kwonlyargs}
for k, v in annotations.items():
if v == List[Image.Image]:
annotations[k] = List[str]
if v == list[Image.Image]:
annotations[k] = list[str]
elif v == Image.Image:
annotations[k] = str
elif str(v) == 'typing.List[modules.control.unit.Unit]':
annotations[k] = List[str]
model_fields = {param: (annotations.get(param, Any), default) for param, default in zip(args, defaults)}
annotations[k] = list[str]
model_fields = {param: (annotations.get(param, Any), default) for param, default in zip(args, defaults, strict=False)}
for fld in additional_fields:
model_def = ModelDef(
@ -464,16 +497,21 @@ def create_model_from_signature(func: Callable, model_name: str, base_model: Typ
if fld in model_fields:
del model_fields[fld]
if PYDANTIC_V2:
config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True, extra='allow' if varkw else 'ignore')
else:
class Config:
arbitrary_types_allowed = True
orm_mode = True
allow_population_by_field_name = True
extra = 'allow' if varkw else 'ignore'
config = Config
model = create_model(
model_name,
**model_fields,
**keyword_only_params,
__base__=base_model,
__config__=config,
**model_fields,
**keyword_only_params,
)
try:
model.__config__.allow_population_by_field_name = True
model.__config__.allow_mutation = True
except Exception:
pass
return model

View File

@ -1,4 +1,3 @@
from typing import Optional, List
from threading import Lock
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from fastapi.responses import JSONResponse
@ -15,7 +14,7 @@ errors.install()
class ReqPreprocess(BaseModel):
image: str = Field(title="Image", description="The base64 encoded image")
model: str = Field(title="Model", description="The model to use for preprocessing")
params: Optional[dict] = Field(default={}, title="Settings", description="Preprocessor settings")
params: dict | None = Field(default={}, title="Settings", description="Preprocessor settings")
class ResPreprocess(BaseModel):
model: str = Field(default='', title="Model", description="The processor model used")
@ -24,20 +23,20 @@ class ResPreprocess(BaseModel):
class ReqMask(BaseModel):
image: str = Field(title="Image", description="The base64 encoded image")
type: str = Field(title="Mask type", description="Type of masking image to return")
mask: Optional[str] = Field(title="Mask", description="If optional maks image is not provided auto-masking will be performed")
model: Optional[str] = Field(title="Model", description="The model to use for preprocessing")
params: Optional[dict] = Field(default={}, title="Settings", description="Preprocessor settings")
mask: str | None = Field(title="Mask", description="If optional maks image is not provided auto-masking will be performed")
model: str | None = Field(title="Model", description="The model to use for preprocessing")
params: dict | None = Field(default={}, title="Settings", description="Preprocessor settings")
class ReqFace(BaseModel):
image: str = Field(title="Image", description="The base64 encoded image")
model: Optional[str] = Field(title="Model", description="The model to use for detection")
model: str | None = Field(title="Model", description="The model to use for detection")
class ResFace(BaseModel):
classes: List[int] = Field(title="Class", description="The class of detected item")
labels: List[str] = Field(title="Label", description="The label of detected item")
boxes: List[List[int]] = Field(title="Box", description="The bounding box of detected item")
images: List[str] = Field(title="Image", description="The base64 encoded images of detected faces")
scores: List[float] = Field(title="Scores", description="The scores of the detected faces")
classes: list[int] = Field(title="Class", description="The class of detected item")
labels: list[str] = Field(title="Label", description="The label of detected item")
boxes: list[list[int]] = Field(title="Box", description="The bounding box of detected item")
images: list[str] = Field(title="Image", description="The base64 encoded images of detected faces")
scores: list[float] = Field(title="Scores", description="The scores of the detected faces")
class ResMask(BaseModel):
mask: str = Field(default='', title="Image", description="The processed image in base64 format")
@ -47,13 +46,13 @@ class ItemPreprocess(BaseModel):
params: dict = Field(title="Params")
class ItemMask(BaseModel):
models: List[str] = Field(title="Models")
colormaps: List[str] = Field(title="Color maps")
models: list[str] = Field(title="Models")
colormaps: list[str] = Field(title="Color maps")
params: dict = Field(title="Params")
types: List[str] = Field(title="Types")
types: list[str] = Field(title="Types")
class APIProcess():
class APIProcess:
def __init__(self, queue_lock: Lock):
self.queue_lock = queue_lock

View File

@ -1,4 +1,3 @@
from typing import Optional
from fastapi.exceptions import HTTPException
import gradio as gr
from modules.api import models
@ -36,7 +35,7 @@ def get_scripts_list():
return models.ResScripts(txt2img = t2ilist, img2img = i2ilist, control = control)
def get_script_info(script_name: Optional[str] = None):
def get_script_info(script_name: str | None = None):
res = []
for script_list in [scripts_manager.scripts_txt2img.scripts, scripts_manager.scripts_img2img.scripts, scripts_manager.scripts_control.scripts]:
for script in script_list:

View File

@ -1,7 +1,6 @@
from typing import List
def xyz_grid_enum(option: str = "") -> List[dict]:
def xyz_grid_enum(option: str = "") -> list[dict]:
from scripts.xyz import xyz_grid_classes # pylint: disable=no-name-in-module
options = []
for x in xyz_grid_classes.axis_options:
@ -23,4 +22,4 @@ def xyz_grid_enum(option: str = "") -> List[dict]:
def register_api():
from modules.shared import api as api_instance
api_instance.add_api_route("/sdapi/v1/xyz-grid", xyz_grid_enum, methods=["GET"], response_model=List[dict])
api_instance.add_api_route("/sdapi/v1/xyz-grid", xyz_grid_enum, methods=["GET"], response_model=list[dict])

View File

@ -1,4 +1,3 @@
from typing import Optional
from functools import wraps
import torch
from modules import rocm
@ -23,7 +22,7 @@ def set_triton_flash_attention(backend: str):
from modules.flash_attn_triton_amd import interface_fa
sdpa_pre_triton_flash_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_triton_flash_atten)
def sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
def sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
if scale is None:
scale = query.shape[-1] ** (-0.5)
@ -56,7 +55,7 @@ def set_flex_attention():
sdpa_pre_flex_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_flex_atten)
def sdpa_flex_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: # pylint: disable=unused-argument
def sdpa_flex_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: # pylint: disable=unused-argument
score_mod = None
block_mask = None
if attn_mask is not None:
@ -96,7 +95,7 @@ def set_ck_flash_attention(backend: str, device: torch.device):
from flash_attn import flash_attn_func
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_flash_atten)
def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
is_unsqueezed = False
if query.dim() == 3:
@ -162,7 +161,7 @@ def set_sage_attention(backend: str, device: torch.device):
sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_sage_atten)
def sdpa_sage_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
def sdpa_sage_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
if (query.shape[-1] in {128, 96, 64}) and (attn_mask is None) and (query.dtype != torch.float32):
if enable_gqa:
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)

View File

@ -373,7 +373,7 @@ class BasicLayer(nn.Module):
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask = attn_mask.masked_fill(attn_mask != 0, (-100.0)).masked_fill(attn_mask == 0, 0.0)
for blk in self.blocks:
blk.H, blk.W = H, W
@ -464,8 +464,8 @@ class SwinTransformer(nn.Module):
patch_size=4,
in_chans=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
depths=None,
num_heads=None,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
@ -479,6 +479,10 @@ class SwinTransformer(nn.Module):
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
use_checkpoint=False):
if num_heads is None:
num_heads = [3, 6, 12, 24]
if depths is None:
depths = [2, 2, 6, 2]
super().__init__()
self.pretrain_img_size = pretrain_img_size
@ -668,8 +672,10 @@ class PositionEmbeddingSine:
class MCLM(nn.Module):
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
super(MCLM, self).__init__()
def __init__(self, d_model, num_heads, pool_ratios=None):
if pool_ratios is None:
pool_ratios = [1, 4, 8]
super().__init__()
self.attention = nn.ModuleList([
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
@ -739,7 +745,7 @@ class MCLM(nn.Module):
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
_g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
outputs_re = []
for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1), strict=False)):
outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
@ -760,8 +766,10 @@ class MCLM(nn.Module):
class MCRM(nn.Module):
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None): # pylint: disable=unused-argument
super(MCRM, self).__init__()
def __init__(self, d_model, num_heads, pool_ratios=None, h=None): # pylint: disable=unused-argument
if pool_ratios is None:
pool_ratios = [4, 8, 16]
super().__init__()
self.attention = nn.ModuleList([
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
@ -1049,7 +1057,7 @@ class BEN_Base(nn.Module):
"""
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise IOError(f"Cannot open video: {video_path}")
raise OSError(f"Cannot open video: {video_path}")
original_fps = cap.get(cv2.CAP_PROPFPS)
original_fps = 30 if original_fps == 0 else original_fps
@ -1225,7 +1233,7 @@ def add_audio_to_video(video_without_audio_path, original_video_path, output_pat
'-of', 'csv=p=0',
original_video_path
]
result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False)
result = subprocess.run(probe_command, capture_output=True, text=True, check=False)
# result.stdout is empty if no audio stream found
if not result.stdout.strip():

View File

@ -4,7 +4,7 @@ import threading
import torch
import numpy as np
from PIL import Image
from modules import modelloader, devices, shared
from modules import modelloader, devices, shared, paths
re_special = re.compile(r'([\\()])')
load_lock = threading.Lock()
@ -18,7 +18,7 @@ class DeepDanbooru:
with load_lock:
if self.model is not None:
return
model_path = os.path.join(shared.models_path, "DeepDanbooru")
model_path = os.path.join(paths.models_path, "DeepDanbooru")
shared.log.debug(f'Caption load: module=DeepDanbooru folder="{model_path}"')
files = modelloader.load_models(
model_path=model_path,
@ -96,7 +96,7 @@ class DeepDanbooru:
x = torch.from_numpy(a).to(device=devices.device, dtype=devices.dtype)
y = self.model(x)[0].detach().float().cpu().numpy()
probability_dict = {}
for current, probability in zip(self.model.tags, y):
for current, probability in zip(self.model.tags, y, strict=False):
if probability < general_threshold:
continue
if current.startswith("rating:") and not include_rating:

View File

@ -671,4 +671,4 @@ class DeepDanbooruModel(nn.Module):
def load_state_dict(self, state_dict, **kwargs): # pylint: disable=arguments-differ,unused-argument
self.tags = state_dict.get('tags', [])
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) # pylint: disable=R1725
super().load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) # pylint: disable=R1725

View File

@ -21,7 +21,7 @@ vl_chat_processor = None
loaded_repo = None
class fake_attrdict():
class fake_attrdict:
class AttrDict(dict): # dot notation access to dictionary attributes
__getattr__ = dict.get
__setattr__ = dict.__setitem__

View File

@ -39,7 +39,7 @@ Extra Options:
"""
@dataclass
class JoyOptions():
class JoyOptions:
repo: str = "fancyfeast/llama-joycaption-alpha-two-hf-llava"
temp: float = 0.5
top_k: float = 10

View File

@ -6,7 +6,6 @@ import os
import math
import json
from pathlib import Path
from typing import Optional
from PIL import Image
import torch
import torch.backends.cuda
@ -126,7 +125,7 @@ class VisionModel(nn.Module):
@staticmethod
def load_model(path: str) -> 'VisionModel':
with open(Path(path) / 'config.json', 'r', encoding='utf8') as f:
with open(Path(path) / 'config.json', encoding='utf8') as f:
config = json.load(f)
from safetensors.torch import load_file
resume = load_file(Path(path) / 'model.safetensors', device='cpu')
@ -244,7 +243,7 @@ class CLIPMlp(nn.Module):
class FastCLIPAttention2(nn.Module):
"""Fast Attention module for CLIP-like. This is NOT a drop-in replacement for CLIPAttention, since it adds additional flexibility. Mainly uses xformers."""
def __init__(self, hidden_size: int, out_dim: int, num_attention_heads: int, out_seq_len: Optional[int] = None, norm_qk: bool = False):
def __init__(self, hidden_size: int, out_dim: int, num_attention_heads: int, out_seq_len: int | None = None, norm_qk: bool = False):
super().__init__()
self.out_seq_len = out_seq_len
self.embed_dim = hidden_size
@ -308,12 +307,12 @@ class FastCLIPEncoderLayer(nn.Module):
self,
hidden_size: int,
num_attention_heads: int,
out_seq_len: Optional[int],
out_seq_len: int | None,
activation_cls = QuickGELUActivation,
use_palm_alt: bool = False,
norm_qk: bool = False,
skip_init: Optional[float] = None,
stochastic_depth: Optional[float] = None,
skip_init: float | None = None,
stochastic_depth: float | None = None,
):
super().__init__()
self.use_palm_alt = use_palm_alt
@ -523,8 +522,8 @@ class CLIPLikeModel(VisionModel):
norm_qk: bool = False,
no_wd_bias: bool = False,
use_gap_head: bool = False,
skip_init: Optional[float] = None,
stochastic_depth: Optional[float] = None,
skip_init: float | None = None,
stochastic_depth: float | None = None,
):
super().__init__(image_size, n_tags)
out_dim = n_tags
@ -939,7 +938,7 @@ class ViT(VisionModel):
stochdepth_rate: float,
use_sine: bool,
loss_type: str,
layerscale_init: Optional[float] = None,
layerscale_init: float | None = None,
head_mean_after: bool = False,
cnn_stem: str = None,
patch_dropout: float = 0.0,
@ -1048,7 +1047,7 @@ def load():
model = VisionModel.load_model(folder)
model.to(dtype=devices.dtype)
model.eval() # required: custom loader, not from_pretrained
with open(os.path.join(folder, 'top_tags.txt'), 'r', encoding='utf8') as f:
with open(os.path.join(folder, 'top_tags.txt'), encoding='utf8') as f:
tags = [line.strip() for line in f.readlines() if line.strip()]
shared.log.info(f'Caption: type=vlm model="JoyTag" repo="{MODEL_REPO}" tags={len(tags)}')
sd_models.move_model(model, devices.device)

View File

@ -330,11 +330,11 @@ def analyze_image(image, clip_model, blip_model):
top_movements = ci.movements.rank(image_features, 5)
top_trendings = ci.trendings.rank(image_features, 5)
top_flavors = ci.flavors.rank(image_features, 5)
medium_ranks = dict(sorted(zip(top_mediums, ci.similarities(image_features, top_mediums)), key=lambda x: x[1], reverse=True))
artist_ranks = dict(sorted(zip(top_artists, ci.similarities(image_features, top_artists)), key=lambda x: x[1], reverse=True))
movement_ranks = dict(sorted(zip(top_movements, ci.similarities(image_features, top_movements)), key=lambda x: x[1], reverse=True))
trending_ranks = dict(sorted(zip(top_trendings, ci.similarities(image_features, top_trendings)), key=lambda x: x[1], reverse=True))
flavor_ranks = dict(sorted(zip(top_flavors, ci.similarities(image_features, top_flavors)), key=lambda x: x[1], reverse=True))
medium_ranks = dict(sorted(zip(top_mediums, ci.similarities(image_features, top_mediums), strict=False), key=lambda x: x[1], reverse=True))
artist_ranks = dict(sorted(zip(top_artists, ci.similarities(image_features, top_artists), strict=False), key=lambda x: x[1], reverse=True))
movement_ranks = dict(sorted(zip(top_movements, ci.similarities(image_features, top_movements), strict=False), key=lambda x: x[1], reverse=True))
trending_ranks = dict(sorted(zip(top_trendings, ci.similarities(image_features, top_trendings), strict=False), key=lambda x: x[1], reverse=True))
flavor_ranks = dict(sorted(zip(top_flavors, ci.similarities(image_features, top_flavors), strict=False), key=lambda x: x[1], reverse=True))
shared.log.debug(f'CLIP analyze: complete time={time.time()-t0:.2f}')
# Format labels as text

View File

@ -708,7 +708,7 @@ class VQA:
debug(f'VQA caption: handler=qwen output_ids_shape={output_ids.shape}')
generated_ids = [
output_ids[len(input_ids):]
for input_ids, output_ids in zip(inputs.input_ids, output_ids)
for input_ids, output_ids in zip(inputs.input_ids, output_ids, strict=False)
]
response = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
if debug_enabled:
@ -887,7 +887,7 @@ class VQA:
def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
try:
import flash_attn # pylint: disable=unused-import
pass # pylint: disable=unused-import
except Exception:
shared.log.error(f'Caption: vlm="{repo}" flash-attn is not available')
return ''

View File

@ -126,7 +126,7 @@ class WaifuDiffusionTagger:
self.tags = []
self.tag_categories = []
with open(csv_path, 'r', encoding='utf-8') as f:
with open(csv_path, encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
self.tags.append(row['name'])
@ -269,7 +269,7 @@ class WaifuDiffusionTagger:
character_count = 0
rating_count = 0
for i, (tag_name, prob) in enumerate(zip(self.tags, probs)):
for i, (tag_name, prob) in enumerate(zip(self.tags, probs, strict=False)):
category = self.tag_categories[i]
tag_lower = tag_name.lower()

View File

@ -10,7 +10,9 @@ selected_model = None
class CivitModel:
def __init__(self, name, fn, sha = None, meta = {}):
def __init__(self, name, fn, sha = None, meta = None):
if meta is None:
meta = {}
self.name = name
self.file = name
self.id = meta.get('id', 0)

View File

@ -10,7 +10,7 @@ full_html = False
base_models = ['', 'AuraFlow', 'Chroma', 'CogVideoX', 'Flux.1 S', 'Flux.1 D', 'Flux.1 Krea', 'Flux.1 Kontext', 'Flux.2 D', 'HiDream', 'Hunyuan 1', 'Hunyuan Video', 'Illustrious', 'Kolors', 'LTXV', 'Lumina', 'Mochi', 'NoobAI', 'PixArt a', 'PixArt E', 'Pony', 'Pony V7', 'Qwen', 'SD 1.4', 'SD 1.5', 'SD 1.5 LCM', 'SD 1.5 Hyper', 'SD 2.0', 'SD 2.1', 'SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper', 'Wan Video 1.3B t2v', 'Wan Video 14B t2v', 'Wan Video 14B i2v 480p', 'Wan Video 14B i2v 720p', 'Wan Video 2.2 TI2V-5B', 'Wan Video 2.2 I2V-A14B', 'Wan Video 2.2 T2V-A14B', 'Wan Video 2.5 T2V', 'Wan Video 2.5 I2V', 'ZImageTurbo', 'Other']
@dataclass
class ModelImage():
class ModelImage:
def __init__(self, dct: dict):
if isinstance(dct, str):
dct = json.loads(dct)
@ -26,7 +26,7 @@ class ModelImage():
@dataclass
class ModelFile():
class ModelFile:
def __init__(self, dct: dict):
if isinstance(dct, str):
dct = json.loads(dct)
@ -43,7 +43,7 @@ class ModelFile():
@dataclass
class ModelVersion():
class ModelVersion:
def __init__(self, dct: dict):
import bs4
if isinstance(dct, str):
@ -65,7 +65,7 @@ class ModelVersion():
@dataclass
class Model():
class Model:
def __init__(self, dct: dict):
import bs4
if isinstance(dct, str):

View File

@ -129,7 +129,7 @@ def main_args():
def compatibility_args():
# removed args are added here as hidden in fixed format for compatbility reasons
from modules.paths import data_path, models_path
from modules.paths import data_path
group_compat = parser.add_argument_group('Compatibility options')
group_compat.add_argument('--backend', type=str, choices=['diffusers', 'original'], help=argparse.SUPPRESS)
group_compat.add_argument("--allow-code", default=os.environ.get("SD_ALLOWCODE", False), action='store_true', help=argparse.SUPPRESS)

View File

@ -46,11 +46,17 @@ def preprocess_image(
input_mask:Image.Image = None,
input_type:str = 0,
unit_type:str = 'controlnet',
active_process:list = [],
active_model:list = [],
selected_models:list = [],
active_process:list = None,
active_model:list = None,
selected_models:list = None,
has_models:bool = False,
):
if selected_models is None:
selected_models = []
if active_model is None:
active_model = []
if active_process is None:
active_process = []
t0 = time.time()
jobid = shared.state.begin('Preprocess')

View File

@ -161,7 +161,7 @@ def update_settings(*settings):
update(['Depth Pro', 'params', 'color_map'], settings[28])
class Processor():
class Processor:
def __init__(self, processor_id: str = None, resize = True):
self.model = None
self.processor_id = None
@ -268,7 +268,9 @@ class Processor():
display(e, 'Control Processor load')
return f'Processor load filed: {processor_id}'
def __call__(self, image_input: Image, mode: str = 'RGB', width: int = 0, height: int = 0, resize_mode: int = 0, resize_name: str = 'None', scale_tab: int = 1, scale_by: float = 1.0, local_config: dict = {}):
def __call__(self, image_input: Image, mode: str = 'RGB', width: int = 0, height: int = 0, resize_mode: int = 0, resize_name: str = 'None', scale_tab: int = 1, scale_by: float = 1.0, local_config: dict = None):
if local_config is None:
local_config = {}
if self.override is not None:
debug(f'Control Processor: id="{self.processor_id}" override={self.override}')
width = image_input.width if image_input is not None else width

View File

@ -1,6 +1,5 @@
import os
import sys
from typing import List, Union
import cv2
from PIL import Image
from modules.control import util # helper functions
@ -141,12 +140,12 @@ def set_pipe(p, has_models, unit_type, selected_models, active_model, active_str
def check_active(p, unit_type, units):
active_process: List[processors.Processor] = [] # all active preprocessors
active_model: List[Union[controlnet.ControlNet, xs.ControlNetXS, t2iadapter.Adapter]] = [] # all active models
active_strength: List[float] = [] # strength factors for all active models
active_start: List[float] = [] # start step for all active models
active_end: List[float] = [] # end step for all active models
active_units: List[unit.Unit] = [] # all active units
active_process: list[processors.Processor] = [] # all active preprocessors
active_model: list[controlnet.ControlNet | xs.ControlNetXS | t2iadapter.Adapter] = [] # all active models
active_strength: list[float] = [] # strength factors for all active models
active_start: list[float] = [] # start step for all active models
active_end: list[float] = [] # end step for all active models
active_units: list[unit.Unit] = [] # all active units
num_units = 0
for u in units:
if u.type != unit_type:
@ -218,7 +217,7 @@ def check_active(p, unit_type, units):
def check_enabled(p, unit_type, units, active_model, active_strength, active_start, active_end):
has_models = False
selected_models: List[Union[controlnet.ControlNetModel, xs.ControlNetXSModel, t2iadapter.AdapterModel]] = None
selected_models: list[controlnet.ControlNetModel | xs.ControlNetXSModel | t2iadapter.AdapterModel] = None
control_conditioning = None
control_guidance_start = None
control_guidance_end = None
@ -254,7 +253,7 @@ def control_set(kwargs):
p_extra_args[k] = v
def init_units(units: List[unit.Unit]):
def init_units(units: list[unit.Unit]):
for u in units:
if not u.enabled:
continue
@ -271,9 +270,9 @@ def init_units(units: List[unit.Unit]):
def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
units: List[unit.Unit] = [], inputs: List[Image.Image] = [], inits: List[Image.Image] = [], mask: Image.Image = None, unit_type: str = None, is_generator: bool = True,
units: list[unit.Unit] = None, inputs: list[Image.Image] = None, inits: list[Image.Image] = None, mask: Image.Image = None, unit_type: str = None, is_generator: bool = True,
input_type: int = 0,
prompt: str = '', negative_prompt: str = '', styles: List[str] = [],
prompt: str = '', negative_prompt: str = '', styles: list[str] = None,
steps: int = 20, sampler_index: int = None,
seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1,
guidance_name: str = 'Default', guidance_scale: float = 6.0, guidance_rescale: float = 0.0, guidance_start: float = 0.0, guidance_stop: float = 1.0,
@ -289,11 +288,23 @@ def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
enable_hr: bool = False, hr_sampler_index: int = None, hr_denoising_strength: float = 0.0, hr_resize_mode: int = 0, hr_resize_context: str = 'None', hr_upscaler: str = None, hr_force: bool = False, hr_second_pass_steps: int = 20,
hr_scale: float = 1.0, hr_resize_x: int = 0, hr_resize_y: int = 0, refiner_steps: int = 5, refiner_start: float = 0.0, refiner_prompt: str = '', refiner_negative: str = '',
video_skip_frames: int = 0, video_type: str = 'None', video_duration: float = 2.0, video_loop: bool = False, video_pad: int = 0, video_interpolate: int = 0,
extra: dict = {},
extra: dict = None,
override_script_name: str = None,
override_script_args = [],
override_script_args = None,
*input_script_args,
):
if override_script_args is None:
override_script_args = []
if extra is None:
extra = {}
if styles is None:
styles = []
if inits is None:
inits = []
if inputs is None:
inputs = []
if units is None:
units = []
global pipe, original_pipeline # pylint: disable=global-statement
if 'refine' in state:
enable_hr = True
@ -303,7 +314,7 @@ def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
init_units(units)
if inputs is None or (type(inputs) is list and len(inputs) == 0):
inputs = [None]
output_images: List[Image.Image] = [] # output images
output_images: list[Image.Image] = [] # output images
processed_image: Image.Image = None # last processed image
if mask is not None and input_type == 0:
input_type = 1 # inpaint always requires control_image

View File

@ -1,4 +1,3 @@
from typing import Union
from PIL import Image
import gradio as gr
from installer import log
@ -7,7 +6,6 @@ from modules.control.units import controlnet
from modules.control.units import xs
from modules.control.units import lite
from modules.control.units import t2iadapter
from modules.control.units import reference # pylint: disable=unused-import
default_device = None
@ -16,7 +14,7 @@ unit_types = ['t2i adapter', 'controlnet', 'xs', 'lite', 'reference', 'ip']
current = []
class Unit(): # mashup of gradio controls and mapping to actual implementation classes
class Unit: # mashup of gradio controls and mapping to actual implementation classes
def update_choices(self, model_id=None):
name = model_id or self.model_name
if name == 'InstantX Union F1':
@ -57,8 +55,10 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
control_mode = None,
control_tile = None,
result_txt = None,
extra_controls: list = [],
extra_controls: list = None,
):
if extra_controls is None:
extra_controls = []
self.model_id = model_id
self.process_id = process_id
self.controls = [gr.Label(value=unit_type, visible=False)] # separator
@ -77,7 +77,7 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
self.process_name = None
self.process: processors.Processor = processors.Processor()
self.adapter: t2iadapter.Adapter = None
self.controlnet: Union[controlnet.ControlNet, xs.ControlNetXS] = None
self.controlnet: controlnet.ControlNet | xs.ControlNetXS = None
# map to input image
self.override: Image = None
# global settings but passed per-unit

View File

@ -104,7 +104,7 @@ def get_gpu_info():
elif torch.cuda.is_available() and torch.version.cuda:
try:
import subprocess
result = subprocess.run('nvidia-smi --query-gpu=driver_version --format=csv,noheader', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
result = subprocess.run('nvidia-smi --query-gpu=driver_version --format=csv,noheader', shell=True, check=False, env=os.environ, capture_output=True)
version = result.stdout.decode(encoding="utf8", errors="ignore").strip()
return version
except Exception:
@ -307,7 +307,7 @@ def set_cuda_tunable():
lines={0}
try:
if os.path.exists(fn):
with open(fn, 'r', encoding='utf8') as f:
with open(fn, encoding='utf8') as f:
lines = sum(1 for _line in f)
except Exception:
pass

View File

@ -1,7 +1,6 @@
from typing import Optional
import torch
class Generator(torch.Generator):
def __init__(self, device: Optional[torch.device] = None):
def __init__(self, device: torch.device | None = None):
super().__init__("cpu")

View File

@ -1,5 +1,6 @@
import platform
from typing import NamedTuple, Callable, Optional
from typing import NamedTuple, Optional
from collections.abc import Callable
import torch
from modules.errors import log
from modules.sd_hijack_utils import CondFunc
@ -86,8 +87,8 @@ def directml_do_hijack():
class OverrideItem(NamedTuple):
value: str
condition: Optional[Callable]
message: Optional[str]
condition: Callable | None
message: str | None
opts_override_table = {

View File

@ -1,5 +1,5 @@
import importlib
from typing import Any, Optional
from typing import Any
import torch
@ -52,7 +52,7 @@ class autocast:
fast_dtype: torch.dtype = torch.float16
prev_fast_dtype: torch.dtype
def __init__(self, dtype: Optional[torch.dtype] = torch.float16):
def __init__(self, dtype: torch.dtype | None = torch.float16):
self.fast_dtype = dtype
def __enter__(self):

View File

@ -1,5 +1,5 @@
# pylint: disable=no-member,no-self-argument,no-method-argument
from typing import Optional, Callable
from collections.abc import Callable
import torch
import torch_directml # pylint: disable=import-error
import modules.dml.amp as amp
@ -9,17 +9,17 @@ from .Generator import Generator
from .device_properties import DeviceProperties
def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
def amd_mem_get_info(device: rDevice | None=None) -> tuple[int, int]:
from .memory_amd import AMDMemoryProvider
return AMDMemoryProvider.mem_get_info(get_device(device).index)
def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
def pdh_mem_get_info(device: rDevice | None=None) -> tuple[int, int]:
mem_info = DirectML.memory_provider.get_memory(get_device(device).index)
return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"])
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument
def mem_get_info(device: rDevice | None=None) -> tuple[int, int]: # pylint: disable=unused-argument
return (8589934592, 8589934592)
@ -28,7 +28,7 @@ class DirectML:
device = Device
Generator = Generator
context_device: Optional[torch.device] = None
context_device: torch.device | None = None
is_autocast_enabled = False
autocast_gpu_dtype = torch.float16
@ -41,7 +41,7 @@ class DirectML:
def is_directml_device(device: torch.device) -> bool:
return device.type == "privateuseone"
def has_float64_support(device: Optional[rDevice]=None) -> bool:
def has_float64_support(device: rDevice | None=None) -> bool:
return torch_directml.has_float64_support(get_device(device).index)
def device_count() -> int:
@ -53,16 +53,16 @@ class DirectML:
def default_device() -> torch.device:
return torch_directml.device(torch_directml.default_device())
def get_device_string(device: Optional[rDevice]=None) -> str:
def get_device_string(device: rDevice | None=None) -> str:
return f"privateuseone:{get_device(device).index}"
def get_device_name(device: Optional[rDevice]=None) -> str:
def get_device_name(device: rDevice | None=None) -> str:
return torch_directml.device_name(get_device(device).index)
def get_device_properties(device: Optional[rDevice]=None) -> DeviceProperties:
def get_device_properties(device: rDevice | None=None) -> DeviceProperties:
return DeviceProperties(get_device(device))
def memory_stats(device: Optional[rDevice]=None):
def memory_stats(device: rDevice | None=None):
return {
"num_ooms": 0,
"num_alloc_retries": 0,
@ -70,11 +70,11 @@ class DirectML:
mem_get_info: Callable = mem_get_info
def memory_allocated(device: Optional[rDevice]=None) -> int:
def memory_allocated(device: rDevice | None=None) -> int:
return sum(torch_directml.gpu_memory(get_device(device).index)) * (1 << 20)
def max_memory_allocated(device: Optional[rDevice]=None):
def max_memory_allocated(device: rDevice | None=None):
return DirectML.memory_allocated(device) # DirectML does not empty GPU memory
def reset_peak_memory_stats(device: Optional[rDevice]=None):
def reset_peak_memory_stats(device: rDevice | None=None):
return

View File

@ -1,4 +1,3 @@
from typing import Optional
import torch
from .utils import rDevice, get_device
@ -6,11 +5,11 @@ from .utils import rDevice, get_device
class Device:
idx: int
def __enter__(self, device: Optional[rDevice]=None):
def __enter__(self, device: rDevice | None=None):
torch.dml.context_device = get_device(device)
self.idx = torch.dml.context_device.index
def __init__(self, device: Optional[rDevice]=None) -> torch.device: # pylint: disable=return-in-init
def __init__(self, device: rDevice | None=None) -> torch.device: # pylint: disable=return-in-init
self.idx = get_device(device).index
def __exit__(self, t, v, tb):

View File

@ -1,9 +1,8 @@
from typing import Type
import torch
from modules.dml.hijack.utils import catch_nan
def make_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
def make_tome_block(block_class: type[torch.nn.Module]) -> type[torch.nn.Module]:
class ToMeBlock(block_class):
# Save for unpatching later
_parent = block_class

View File

@ -1,4 +1,3 @@
from typing import Optional
import torch
import transformers.models.clip.modeling_clip
@ -22,9 +21,9 @@ def _make_causal_mask(
def CLIPTextEmbeddings_forward(
self: transformers.models.clip.modeling_clip.CLIPTextEmbeddings,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor | None = None,
position_ids: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
) -> torch.Tensor:
from modules.devices import dtype
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

View File

@ -1,5 +1,5 @@
import torch
from typing import Callable
from collections.abc import Callable
from modules.shared import log, opts

View File

@ -1,6 +1,6 @@
from ctypes import CDLL, POINTER
from ctypes.wintypes import LPCWSTR, LPDWORD, DWORD
from typing import Callable
from collections.abc import Callable
from .structures import PDH_HQUERY, PDH_HCOUNTER, PPDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
from .defines import PDH_FUNCTION, PZZWSTR, DWORD_PTR

View File

@ -1,9 +1,9 @@
from typing import Optional, Union
from typing import Union
import torch
rDevice = Union[torch.device, int]
def get_device(device: Optional[rDevice]=None) -> torch.device:
def get_device(device: rDevice | None=None) -> torch.device:
if device is None:
device = torch.dml.current_device()
return torch.device(device)

View File

@ -10,13 +10,17 @@ install_traceback()
already_displayed = {}
def install(suppress=[]):
def install(suppress=None):
if suppress is None:
suppress = []
warnings.filterwarnings("ignore", category=UserWarning)
install_traceback(suppress=suppress)
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s')
def display(e: Exception, task: str, suppress=[]):
def display(e: Exception, task: str, suppress=None):
if suppress is None:
suppress = []
if isinstance(e, ErrorLimiterAbort):
return
log.critical(f"{task or 'error'}: {type(e).__name__}")
@ -45,7 +49,9 @@ def run(code, task: str):
display(e, task)
def exception(suppress=[]):
def exception(suppress=None):
if suppress is None:
suppress = []
console = get_console()
console.print_exception(show_locals=False, max_frames=16, extra_lines=2, suppress=suppress, theme="ansi_dark", word_wrap=False, width=min([console.width, 200]))

View File

@ -186,7 +186,7 @@ class Extension:
continue
priority = '50'
if os.path.isfile(os.path.join(dirpath, "..", ".priority")):
with open(os.path.join(dirpath, "..", ".priority"), "r", encoding="utf-8") as f:
with open(os.path.join(dirpath, "..", ".priority"), encoding="utf-8") as f:
priority = str(f.read().strip())
res.append(scripts_manager.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
if priority != '50':

View File

@ -73,8 +73,12 @@ def is_stepwise(en_obj):
return any([len(str(x).split("@")) > 1 for x in all_args]) # noqa C419 # pylint: disable=use-a-generator
def activate(p, extra_network_data=None, step=0, include=[], exclude=[]):
def activate(p, extra_network_data=None, step=0, include=None, exclude=None):
"""call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list"""
if exclude is None:
exclude = []
if include is None:
include = []
if p.disable_extra_networks:
return
extra_network_data = extra_network_data or p.network_data

View File

@ -33,7 +33,7 @@ def run_modelmerger(id_task, **kwargs): # pylint: disable=unused-argument
from installer import install
install('tensordict', quiet=True)
try:
from tensordict import TensorDict # pylint: disable=unused-import
pass # pylint: disable=unused-import
except Exception as e:
shared.log.error(f"Merge: {e}")
return [*[gr.update() for _ in range(4)], "tensordict not available"]

View File

@ -1,4 +1,3 @@
from typing import List
import os
import cv2
import torch
@ -34,7 +33,7 @@ def hijack_load_ip_adapter(self):
def face_id(
p: processing.StableDiffusionProcessing,
app,
source_images: List[Image.Image],
source_images: list[Image.Image],
model: str,
override: bool,
cache: bool,

View File

@ -1,4 +1,3 @@
from typing import List
import os
import cv2
import numpy as np
@ -12,7 +11,7 @@ insightface_app = None
swapper = None
def face_swap(p: processing.StableDiffusionProcessing, app, input_images: List[Image.Image], source_image: Image.Image, cache: bool):
def face_swap(p: processing.StableDiffusionProcessing, app, input_images: list[Image.Image], source_image: Image.Image, cache: bool):
global swapper # pylint: disable=global-statement
if swapper is None:
import insightface.model_zoo

View File

@ -14,7 +14,8 @@
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any
from collections.abc import Callable
import cv2
import numpy as np
@ -544,40 +545,40 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
prompt: str | list[str] = None,
prompt_2: str | list[str] | None = None,
image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
negative_prompt: str | list[str] | None = None,
negative_prompt_2: str | list[str] | None = None,
num_images_per_prompt: int | None = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.FloatTensor | None = None,
prompt_embeds: torch.FloatTensor | None = None,
negative_prompt_embeds: torch.FloatTensor | None = None,
pooled_prompt_embeds: torch.FloatTensor | None = None,
negative_pooled_prompt_embeds: torch.FloatTensor | None = None,
image_embeds: torch.FloatTensor | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
cross_attention_kwargs: dict[str, Any] | None = None,
controlnet_conditioning_scale: float | list[float] = 1.0,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None,
negative_original_size: Optional[Tuple[int, int]] = None,
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = None,
control_guidance_start: float | list[float] = 0.0,
control_guidance_end: float | list[float] = 1.0,
original_size: tuple[int, int] = None,
crops_coords_top_left: tuple[int, int] = (0, 0),
target_size: tuple[int, int] = None,
negative_original_size: tuple[int, int] | None = None,
negative_crops_coords_top_left: tuple[int, int] = (0, 0),
negative_target_size: tuple[int, int] | None = None,
clip_skip: int | None = None,
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = None,
**kwargs,
):
r"""
@ -890,7 +891,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
for s, e in zip(control_guidance_start, control_guidance_end, strict=False)
]
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
@ -970,7 +971,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
controlnet_added_cond_kwargs = added_cond_kwargs
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i], strict=False)]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):

View File

@ -1,7 +1,8 @@
### original <https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/pipeline.py>
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Union
from collections.abc import Callable
import PIL
import torch
from transformers import CLIPImageProcessor
@ -26,8 +27,8 @@ from modules.face.photomaker_model_v2 import PhotoMakerIDEncoder_CLIPInsightface
PipelineImageInput = Union[
PIL.Image.Image,
torch.FloatTensor,
List[PIL.Image.Image],
List[torch.FloatTensor],
list[PIL.Image.Image],
list[torch.FloatTensor],
]
@ -49,10 +50,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
num_inference_steps: int | None = None,
device: str | torch.device | None = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
**kwargs,
):
"""
@ -110,7 +111,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
@validate_hf_hub_args
def load_photomaker_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
weight_name: str,
subfolder: str = '',
trigger_word: str = 'img',
@ -214,21 +215,21 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
def encode_prompt_with_trigger_word(
self,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
prompt_2: str | None = None,
device: torch.device | None = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
negative_prompt: str | None = None,
negative_prompt_2: str | None = None,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
pooled_prompt_embeds: torch.Tensor | None = None,
negative_pooled_prompt_embeds: torch.Tensor | None = None,
lora_scale: float | None = None,
clip_skip: int | None = None,
### Added args
num_id_images: int = 1,
class_tokens_mask: Optional[torch.LongTensor] = None,
class_tokens_mask: torch.LongTensor | None = None,
):
device = device or self._execution_device
@ -273,7 +274,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): # pylint: disable=redefined-argument-from-local
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False): # pylint: disable=redefined-argument-from-local
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer)
@ -362,7 +363,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
uncond_tokens: List[str]
uncond_tokens: list[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@ -377,7 +378,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): # pylint: disable=redefined-argument-from-local
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders, strict=False): # pylint: disable=redefined-argument-from-local
if isinstance(self, TextualInversionLoaderMixin):
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
@ -444,49 +445,47 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
prompt: str | list[str] = None,
prompt_2: str | list[str] | None = None,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
denoising_end: Optional[float] = None,
timesteps: list[int] = None,
sigmas: list[float] = None,
denoising_end: float | None = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
negative_prompt: str | list[str] | None = None,
negative_prompt_2: str | list[str] | None = None,
num_images_per_prompt: int | None = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
output_type: Optional[str] = "pil",
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
pooled_prompt_embeds: torch.Tensor | None = None,
negative_pooled_prompt_embeds: torch.Tensor | None = None,
ip_adapter_image: PipelineImageInput | None = None,
ip_adapter_image_embeds: list[torch.Tensor] | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
cross_attention_kwargs: dict[str, Any] | None = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
negative_original_size: Optional[Tuple[int, int]] = None,
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
original_size: tuple[int, int] | None = None,
crops_coords_top_left: tuple[int, int] = (0, 0),
target_size: tuple[int, int] | None = None,
negative_original_size: tuple[int, int] | None = None,
negative_crops_coords_top_left: tuple[int, int] = (0, 0),
negative_target_size: tuple[int, int] | None = None,
clip_skip: int | None = None,
callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
callback_on_step_end_tensor_inputs: list[str] = None,
# Added parameters (for PhotoMaker)
input_id_images: PipelineImageInput = None,
start_merge_step: int = 10,
class_tokens_mask: Optional[torch.LongTensor] = None,
id_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
class_tokens_mask: torch.LongTensor | None = None,
id_embeds: torch.FloatTensor | None = None,
prompt_embeds_text_only: torch.FloatTensor | None = None,
pooled_prompt_embeds_text_only: torch.FloatTensor | None = None,
**kwargs,
):
r"""
@ -512,6 +511,8 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if callback_on_step_end_tensor_inputs is None:
callback_on_step_end_tensor_inputs = ["latents"]
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)

View File

@ -1,4 +1,3 @@
from typing import List
import os
import cv2
import torch
@ -43,8 +42,8 @@ def get_model(model_name: str):
def reswapper(
p: processing.StableDiffusionProcessing,
app,
source_images: List[Image.Image],
target_images: List[Image.Image],
source_images: list[Image.Image],
target_images: list[Image.Image],
model_name: str,
original: bool,
):

View File

@ -6,7 +6,7 @@ import torch.nn.functional as F
class ReSwapperModel(nn.Module):
def __init__(self):
super(ReSwapperModel, self).__init__()
super().__init__()
# self.pad = nn.ReflectionPad2d(3)
# Encoder for target face
@ -87,7 +87,7 @@ class ReSwapperModel(nn.Module):
class StyleBlock(nn.Module):
def __init__(self, in_channels, out_channels, blockIndex):
super(StyleBlock, self).__init__()
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
self.style1 = nn.Linear(512, 2048)

View File

@ -2,7 +2,8 @@ import itertools
import os
from collections import UserDict
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterator, List, Optional, Union
from typing import Union
from collections.abc import Callable, Iterator
from installer import log
@ -10,19 +11,19 @@ do_cache_folders = os.environ.get('SD_NO_CACHE', None) is None
class Directory: # forward declaration
...
FilePathList = List[str]
FilePathList = list[str]
FilePathIterator = Iterator[str]
DirectoryPathList = List[str]
DirectoryPathList = list[str]
DirectoryPathIterator = Iterator[str]
DirectoryList = List[Directory]
DirectoryList = list[Directory]
DirectoryIterator = Iterator[Directory]
DirectoryCollection = Dict[str, Directory]
DirectoryCollection = dict[str, Directory]
ExtensionFilter = Callable
ExtensionList = list[str]
RecursiveType = Union[bool,Callable]
def real_path(directory_path:str) -> Union[str, None]:
def real_path(directory_path:str) -> str | None:
try:
return os.path.abspath(os.path.expanduser(directory_path))
except Exception:
@ -52,7 +53,7 @@ class Directory(Directory): # pylint: disable=E0102
def clear(self) -> None:
self._update(Directory.from_dict({
'path': None,
'mtime': float(),
'mtime': 0.0,
'files': [],
'directories': []
}))
@ -125,7 +126,7 @@ def clean_directory(directory: Directory, /, recursive: RecursiveType=False) ->
return is_clean
def get_directory(directory_or_path: str, /, fetch: bool=True) -> Union[Directory, None]:
def get_directory(directory_or_path: str, /, fetch: bool=True) -> Directory | None:
if isinstance(directory_or_path, Directory):
if directory_or_path.is_directory:
return directory_or_path
@ -143,7 +144,7 @@ def get_directory(directory_or_path: str, /, fetch: bool=True) -> Union[Director
return cache_folders[directory_or_path] if directory_or_path in cache_folders else None
def fetch_directory(directory_path: str) -> Union[Directory, None]:
def fetch_directory(directory_path: str) -> Directory | None:
directory: Directory
for directory in _walk(directory_path, recurse=False):
return directory # The return is intentional, we get a generator, we only need the one
@ -255,7 +256,7 @@ def get_directories(*directory_paths: DirectoryPathList, fetch:bool=True, recurs
return filter(bool, directories)
def directory_files(*directories_or_paths: Union[DirectoryPathList, DirectoryList], recursive: RecursiveType=True) -> FilePathIterator:
def directory_files(*directories_or_paths: DirectoryPathList | DirectoryList, recursive: RecursiveType=True) -> FilePathIterator:
return itertools.chain.from_iterable(
itertools.chain(
directory_object.files,
@ -275,7 +276,7 @@ def directory_files(*directories_or_paths: Union[DirectoryPathList, DirectoryLis
)
def extension_filter(ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None) -> ExtensionFilter:
def extension_filter(ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None) -> ExtensionFilter:
if ext_filter:
ext_filter = [*map(str.upper, ext_filter)]
if ext_blacklist:
@ -289,11 +290,11 @@ def not_hidden(filepath: str) -> bool:
return not os.path.basename(filepath).startswith('.')
def filter_files(file_paths: FilePathList, ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None) -> FilePathIterator:
def filter_files(file_paths: FilePathList, ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None) -> FilePathIterator:
return filter(extension_filter(ext_filter, ext_blacklist), file_paths)
def list_files(*directory_paths:DirectoryPathList, ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None, recursive:RecursiveType=True) -> FilePathIterator:
def list_files(*directory_paths:DirectoryPathList, ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None, recursive:RecursiveType=True) -> FilePathIterator:
return filter_files(itertools.chain.from_iterable(
directory_files(directory, recursive=recursive)
for directory in get_directories(*directory_paths, recursive=recursive)

View File

@ -1,4 +1,3 @@
from typing import Optional, List
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from fastapi.exceptions import HTTPException
from modules import shared
@ -8,39 +7,39 @@ class ReqFramepack(BaseModel):
variant: str = Field(default=None, title="Model variant", description="Model variant to use")
prompt: str = Field(default=None, title="Prompt", description="Prompt for the model")
init_image: str = Field(default=None, title="Initial image", description="Base64 encoded initial image")
end_image: Optional[str] = Field(default=None, title="End image", description="Base64 encoded end image")
start_weight: Optional[float] = Field(default=1.0, title="Start weight", description="Weight of the initial image")
end_weight: Optional[float] = Field(default=1.0, title="End weight", description="Weight of the end image")
vision_weight: Optional[float] = Field(default=1.0, title="Vision weight", description="Weight of the vision model")
system_prompt: Optional[str] = Field(default=None, title="System prompt", description="System prompt for the model")
optimized_prompt: Optional[bool] = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model")
section_prompt: Optional[str] = Field(default=None, title="Section prompt", description="Prompt for each section")
negative_prompt: Optional[str] = Field(default=None, title="Negative prompt", description="Negative prompt for the model")
styles: Optional[List[str]] = Field(default=None, title="Styles", description="Styles for the model")
seed: Optional[int] = Field(default=None, title="Seed", description="Seed for the model")
resolution: Optional[int] = Field(default=640, title="Resolution", description="Resolution of the image")
duration: Optional[float] = Field(default=4, title="Duration", description="Duration of the video in seconds")
latent_ws: Optional[int] = Field(default=9, title="Latent window size", description="Size of the latent window")
steps: Optional[int] = Field(default=25, title="Video steps", description="Number of steps for the video generation")
cfg_scale: Optional[float] = Field(default=1.0, title="CFG scale", description="CFG scale for the model")
cfg_distilled: Optional[float] = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model")
cfg_rescale: Optional[float] = Field(default=0.0, title="CFG re-scale", description="CFG re-scale for the model")
shift: Optional[float] = Field(default=0, title="Sampler shift", description="Shift for the sampler")
use_teacache: Optional[bool] = Field(default=True, title="Enable TeaCache", description="Use TeaCache for the model")
use_cfgzero: Optional[bool] = Field(default=False, title="Enable CFGZero", description="Use CFGZero for the model")
mp4_fps: Optional[int] = Field(default=30, title="FPS", description="Frames per second for the video")
mp4_codec: Optional[str] = Field(default="libx264", title="Codec", description="Codec for the video")
mp4_sf: Optional[bool] = Field(default=False, title="Save SafeTensors", description="Save SafeTensors for the video")
mp4_video: Optional[bool] = Field(default=True, title="Save Video", description="Save video")
mp4_frames: Optional[bool] = Field(default=False, title="Save Frames", description="Save frames for the video")
mp4_opt: Optional[str] = Field(default="crf:16", title="Options", description="Options for the video codec")
mp4_ext: Optional[str] = Field(default="mp4", title="Format", description="Format for the video")
mp4_interpolate: Optional[int] = Field(default=0, title="Interpolation", description="Interpolation for the video")
attention: Optional[str] = Field(default="Default", title="Attention", description="Attention type for the model")
vae_type: Optional[str] = Field(default="Local", title="VAE", description="VAE type for the model")
vlm_enhance: Optional[bool] = Field(default=False, title="VLM enhance", description="Enable VLM enhance")
vlm_model: Optional[str] = Field(default=None, title="VLM model", description="VLM model to use")
vlm_system_prompt: Optional[str] = Field(default=None, title="VLM system prompt", description="System prompt for the VLM model")
end_image: str | None = Field(default=None, title="End image", description="Base64 encoded end image")
start_weight: float | None = Field(default=1.0, title="Start weight", description="Weight of the initial image")
end_weight: float | None = Field(default=1.0, title="End weight", description="Weight of the end image")
vision_weight: float | None = Field(default=1.0, title="Vision weight", description="Weight of the vision model")
system_prompt: str | None = Field(default=None, title="System prompt", description="System prompt for the model")
optimized_prompt: bool | None = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model")
section_prompt: str | None = Field(default=None, title="Section prompt", description="Prompt for each section")
negative_prompt: str | None = Field(default=None, title="Negative prompt", description="Negative prompt for the model")
styles: list[str] | None = Field(default=None, title="Styles", description="Styles for the model")
seed: int | None = Field(default=None, title="Seed", description="Seed for the model")
resolution: int | None = Field(default=640, title="Resolution", description="Resolution of the image")
duration: float | None = Field(default=4, title="Duration", description="Duration of the video in seconds")
latent_ws: int | None = Field(default=9, title="Latent window size", description="Size of the latent window")
steps: int | None = Field(default=25, title="Video steps", description="Number of steps for the video generation")
cfg_scale: float | None = Field(default=1.0, title="CFG scale", description="CFG scale for the model")
cfg_distilled: float | None = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model")
cfg_rescale: float | None = Field(default=0.0, title="CFG re-scale", description="CFG re-scale for the model")
shift: float | None = Field(default=0, title="Sampler shift", description="Shift for the sampler")
use_teacache: bool | None = Field(default=True, title="Enable TeaCache", description="Use TeaCache for the model")
use_cfgzero: bool | None = Field(default=False, title="Enable CFGZero", description="Use CFGZero for the model")
mp4_fps: int | None = Field(default=30, title="FPS", description="Frames per second for the video")
mp4_codec: str | None = Field(default="libx264", title="Codec", description="Codec for the video")
mp4_sf: bool | None = Field(default=False, title="Save SafeTensors", description="Save SafeTensors for the video")
mp4_video: bool | None = Field(default=True, title="Save Video", description="Save video")
mp4_frames: bool | None = Field(default=False, title="Save Frames", description="Save frames for the video")
mp4_opt: str | None = Field(default="crf:16", title="Options", description="Options for the video codec")
mp4_ext: str | None = Field(default="mp4", title="Format", description="Format for the video")
mp4_interpolate: int | None = Field(default=0, title="Interpolation", description="Interpolation for the video")
attention: str | None = Field(default="Default", title="Attention", description="Attention type for the model")
vae_type: str | None = Field(default="Local", title="VAE", description="VAE type for the model")
vlm_enhance: bool | None = Field(default=False, title="VLM enhance", description="Enable VLM enhance")
vlm_model: str | None = Field(default=None, title="VLM model", description="VLM model to use")
vlm_system_prompt: str | None = Field(default=None, title="VLM system prompt", description="System prompt for the VLM model")
class ResFramepack(BaseModel):

View File

@ -43,8 +43,10 @@ def worker(
mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate,
vae_type,
variant,
metadata:dict={},
metadata:dict=None,
):
if metadata is None:
metadata = {}
timer.process.reset()
memstats.reset_stats()
if stream is None or shared.state.interrupted or shared.state.skipped:

View File

@ -1,4 +1,3 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
@ -251,7 +250,7 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
class HunyuanVideoAdaNorm(nn.Module):
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
def __init__(self, in_features: int, out_features: int | None = None) -> None:
super().__init__()
out_features = out_features or 2 * in_features
@ -260,7 +259,7 @@ class HunyuanVideoAdaNorm(nn.Module):
def forward(
self, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
temb = self.linear(self.nonlinearity(temb))
gate_msa, gate_mlp = temb.chunk(2, dim=-1)
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
@ -298,7 +297,7 @@ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
@ -346,7 +345,7 @@ class HunyuanVideoIndividualTokenRefiner(nn.Module):
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: torch.Tensor | None = None,
) -> None:
self_attn_mask = None
if attention_mask is not None:
@ -396,7 +395,7 @@ class HunyuanVideoTokenRefiner(nn.Module):
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
attention_mask: torch.LongTensor | None = None,
) -> torch.Tensor:
if attention_mask is None:
pooled_projections = hidden_states.mean(dim=1)
@ -464,8 +463,8 @@ class AdaLayerNormZero(nn.Module):
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = emb.unsqueeze(-2)
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
@ -487,8 +486,8 @@ class AdaLayerNormZeroSingle(nn.Module):
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = emb.unsqueeze(-2)
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
@ -558,8 +557,8 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: torch.Tensor | None = None,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@ -636,9 +635,9 @@ class HunyuanVideoTransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
attention_mask: torch.Tensor | None = None,
freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
@ -734,7 +733,7 @@ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterM
text_embed_dim: int = 4096,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56),
rope_axes_dim: tuple[int] = (16, 56, 56),
has_image_proj=False,
image_proj_dim=1152,
has_clean_x_embedder=False,

View File

@ -102,14 +102,14 @@ def just_crop(image, w, h):
def write_to_json(data, file_path):
temp_file_path = file_path + ".tmp"
with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
with open(temp_file_path, 'w', encoding='utf-8') as temp_file:
json.dump(data, temp_file, indent=4)
os.replace(temp_file_path, file_path)
return
def read_from_json(file_path):
with open(file_path, 'rt', encoding='utf-8') as file:
with open(file_path, encoding='utf-8') as file:
data = json.load(file)
return data
@ -283,7 +283,7 @@ def add_tensors_with_padding(tensor1, tensor2):
shape1 = tensor1.shape
shape2 = tensor2.shape
new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2, strict=False))
padded_tensor1 = torch.zeros(new_shape)
padded_tensor2 = torch.zeros(new_shape)

View File

@ -5,7 +5,7 @@ import os
from PIL import Image
import gradio as gr
from modules import shared, gr_tempdir, script_callbacks, images
from modules.infotext import parse, mapping, quote, unquote # pylint: disable=unused-import
from modules.infotext import parse, mapping # pylint: disable=unused-import
type_of_gr_update = type(gr.update())
@ -259,7 +259,7 @@ def connect_paste(button, local_paste_fields, input_comp, override_settings_comp
from modules.paths import params_path
if prompt is None or len(prompt.strip()) == 0:
if os.path.exists(params_path):
with open(params_path, "r", encoding="utf8") as file:
with open(params_path, encoding="utf8") as file:
prompt = file.read()
shared.log.debug(f'Prompt parse: type="params" prompt="{prompt}"')
else:

View File

@ -1,7 +1,8 @@
# Original: invokeai.backend.quantization.gguf.utils
# Largely based on https://github.com/city96/ComfyUI-GGUF
from typing import Callable, Optional, Union
from typing import Union
from collections.abc import Callable
import gguf
import torch
@ -28,7 +29,7 @@ def get_scale_min(scales: torch.Tensor):
# Legacy Quants #
def dequantize_blocks_Q8_0(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
d, x = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
@ -37,7 +38,7 @@ def dequantize_blocks_Q8_0(
def dequantize_blocks_Q5_1(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@ -58,7 +59,7 @@ def dequantize_blocks_Q5_1(
def dequantize_blocks_Q5_0(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@ -79,7 +80,7 @@ def dequantize_blocks_Q5_0(
def dequantize_blocks_Q4_1(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@ -96,7 +97,7 @@ def dequantize_blocks_Q4_1(
def dequantize_blocks_Q4_0(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@ -111,13 +112,13 @@ def dequantize_blocks_Q4_0(
def dequantize_blocks_BF16(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
def dequantize_blocks_Q6_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@ -147,7 +148,7 @@ def dequantize_blocks_Q6_K(
def dequantize_blocks_Q5_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@ -175,7 +176,7 @@ def dequantize_blocks_Q5_K(
def dequantize_blocks_Q4_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@ -197,7 +198,7 @@ def dequantize_blocks_Q4_K(
def dequantize_blocks_Q3_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@ -232,7 +233,7 @@ def dequantize_blocks_Q3_K(
def dequantize_blocks_Q2_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@ -254,7 +255,7 @@ def dequantize_blocks_Q2_K(
DEQUANTIZE_FUNCTIONS: dict[
gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor]
gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, torch.dtype | None], torch.Tensor]
] = {
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
@ -270,7 +271,7 @@ DEQUANTIZE_FUNCTIONS: dict[
}
def is_torch_compatible(tensor: Optional[torch.Tensor]):
def is_torch_compatible(tensor: torch.Tensor | None):
return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES
@ -279,7 +280,7 @@ def is_quantized(tensor: torch.Tensor):
def dequantize(
data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None
data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: torch.dtype | None = None
):
"""
Dequantize tensor back to usable shape/dtype

View File

@ -9,8 +9,10 @@ import torch
from modules import shared, devices
class Item():
def __init__(self, latent, preview=None, info=None, ops=[]):
class Item:
def __init__(self, latent, preview=None, info=None, ops=None):
if ops is None:
ops = []
self.ts = datetime.datetime.now().replace(microsecond=0)
self.name = self.ts.strftime('%Y-%m-%d %H:%M:%S')
self.latent = latent.detach().clone().to(devices.cpu)
@ -20,7 +22,7 @@ class Item():
self.size = sys.getsizeof(self.latent.storage())
class History():
class History:
def __init__(self):
self.index = -1
self.latents = deque(maxlen=1024)
@ -58,7 +60,9 @@ class History():
return i
return -1
def add(self, latent, preview=None, info=None, ops=[]):
def add(self, latent, preview=None, info=None, ops=None):
if ops is None:
ops = []
shared.state.latent_history += 1
if shared.opts.latent_history == 0:
return

View File

@ -171,7 +171,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
calc_img = Image.new("RGB", (1, 1), shared.opts.grid_background)
calc_d = ImageDraw.Draw(calc_img)
title_texts = [title] if title else [[GridAnnotation()]]
for texts, allowed_width in zip(x_texts + y_texts + title_texts, [width] * len(x_texts) + [pad_left] * len(y_texts) + [(width+margin)*cols]):
for texts, allowed_width in zip(x_texts + y_texts + title_texts, [width] * len(x_texts) + [pad_left] * len(y_texts) + [(width+margin)*cols], strict=False):
items = [] + texts
texts.clear()
for line in items:

View File

@ -1,4 +1,3 @@
from typing import Union
import sys
import time
import numpy as np
@ -8,7 +7,7 @@ from modules import shared, upscaler
from modules.image import sharpfin
def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None):
def resize_image(resize_mode: int, im: Image.Image | torch.Tensor, width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None):
upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img
def verify_image(image):
@ -34,7 +33,7 @@ def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width:
im = vae_decode(latents, shared.sd_model, output_type='pil', vae_type='Tiny')[0]
return im
def resize(im: Union[Image.Image, torch.Tensor], w, h):
def resize(im: Image.Image | torch.Tensor, w, h):
w, h = int(w), int(h)
if upscaler_name is None or upscaler_name == "None" or (hasattr(im, 'mode') and im.mode == 'L'):
return sharpfin.resize(im, (w, h), linearize=False) # force for mask

View File

@ -22,7 +22,6 @@ _triton_ok = False
def check_sharpfin():
global _sharpfin_checked, _sharpfin_ok, _triton_ok # pylint: disable=global-statement
if not _sharpfin_checked:
from modules.sharpfin.functional import scale # pylint: disable=unused-import
_sharpfin_ok = True
try:
from modules.sharpfin import TRITON_AVAILABLE

View File

@ -1,10 +1,11 @@
from modules.image.util import flatten, draw_text # pylint: disable=unused-import
from modules.image.save import save_image # pylint: disable=unused-import
from modules.image.convert import to_pil, to_tensor # pylint: disable=unused-import
from modules.image.metadata import read_info_from_image, image_data # pylint: disable=unused-import
from modules.image.resize import resize_image # pylint: disable=unused-import
from modules.image.sharpfin import resize # pylint: disable=unused-import
from modules.image.namegen import FilenameGenerator, get_next_sequence_number # pylint: disable=unused-import
from modules.image.watermark import set_watermark, get_watermark # pylint: disable=unused-import
from modules.image.grid import image_grid, get_grid_size, split_grid, combine_grid, check_grid_size, get_font, draw_grid_annotations, draw_prompt_matrix, GridAnnotation, Grid # pylint: disable=unused-import
from modules.video import save_video # pylint: disable=unused-import
from modules.image.metadata import image_data, read_info_from_image
from modules.image.save import save_image, sanitize_filename_part
from modules.image.resize import resize_image
from modules.image.grid import image_grid, check_grid_size, get_grid_size, draw_grid_annotations, draw_prompt_matrix
__all__ = [
'image_data', 'read_info_from_image',
'save_image', 'sanitize_filename_part',
'resize_image',
'image_grid', 'check_grid_size', 'get_grid_size', 'draw_grid_annotations', 'draw_prompt_matrix'
]

View File

@ -77,7 +77,7 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args)
caption_file = os.path.splitext(image_file)[0] + '.txt'
prompt_type='default'
if os.path.exists(caption_file):
with open(caption_file, 'r', encoding='utf8') as f:
with open(caption_file, encoding='utf8') as f:
p.prompt = f.read()
prompt_type='file'
else:

View File

@ -129,7 +129,7 @@ if __name__ == '__main__':
import sys
if len(sys.argv) > 1:
if os.path.exists(sys.argv[1]):
with open(sys.argv[1], 'r', encoding='utf8') as f:
with open(sys.argv[1], encoding='utf8') as f:
parse(f.read())
else:
parse(sys.argv[1])

View File

@ -1,3 +1,2 @@
# a1111 compatibility module: unused
from modules.infotext import parse as parse_generation_parameters # pylint: disable=unused-import

View File

@ -81,7 +81,9 @@ def warn_once(msg):
warned = True
class OpenVINOGraphModule(torch.nn.Module):
def __init__(self, gm, partition_id, use_python_fusion_cache, model_hash_str: str = None, file_name="", int_inputs=[]):
def __init__(self, gm, partition_id, use_python_fusion_cache, model_hash_str: str = None, file_name="", int_inputs=None):
if int_inputs is None:
int_inputs = []
super().__init__()
self.gm = gm
self.int_inputs = int_inputs
@ -192,7 +194,7 @@ def execute(
elif executor == "strictly_openvino":
return openvino_execute(gm, *args, executor_parameters=executor_parameters, file_name=file_name)
msg = "Received unexpected value for 'executor': {0}. Allowed values are: openvino, strictly_openvino.".format(executor)
msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: openvino, strictly_openvino."
raise ValueError(msg)
@ -373,7 +375,7 @@ def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition
ov_inputs = []
for arg in flat_args:
if not isinstance(arg, int):
ov_inputs.append((arg.detach().cpu().numpy()))
ov_inputs.append(arg.detach().cpu().numpy())
res = req.infer(ov_inputs, share_inputs=True, share_outputs=True)
@ -423,7 +425,9 @@ def openvino_execute_partitioned(gm: GraphModule, *args, executor_parameters=Non
return shared.compiled_model_state.partitioned_modules[signature][0](*ov_inputs)
def partition_graph(gm: GraphModule, use_python_fusion_cache: bool, model_hash_str: str = None, file_name="", int_inputs=[]):
def partition_graph(gm: GraphModule, use_python_fusion_cache: bool, model_hash_str: str = None, file_name="", int_inputs=None):
if int_inputs is None:
int_inputs = []
for node in gm.graph.nodes:
if node.op == "call_module" and "fused_" in node.name:
openvino_submodule = getattr(gm, node.name)
@ -509,7 +513,7 @@ def openvino_fx(subgraph, example_inputs, options=None):
if os.path.isfile(maybe_fs_cached_name + ".xml") and os.path.isfile(maybe_fs_cached_name + ".bin"):
example_inputs_reordered = []
if (os.path.isfile(maybe_fs_cached_name + ".txt")):
f = open(maybe_fs_cached_name + ".txt", "r")
f = open(maybe_fs_cached_name + ".txt")
for input_data in example_inputs:
shape = f.readline()
if (str(input_data.size()) != shape):
@ -532,7 +536,7 @@ def openvino_fx(subgraph, example_inputs, options=None):
if (shared.compiled_model_state.cn_model != [] and str(shared.compiled_model_state.cn_model) in maybe_fs_cached_name):
args_reordered = []
if (os.path.isfile(maybe_fs_cached_name + ".txt")):
f = open(maybe_fs_cached_name + ".txt", "r")
f = open(maybe_fs_cached_name + ".txt")
for input_data in args:
shape = f.readline()
if (str(input_data.size()) != shape):

View File

@ -288,7 +288,19 @@ def parse_params(p: processing.StableDiffusionProcessing, adapters: list, adapte
return adapter_images, adapter_masks, adapter_scales, adapter_crops, adapter_starts, adapter_ends
def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapter_scales=[1.0], adapter_crops=[False], adapter_starts=[0.0], adapter_ends=[1.0], adapter_images=[]):
def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=None, adapter_scales=None, adapter_crops=None, adapter_starts=None, adapter_ends=None, adapter_images=None):
if adapter_images is None:
adapter_images = []
if adapter_ends is None:
adapter_ends = [1.0]
if adapter_starts is None:
adapter_starts = [0.0]
if adapter_crops is None:
adapter_crops = [False]
if adapter_scales is None:
adapter_scales = [1.0]
if adapter_names is None:
adapter_names = []
global adapters_loaded # pylint: disable=global-statement
# overrides
if hasattr(p, 'ip_adapter_names'):
@ -361,7 +373,7 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
if adapter_starts[i] > 0:
adapter_scales[i] = 0.00
pipe.set_ip_adapter_scale(adapter_scales if len(adapter_scales) > 1 else adapter_scales[0])
ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}:{crop}' for adapter, scale, start, end, crop in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends, adapter_crops)]
ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}:{crop}' for adapter, scale, start, end, crop in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends, adapter_crops, strict=False)]
if hasattr(pipe, 'transformer') and 'Nunchaku' in pipe.transformer.__class__.__name__:
if isinstance(repos, str):
sd_models.clear_caches(full=True)

View File

@ -63,7 +63,7 @@ try:
except Exception:
pass
try:
import torch.distributed.distributed_c10d as _c10d # pylint: disable=unused-import,ungrouped-imports
pass # pylint: disable=unused-import,ungrouped-imports
except Exception:
errors.log.warning('Loader: torch is not built with distributed support')
@ -73,7 +73,6 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvisi
torchvision = None
try:
import torchvision # pylint: disable=W0611,C0411
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them # pylint: disable=W0611,C0411
except Exception as e:
errors.log.error(f'Loader: torchvision=={torchvision.__version__ if "torchvision" in sys.modules else None} {e}')
if '_no_nep' in str(e):
@ -100,7 +99,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
timer.startup.record("torch")
try:
import bitsandbytes # pylint: disable=W0611,C0411
import bitsandbytes # pylint: disable=unused-import
_bnb = True
except Exception:
_bnb = False
@ -132,7 +131,6 @@ except Exception as e:
errors.log.warning(f'Torch onnxruntime: {e}')
timer.startup.record("onnx")
from fastapi import FastAPI # pylint: disable=W0611,C0411
timer.startup.record("fastapi")
import gradio # pylint: disable=W0611,C0411
@ -161,17 +159,16 @@ except Exception as e:
sys.exit(1)
try:
import pillow_jxl # pylint: disable=W0611,C0411
pass # pylint: disable=W0611,C0411
except Exception:
pass
from PIL import Image # pylint: disable=W0611,C0411
timer.startup.record("pillow")
import cv2 # pylint: disable=W0611,C0411
timer.startup.record("cv2")
class _tqdm_cls():
class _tqdm_cls:
def __call__(self, *args, **kwargs):
bar_format = 'Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + '{desc}' + '\x1b[0m'
return tqdm_lib.tqdm(*args, bar_format=bar_format, ncols=80, colour='#327fba', **kwargs)

View File

@ -28,7 +28,7 @@ def localization_js(current_localization_name):
data = {}
if fn is not None:
try:
with open(fn, "r", encoding="utf8") as file:
with open(fn, encoding="utf8") as file:
data = json.load(file)
except Exception as e:
errors.log.error(f"Error loading localization from {fn}:")

View File

@ -1,4 +1,3 @@
from typing import List
import os
import re
import numpy as np
@ -19,7 +18,7 @@ def get_stepwise(param, step, steps): # from https://github.com/cheald/sd-webui-
return steps[0][0]
steps = [[s[0], s[1] if len(s) == 2 else 1] for s in steps] # Add implicit 1s to any steps which don't have a weight
steps.sort(key=lambda k: k[1]) # Sort by index
steps = [list(v) for v in zip(*steps)]
steps = [list(v) for v in zip(*steps, strict=False)]
return steps
def calculate_weight(m, step, max_steps, step_offset=2):
@ -170,10 +169,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
self.model = None
self.errors = {}
def signature(self, names: List[str], te_multipliers: List, unet_multipliers: List):
return [f'{name}:{te}:{unet}' for name, te, unet in zip(names, te_multipliers, unet_multipliers)]
def signature(self, names: list[str], te_multipliers: list, unet_multipliers: list):
return [f'{name}:{te}:{unet}' for name, te, unet in zip(names, te_multipliers, unet_multipliers, strict=False)]
def changed(self, requested: List[str], include: List[str] = None, exclude: List[str] = None) -> bool:
def changed(self, requested: list[str], include: list[str] = None, exclude: list[str] = None) -> bool:
if shared.opts.lora_force_reload:
debug_log(f'Network check: type=LoRA requested={requested} status=forced')
return True
@ -190,7 +189,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
sd_model.loaded_loras[key] = requested
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=changed')
return True
for req, load in zip(requested, loaded):
for req, load in zip(requested, loaded, strict=False):
if req != load:
sd_model.loaded_loras[key] = requested
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=changed')
@ -198,7 +197,11 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=same')
return False
def activate(self, p, params_list, step=0, include=[], exclude=[]):
def activate(self, p, params_list, step=0, include=None, exclude=None):
if exclude is None:
exclude = []
if include is None:
include = []
self.errors.clear()
if self.active:
if self.model != shared.opts.sd_model_checkpoint: # reset if model changed

View File

@ -1,4 +1,3 @@
from typing import Union
import re
import time
import torch
@ -12,7 +11,7 @@ bnb = None
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, wanted_names: tuple):
def network_backup_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, network_layer_name: str, wanted_names: tuple):
global bnb # pylint: disable=W0603
backup_size = 0
if len(l.loaded_networks) > 0 and network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in l.loaded_networks]): # noqa: C419 # pylint: disable=R1729
@ -76,7 +75,7 @@ def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n
return backup_size
def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, use_previous: bool = False):
def network_calc_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, network_layer_name: str, use_previous: bool = False):
if shared.opts.diffusers_offload_mode == "none":
try:
self.to(devices.device)
@ -147,7 +146,7 @@ def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
return batch_updown, batch_ex_bias
def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], model_weights: Union[None, torch.Tensor] = None, lora_weights: torch.Tensor = None, deactivate: bool = False, device: torch.device = None, bias: bool = False):
def network_add_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, model_weights: None | torch.Tensor = None, lora_weights: torch.Tensor = None, deactivate: bool = False, device: torch.device = None, bias: bool = False):
if lora_weights is None:
return
if deactivate:
@ -239,7 +238,7 @@ def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.G
del model_weights, lora_weights, new_weight, weight # required to avoid memory leak
def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False, device: torch.device = devices.device):
def network_apply_direct(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False, device: torch.device = devices.device):
weights_backup = getattr(self, "network_weights_backup", False)
bias_backup = getattr(self, "network_bias_backup", False)
if not isinstance(weights_backup, bool): # remove previous backup if we switched settings
@ -266,7 +265,7 @@ def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
l.timer.apply += time.time() - t0
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, device: torch.device, deactivate: bool = False):
def network_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, updown: torch.Tensor, ex_bias: torch.Tensor, device: torch.device, deactivate: bool = False):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
if weights_backup is None and bias_backup is None:

View File

@ -1,4 +1,3 @@
from typing import List
import os
from modules.lora import lora_timers
from modules.lora import network_lora, network_hada, network_ia3, network_oft, network_lokr, network_full, network_norm, network_glora
@ -16,6 +15,6 @@ module_types = [
network_norm.ModuleTypeNorm(),
network_glora.ModuleTypeGLora(),
]
loaded_networks: List = [] # no type due to circular import
previously_loaded_networks: List = [] # no type due to circular import
loaded_networks: list = [] # no type due to circular import
previously_loaded_networks: list = [] # no type due to circular import
extra_network_lora = None # initialized in extra_networks.py

View File

@ -1,7 +1,6 @@
import os
import re
import bisect
from typing import Dict
import torch
from modules import shared
@ -23,7 +22,7 @@ re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}
def make_unet_conversion_map() -> Dict[str, str]:
def make_unet_conversion_map() -> dict[str, str]:
unet_conversion_map_layer = []
for i in range(4): # num_blocks is 3 in sdxl
@ -213,10 +212,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
ait_sd.update({k: down_weight for k in ait_down_keys})
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0), strict=False)}) # noqa: C416 # pylint: disable=unnecessary-comprehension
else:
# down_weight is chunked to each split
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0), strict=False)}) # noqa: C416 # pylint: disable=unnecessary-comprehension
# up_weight is sparse: only non-zero values are copied to each split
i = 0

View File

@ -1,4 +1,3 @@
from typing import Union
import os
import time
import diffusers
@ -50,7 +49,7 @@ def load_per_module(sd_model: diffusers.DiffusionPipeline, filename: str, adapte
return adapter_name
def load_diffusers(name: str, network_on_disk: network.NetworkOnDisk, lora_scale:float=shared.opts.extra_networks_default_multiplier, lora_module=None) -> Union[network.Network, None]:
def load_diffusers(name: str, network_on_disk: network.NetworkOnDisk, lora_scale:float=shared.opts.extra_networks_default_multiplier, lora_module=None) -> network.Network | None:
t0 = time.time()
name = name.replace(".", "_")
sd_model: diffusers.DiffusionPipeline = getattr(shared.sd_model, "pipe", shared.sd_model)

View File

@ -1,4 +1,3 @@
from typing import Union
import os
import time
import concurrent
@ -39,7 +38,7 @@ def lora_dump(lora, dct):
f.write(line + "\n")
def load_safetensors(name, network_on_disk: network.NetworkOnDisk) -> Union[network.Network, None]:
def load_safetensors(name, network_on_disk: network.NetworkOnDisk) -> network.Network | None:
if not shared.sd_loaded:
return None
@ -241,7 +240,7 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
lora_diffusers.diffuser_scales.clear()
t0 = time.time()
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names, strict=False)):
net = None
if network_on_disk is not None:
shorthash = getattr(network_on_disk, 'shorthash', '').lower()

View File

@ -10,7 +10,7 @@ def load_nunchaku(names, strengths):
global previously_loaded # pylint: disable=global-statement
strengths = [s[0] if isinstance(s, list) else s for s in strengths]
networks = lora_load.gather_networks(names)
networks = [(network, strength) for network, strength in zip(networks, strengths) if network is not None and strength > 0]
networks = [(network, strength) for network, strength in zip(networks, strengths, strict=False) if network is not None and strength > 0]
loras = [(network.filename, strength) for network, strength in networks]
is_changed = loras != previously_loaded
if not is_changed:

View File

@ -1,4 +1,4 @@
class Timer():
class Timer:
list: float = 0
load: float = 0
backup: float = 0

View File

@ -1,6 +1,5 @@
import os
import enum
from typing import Union
from collections import namedtuple
from modules import sd_models, hashes, shared
@ -120,7 +119,7 @@ class NetworkOnDisk:
if self.filename is not None:
fn = os.path.splitext(self.filename)[0] + '.txt'
if os.path.exists(fn):
with open(fn, "r", encoding="utf-8") as file:
with open(fn, encoding="utf-8") as file:
return file.read()
return None
@ -144,7 +143,7 @@ class Network: # LoraModule
class ModuleType:
def create_module(self, net: Network, weights: NetworkWeights) -> Union[Network, None]: # pylint: disable=W0613
def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: # pylint: disable=W0613
return None

View File

@ -11,7 +11,11 @@ applied_layers: list[str] = []
default_components = ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'text_encoder_4', 'unet', 'transformer', 'transformer_2']
def network_activate(include=[], exclude=[]):
def network_activate(include=None, exclude=None):
if exclude is None:
exclude = []
if include is None:
include = []
t0 = time.time()
with limit_errors("network_activate"):
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
@ -77,7 +81,11 @@ def network_activate(include=[], exclude=[]):
sd_models.set_diffuser_offload(sd_model, op="model")
def network_deactivate(include=[], exclude=[]):
def network_deactivate(include=None, exclude=None):
if exclude is None:
exclude = []
if include is None:
include = []
if not shared.opts.lora_fuse_native or shared.opts.lora_force_diffusers:
return
if len(l.previously_loaded_networks) == 0:

View File

@ -1,5 +1,4 @@
from types import SimpleNamespace
from typing import List
import os
import sys
import time
@ -235,7 +234,7 @@ def run_segment(input_image: gr.Image, input_mask: np.ndarray):
combined_mask = np.zeros(input_mask.shape, dtype='uint8')
input_mask_size = np.count_nonzero(input_mask)
debug(f'Segment SAM: {vars(opts)}')
for mask, score in zip(outputs['masks'], outputs['scores']):
for mask, score in zip(outputs['masks'], outputs['scores'], strict=False):
mask = mask.astype('uint8')
mask_size = np.count_nonzero(mask)
if mask_size == 0:
@ -561,7 +560,7 @@ def create_segment_ui():
return controls
def bind_controls(image_controls: List[gr.Image], preview_image: gr.Image, output_image: gr.Image):
def bind_controls(image_controls: list[gr.Image], preview_image: gr.Image, output_image: gr.Image):
for image_control in image_controls:
btn_mask.click(run_mask, inputs=[image_control], outputs=[preview_image])
btn_lama.click(run_lama, inputs=[image_control], outputs=[output_image])

View File

@ -2,7 +2,7 @@ from collections import defaultdict
import torch
class MemUsageMonitor():
class MemUsageMonitor:
device = None
disabled = False
opts = None

View File

@ -24,7 +24,7 @@ def get_docker_limit():
if docker_limit is not None:
return docker_limit
try:
with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r', encoding='utf8') as f:
with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', encoding='utf8') as f:
docker_limit = float(f.read())
except Exception:
docker_limit = sys.float_info.max
@ -145,7 +145,9 @@ class Object:
return f'{self.fn}.{self.name} type={self.type} size={self.size} ref={self.refcount}'
def get_objects(gcl={}, threshold:int=0):
def get_objects(gcl=None, threshold:int=0):
if gcl is None:
gcl = {}
objects = []
seen = []

View File

@ -260,7 +260,9 @@ def calculate_model_hash(state_dict):
return func.hexdigest()
def convert(model_path:str, checkpoint_path:str, metadata:dict={}):
def convert(model_path:str, checkpoint_path:str, metadata:dict=None):
if metadata is None:
metadata = {}
unet_path = os.path.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = os.path.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = os.path.join(model_path, "text_encoder", "model.safetensors")

View File

@ -1,7 +1,6 @@
import os
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Dict, Optional, Tuple, Set
import safetensors.torch
import torch
import modules.memstats
@ -37,7 +36,7 @@ KEY_POSITION_IDS = ".".join(
)
def fix_clip(model: Dict) -> Dict:
def fix_clip(model: dict) -> dict:
if KEY_POSITION_IDS in model.keys():
model[KEY_POSITION_IDS] = torch.tensor(
[list(range(MAX_TOKENS))],
@ -48,7 +47,7 @@ def fix_clip(model: Dict) -> Dict:
return model
def prune_sd_model(model: Dict, keyset: Set) -> Dict:
def prune_sd_model(model: dict, keyset: set) -> dict:
keys = list(model.keys())
for k in keys:
if (
@ -60,7 +59,7 @@ def prune_sd_model(model: Dict, keyset: Set) -> Dict:
return model
def restore_sd_model(original_model: Dict, merged_model: Dict) -> Dict:
def restore_sd_model(original_model: dict, merged_model: dict) -> dict:
for k in original_model:
if k not in merged_model:
merged_model[k] = original_model[k]
@ -72,11 +71,11 @@ def log_vram(txt=""):
def load_thetas(
models: Dict[str, os.PathLike],
models: dict[str, os.PathLike],
prune: bool,
device: torch.device,
precision: str,
) -> Dict:
) -> dict:
from tensordict import TensorDict
thetas = {k: TensorDict.from_dict(read_state_dict(m, "cpu")) for k, m in models.items()}
if prune:
@ -95,7 +94,7 @@ def load_thetas(
def merge_models(
models: Dict[str, os.PathLike],
models: dict[str, os.PathLike],
merge_mode: str,
precision: str = "fp16",
weights_clip: bool = False,
@ -104,7 +103,7 @@ def merge_models(
prune: bool = False,
threads: int = 4,
**kwargs,
) -> Dict:
) -> dict:
thetas = load_thetas(models, prune, device, precision)
# log.info(f'Merge start: models={models.values()} precision={precision} clip={weights_clip} rebasin={re_basin} prune={prune} threads={threads}')
weight_matcher = WeightClass(thetas["model_a"], **kwargs)
@ -136,13 +135,13 @@ def merge_models(
def un_prune_model(
merged: Dict,
thetas: Dict,
models: Dict,
merged: dict,
thetas: dict,
models: dict,
device: torch.device,
prune: bool,
precision: str,
) -> Dict:
) -> dict:
if prune:
log.info("Merge restoring pruned keys")
del thetas
@ -180,7 +179,7 @@ def un_prune_model(
def simple_merge(
thetas: Dict[str, Dict],
thetas: dict[str, dict],
weight_matcher: WeightClass,
merge_mode: str,
precision: str = "fp16",
@ -188,7 +187,7 @@ def simple_merge(
device: torch.device = None,
work_device: torch.device = None,
threads: int = 4,
) -> Dict:
) -> dict:
futures = []
import rich.progress as p
with p.Progress(p.TextColumn('[cyan]{task.description}'), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TextColumn('[cyan]keys={task.fields[keys]}'), console=console) as progress:
@ -227,7 +226,7 @@ def simple_merge(
def rebasin_merge(
thetas: Dict[str, os.PathLike],
thetas: dict[str, os.PathLike],
weight_matcher: WeightClass,
merge_mode: str,
precision: str = "fp16",
@ -306,14 +305,14 @@ def simple_merge_key(progress, task, key, thetas, *args, **kwargs):
def merge_key( # pylint: disable=inconsistent-return-statements
key: str,
thetas: Dict,
thetas: dict,
weight_matcher: WeightClass,
merge_mode: str,
precision: str = "fp16",
weights_clip: bool = False,
device: torch.device = None,
work_device: torch.device = None,
) -> Optional[Tuple[str, Dict]]:
) -> tuple[str, dict] | None:
if work_device is None:
work_device = device
@ -376,11 +375,11 @@ def merge_key_context(*args, **kwargs):
def get_merge_method_args(
current_bases: Dict,
thetas: Dict,
current_bases: dict,
thetas: dict,
key: str,
work_device: torch.device,
) -> Dict:
) -> dict:
merge_method_args = {
"a": thetas["model_a"][key].to(work_device),
"b": thetas["model_b"][key].to(work_device),

View File

@ -1,5 +1,4 @@
import math
from typing import Tuple
import torch
from torch import Tensor
@ -151,7 +150,7 @@ def kth_abs_value(a: Tensor, k: int) -> Tensor:
return torch.kthvalue(torch.abs(a.float()), k)[0]
def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
def ratio_to_region(width: float, offset: float, n: int) -> tuple[int, int, bool]:
if width < 0:
offset += width
width = -width
@ -233,7 +232,7 @@ def ties_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: flo
delta_filters = (signs == final_sign).float()
res = torch.zeros_like(c, device=c.device)
for delta_filter, delta in zip(delta_filters, deltas):
for delta_filter, delta in zip(delta_filters, deltas, strict=False):
res += delta_filter * delta
param_count = torch.sum(delta_filters, dim=0)

View File

@ -206,7 +206,7 @@ def test_model(pipe: diffusers.StableDiffusionXLPipeline, fn: str, **kwargs):
if not test.generate:
return
try:
generator = torch.Generator(devices.device).manual_seed(int(4242))
generator = torch.Generator(devices.device).manual_seed(4242)
args = {
'prompt': test.prompt,
'negative_prompt': test.negative,
@ -278,7 +278,7 @@ def save_model(pipe: diffusers.StableDiffusionXLPipeline):
yield msg(f'pretrained={folder}')
shared.log.info(f'Modules merge save: type=sdxl diffusers="{folder}"')
pipe.save_pretrained(folder, safe_serialization=True, push_to_hub=False)
with open(os.path.join(folder, 'vae', 'config.json'), 'r', encoding='utf8') as f:
with open(os.path.join(folder, 'vae', 'config.json'), encoding='utf8') as f:
vae_config = json.load(f)
vae_config['force_upcast'] = False
vae_config['scaling_factor'] = 0.13025

View File

@ -53,8 +53,6 @@ def install_nunchaku():
import os
import sys
import platform
import importlib
import importlib.metadata
import torch
python_ver = f'{sys.version_info.major}{sys.version_info.minor}'
if python_ver not in ['311', '312', '313']:

View File

@ -220,5 +220,13 @@ class Shared(sys.modules[__name__].__class__):
model_type = 'unknown'
return model_type
@property
def console(self):
try:
from installer import get_console
return get_console()
except ImportError:
return None
model_data = ModelData()

View File

@ -4,13 +4,12 @@ import time
import shutil
import importlib
import contextlib
from typing import Dict
from urllib.parse import urlparse
import huggingface_hub as hf
from installer import install, log
from modules import shared, errors, files_cache
from modules.upscaler import Upscaler
from modules.paths import script_path, models_path
from modules import paths
loggedin = None
@ -55,7 +54,7 @@ def hf_login(token=None):
return True
def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config: Dict[str, str] = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None):
def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config: dict[str, str] = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None):
if hub_id is None or len(hub_id) == 0:
return None
from diffusers import DiffusionPipeline
@ -117,7 +116,7 @@ def load_diffusers_models(clear=True):
# t0 = time.time()
place = shared.opts.diffusers_dir
if place is None or len(place) == 0 or not os.path.isdir(place):
place = os.path.join(models_path, 'Diffusers')
place = os.path.join(paths.models_path, 'Diffusers')
if clear:
diffuser_repos.clear()
already_found = []
@ -382,25 +381,25 @@ def cleanup_models():
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
# somehow auto-register and just do these things...
root_path = script_path
src_path = models_path
dest_path = os.path.join(models_path, "Stable-diffusion")
root_path = paths.script_path
src_path = paths.models_path
dest_path = os.path.join(paths.models_path, "Stable-diffusion")
# move_files(src_path, dest_path, ".ckpt")
# move_files(src_path, dest_path, ".safetensors")
src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
dest_path = os.path.join(paths.models_path, "ESRGAN")
move_files(src_path, dest_path)
src_path = os.path.join(models_path, "BSRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
src_path = os.path.join(paths.models_path, "BSRGAN")
dest_path = os.path.join(paths.models_path, "ESRGAN")
move_files(src_path, dest_path, ".pth")
src_path = os.path.join(root_path, "SwinIR")
dest_path = os.path.join(models_path, "SwinIR")
dest_path = os.path.join(paths.models_path, "SwinIR")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
dest_path = os.path.join(models_path, "LDSR")
dest_path = os.path.join(paths.models_path, "LDSR")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "SCUNet")
dest_path = os.path.join(models_path, "SCUNet")
dest_path = os.path.join(paths.models_path, "SCUNet")
move_files(src_path, dest_path)
@ -430,7 +429,7 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced, so we'll try to import any _model.py files before looking in __subclasses__
t0 = time.time()
modules_dir = os.path.join(shared.script_path, "modules", "postprocess")
modules_dir = os.path.join(paths.script_path, "modules", "postprocess")
for file in os.listdir(modules_dir):
if "_model.py" in file:
model_name = file.replace("_model.py", "")

View File

@ -28,7 +28,7 @@ def stat(fn: str):
return size, mtime
class Module():
class Module:
name: str = ''
cls: str = None
device: str = None
@ -61,7 +61,7 @@ class Module():
return s
class Model():
class Model:
name: str = ''
fn: str = ''
type: str = ''

View File

@ -1,17 +1,18 @@
import os
from typing import Type, Callable, TypeVar, Dict, Any
from typing import TypeVar, Any
from collections.abc import Callable
import torch
import diffusers
from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextModelWithProjection
class ENVStore:
__DESERIALIZER: Dict[Type, Callable[[str,], Any]] = {
__DESERIALIZER: dict[type, Callable[[str,], Any]] = {
bool: lambda x: bool(int(x)),
int: int,
str: lambda x: x,
}
__SERIALIZER: Dict[Type, Callable[[Any,], str]] = {
__SERIALIZER: dict[type, Callable[[Any,], str]] = {
bool: lambda x: str(int(x)),
int: str,
str: lambda x: x,
@ -89,7 +90,7 @@ def get_loader_arguments(no_variant: bool = False):
T = TypeVar("T")
def from_pretrained(cls: Type[T], pretrained_model_name_or_path: os.PathLike, *args, no_variant: bool = False, **kwargs) -> T:
def from_pretrained(cls: type[T], pretrained_model_name_or_path: os.PathLike, *args, no_variant: bool = False, **kwargs) -> T:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if pretrained_model_name_or_path.endswith(".onnx"):
cls = diffusers.OnnxRuntimeModel

View File

@ -16,7 +16,7 @@ except Exception as e:
class DynamicSessionOptions(ort.SessionOptions):
config: Optional[Dict] = None
config: dict | None = None
def __init__(self):
super().__init__()
@ -28,7 +28,7 @@ class DynamicSessionOptions(ort.SessionOptions):
return sess_options.copy()
return DynamicSessionOptions()
def enable_static_dims(self, config: Dict):
def enable_static_dims(self, config: dict):
self.config = config
self.add_free_dimension_override_by_name("unet_sample_batch", config["hidden_batch_size"])
self.add_free_dimension_override_by_name("unet_sample_channels", 4)
@ -103,9 +103,9 @@ class OnnxRuntimeModel(TorchCompatibleModule, diffusers.OnnxRuntimeModel):
class VAEConfig:
DEFAULTS = { "scaling_factor": 0.18215 }
config: Dict
config: dict
def __init__(self, config: Dict):
def __init__(self, config: dict):
self.config = config
def __getattr__(self, key):

View File

@ -1,6 +1,5 @@
import sys
from enum import Enum
from typing import Tuple, List
from installer import log
from modules import devices
@ -33,7 +32,7 @@ TORCH_DEVICE_TO_EP = {
try:
import onnxruntime as ort
available_execution_providers: List[ExecutionProvider] = ort.get_available_providers()
available_execution_providers: list[ExecutionProvider] = ort.get_available_providers()
except Exception as e:
log.error(f'ONNX import error: {e}')
available_execution_providers = []
@ -90,7 +89,7 @@ def get_execution_provider_options():
return execution_provider_options
def get_provider() -> Tuple:
def get_provider() -> tuple:
from modules.shared import opts
return (opts.onnx_execution_provider, get_execution_provider_options(),)

View File

@ -103,12 +103,12 @@ class OnnxRawPipeline(PipelineBase):
path: os.PathLike
original_filename: str
constructor: Type[PipelineBase]
init_dict: Dict[str, Tuple[str]] = {}
constructor: type[PipelineBase]
init_dict: dict[str, tuple[str]] = {}
default_scheduler: Any = None # for Img2Img
def __init__(self, constructor: Type[PipelineBase], path: os.PathLike): # pylint: disable=super-init-not-called
def __init__(self, constructor: type[PipelineBase], path: os.PathLike): # pylint: disable=super-init-not-called
self._is_sdxl = check_pipeline_sdxl(constructor)
self.from_diffusers_cache = check_diffusers_cache(path)
self.path = path
@ -150,7 +150,7 @@ class OnnxRawPipeline(PipelineBase):
pipeline.scheduler = self.default_scheduler
return pipeline
def convert(self, submodels: List[str], in_dir: os.PathLike, out_dir: os.PathLike):
def convert(self, submodels: list[str], in_dir: os.PathLike, out_dir: os.PathLike):
install('onnx') # may not be installed yet, this performs check and installs as needed
import onnx
shutil.rmtree("cache", ignore_errors=True)
@ -218,7 +218,7 @@ class OnnxRawPipeline(PipelineBase):
with open(os.path.join(out_dir, "model_index.json"), 'w', encoding="utf-8") as file:
json.dump(model_index, file)
def run_olive(self, submodels: List[str], in_dir: os.PathLike, out_dir: os.PathLike):
def run_olive(self, submodels: list[str], in_dir: os.PathLike, out_dir: os.PathLike):
from olive.model import ONNXModelHandler
from olive.workflows import run as run_workflows
@ -235,8 +235,8 @@ class OnnxRawPipeline(PipelineBase):
for submodel in submodels:
log.info(f"\nProcessing {submodel}")
with open(os.path.join(sd_configs_path, "olive", 'sdxl' if self._is_sdxl else 'sd', f"{submodel}.json"), "r", encoding="utf-8") as config_file:
olive_config: Dict[str, Dict[str, Dict]] = json.load(config_file)
with open(os.path.join(sd_configs_path, "olive", 'sdxl' if self._is_sdxl else 'sd', f"{submodel}.json"), encoding="utf-8") as config_file:
olive_config: dict[str, dict[str, dict]] = json.load(config_file)
for flow in olive_config["pass_flows"]:
for i in range(len(flow)):
@ -257,7 +257,7 @@ class OnnxRawPipeline(PipelineBase):
run_workflows(olive_config)
with open(os.path.join("footprints", f"{submodel}_{EP_TO_NAME[shared.opts.onnx_execution_provider]}_footprints.json"), "r", encoding="utf-8") as footprint_file:
with open(os.path.join("footprints", f"{submodel}_{EP_TO_NAME[shared.opts.onnx_execution_provider]}_footprints.json"), encoding="utf-8") as footprint_file:
footprints = json.load(footprint_file)
processor_final_pass_footprint = None
for _, footprint in footprints.items():

View File

@ -1,5 +1,6 @@
import inspect
from typing import Union, Optional, Callable, List, Any
from typing import Any
from collections.abc import Callable
import numpy as np
import torch
import diffusers
@ -33,20 +34,20 @@ class OnnxStableDiffusionImg2ImgPipeline(diffusers.OnnxStableDiffusionImg2ImgPip
def __call__(
self,
prompt: Union[str, List[str]],
prompt: str | list[str],
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
num_inference_steps: int | None = 50,
guidance_scale: float | None = 7.5,
negative_prompt: str | list[str] | None = None,
num_images_per_prompt: int | None = 1,
eta: float | None = 0.0,
generator: torch.Generator | list[torch.Generator] | None = None,
prompt_embeds: np.ndarray | None = None,
negative_prompt_embeds: np.ndarray | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback: Callable[[int, int, np.ndarray], None] | None = None,
callback_steps: int = 1,
):
# check inputs. Raise error if not correct

Some files were not shown because too many files have changed in this diff Show More