diff --git a/CHANGELOG.md b/CHANGELOG.md
index 16541b7c2..b813f4591 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -45,7 +45,7 @@ TBD
see [docs](https://vladmandic.github.io/sdnext-docs/Python/) for details
- remove hard-dependnecies:
`clip, numba, skimage, torchsde, omegaconf, antlr, patch-ng, patch-ng, astunparse, addict, inflection, jsonmerge, kornia`,
- `resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets`
+ `resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets, imp`
these are now installed on-demand when needed
- refactor: to/from image/tensor logic, thanks @CalamitousFelicitousness
- refactor: switch to `pyproject.toml` for tool configs
@@ -57,6 +57,8 @@ TBD
- refactor: unified command line parsing
- refactor: launch use threads to async execute non-critical tasks
- refactor: switch from deprecated `pkg_resources` to `importlib`
+ - refactor: modernize typing and type annotations
+ - refactor: improve `pydantic==2.x` compatibility
- update `lint` rules, thanks @awsr
- remove requirements: `clip`, `open-clip`
- update `requirements`
diff --git a/installer.py b/installer.py
index fb44e4053..e8e2fbe72 100644
--- a/installer.py
+++ b/installer.py
@@ -1,4 +1,4 @@
-from typing import overload, List, Optional
+from typing import overload
import os
import sys
import json
@@ -106,10 +106,12 @@ def str_to_bool(val: str | bool | None) -> bool | None:
return val
-def install_traceback(suppress: list = []):
+def install_traceback(suppress: list | None = None):
from rich.traceback import install as traceback_install
from rich.pretty import install as pretty_install
+ if suppress is None:
+ suppress = []
width = os.environ.get("SD_TRACEWIDTH", console.width if console else None)
if width is not None:
width = int(width)
@@ -133,7 +135,7 @@ def setup_logging():
from functools import partial, partialmethod
from logging.handlers import RotatingFileHandler
try:
- import rich # pylint: disable=unused-import
+ import rich # noqa: F401 # pylint: disable=unused-import
except Exception:
log.error('Please restart SD.Next so changes take effect')
sys.exit(1)
@@ -187,7 +189,7 @@ def setup_logging():
_Segment = Segment
left = _Segment(" " * self.left, style) if self.left else None
right = [_Segment.line()]
- blank_line: Optional[List[Segment]] = None
+ blank_line: list[Segment] | None = None
if self.top:
blank_line = [_Segment(f'{" " * width}\n', style)]
yield from blank_line * self.top
@@ -215,8 +217,10 @@ def setup_logging():
logging.Logger.trace = partialmethod(logging.Logger.log, logging.TRACE)
logging.trace = partial(logging.log, logging.TRACE)
- def exception_hook(e: Exception, suppress=[]):
+ def exception_hook(e: Exception, suppress=None):
from rich.traceback import Traceback
+ if suppress is None:
+ suppress = []
tb = Traceback.from_exception(type(e), e, e.__traceback__, show_locals=False, max_frames=16, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=console.width)
# print-to-console, does not get printed-to-file
exc_type, exc_value, exc_traceback = sys.exc_info()
@@ -416,7 +420,7 @@ def uninstall(package, quiet = False):
def run(cmd: str, arg: str):
- result = subprocess.run(f'"{cmd}" {arg}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ result = subprocess.run(f'"{cmd}" {arg}', shell=True, check=False, env=os.environ, capture_output=True)
txt = result.stdout.decode(encoding="utf8", errors="ignore")
if len(result.stderr) > 0:
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
@@ -461,7 +465,7 @@ def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True):
all_args = f'{pip_log}{arg} {env_args}'.strip()
if not quiet:
log.debug(f'Running: {pipCmd}="{all_args}"')
- result = subprocess.run(f'"{sys.executable}" -m {pipCmd} {all_args}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ result = subprocess.run(f'"{sys.executable}" -m {pipCmd} {all_args}', shell=True, check=False, env=os.environ, capture_output=True)
txt = result.stdout.decode(encoding="utf8", errors="ignore")
if len(result.stderr) > 0:
if uv and result.returncode != 0:
@@ -509,7 +513,7 @@ def git(arg: str, folder: str = None, ignore: bool = False, optional: bool = Fal
git_cmd = os.environ.get('GIT', "git")
if git_cmd != "git":
git_cmd = os.path.abspath(git_cmd)
- result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder or '.')
+ result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, capture_output=True, cwd=folder or '.')
stdout = result.stdout.decode(encoding="utf8", errors="ignore")
if len(result.stderr) > 0:
stdout += ('\n' if len(stdout) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
@@ -639,7 +643,11 @@ def get_platform():
# check python version
-def check_python(supported_minors=[], experimental_minors=[], reason=None):
+def check_python(supported_minors=None, experimental_minors=None, reason=None):
+ if experimental_minors is None:
+ experimental_minors = []
+ if supported_minors is None:
+ supported_minors = []
if supported_minors is None or len(supported_minors) == 0:
supported_minors = [10, 11, 12, 13]
experimental_minors = [14]
@@ -911,8 +919,6 @@ def install_torch_addons():
if 'xformers' in xformers_package:
try:
install(xformers_package, ignore=True, no_deps=True)
- import torch # pylint: disable=unused-import
- import xformers # pylint: disable=unused-import
except Exception as e:
log.debug(f'xFormers cannot install: {e}')
elif not args.experimental and not args.use_xformers and opts.get('cross_attention_optimization', '') != 'xFormers':
@@ -1126,7 +1132,7 @@ def run_extension_installer(folder):
if os.environ.get('PYTHONPATH', None) is not None:
seperator = ';' if sys.platform == 'win32' else ':'
env['PYTHONPATH'] += seperator + os.environ.get('PYTHONPATH', None)
- result = subprocess.run(f'"{sys.executable}" "{path_installer}"', shell=True, env=env, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder)
+ result = subprocess.run(f'"{sys.executable}" "{path_installer}"', shell=True, env=env, check=False, capture_output=True, cwd=folder)
txt = result.stdout.decode(encoding="utf8", errors="ignore")
debug(f'Extension installer: file="{path_installer}" {txt}')
if result.returncode != 0:
@@ -1265,7 +1271,7 @@ def ensure_base_requirements():
local_log = logging.getLogger('sdnext.installer')
global setuptools, distutils # pylint: disable=global-statement
# python may ship with incompatible setuptools
- subprocess.run(f'"{sys.executable}" -m pip install setuptools=={setuptools_version}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ subprocess.run(f'"{sys.executable}" -m pip install setuptools=={setuptools_version}', shell=True, check=False, env=os.environ, capture_output=True)
# need to delete all references to modules to be able to reload them otherwise python will use cached version
modules = [m for m in sys.modules if m.startswith('setuptools') or m.startswith('distutils')]
for m in modules:
@@ -1399,7 +1405,7 @@ def install_requirements():
log.info('Install: verifying requirements')
if args.new:
log.debug('Install: flag=new')
- with open('requirements.txt', 'r', encoding='utf8') as f:
+ with open('requirements.txt', encoding='utf8') as f:
lines = [line.strip() for line in f.readlines() if line.strip() != '' and not line.startswith('#') and line is not None]
for line in lines:
if not installed(line, quiet=True):
@@ -1495,20 +1501,20 @@ def get_version(force=False):
t_start = time.time()
if (version is None) or (version.get('branch', 'unknown') == 'unknown') or force:
try:
- subprocess.run('git config log.showsignature false', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
+ subprocess.run('git config log.showsignature false', capture_output=True, shell=True, check=True)
except Exception:
pass
try:
- res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
+ res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', capture_output=True, shell=True, check=True)
ver = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ' '
commit, updated = ver.split(' ')
version['commit'], version['updated'] = commit, updated
except Exception as e:
log.warning(f'Version: where=commit {e}')
try:
- res = subprocess.run('git remote get-url origin', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
+ res = subprocess.run('git remote get-url origin', capture_output=True, shell=True, check=True)
origin = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
- res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
+ res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True)
branch_name = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
version['url'] = origin.replace('\n', '').removesuffix('.git') + '/tree/' + branch_name.replace('\n', '')
version['branch'] = branch_name.replace('\n', '')
@@ -1520,7 +1526,7 @@ def get_version(force=False):
try:
if os.path.exists('extensions-builtin/sdnext-modernui'):
os.chdir('extensions-builtin/sdnext-modernui')
- res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
+ res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True)
branch_ui = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
branch_ui = 'dev' if 'dev' in branch_ui else 'main'
version['ui'] = branch_ui
@@ -1536,7 +1542,7 @@ def get_version(force=False):
version['kanvas'] = 'disabled'
elif os.path.exists('extensions-builtin/sdnext-kanvas'):
os.chdir('extensions-builtin/sdnext-kanvas')
- res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
+ res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True)
branch_kanvas = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
branch_kanvas = 'dev' if 'dev' in branch_kanvas else 'main'
version['kanvas'] = branch_kanvas
@@ -1723,7 +1729,7 @@ def check_timestamp():
ok = True
setup_time = -1
version_time = -1
- with open(log_file, 'r', encoding='utf8') as f:
+ with open(log_file, encoding='utf8') as f:
lines = f.readlines()
for line in lines:
if 'Setup complete without errors' in line:
@@ -1752,7 +1758,6 @@ def check_timestamp():
def add_args(parser):
- import argparse
group_install = parser.add_argument_group('Install')
group_install.add_argument('--quick', default=os.environ.get("SD_QUICK",False), action='store_true', help="Bypass version checks, default: %(default)s")
group_install.add_argument('--reset', default=os.environ.get("SD_RESET",False), action='store_true', help="Reset main repository to latest version, default: %(default)s")
@@ -1832,7 +1837,7 @@ def read_options():
t_start = time.time()
global opts # pylint: disable=global-statement
if os.path.isfile(args.config):
- with open(args.config, "r", encoding="utf8") as file:
+ with open(args.config, encoding="utf8") as file:
try:
opts = json.load(file)
if type(opts) is str:
diff --git a/launch.py b/launch.py
index e51de7510..3ba43ef5b 100755
--- a/launch.py
+++ b/launch.py
@@ -72,7 +72,7 @@ def get_custom_args():
rec('args')
-@lru_cache()
+@lru_cache
def commit_hash(): # compatbility function
global stored_commit_hash # pylint: disable=global-statement
if stored_commit_hash is not None:
@@ -85,7 +85,7 @@ def commit_hash(): # compatbility function
return stored_commit_hash
-@lru_cache()
+@lru_cache
def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compatbility function
if desc is not None:
installer.log.info(desc)
@@ -94,7 +94,7 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compat
if result.returncode != 0:
raise RuntimeError(f"""{errdesc or 'Error running command'} Command: {command} Error code: {result.returncode}""")
return ''
- result = subprocess.run(command, stdout=subprocess.PIPE, check=False, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
+ result = subprocess.run(command, capture_output=True, check=False, shell=True, env=os.environ if custom_env is None else custom_env)
if result.returncode != 0:
raise RuntimeError(f"""{errdesc or 'Error running command'}: {command} code: {result.returncode}
{result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''}
@@ -104,26 +104,26 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compat
def check_run(command): # compatbility function
- result = subprocess.run(command, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
+ result = subprocess.run(command, check=False, capture_output=True, shell=True)
return result.returncode == 0
-@lru_cache()
+@lru_cache
def is_installed(pkg): # compatbility function
return installer.installed(pkg)
-@lru_cache()
+@lru_cache
def repo_dir(name): # compatbility function
return os.path.join(script_path, dir_repos, name)
-@lru_cache()
+@lru_cache
def run_python(code, desc=None, errdesc=None): # compatbility function
return run(f'"{sys.executable}" -c "{code}"', desc, errdesc)
-@lru_cache()
+@lru_cache
def run_pip(pkg, desc=None): # compatbility function
forbidden = ['onnxruntime', 'opencv-python']
if desc is None:
@@ -136,7 +136,7 @@ def run_pip(pkg, desc=None): # compatbility function
return run(f'"{sys.executable}" -m pip {pkg} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
-@lru_cache()
+@lru_cache
def check_run_python(code): # compatbility function
return check_run(f'"{sys.executable}" -c "{code}"')
diff --git a/modules/apg/pipeline_stable_cascade_prior_apg.py b/modules/apg/pipeline_stable_cascade_prior_apg.py
index 0e311ad4e..6ffaf8089 100644
--- a/modules/apg/pipeline_stable_cascade_prior_apg.py
+++ b/modules/apg/pipeline_stable_cascade_prior_apg.py
@@ -14,7 +14,7 @@
from dataclasses import dataclass
from math import ceil
-from typing import Callable, Dict, List, Optional, Union
+from collections.abc import Callable
import numpy as np
import PIL
@@ -63,11 +63,11 @@ class StableCascadePriorPipelineOutput(BaseOutput):
Text embeddings for the negative prompt.
"""
- image_embeddings: Union[torch.Tensor, np.ndarray]
- prompt_embeds: Union[torch.Tensor, np.ndarray]
- prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
- negative_prompt_embeds: Union[torch.Tensor, np.ndarray]
- negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
+ image_embeddings: torch.Tensor | np.ndarray
+ prompt_embeds: torch.Tensor | np.ndarray
+ prompt_embeds_pooled: torch.Tensor | np.ndarray
+ negative_prompt_embeds: torch.Tensor | np.ndarray
+ negative_prompt_embeds_pooled: torch.Tensor | np.ndarray
class StableCascadePriorPipelineAPG(DiffusionPipeline):
@@ -109,8 +109,8 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
prior: StableCascadeUNet,
scheduler: DDPMWuerstchenScheduler,
resolution_multiple: float = 42.67,
- feature_extractor: Optional[CLIPImageProcessor] = None,
- image_encoder: Optional[CLIPVisionModelWithProjection] = None,
+ feature_extractor: CLIPImageProcessor | None = None,
+ image_encoder: CLIPVisionModelWithProjection | None = None,
) -> None:
super().__init__()
self.register_modules(
@@ -151,10 +151,10 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
do_classifier_free_guidance,
prompt=None,
negative_prompt=None,
- prompt_embeds: Optional[torch.Tensor] = None,
- prompt_embeds_pooled: Optional[torch.Tensor] = None,
- negative_prompt_embeds: Optional[torch.Tensor] = None,
- negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
+ prompt_embeds: torch.Tensor | None = None,
+ prompt_embeds_pooled: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds_pooled: torch.Tensor | None = None,
):
if prompt_embeds is None:
# get prompt text embeddings
@@ -196,7 +196,7 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)
if negative_prompt_embeds is None and do_classifier_free_guidance:
- uncond_tokens: List[str]
+ uncond_tokens: list[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
@@ -367,26 +367,26 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
- prompt: Optional[Union[str, List[str]]] = None,
- images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
+ prompt: str | list[str] | None = None,
+ images: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image] | None = None,
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 20,
- timesteps: List[float] = None,
+ timesteps: list[float] = None,
guidance_scale: float = 4.0,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- prompt_embeds: Optional[torch.Tensor] = None,
- prompt_embeds_pooled: Optional[torch.Tensor] = None,
- negative_prompt_embeds: Optional[torch.Tensor] = None,
- negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
- image_embeds: Optional[torch.Tensor] = None,
- num_images_per_prompt: Optional[int] = 1,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.Tensor] = None,
- output_type: Optional[str] = "pt",
+ negative_prompt: str | list[str] | None = None,
+ prompt_embeds: torch.Tensor | None = None,
+ prompt_embeds_pooled: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds_pooled: torch.Tensor | None = None,
+ image_embeds: torch.Tensor | None = None,
+ num_images_per_prompt: int | None = 1,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ output_type: str | None = "pt",
return_dict: bool = True,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ callback_on_step_end: Callable[[int, int, dict], None] | None = None,
+ callback_on_step_end_tensor_inputs: list[str] | None = None,
):
"""
Function invoked when calling the pipeline for generation.
@@ -460,6 +460,8 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
"""
# 0. Define commonly used variables
+ if callback_on_step_end_tensor_inputs is None:
+ callback_on_step_end_tensor_inputs = ["latents"]
device = self._execution_device
dtype = next(self.prior.parameters()).dtype
self._guidance_scale = guidance_scale
diff --git a/modules/apg/pipeline_stable_diffision_xl_apg.py b/modules/apg/pipeline_stable_diffision_xl_apg.py
index 3371877fd..b09ae242f 100644
--- a/modules/apg/pipeline_stable_diffision_xl_apg.py
+++ b/modules/apg/pipeline_stable_diffision_xl_apg.py
@@ -13,7 +13,8 @@
# limitations under the License.
import inspect
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any
+from collections.abc import Callable
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
@@ -28,7 +29,6 @@ from diffusers.utils import USE_PEFT_BACKEND, deprecate, is_invisible_watermark_
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
-from diffusers.models.attention_processor import Attention
from modules import apg
if is_invisible_watermark_available():
@@ -76,10 +76,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- sigmas: Optional[List[float]] = None,
+ num_inference_steps: int | None = None,
+ device: str | torch.device | None = None,
+ timesteps: list[int] | None = None,
+ sigmas: list[float] | None = None,
**kwargs,
):
"""
@@ -217,7 +217,7 @@ class StableDiffusionXLPipelineAPG(
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
force_zeros_for_empty_prompt: bool = True,
- add_watermarker: Optional[bool] = None,
+ add_watermarker: bool | None = None,
):
super().__init__()
@@ -248,18 +248,18 @@ class StableDiffusionXLPipelineAPG(
def encode_prompt(
self,
prompt: str,
- prompt_2: Optional[str] = None,
- device: Optional[torch.device] = None,
+ prompt_2: str | None = None,
+ device: torch.device | None = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
- negative_prompt: Optional[str] = None,
- negative_prompt_2: Optional[str] = None,
- prompt_embeds: Optional[torch.Tensor] = None,
- negative_prompt_embeds: Optional[torch.Tensor] = None,
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
- negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
- lora_scale: Optional[float] = None,
- clip_skip: Optional[int] = None,
+ negative_prompt: str | None = None,
+ negative_prompt_2: str | None = None,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ pooled_prompt_embeds: torch.Tensor | None = None,
+ negative_pooled_prompt_embeds: torch.Tensor | None = None,
+ lora_scale: float | None = None,
+ clip_skip: int | None = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -343,7 +343,7 @@ class StableDiffusionXLPipelineAPG(
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
- for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer)
@@ -396,7 +396,7 @@ class StableDiffusionXLPipelineAPG(
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
- uncond_tokens: List[str]
+ uncond_tokens: list[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@@ -412,7 +412,7 @@ class StableDiffusionXLPipelineAPG(
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
- for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders, strict=False):
if isinstance(self, TextualInversionLoaderMixin):
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
@@ -521,7 +521,7 @@ class StableDiffusionXLPipelineAPG(
)
for single_ip_adapter_image, image_proj_layer in zip(
- ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers, strict=False
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
@@ -793,42 +793,40 @@ class StableDiffusionXLPipelineAPG(
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
- prompt: Union[str, List[str]] = None,
- prompt_2: Optional[Union[str, List[str]]] = None,
- height: Optional[int] = None,
- width: Optional[int] = None,
+ prompt: str | list[str] | None = None,
+ prompt_2: str | list[str] | None = None,
+ height: int | None = None,
+ width: int | None = None,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
- sigmas: List[float] = None,
- denoising_end: Optional[float] = None,
+ timesteps: list[int] | None = None,
+ sigmas: list[float] | None = None,
+ denoising_end: float | None = None,
guidance_scale: float = 5.0,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
- num_images_per_prompt: Optional[int] = 1,
+ negative_prompt: str | list[str] | None = None,
+ negative_prompt_2: str | list[str] | None = None,
+ num_images_per_prompt: int | None = 1,
eta: float = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.Tensor] = None,
- prompt_embeds: Optional[torch.Tensor] = None,
- negative_prompt_embeds: Optional[torch.Tensor] = None,
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
- negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
- ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
- output_type: Optional[str] = "pil",
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ pooled_prompt_embeds: torch.Tensor | None = None,
+ negative_pooled_prompt_embeds: torch.Tensor | None = None,
+ ip_adapter_image: PipelineImageInput | None = None,
+ ip_adapter_image_embeds: list[torch.Tensor] | None = None,
+ output_type: str | None = "pil",
return_dict: bool = True,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ cross_attention_kwargs: dict[str, Any] | None = None,
guidance_rescale: float = 0.0,
- original_size: Optional[Tuple[int, int]] = None,
- crops_coords_top_left: Tuple[int, int] = (0, 0),
- target_size: Optional[Tuple[int, int]] = None,
- negative_original_size: Optional[Tuple[int, int]] = None,
- negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
- negative_target_size: Optional[Tuple[int, int]] = None,
- clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
- ] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ original_size: tuple[int, int] | None = None,
+ crops_coords_top_left: tuple[int, int] = (0, 0),
+ target_size: tuple[int, int] | None = None,
+ negative_original_size: tuple[int, int] | None = None,
+ negative_crops_coords_top_left: tuple[int, int] = (0, 0),
+ negative_target_size: tuple[int, int] | None = None,
+ clip_skip: int | None = None,
+ callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
+ callback_on_step_end_tensor_inputs: list[str] | None = None,
**kwargs,
):
r"""
@@ -976,6 +974,8 @@ class StableDiffusionXLPipelineAPG(
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
+ if callback_on_step_end_tensor_inputs is None:
+ callback_on_step_end_tensor_inputs = ["latents"]
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
diff --git a/modules/apg/pipeline_stable_diffusion_apg.py b/modules/apg/pipeline_stable_diffusion_apg.py
index 6eb6bae90..ae1eb26e0 100644
--- a/modules/apg/pipeline_stable_diffusion_apg.py
+++ b/modules/apg/pipeline_stable_diffusion_apg.py
@@ -13,7 +13,8 @@
# limitations under the License.
import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any
+from collections.abc import Callable
import torch
from packaging import version
@@ -71,10 +72,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
def retrieve_timesteps(
scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- sigmas: Optional[List[float]] = None,
+ num_inference_steps: int | None = None,
+ device: str | torch.device | None = None,
+ timesteps: list[int] | None = None,
+ sigmas: list[float] | None = None,
**kwargs,
):
"""
@@ -273,9 +274,9 @@ class StableDiffusionPipelineAPG(
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
- prompt_embeds: Optional[torch.Tensor] = None,
- negative_prompt_embeds: Optional[torch.Tensor] = None,
- lora_scale: Optional[float] = None,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ lora_scale: float | None = None,
**kwargs,
):
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
@@ -305,10 +306,10 @@ class StableDiffusionPipelineAPG(
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
- prompt_embeds: Optional[torch.Tensor] = None,
- negative_prompt_embeds: Optional[torch.Tensor] = None,
- lora_scale: Optional[float] = None,
- clip_skip: Optional[int] = None,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ lora_scale: float | None = None,
+ clip_skip: int | None = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -421,7 +422,7 @@ class StableDiffusionPipelineAPG(
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
- uncond_tokens: List[str]
+ uncond_tokens: list[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
@@ -520,7 +521,7 @@ class StableDiffusionPipelineAPG(
)
for single_ip_adapter_image, image_proj_layer in zip(
- ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers, strict=False
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
@@ -748,31 +749,29 @@ class StableDiffusionPipelineAPG(
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
- prompt: Union[str, List[str]] = None,
- height: Optional[int] = None,
- width: Optional[int] = None,
+ prompt: str | list[str] | None = None,
+ height: int | None = None,
+ width: int | None = None,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
- sigmas: List[float] = None,
+ timesteps: list[int] | None = None,
+ sigmas: list[float] | None = None,
guidance_scale: float = 7.5,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- num_images_per_prompt: Optional[int] = 1,
+ negative_prompt: str | list[str] | None = None,
+ num_images_per_prompt: int | None = 1,
eta: float = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.Tensor] = None,
- prompt_embeds: Optional[torch.Tensor] = None,
- negative_prompt_embeds: Optional[torch.Tensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
- ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
- output_type: Optional[str] = "pil",
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ ip_adapter_image: PipelineImageInput | None = None,
+ ip_adapter_image_embeds: list[torch.Tensor] | None = None,
+ output_type: str | None = "pil",
return_dict: bool = True,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ cross_attention_kwargs: dict[str, Any] | None = None,
guidance_rescale: float = 0.0,
- clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
- ] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ clip_skip: int | None = None,
+ callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
+ callback_on_step_end_tensor_inputs: list[str] | None = None,
**kwargs,
):
r"""
@@ -861,6 +860,8 @@ class StableDiffusionPipelineAPG(
"not-safe-for-work" (nsfw) content.
"""
+ if callback_on_step_end_tensor_inputs is None:
+ callback_on_step_end_tensor_inputs = ["latents"]
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
diff --git a/modules/api/api.py b/modules/api/api.py
index 9ba641b23..669362d7d 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,4 +1,3 @@
-from typing import List, Optional
from threading import Lock
from secrets import compare_digest
from fastapi import FastAPI, APIRouter, Depends, Request
@@ -19,7 +18,7 @@ class Api:
user, password = auth.split(":")
self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip()
if shared.cmd_opts.auth_file:
- with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file:
+ with open(shared.cmd_opts.auth_file, encoding="utf8") as file:
for line in file.readlines():
user, password = line.split(":")
self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip()
@@ -41,7 +40,7 @@ class Api:
self.add_api_route("/js", server.get_js, methods=["GET"], auth=False)
# server api
self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
- self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=List[str])
+ self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=list[str])
self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"])
self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"])
self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"])
@@ -56,7 +55,7 @@ class Api:
self.add_api_route("/sdapi/v1/options", server.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", server.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
- self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], response_model=List[models.ResGPU])
+ self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], response_model=list[models.ResGPU])
# core api using locking
self.add_api_route("/sdapi/v1/txt2img", self.generate.post_text2img, methods=["POST"], response_model=models.ResTxt2Img)
@@ -71,21 +70,21 @@ class Api:
# api dealing with optional scripts
self.add_api_route("/sdapi/v1/scripts", script.get_scripts_list, methods=["GET"], response_model=models.ResScripts)
- self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=List[models.ItemScript])
+ self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=list[models.ItemScript])
# enumerator api
- self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=List[process.ItemPreprocess])
+ self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=list[process.ItemPreprocess])
self.add_api_route("/sdapi/v1/masking", self.process.get_mask, methods=["GET"], response_model=process.ItemMask)
- self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=List[models.ItemSampler])
- self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=List[models.ItemUpscaler])
- self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=List[models.ItemModel])
- self.add_api_route("/sdapi/v1/controlnets", endpoints.get_controlnets, methods=["GET"], response_model=List[str])
- self.add_api_route("/sdapi/v1/detailers", endpoints.get_detailers, methods=["GET"], response_model=List[models.ItemDetailer])
- self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=List[models.ItemStyle])
+ self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=list[models.ItemSampler])
+ self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=list[models.ItemUpscaler])
+ self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=list[models.ItemModel])
+ self.add_api_route("/sdapi/v1/controlnets", endpoints.get_controlnets, methods=["GET"], response_model=list[str])
+ self.add_api_route("/sdapi/v1/detailers", endpoints.get_detailers, methods=["GET"], response_model=list[models.ItemDetailer])
+ self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=list[models.ItemStyle])
self.add_api_route("/sdapi/v1/embeddings", endpoints.get_embeddings, methods=["GET"], response_model=models.ResEmbeddings)
- self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae])
- self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension])
- self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork])
+ self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=list[models.ItemVae])
+ self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=list[models.ItemExtension])
+ self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=list[models.ItemExtraNetwork])
# functional api
self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo)
@@ -96,7 +95,7 @@ class Api:
self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])
self.add_api_route("/sdapi/v1/lock-checkpoint", endpoints.post_lock_checkpoint, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"])
- self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=List[str])
+ self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=list[str])
self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int)
self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"])
self.add_api_route("/sdapi/v1/sampler", endpoints.get_sampler, methods=["GET"], response_model=dict)
@@ -146,7 +145,7 @@ class Api:
shared.log.error(f'API authentication: user="{credentials.username}"')
raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"})
- def get_session_start(self, req: Request, agent: Optional[str] = None):
+ def get_session_start(self, req: Request, agent: str | None = None):
token = req.cookies.get("access-token") or req.cookies.get("access-token-unsecure")
user = self.app.tokens.get(token) if hasattr(self.app, 'tokens') else None
shared.log.info(f'Browser session: user={user} client={req.client.host} agent={agent}')
diff --git a/modules/api/caption.py b/modules/api/caption.py
index ad164889b..82110da74 100644
--- a/modules/api/caption.py
+++ b/modules/api/caption.py
@@ -25,7 +25,7 @@ Core processing logic is shared between direct and dispatch handlers via
``do_openclip``, ``do_tagger``, and ``do_vqa`` functions to avoid duplication.
"""
-from typing import Optional, List, Union, Literal, Annotated
+from typing import Literal, Annotated
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from fastapi.exceptions import HTTPException
from modules import shared
@@ -49,21 +49,21 @@ class ReqCaption(BaseModel):
mode: str = Field(default="best", title="Mode", description="Caption mode. 'best': Most thorough analysis, slowest but highest quality. 'fast': Quick caption with minimal flavor terms. 'classic': Standard captioning with balanced quality and speed. 'caption': BLIP caption only, no CLIP flavor matching. 'negative': Generate terms suitable for use as a negative prompt.")
analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed image analysis breakdown (medium, artist, movement, trending, flavor) in addition to caption.")
# Advanced settings (optional per-request overrides)
- max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.")
- chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up captioning but increase VRAM usage.")
- min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.")
- max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.")
- flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.")
- num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.")
+ max_length: int | None = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.")
+ chunk_size: int | None = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up captioning but increase VRAM usage.")
+ min_flavors: int | None = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.")
+ max_flavors: int | None = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.")
+ flavor_count: int | None = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.")
+ num_beams: int | None = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.")
class ResCaption(BaseModel):
"""Response model for image captioning results."""
- caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption/prompt describing the image content and style.")
- medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (e.g., 'oil painting', 'digital art', 'photograph'). Only returned when analyze=True.")
- artist: Optional[str] = Field(default=None, title="Artist", description="Detected similar artist style (e.g., 'by greg rutkowski'). Only returned when analyze=True.")
- movement: Optional[str] = Field(default=None, title="Movement", description="Detected art movement (e.g., 'art nouveau', 'impressionism'). Only returned when analyze=True.")
- trending: Optional[str] = Field(default=None, title="Trending", description="Trending/platform tags (e.g., 'trending on artstation'). Only returned when analyze=True.")
- flavor: Optional[str] = Field(default=None, title="Flavor", description="Additional descriptive elements (e.g., 'cinematic lighting', 'highly detailed'). Only returned when analyze=True.")
+ caption: str | None = Field(default=None, title="Caption", description="Generated caption/prompt describing the image content and style.")
+ medium: str | None = Field(default=None, title="Medium", description="Detected artistic medium (e.g., 'oil painting', 'digital art', 'photograph'). Only returned when analyze=True.")
+ artist: str | None = Field(default=None, title="Artist", description="Detected similar artist style (e.g., 'by greg rutkowski'). Only returned when analyze=True.")
+ movement: str | None = Field(default=None, title="Movement", description="Detected art movement (e.g., 'art nouveau', 'impressionism'). Only returned when analyze=True.")
+ trending: str | None = Field(default=None, title="Trending", description="Trending/platform tags (e.g., 'trending on artstation'). Only returned when analyze=True.")
+ flavor: str | None = Field(default=None, title="Flavor", description="Additional descriptive elements (e.g., 'cinematic lighting', 'highly detailed'). Only returned when analyze=True.")
class ReqVQA(BaseModel):
"""Request model for Vision-Language Model (VLM) captioning.
@@ -74,32 +74,32 @@ class ReqVQA(BaseModel):
image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string containing the image data.")
model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="Select which model to use for Visual Language tasks. Use GET /sdapi/v1/vqa/models for full list. Models which support thinking mode are indicated in capabilities.")
question: str = Field(default="describe the image", title="Question/Task", description="Task for the model to perform. Common tasks: 'Short Caption', 'Normal Caption', 'Long Caption'. Set to 'Use Prompt' to pass custom text via the prompt field. Florence-2 tasks: 'Object Detection', 'OCR (Read Text)', 'Phrase Grounding', 'Dense Region Caption', 'Region Proposal', 'OCR with Regions'. PromptGen tasks: 'Analyze', 'Generate Tags', 'Mixed Caption'. Moondream tasks: 'Point at...', 'Detect all...', 'Detect Gaze' (Moondream 2 only). Use GET /sdapi/v1/vqa/prompts?model= to list tasks available for a specific model.")
- prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text. Required when question is 'Use Prompt'. For 'Point at...' tasks, specify what to find (e.g., 'the red car'). For 'Detect all...' tasks, specify what to detect (e.g., 'faces').")
+ prompt: str | None = Field(default=None, title="Prompt", description="Custom prompt text. Required when question is 'Use Prompt'. For 'Point at...' tasks, specify what to find (e.g., 'the red car'). For 'Detect all...' tasks, specify what to detect (e.g., 'faces').")
system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt controls behavior of the LLM. Processed first and persists throughout conversation. Has highest priority weighting and is always appended at the beginning of the sequence. Use for: Response formatting rules, role definition, style.")
include_annotated: bool = Field(default=False, title="Include Annotated Image", description="If True and the task produces detection results (object detection, point detection, gaze), returns annotated image with bounding boxes/points drawn. Only applicable for detection tasks on models like Florence-2 and Moondream.")
# LLM generation parameters (optional overrides)
- max_tokens: Optional[int] = Field(default=None, title="Max Tokens", description="Maximum number of tokens the model can generate in its response. The model is not aware of this limit during generation; it simply sets the hard limit for the length and will forcefully cut off the response when reached.")
- temperature: Optional[float] = Field(default=None, title="Temperature", description="Controls randomness in token selection. Lower values (e.g., 0.1) make outputs more focused and deterministic, always choosing high-probability tokens. Higher values (e.g., 0.9) increase creativity and diversity by allowing less probable tokens. Set to 0 for fully deterministic output.")
- top_k: Optional[int] = Field(default=None, title="Top-K", description="Limits token selection to the K most likely candidates at each step. Lower values (e.g., 40) make outputs more focused and predictable, while higher values allow more diverse choices. Set to 0 to disable.")
- top_p: Optional[float] = Field(default=None, title="Top-P", description="Selects tokens from the smallest set whose cumulative probability exceeds P (e.g., 0.9). Dynamically adapts the number of candidates based on model confidence; fewer options when certain, more when uncertain. Set to 1 to disable.")
- num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Maintains multiple candidate paths simultaneously and selects the overall best sequence. More thorough but much slower and less creative than random sampling. Generally not recommended; most modern VLMs perform better with sampling methods. Set to 1 to disable.")
- do_sample: Optional[bool] = Field(default=None, title="Use Samplers", description="Enable to use sampling (randomly selecting tokens based on sampling methods like Top-K or Top-P) or disable to use greedy decoding (selecting the most probable token at each step). Enabling makes outputs more diverse and creative but less deterministic.")
- thinking_mode: Optional[bool] = Field(default=None, title="Thinking Mode", description="Enables thinking/reasoning, allowing the model to take more time to generate responses. Can lead to more thoughtful and detailed answers but increases response time. Only works with models that support this feature.")
- prefill: Optional[str] = Field(default=None, title="Prefill Text", description="Pre-fills the start of the model's response to guide its output format or content by forcing it to continue the prefill text. Prefill is filtered out and does not appear in the final response unless keep_prefill is True. Leave empty to let the model generate from scratch.")
- keep_thinking: Optional[bool] = Field(default=None, title="Keep Thinking Trace", description="Include the model's reasoning process in the final output. Useful for understanding how the model arrived at its answer. Only works with models that support thinking mode.")
- keep_prefill: Optional[bool] = Field(default=None, title="Keep Prefill", description="Include the prefill text at the beginning of the final output. If disabled, the prefill text used to guide the model is removed from the result.")
+ max_tokens: int | None = Field(default=None, title="Max Tokens", description="Maximum number of tokens the model can generate in its response. The model is not aware of this limit during generation; it simply sets the hard limit for the length and will forcefully cut off the response when reached.")
+ temperature: float | None = Field(default=None, title="Temperature", description="Controls randomness in token selection. Lower values (e.g., 0.1) make outputs more focused and deterministic, always choosing high-probability tokens. Higher values (e.g., 0.9) increase creativity and diversity by allowing less probable tokens. Set to 0 for fully deterministic output.")
+ top_k: int | None = Field(default=None, title="Top-K", description="Limits token selection to the K most likely candidates at each step. Lower values (e.g., 40) make outputs more focused and predictable, while higher values allow more diverse choices. Set to 0 to disable.")
+ top_p: float | None = Field(default=None, title="Top-P", description="Selects tokens from the smallest set whose cumulative probability exceeds P (e.g., 0.9). Dynamically adapts the number of candidates based on model confidence; fewer options when certain, more when uncertain. Set to 1 to disable.")
+ num_beams: int | None = Field(default=None, title="Num Beams", description="Maintains multiple candidate paths simultaneously and selects the overall best sequence. More thorough but much slower and less creative than random sampling. Generally not recommended; most modern VLMs perform better with sampling methods. Set to 1 to disable.")
+ do_sample: bool | None = Field(default=None, title="Use Samplers", description="Enable to use sampling (randomly selecting tokens based on sampling methods like Top-K or Top-P) or disable to use greedy decoding (selecting the most probable token at each step). Enabling makes outputs more diverse and creative but less deterministic.")
+ thinking_mode: bool | None = Field(default=None, title="Thinking Mode", description="Enables thinking/reasoning, allowing the model to take more time to generate responses. Can lead to more thoughtful and detailed answers but increases response time. Only works with models that support this feature.")
+ prefill: str | None = Field(default=None, title="Prefill Text", description="Pre-fills the start of the model's response to guide its output format or content by forcing it to continue the prefill text. Prefill is filtered out and does not appear in the final response unless keep_prefill is True. Leave empty to let the model generate from scratch.")
+ keep_thinking: bool | None = Field(default=None, title="Keep Thinking Trace", description="Include the model's reasoning process in the final output. Useful for understanding how the model arrived at its answer. Only works with models that support thinking mode.")
+ keep_prefill: bool | None = Field(default=None, title="Keep Prefill", description="Include the prefill text at the beginning of the final output. If disabled, the prefill text used to guide the model is removed from the result.")
class ResVQA(BaseModel):
"""Response model for VLM captioning results."""
- answer: Optional[str] = Field(default=None, title="Answer", description="Generated caption, answer, or analysis from the VLM. Format depends on the question/task type.")
- annotated_image: Optional[str] = Field(default=None, title="Annotated Image", description="Base64 encoded PNG image with detection results drawn (bounding boxes, points). Only returned when include_annotated=True and the task produces detection results.")
+ answer: str | None = Field(default=None, title="Answer", description="Generated caption, answer, or analysis from the VLM. Format depends on the question/task type.")
+ annotated_image: str | None = Field(default=None, title="Annotated Image", description="Base64 encoded PNG image with detection results drawn (bounding boxes, points). Only returned when include_annotated=True and the task produces detection results.")
class ItemVLMModel(BaseModel):
"""VLM model information."""
name: str = Field(title="Name", description="Display name of the model")
repo: str = Field(title="Repository", description="HuggingFace repository ID")
- prompts: List[str] = Field(title="Prompts", description="Available prompts/tasks for this model")
- capabilities: List[str] = Field(title="Capabilities", description="Model capabilities. Possible values: 'caption' (image captioning), 'vqa' (visual question answering), 'detection' (object/point detection), 'ocr' (text recognition), 'thinking' (reasoning mode support).")
+ prompts: list[str] = Field(title="Prompts", description="Available prompts/tasks for this model")
+ capabilities: list[str] = Field(title="Capabilities", description="Model capabilities. Possible values: 'caption' (image captioning), 'vqa' (visual question answering), 'detection' (object/point detection), 'ocr' (text recognition), 'thinking' (reasoning mode support).")
class ResVLMPrompts(BaseModel):
"""Available VLM prompts grouped by category.
@@ -107,12 +107,12 @@ class ResVLMPrompts(BaseModel):
When called without ``model`` parameter, returns all prompt categories.
When called with ``model``, returns only the ``available`` field with prompts for that model.
"""
- common: Optional[List[str]] = Field(default=None, title="Common", description="Prompts available for all models: Use Prompt, Short/Normal/Long Caption.")
- florence: Optional[List[str]] = Field(default=None, title="Florence", description="Florence-2 base model tasks: Phrase Grounding, Object Detection, Dense Region Caption, Region Proposal, OCR (Read Text), OCR with Regions.")
- promptgen: Optional[List[str]] = Field(default=None, title="PromptGen", description="MiaoshouAI PromptGen fine-tune tasks: Analyze, Generate Tags, Mixed Caption, Mixed Caption+. Only available on PromptGen models.")
- moondream: Optional[List[str]] = Field(default=None, title="Moondream", description="Moondream 2 and 3 tasks: Point at..., Detect all...")
- moondream2_only: Optional[List[str]] = Field(default=None, title="Moondream 2 Only", description="Moondream 2 exclusive tasks: Detect Gaze. Not available in Moondream 3.")
- available: Optional[List[str]] = Field(default=None, title="Available", description="Populated only when filtering by model. Contains the combined list of prompts available for the specified model.")
+ common: list[str] | None = Field(default=None, title="Common", description="Prompts available for all models: Use Prompt, Short/Normal/Long Caption.")
+ florence: list[str] | None = Field(default=None, title="Florence", description="Florence-2 base model tasks: Phrase Grounding, Object Detection, Dense Region Caption, Region Proposal, OCR (Read Text), OCR with Regions.")
+ promptgen: list[str] | None = Field(default=None, title="PromptGen", description="MiaoshouAI PromptGen fine-tune tasks: Analyze, Generate Tags, Mixed Caption, Mixed Caption+. Only available on PromptGen models.")
+ moondream: list[str] | None = Field(default=None, title="Moondream", description="Moondream 2 and 3 tasks: Point at..., Detect all...")
+ moondream2_only: list[str] | None = Field(default=None, title="Moondream 2 Only", description="Moondream 2 exclusive tasks: Detect Gaze. Not available in Moondream 3.")
+ available: list[str] | None = Field(default=None, title="Available", description="Populated only when filtering by model. Contains the combined list of prompts available for the specified model.")
class ItemTaggerModel(BaseModel):
"""Tagger model information."""
@@ -136,7 +136,7 @@ class ReqTagger(BaseModel):
class ResTagger(BaseModel):
"""Response model for image tagging results."""
tags: str = Field(title="Tags", description="Comma-separated list of detected tags")
- scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (when show_scores=True)")
+ scores: dict | None = Field(default=None, title="Scores", description="Tag confidence scores (when show_scores=True)")
# =============================================================================
@@ -158,12 +158,12 @@ class ReqCaptionOpenCLIP(BaseModel):
blip_model: str = Field(default="blip-large", title="Caption Model", description="BLIP model used to generate the initial image caption.")
mode: str = Field(default="best", title="Mode", description="Caption mode: 'best' (highest quality, slowest), 'fast' (quick, fewer flavors), 'classic' (balanced), 'caption' (BLIP only, no CLIP matching), 'negative' (for negative prompts).")
analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed breakdown (medium, artist, movement, trending, flavor).")
- max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum tokens in generated caption.")
- chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing flavors.")
- min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum descriptive tags to keep.")
- max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum descriptive tags to keep.")
- flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of intermediate candidate pool.")
- num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beams for beam search during caption generation.")
+ max_length: int | None = Field(default=None, title="Max Length", description="Maximum tokens in generated caption.")
+ chunk_size: int | None = Field(default=None, title="Chunk Size", description="Batch size for processing flavors.")
+ min_flavors: int | None = Field(default=None, title="Min Flavors", description="Minimum descriptive tags to keep.")
+ max_flavors: int | None = Field(default=None, title="Max Flavors", description="Maximum descriptive tags to keep.")
+ flavor_count: int | None = Field(default=None, title="Intermediates", description="Size of intermediate candidate pool.")
+ num_beams: int | None = Field(default=None, title="Num Beams", description="Beams for beam search during caption generation.")
class ReqCaptionTagger(BaseModel):
@@ -196,24 +196,24 @@ class ReqCaptionVLM(BaseModel):
image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string.")
model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="VLM model to use. See GET /sdapi/v1/vqa/models for full list.")
question: str = Field(default="describe the image", title="Question/Task", description="Task to perform: 'Short Caption', 'Normal Caption', 'Long Caption', 'Use Prompt' (custom text via prompt field). Model-specific tasks available via GET /sdapi/v1/vqa/prompts.")
- prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text when question is 'Use Prompt'.")
+ prompt: str | None = Field(default=None, title="Prompt", description="Custom prompt text when question is 'Use Prompt'.")
system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt for LLM behavior.")
include_annotated: bool = Field(default=False, title="Include Annotated Image", description="Return annotated image for detection tasks.")
- max_tokens: Optional[int] = Field(default=None, title="Max Tokens", description="Maximum tokens in response.")
- temperature: Optional[float] = Field(default=None, title="Temperature", description="Randomness in token selection (0=deterministic, 0.9=creative).")
- top_k: Optional[int] = Field(default=None, title="Top-K", description="Limit to K most likely tokens per step.")
- top_p: Optional[float] = Field(default=None, title="Top-P", description="Nucleus sampling threshold.")
- num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beam search width (1=disabled).")
- do_sample: Optional[bool] = Field(default=None, title="Use Samplers", description="Enable sampling vs greedy decoding.")
- thinking_mode: Optional[bool] = Field(default=None, title="Thinking Mode", description="Enable reasoning mode (supported models only).")
- prefill: Optional[str] = Field(default=None, title="Prefill Text", description="Pre-fill response start to guide output.")
- keep_thinking: Optional[bool] = Field(default=None, title="Keep Thinking Trace", description="Include reasoning in output.")
- keep_prefill: Optional[bool] = Field(default=None, title="Keep Prefill", description="Keep prefill text in final output.")
+ max_tokens: int | None = Field(default=None, title="Max Tokens", description="Maximum tokens in response.")
+ temperature: float | None = Field(default=None, title="Temperature", description="Randomness in token selection (0=deterministic, 0.9=creative).")
+ top_k: int | None = Field(default=None, title="Top-K", description="Limit to K most likely tokens per step.")
+ top_p: float | None = Field(default=None, title="Top-P", description="Nucleus sampling threshold.")
+ num_beams: int | None = Field(default=None, title="Num Beams", description="Beam search width (1=disabled).")
+ do_sample: bool | None = Field(default=None, title="Use Samplers", description="Enable sampling vs greedy decoding.")
+ thinking_mode: bool | None = Field(default=None, title="Thinking Mode", description="Enable reasoning mode (supported models only).")
+ prefill: str | None = Field(default=None, title="Prefill Text", description="Pre-fill response start to guide output.")
+ keep_thinking: bool | None = Field(default=None, title="Keep Thinking Trace", description="Include reasoning in output.")
+ keep_prefill: bool | None = Field(default=None, title="Keep Prefill", description="Keep prefill text in final output.")
# Discriminated union for the dispatch endpoint
ReqCaptionDispatch = Annotated[
- Union[ReqCaptionOpenCLIP, ReqCaptionTagger, ReqCaptionVLM],
+ ReqCaptionOpenCLIP | ReqCaptionTagger | ReqCaptionVLM,
Field(discriminator="backend")
]
@@ -226,18 +226,18 @@ class ResCaptionDispatch(BaseModel):
# Common
backend: str = Field(title="Backend", description="The backend that processed the request: 'openclip', 'tagger', or 'vlm'.")
# OpenCLIP fields
- caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption (OpenCLIP backend).")
- medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (OpenCLIP with analyze=True).")
- artist: Optional[str] = Field(default=None, title="Artist", description="Detected artist style (OpenCLIP with analyze=True).")
- movement: Optional[str] = Field(default=None, title="Movement", description="Detected art movement (OpenCLIP with analyze=True).")
- trending: Optional[str] = Field(default=None, title="Trending", description="Trending tags (OpenCLIP with analyze=True).")
- flavor: Optional[str] = Field(default=None, title="Flavor", description="Flavor descriptors (OpenCLIP with analyze=True).")
+ caption: str | None = Field(default=None, title="Caption", description="Generated caption (OpenCLIP backend).")
+ medium: str | None = Field(default=None, title="Medium", description="Detected artistic medium (OpenCLIP with analyze=True).")
+ artist: str | None = Field(default=None, title="Artist", description="Detected artist style (OpenCLIP with analyze=True).")
+ movement: str | None = Field(default=None, title="Movement", description="Detected art movement (OpenCLIP with analyze=True).")
+ trending: str | None = Field(default=None, title="Trending", description="Trending tags (OpenCLIP with analyze=True).")
+ flavor: str | None = Field(default=None, title="Flavor", description="Flavor descriptors (OpenCLIP with analyze=True).")
# Tagger fields
- tags: Optional[str] = Field(default=None, title="Tags", description="Comma-separated tags (Tagger backend).")
- scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (Tagger with show_scores=True).")
+ tags: str | None = Field(default=None, title="Tags", description="Comma-separated tags (Tagger backend).")
+ scores: dict | None = Field(default=None, title="Scores", description="Tag confidence scores (Tagger with show_scores=True).")
# VLM fields
- answer: Optional[str] = Field(default=None, title="Answer", description="VLM response (VLM backend).")
- annotated_image: Optional[str] = Field(default=None, title="Annotated Image", description="Base64 annotated image (VLM with include_annotated=True).")
+ answer: str | None = Field(default=None, title="Answer", description="VLM response (VLM backend).")
+ annotated_image: str | None = Field(default=None, title="Annotated Image", description="Base64 annotated image (VLM with include_annotated=True).")
# =============================================================================
@@ -596,7 +596,7 @@ def get_vqa_models():
return models_list
-def get_vqa_prompts(model: Optional[str] = None):
+def get_vqa_prompts(model: str | None = None):
"""
List available prompts/tasks for VLM models.
@@ -653,11 +653,11 @@ def get_tagger_models():
def register_api():
from modules.shared import api
- api.add_api_route("/sdapi/v1/openclip", get_caption, methods=["GET"], response_model=List[str], tags=["Caption"])
+ api.add_api_route("/sdapi/v1/openclip", get_caption, methods=["GET"], response_model=list[str], tags=["Caption"])
api.add_api_route("/sdapi/v1/caption", post_caption_dispatch, methods=["POST"], response_model=ResCaptionDispatch, tags=["Caption"])
api.add_api_route("/sdapi/v1/openclip", post_caption, methods=["POST"], response_model=ResCaption, tags=["Caption"])
api.add_api_route("/sdapi/v1/vqa", post_vqa, methods=["POST"], response_model=ResVQA, tags=["Caption"])
- api.add_api_route("/sdapi/v1/vqa/models", get_vqa_models, methods=["GET"], response_model=List[ItemVLMModel], tags=["Caption"])
+ api.add_api_route("/sdapi/v1/vqa/models", get_vqa_models, methods=["GET"], response_model=list[ItemVLMModel], tags=["Caption"])
api.add_api_route("/sdapi/v1/vqa/prompts", get_vqa_prompts, methods=["GET"], response_model=ResVLMPrompts, tags=["Caption"])
api.add_api_route("/sdapi/v1/tagger", post_tagger, methods=["POST"], response_model=ResTagger, tags=["Caption"])
- api.add_api_route("/sdapi/v1/tagger/models", get_tagger_models, methods=["GET"], response_model=List[ItemTaggerModel], tags=["Caption"])
+ api.add_api_route("/sdapi/v1/tagger/models", get_tagger_models, methods=["GET"], response_model=list[ItemTaggerModel], tags=["Caption"])
diff --git a/modules/api/control.py b/modules/api/control.py
index 63b7636dd..5f5882e04 100644
--- a/modules/api/control.py
+++ b/modules/api/control.py
@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional
from threading import Lock
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from modules import errors, shared, processing_helpers
@@ -43,9 +43,9 @@ ReqControl = models.create_model_from_signature(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
- {"key": "ip_adapter", "type": Optional[List[models.ItemIPAdapter]], "default": None, "exclude": True},
+ {"key": "ip_adapter", "type": Optional[list[models.ItemIPAdapter]], "default": None, "exclude": True},
{"key": "face", "type": Optional[models.ItemFace], "default": None, "exclude": True},
- {"key": "control", "type": Optional[List[ItemControl]], "default": [], "exclude": True},
+ {"key": "control", "type": Optional[list[ItemControl]], "default": [], "exclude": True},
{"key": "xyz", "type": Optional[ItemXYZ], "default": None, "exclude": True},
# {"key": "extra", "type": Optional[dict], "default": {}, "exclude": True},
]
@@ -55,13 +55,13 @@ if not hasattr(ReqControl, "__config__"):
class ResControl(BaseModel):
- images: List[str] = Field(default=None, title="Images", description="")
- processed: List[str] = Field(default=None, title="Processed", description="")
+ images: list[str] = Field(default=None, title="Images", description="")
+ processed: list[str] = Field(default=None, title="Processed", description="")
params: dict = Field(default={}, title="Settings", description="")
info: str = Field(default="", title="Info", description="")
-class APIControl():
+class APIControl:
def __init__(self, queue_lock: Lock):
self.queue_lock = queue_lock
self.default_script_arg = []
diff --git a/modules/api/endpoints.py b/modules/api/endpoints.py
index 0ad141b48..543756906 100644
--- a/modules/api/endpoints.py
+++ b/modules/api/endpoints.py
@@ -1,4 +1,3 @@
-from typing import Optional
from modules import shared
from modules.api import models, helpers
@@ -43,7 +42,7 @@ def get_sd_models():
checkpoints.append(model)
return checkpoints
-def get_controlnets(model_type: Optional[str] = None):
+def get_controlnets(model_type: str | None = None):
from modules.control.units.controlnet import api_list_models
return api_list_models(model_type)
@@ -60,7 +59,7 @@ def get_embeddings():
return models.ResEmbeddings(loaded=[], skipped=[])
return models.ResEmbeddings(loaded=list(db.word_embeddings.keys()), skipped=list(db.skipped_embeddings.keys()))
-def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, filename: Optional[str] = None, title: Optional[str] = None, fullname: Optional[str] = None, hash: Optional[str] = None): # pylint: disable=redefined-builtin
+def get_extra_networks(page: str | None = None, name: str | None = None, filename: str | None = None, title: str | None = None, fullname: str | None = None, hash: str | None = None): # pylint: disable=redefined-builtin
res = []
for pg in shared.extra_networks:
if page is not None and pg.name != page.lower():
diff --git a/modules/api/gallery.py b/modules/api/gallery.py
index e4dc8ba0b..ac6cf93e1 100644
--- a/modules/api/gallery.py
+++ b/modules/api/gallery.py
@@ -2,7 +2,6 @@ import io
import os
import time
import base64
-from typing import List, Union
from urllib.parse import quote, unquote
from fastapi import FastAPI
from fastapi.responses import JSONResponse
@@ -52,7 +51,7 @@ class ConnectionManager:
debug(f'Browser WS disconnect: client={ws.client.host}')
self.active.remove(ws)
- async def send(self, ws: WebSocket, data: Union[str, dict, bytes]):
+ async def send(self, ws: WebSocket, data: str | dict | bytes):
# debug(f'Browser WS send: client={ws.client.host} data={type(data)}')
if ws.client_state != WebSocketState.CONNECTED:
return
@@ -65,7 +64,7 @@ class ConnectionManager:
else:
debug(f'Browser WS send: client={ws.client.host} data={type(data)} unknown')
- async def broadcast(self, data: Union[str, dict, bytes]):
+ async def broadcast(self, data: str | dict | bytes):
for ws in self.active:
await self.send(ws, data)
@@ -206,7 +205,7 @@ def register_api(app: FastAPI): # register api
shared.log.error(f'Gallery: {folder} {e}')
return []
- shared.api.add_api_route("/sdapi/v1/browser/folders", get_folders, methods=["GET"], response_model=List[str])
+ shared.api.add_api_route("/sdapi/v1/browser/folders", get_folders, methods=["GET"], response_model=list[str])
shared.api.add_api_route("/sdapi/v1/browser/thumb", get_thumb, methods=["GET"], response_model=dict)
shared.api.add_api_route("/sdapi/v1/browser/files", ht_files, methods=["GET"], response_model=list)
diff --git a/modules/api/generate.py b/modules/api/generate.py
index 102b15f2c..9fd1d2bf3 100644
--- a/modules/api/generate.py
+++ b/modules/api/generate.py
@@ -9,7 +9,7 @@ from modules.paths import resolve_output_path
errors.install()
-class APIGenerate():
+class APIGenerate:
def __init__(self, queue_lock: Lock):
self.queue_lock = queue_lock
self.default_script_arg_txt2img = []
diff --git a/modules/api/loras.py b/modules/api/loras.py
index c192ec62d..8a9d7a594 100644
--- a/modules/api/loras.py
+++ b/modules/api/loras.py
@@ -1,4 +1,3 @@
-from typing import List
from fastapi.exceptions import HTTPException
@@ -25,5 +24,5 @@ def post_refresh_loras():
def register_api():
from modules.shared import api
api.add_api_route("/sdapi/v1/lora", get_lora, methods=["GET"], response_model=dict)
- api.add_api_route("/sdapi/v1/loras", get_loras, methods=["GET"], response_model=List[dict])
+ api.add_api_route("/sdapi/v1/loras", get_loras, methods=["GET"], response_model=list[dict])
api.add_api_route("/sdapi/v1/refresh-loras", post_refresh_loras, methods=["POST"])
diff --git a/modules/api/models.py b/modules/api/models.py
index b1425d267..7305fc7bf 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,7 +1,15 @@
import re
import inspect
-from typing import Any, Optional, Dict, List, Type, Callable, Union
-from pydantic import BaseModel, Field, create_model # pylint: disable=no-name-in-module
+from typing import Any, Optional, Union
+from collections.abc import Callable
+import pydantic
+from pydantic import BaseModel, Field, create_model
+try:
+ from pydantic import ConfigDict
+ PYDANTIC_V2 = True
+except ImportError:
+ ConfigDict = None
+ PYDANTIC_V2 = False
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
import modules.shared as shared
@@ -41,8 +49,10 @@ class PydanticModelGenerator:
model_name: str = None,
class_instance = None,
additional_fields = None,
- exclude_fields: List = [],
+ exclude_fields: list = None,
):
+ if exclude_fields is None:
+ exclude_fields = []
def field_type_generator(_k, v):
field_type = v.annotation
return Optional[field_type]
@@ -80,12 +90,15 @@ class PydanticModelGenerator:
def generate_model(self):
model_fields = { d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def }
- DynamicModel = create_model(self._model_name, **model_fields)
- try:
- DynamicModel.__config__.allow_population_by_field_name = True
- DynamicModel.__config__.allow_mutation = True
- except Exception:
- pass
+ if PYDANTIC_V2:
+ config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True)
+ else:
+ class Config:
+ arbitrary_types_allowed = True
+ orm_mode = True
+ allow_population_by_field_name = True
+ config = Config
+ DynamicModel = create_model(self._model_name, __config__=config, **model_fields)
return DynamicModel
### item classes
@@ -100,49 +113,49 @@ class ItemVae(BaseModel):
class ItemUpscaler(BaseModel):
name: str = Field(title="Name")
- model_name: Optional[str] = Field(title="Model Name")
- model_path: Optional[str] = Field(title="Path")
- model_url: Optional[str] = Field(title="URL")
- scale: Optional[float] = Field(title="Scale")
+ model_name: str | None = Field(title="Model Name")
+ model_path: str | None = Field(title="Path")
+ model_url: str | None = Field(title="URL")
+ scale: float | None = Field(title="Scale")
class ItemModel(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
filename: str = Field(title="Filename")
type: str = Field(title="Model type")
- sha256: Optional[str] = Field(title="SHA256 hash")
- hash: Optional[str] = Field(title="Short hash")
- config: Optional[str] = Field(title="Config file")
+ sha256: str | None = Field(title="SHA256 hash")
+ hash: str | None = Field(title="Short hash")
+ config: str | None = Field(title="Config file")
class ItemHypernetwork(BaseModel):
name: str = Field(title="Name")
- path: Optional[str] = Field(title="Path")
+ path: str | None = Field(title="Path")
class ItemDetailer(BaseModel):
name: str = Field(title="Name")
- path: Optional[str] = Field(title="Path")
+ path: str | None = Field(title="Path")
class ItemGAN(BaseModel):
name: str = Field(title="Name")
- path: Optional[str] = Field(title="Path")
- scale: Optional[int] = Field(title="Scale")
+ path: str | None = Field(title="Path")
+ scale: int | None = Field(title="Scale")
class ItemStyle(BaseModel):
name: str = Field(title="Name")
- prompt: Optional[str] = Field(title="Prompt")
- negative_prompt: Optional[str] = Field(title="Negative Prompt")
- extra: Optional[str] = Field(title="Extra")
- filename: Optional[str] = Field(title="Filename")
- preview: Optional[str] = Field(title="Preview")
+ prompt: str | None = Field(title="Prompt")
+ negative_prompt: str | None = Field(title="Negative Prompt")
+ extra: str | None = Field(title="Extra")
+ filename: str | None = Field(title="Filename")
+ preview: str | None = Field(title="Preview")
class ItemExtraNetwork(BaseModel):
name: str = Field(title="Name")
type: str = Field(title="Type")
- title: Optional[str] = Field(title="Title")
- fullname: Optional[str] = Field(title="Fullname")
- filename: Optional[str] = Field(title="Filename")
- hash: Optional[str] = Field(title="Hash")
- preview: Optional[str] = Field(title="Preview image URL")
+ title: str | None = Field(title="Title")
+ fullname: str | None = Field(title="Fullname")
+ filename: str | None = Field(title="Filename")
+ hash: str | None = Field(title="Hash")
+ preview: str | None = Field(title="Preview image URL")
class ItemArtist(BaseModel):
name: str = Field(title="Name")
@@ -150,16 +163,16 @@ class ItemArtist(BaseModel):
category: str = Field(title="Category")
class ItemEmbedding(BaseModel):
- step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
- sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
- sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+ step: int | None = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+ sd_checkpoint: str | None = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+ sd_checkpoint_name: str | None = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class ItemIPAdapter(BaseModel):
adapter: str = Field(title="Adapter", default="Base", description="IP adapter name")
- images: List[str] = Field(title="Image", default=[], description="IP adapter input images")
- masks: Optional[List[str]] = Field(title="Mask", default=[], description="IP adapter mask images")
+ images: list[str] = Field(title="Image", default=[], description="IP adapter input images")
+ masks: list[str] | None = Field(title="Mask", default=[], description="IP adapter mask images")
scale: float = Field(title="Scale", default=0.5, ge=0, le=1, description="IP adapter scale")
start: float = Field(title="Start", default=0.0, ge=0, le=1, description="IP adapter start step")
end: float = Field(title="End", default=1.0, gt=0, le=1, description="IP adapter end step")
@@ -183,17 +196,17 @@ class ItemFace(BaseModel):
class ScriptArg(BaseModel):
label: str = Field(default=None, title="Label", description="Name of the argument in UI")
- value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
- minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
- maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
- step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
- choices: Optional[Any] = Field(default=None, title="Choices", description="Possible values for the argument")
+ value: Any | None = Field(default=None, title="Value", description="Default value of the argument")
+    minimum: Any | None = Field(default=None, title="Minimum", description="Minimum allowed value for the argument in UI")
+    maximum: Any | None = Field(default=None, title="Maximum", description="Maximum allowed value for the argument in UI")
+    step: Any | None = Field(default=None, title="Step", description="Step for changing value of the argument in UI")
+ choices: Any | None = Field(default=None, title="Choices", description="Possible values for the argument")
class ItemScript(BaseModel):
name: str = Field(default=None, title="Name", description="Script name")
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
- args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
+ args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
class ItemExtension(BaseModel):
name: str = Field(title="Name", description="Extension name")
@@ -201,13 +214,13 @@ class ItemExtension(BaseModel):
branch: str = Field(default="uknnown", title="Branch", description="Extension Repository Branch")
commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
version: str = Field(title="Version", description="Extension Version")
- commit_date: Union[str, int] = Field(title="Commit Date", description="Extension Repository Commit Date")
+ commit_date: str | int = Field(title="Commit Date", description="Extension Repository Commit Date")
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
class ItemScheduler(BaseModel):
name: str = Field(title="Name", description="Scheduler name")
cls: str = Field(title="Class", description="Scheduler class name")
- options: Dict[str, Any] = Field(title="Options", description="Dictionary of scheduler options")
+ options: dict[str, Any] = Field(title="Options", description="Dictionary of scheduler options")
### request/response classes
@@ -223,7 +236,7 @@ ReqTxt2Img = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
- {"key": "ip_adapter", "type": Optional[List[ItemIPAdapter]], "default": None, "exclude": True},
+ {"key": "ip_adapter", "type": Optional[list[ItemIPAdapter]], "default": None, "exclude": True},
{"key": "face", "type": Optional[ItemFace], "default": None, "exclude": True},
{"key": "extra", "type": Optional[dict], "default": {}, "exclude": True},
]
@@ -233,7 +246,7 @@ if not hasattr(ReqTxt2Img, "__config__"):
StableDiffusionTxt2ImgProcessingAPI = ReqTxt2Img
class ResTxt2Img(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
+ images: list[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
parameters: dict
info: str
@@ -253,7 +266,7 @@ ReqImg2Img = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
- {"key": "ip_adapter", "type": Optional[List[ItemIPAdapter]], "default": None, "exclude": True},
+ {"key": "ip_adapter", "type": Optional[list[ItemIPAdapter]], "default": None, "exclude": True},
{"key": "face_id", "type": Optional[ItemFace], "default": None, "exclude": True},
{"key": "extra", "type": Optional[dict], "default": {}, "exclude": True},
]
@@ -263,7 +276,7 @@ if not hasattr(ReqImg2Img, "__config__"):
StableDiffusionImg2ImgProcessingAPI = ReqImg2Img
class ResImg2Img(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
+ images: list[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
parameters: dict
info: str
@@ -289,9 +302,9 @@ class ResProcess(BaseModel):
class ReqPromptEnhance(BaseModel):
prompt: str = Field(title="Prompt", description="Prompt to enhance")
type: str = Field(title="Type", default='text', description="Type of enhancement to perform")
- model: Optional[str] = Field(title="Model", default=None, description="Model to use for enhancement")
- system_prompt: Optional[str] = Field(title="System prompt", default=None, description="Model system prompt")
- image: Optional[str] = Field(title="Image", default=None, description="Image to work on, must be a Base64 string containing the image's data.")
+ model: str | None = Field(title="Model", default=None, description="Model to use for enhancement")
+ system_prompt: str | None = Field(title="System prompt", default=None, description="Model system prompt")
+ image: str | None = Field(title="Image", default=None, description="Image to work on, must be a Base64 string containing the image's data.")
seed: int = Field(title="Seed", default=-1, description="Seed used to generate the prompt")
nsfw: bool = Field(title="NSFW", default=True, description="Should NSFW content be allowed?")
@@ -306,10 +319,10 @@ class ResProcessImage(ResProcess):
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
class ReqProcessBatch(ReqProcess):
- imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+ imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ResProcessBatch(ResProcess):
- images: List[str] = Field(title="Images", description="The generated images in base64 format.")
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class ReqImageInfo(BaseModel):
image: str = Field(title="Image", description="The base64 encoded image")
@@ -325,38 +338,38 @@ class ReqGetLog(BaseModel):
class ReqPostLog(BaseModel):
- message: Optional[str] = Field(default=None, title="Message", description="The info message to log")
- debug: Optional[str] = Field(default=None, title="Debug message", description="The debug message to log")
- error: Optional[str] = Field(default=None, title="Error message", description="The error message to log")
+ message: str | None = Field(default=None, title="Message", description="The info message to log")
+ debug: str | None = Field(default=None, title="Debug message", description="The debug message to log")
+ error: str | None = Field(default=None, title="Error message", description="The error message to log")
class ReqHistory(BaseModel):
- id: Union[int, str, None] = Field(default=None, title="Task ID", description="Task ID")
+ id: int | str | None = Field(default=None, title="Task ID", description="Task ID")
class ReqProgress(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
class ResProgress(BaseModel):
- id: Union[int, str, None] = Field(title="TaskID", description="Task ID")
+ id: int | str | None = Field(title="TaskID", description="Task ID")
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
- current_image: Optional[str] = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
- textinfo: Optional[str] = Field(default=None, title="Info text", description="Info text used by WebUI.")
+ current_image: str | None = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+ textinfo: str | None = Field(default=None, title="Info text", description="Info text used by WebUI.")
class ResHistory(BaseModel):
- id: Union[int, str, None] = Field(title="ID", description="Task ID")
+ id: int | str | None = Field(title="ID", description="Task ID")
job: str = Field(title="Job", description="Job name")
op: str = Field(title="Operation", description="Job state")
- timestamp: Union[float, None] = Field(title="Timestamp", description="Job timestamp")
- duration: Union[float, None] = Field(title="Duration", description="Job duration")
- outputs: List[str] = Field(title="Outputs", description="List of filenames")
+ timestamp: float | None = Field(title="Timestamp", description="Job timestamp")
+ duration: float | None = Field(title="Duration", description="Job duration")
+ outputs: list[str] = Field(title="Outputs", description="List of filenames")
class ResStatus(BaseModel):
status: str = Field(title="Status", description="Current status")
task: str = Field(title="Task", description="Current job")
- timestamp: Optional[str] = Field(title="Timestamp", description="Timestamp of the current job")
+ timestamp: str | None = Field(title="Timestamp", description="Timestamp of the current job")
current: str = Field(title="Task", description="Current job")
- id: Union[int, str, None] = Field(title="ID", description="ID of the current task")
+ id: int | str | None = Field(title="ID", description="ID of the current task")
job: int = Field(title="Job", description="Current job")
jobs: int = Field(title="Jobs", description="Total jobs")
total: int = Field(title="Total Jobs", description="Total jobs")
@@ -364,9 +377,9 @@ class ResStatus(BaseModel):
steps: int = Field(title="Steps", description="Total steps")
queued: int = Field(title="Queued", description="Number of queued tasks")
uptime: int = Field(title="Uptime", description="Uptime of the server")
- elapsed: Optional[float] = Field(default=None, title="Elapsed time")
- eta: Optional[float] = Field(default=None, title="ETA in secs")
- progress: Optional[float] = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
+ elapsed: float | None = Field(default=None, title="Elapsed time")
+ eta: float | None = Field(default=None, title="ETA in secs")
+ progress: float | None = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
class ReqLatentHistory(BaseModel):
name: str = Field(title="Name", description="Name of the history item to select")
@@ -392,7 +405,15 @@ for key, metadata in shared.opts.data_labels.items():
else:
fields.update({key: (Optional[optType], Field())})
-OptionsModel = create_model("Options", **fields)
+if PYDANTIC_V2:
+ config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True)
+else:
+ class Config:
+ arbitrary_types_allowed = True
+ orm_mode = True
+ allow_population_by_field_name = True
+ config = Config
+OptionsModel = create_model("Options", __config__=config, **fields)
flags = {}
_options = vars(shared.parser)['_option_string_actions']
@@ -404,7 +425,15 @@ for key in _options:
_type = type(_options[key].default)
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
-FlagsModel = create_model("Flags", **flags)
+if PYDANTIC_V2:
+ config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True)
+else:
+ class Config:
+ arbitrary_types_allowed = True
+ orm_mode = True
+ allow_population_by_field_name = True
+ config = Config
+FlagsModel = create_model("Flags", __config__=config, **flags)
class ResEmbeddings(BaseModel):
loaded: list = Field(default=None, title="loaded", description="List of loaded embeddings")
@@ -426,9 +455,13 @@ class ResGPU(BaseModel): # definition of http response
# helper function
-def create_model_from_signature(func: Callable, model_name: str, base_model: Type[BaseModel] = BaseModel, additional_fields: List = [], exclude_fields: List[str] = []) -> type[BaseModel]:
+def create_model_from_signature(func: Callable, model_name: str, base_model: type[BaseModel] = BaseModel, additional_fields: list = None, exclude_fields: list[str] = None) -> type[BaseModel]:
from PIL import Image
+ if exclude_fields is None:
+ exclude_fields = []
+ if additional_fields is None:
+ additional_fields = []
class Config:
extra = 'allow'
@@ -443,13 +476,13 @@ def create_model_from_signature(func: Callable, model_name: str, base_model: Typ
defaults = (...,) * non_default_args + defaults
keyword_only_params = {param: kwonlydefaults.get(param, Any) for param in kwonlyargs}
for k, v in annotations.items():
- if v == List[Image.Image]:
- annotations[k] = List[str]
+ if v == list[Image.Image]:
+ annotations[k] = list[str]
elif v == Image.Image:
annotations[k] = str
elif str(v) == 'typing.List[modules.control.unit.Unit]':
- annotations[k] = List[str]
- model_fields = {param: (annotations.get(param, Any), default) for param, default in zip(args, defaults)}
+ annotations[k] = list[str]
+ model_fields = {param: (annotations.get(param, Any), default) for param, default in zip(args, defaults, strict=False)}
for fld in additional_fields:
model_def = ModelDef(
@@ -464,16 +497,21 @@ def create_model_from_signature(func: Callable, model_name: str, base_model: Typ
if fld in model_fields:
del model_fields[fld]
+ if PYDANTIC_V2:
+ config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True, extra='allow' if varkw else 'ignore')
+ else:
+ class Config:
+ arbitrary_types_allowed = True
+ orm_mode = True
+ allow_population_by_field_name = True
+ extra = 'allow' if varkw else 'ignore'
+ config = Config
+
model = create_model(
model_name,
- **model_fields,
- **keyword_only_params,
__base__=base_model,
__config__=config,
+ **model_fields,
+ **keyword_only_params,
)
- try:
- model.__config__.allow_population_by_field_name = True
- model.__config__.allow_mutation = True
- except Exception:
- pass
return model
diff --git a/modules/api/process.py b/modules/api/process.py
index c106a18e2..c3d102252 100644
--- a/modules/api/process.py
+++ b/modules/api/process.py
@@ -1,4 +1,3 @@
-from typing import Optional, List
from threading import Lock
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from fastapi.responses import JSONResponse
@@ -15,7 +14,7 @@ errors.install()
class ReqPreprocess(BaseModel):
image: str = Field(title="Image", description="The base64 encoded image")
model: str = Field(title="Model", description="The model to use for preprocessing")
- params: Optional[dict] = Field(default={}, title="Settings", description="Preprocessor settings")
+ params: dict | None = Field(default={}, title="Settings", description="Preprocessor settings")
class ResPreprocess(BaseModel):
model: str = Field(default='', title="Model", description="The processor model used")
@@ -24,20 +23,20 @@ class ResPreprocess(BaseModel):
class ReqMask(BaseModel):
image: str = Field(title="Image", description="The base64 encoded image")
type: str = Field(title="Mask type", description="Type of masking image to return")
- mask: Optional[str] = Field(title="Mask", description="If optional maks image is not provided auto-masking will be performed")
- model: Optional[str] = Field(title="Model", description="The model to use for preprocessing")
- params: Optional[dict] = Field(default={}, title="Settings", description="Preprocessor settings")
+    mask: str | None = Field(title="Mask", description="If optional mask image is not provided auto-masking will be performed")
+ model: str | None = Field(title="Model", description="The model to use for preprocessing")
+ params: dict | None = Field(default={}, title="Settings", description="Preprocessor settings")
class ReqFace(BaseModel):
image: str = Field(title="Image", description="The base64 encoded image")
- model: Optional[str] = Field(title="Model", description="The model to use for detection")
+ model: str | None = Field(title="Model", description="The model to use for detection")
class ResFace(BaseModel):
- classes: List[int] = Field(title="Class", description="The class of detected item")
- labels: List[str] = Field(title="Label", description="The label of detected item")
- boxes: List[List[int]] = Field(title="Box", description="The bounding box of detected item")
- images: List[str] = Field(title="Image", description="The base64 encoded images of detected faces")
- scores: List[float] = Field(title="Scores", description="The scores of the detected faces")
+ classes: list[int] = Field(title="Class", description="The class of detected item")
+ labels: list[str] = Field(title="Label", description="The label of detected item")
+ boxes: list[list[int]] = Field(title="Box", description="The bounding box of detected item")
+ images: list[str] = Field(title="Image", description="The base64 encoded images of detected faces")
+ scores: list[float] = Field(title="Scores", description="The scores of the detected faces")
class ResMask(BaseModel):
mask: str = Field(default='', title="Image", description="The processed image in base64 format")
@@ -47,13 +46,13 @@ class ItemPreprocess(BaseModel):
params: dict = Field(title="Params")
class ItemMask(BaseModel):
- models: List[str] = Field(title="Models")
- colormaps: List[str] = Field(title="Color maps")
+ models: list[str] = Field(title="Models")
+ colormaps: list[str] = Field(title="Color maps")
params: dict = Field(title="Params")
- types: List[str] = Field(title="Types")
+ types: list[str] = Field(title="Types")
-class APIProcess():
+class APIProcess:
def __init__(self, queue_lock: Lock):
self.queue_lock = queue_lock
diff --git a/modules/api/script.py b/modules/api/script.py
index 064d26143..ce8df9339 100644
--- a/modules/api/script.py
+++ b/modules/api/script.py
@@ -1,4 +1,3 @@
-from typing import Optional
from fastapi.exceptions import HTTPException
import gradio as gr
from modules.api import models
@@ -36,7 +35,7 @@ def get_scripts_list():
return models.ResScripts(txt2img = t2ilist, img2img = i2ilist, control = control)
-def get_script_info(script_name: Optional[str] = None):
+def get_script_info(script_name: str | None = None):
res = []
for script_list in [scripts_manager.scripts_txt2img.scripts, scripts_manager.scripts_img2img.scripts, scripts_manager.scripts_control.scripts]:
for script in script_list:
diff --git a/modules/api/xyz_grid.py b/modules/api/xyz_grid.py
index 569ae98b0..ec230c324 100644
--- a/modules/api/xyz_grid.py
+++ b/modules/api/xyz_grid.py
@@ -1,7 +1,6 @@
-from typing import List
-def xyz_grid_enum(option: str = "") -> List[dict]:
+def xyz_grid_enum(option: str = "") -> list[dict]:
from scripts.xyz import xyz_grid_classes # pylint: disable=no-name-in-module
options = []
for x in xyz_grid_classes.axis_options:
@@ -23,4 +22,4 @@ def xyz_grid_enum(option: str = "") -> List[dict]:
def register_api():
from modules.shared import api as api_instance
- api_instance.add_api_route("/sdapi/v1/xyz-grid", xyz_grid_enum, methods=["GET"], response_model=List[dict])
+ api_instance.add_api_route("/sdapi/v1/xyz-grid", xyz_grid_enum, methods=["GET"], response_model=list[dict])
diff --git a/modules/attention.py b/modules/attention.py
index a0a29bfb1..6490f6abb 100644
--- a/modules/attention.py
+++ b/modules/attention.py
@@ -1,4 +1,3 @@
-from typing import Optional
from functools import wraps
import torch
from modules import rocm
@@ -23,7 +22,7 @@ def set_triton_flash_attention(backend: str):
from modules.flash_attn_triton_amd import interface_fa
sdpa_pre_triton_flash_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_triton_flash_atten)
- def sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
+ def sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
if scale is None:
scale = query.shape[-1] ** (-0.5)
@@ -56,7 +55,7 @@ def set_flex_attention():
sdpa_pre_flex_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_flex_atten)
- def sdpa_flex_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: # pylint: disable=unused-argument
+ def sdpa_flex_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: # pylint: disable=unused-argument
score_mod = None
block_mask = None
if attn_mask is not None:
@@ -96,7 +95,7 @@ def set_ck_flash_attention(backend: str, device: torch.device):
from flash_attn import flash_attn_func
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_flash_atten)
- def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
+ def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
is_unsqueezed = False
if query.dim() == 3:
@@ -162,7 +161,7 @@ def set_sage_attention(backend: str, device: torch.device):
sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_sage_atten)
- def sdpa_sage_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
+ def sdpa_sage_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
if (query.shape[-1] in {128, 96, 64}) and (attn_mask is None) and (query.dtype != torch.float32):
if enable_gqa:
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
diff --git a/modules/ben2/ben2_model.py b/modules/ben2/ben2_model.py
index fe38c571e..2eccfc8ce 100644
--- a/modules/ben2/ben2_model.py
+++ b/modules/ben2/ben2_model.py
@@ -373,7 +373,7 @@ class BasicLayer(nn.Module):
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, (-100.0)).masked_fill(attn_mask == 0, 0.0)
for blk in self.blocks:
blk.H, blk.W = H, W
@@ -464,8 +464,8 @@ class SwinTransformer(nn.Module):
patch_size=4,
in_chans=3,
embed_dim=96,
- depths=[2, 2, 6, 2],
- num_heads=[3, 6, 12, 24],
+ depths=None,
+ num_heads=None,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
@@ -479,6 +479,10 @@ class SwinTransformer(nn.Module):
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
use_checkpoint=False):
+ if num_heads is None:
+ num_heads = [3, 6, 12, 24]
+ if depths is None:
+ depths = [2, 2, 6, 2]
super().__init__()
self.pretrain_img_size = pretrain_img_size
@@ -668,8 +672,10 @@ class PositionEmbeddingSine:
class MCLM(nn.Module):
- def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
- super(MCLM, self).__init__()
+ def __init__(self, d_model, num_heads, pool_ratios=None):
+ if pool_ratios is None:
+ pool_ratios = [1, 4, 8]
+ super().__init__()
self.attention = nn.ModuleList([
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
@@ -739,7 +745,7 @@ class MCLM(nn.Module):
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
_g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
outputs_re = []
- for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
+ for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1), strict=False)):
outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
@@ -760,8 +766,10 @@ class MCLM(nn.Module):
class MCRM(nn.Module):
- def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None): # pylint: disable=unused-argument
- super(MCRM, self).__init__()
+ def __init__(self, d_model, num_heads, pool_ratios=None, h=None): # pylint: disable=unused-argument
+ if pool_ratios is None:
+ pool_ratios = [4, 8, 16]
+ super().__init__()
self.attention = nn.ModuleList([
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
@@ -1049,7 +1057,7 @@ class BEN_Base(nn.Module):
"""
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
- raise IOError(f"Cannot open video: {video_path}")
+ raise OSError(f"Cannot open video: {video_path}")
    original_fps = cap.get(cv2.CAP_PROP_FPS)
original_fps = 30 if original_fps == 0 else original_fps
@@ -1225,7 +1233,7 @@ def add_audio_to_video(video_without_audio_path, original_video_path, output_pat
'-of', 'csv=p=0',
original_video_path
]
- result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False)
+ result = subprocess.run(probe_command, capture_output=True, text=True, check=False)
# result.stdout is empty if no audio stream found
if not result.stdout.strip():
diff --git a/modules/caption/deepbooru.py b/modules/caption/deepbooru.py
index d8324e428..5162a0dc8 100644
--- a/modules/caption/deepbooru.py
+++ b/modules/caption/deepbooru.py
@@ -4,7 +4,7 @@ import threading
import torch
import numpy as np
from PIL import Image
-from modules import modelloader, devices, shared
+from modules import modelloader, devices, shared, paths
re_special = re.compile(r'([\\()])')
load_lock = threading.Lock()
@@ -18,7 +18,7 @@ class DeepDanbooru:
with load_lock:
if self.model is not None:
return
- model_path = os.path.join(shared.models_path, "DeepDanbooru")
+ model_path = os.path.join(paths.models_path, "DeepDanbooru")
shared.log.debug(f'Caption load: module=DeepDanbooru folder="{model_path}"')
files = modelloader.load_models(
model_path=model_path,
@@ -96,7 +96,7 @@ class DeepDanbooru:
x = torch.from_numpy(a).to(device=devices.device, dtype=devices.dtype)
y = self.model(x)[0].detach().float().cpu().numpy()
probability_dict = {}
- for current, probability in zip(self.model.tags, y):
+ for current, probability in zip(self.model.tags, y, strict=False):
if probability < general_threshold:
continue
if current.startswith("rating:") and not include_rating:
diff --git a/modules/caption/deepbooru_model.py b/modules/caption/deepbooru_model.py
index 2963385c3..9489182ab 100644
--- a/modules/caption/deepbooru_model.py
+++ b/modules/caption/deepbooru_model.py
@@ -671,4 +671,4 @@ class DeepDanbooruModel(nn.Module):
def load_state_dict(self, state_dict, **kwargs): # pylint: disable=arguments-differ,unused-argument
self.tags = state_dict.get('tags', [])
- super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) # pylint: disable=R1725
+ super().load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) # pylint: disable=R1725
diff --git a/modules/caption/deepseek.py b/modules/caption/deepseek.py
index e7d3eac0c..44bff3fb6 100644
--- a/modules/caption/deepseek.py
+++ b/modules/caption/deepseek.py
@@ -21,7 +21,7 @@ vl_chat_processor = None
loaded_repo = None
-class fake_attrdict():
+class fake_attrdict:
class AttrDict(dict): # dot notation access to dictionary attributes
__getattr__ = dict.get
__setattr__ = dict.__setitem__
diff --git a/modules/caption/joycaption.py b/modules/caption/joycaption.py
index 3fc990ba8..99f4ef677 100644
--- a/modules/caption/joycaption.py
+++ b/modules/caption/joycaption.py
@@ -39,7 +39,7 @@ Extra Options:
"""
@dataclass
-class JoyOptions():
+class JoyOptions:
repo: str = "fancyfeast/llama-joycaption-alpha-two-hf-llava"
temp: float = 0.5
top_k: float = 10
diff --git a/modules/caption/joytag.py b/modules/caption/joytag.py
index efae5bbca..042f87529 100644
--- a/modules/caption/joytag.py
+++ b/modules/caption/joytag.py
@@ -6,7 +6,6 @@ import os
import math
import json
from pathlib import Path
-from typing import Optional
from PIL import Image
import torch
import torch.backends.cuda
@@ -126,7 +125,7 @@ class VisionModel(nn.Module):
@staticmethod
def load_model(path: str) -> 'VisionModel':
- with open(Path(path) / 'config.json', 'r', encoding='utf8') as f:
+ with open(Path(path) / 'config.json', encoding='utf8') as f:
config = json.load(f)
from safetensors.torch import load_file
resume = load_file(Path(path) / 'model.safetensors', device='cpu')
@@ -244,7 +243,7 @@ class CLIPMlp(nn.Module):
class FastCLIPAttention2(nn.Module):
"""Fast Attention module for CLIP-like. This is NOT a drop-in replacement for CLIPAttention, since it adds additional flexibility. Mainly uses xformers."""
- def __init__(self, hidden_size: int, out_dim: int, num_attention_heads: int, out_seq_len: Optional[int] = None, norm_qk: bool = False):
+ def __init__(self, hidden_size: int, out_dim: int, num_attention_heads: int, out_seq_len: int | None = None, norm_qk: bool = False):
super().__init__()
self.out_seq_len = out_seq_len
self.embed_dim = hidden_size
@@ -308,12 +307,12 @@ class FastCLIPEncoderLayer(nn.Module):
self,
hidden_size: int,
num_attention_heads: int,
- out_seq_len: Optional[int],
+ out_seq_len: int | None,
activation_cls = QuickGELUActivation,
use_palm_alt: bool = False,
norm_qk: bool = False,
- skip_init: Optional[float] = None,
- stochastic_depth: Optional[float] = None,
+ skip_init: float | None = None,
+ stochastic_depth: float | None = None,
):
super().__init__()
self.use_palm_alt = use_palm_alt
@@ -523,8 +522,8 @@ class CLIPLikeModel(VisionModel):
norm_qk: bool = False,
no_wd_bias: bool = False,
use_gap_head: bool = False,
- skip_init: Optional[float] = None,
- stochastic_depth: Optional[float] = None,
+ skip_init: float | None = None,
+ stochastic_depth: float | None = None,
):
super().__init__(image_size, n_tags)
out_dim = n_tags
@@ -939,7 +938,7 @@ class ViT(VisionModel):
stochdepth_rate: float,
use_sine: bool,
loss_type: str,
- layerscale_init: Optional[float] = None,
+ layerscale_init: float | None = None,
head_mean_after: bool = False,
cnn_stem: str = None,
patch_dropout: float = 0.0,
@@ -1048,7 +1047,7 @@ def load():
model = VisionModel.load_model(folder)
model.to(dtype=devices.dtype)
model.eval() # required: custom loader, not from_pretrained
- with open(os.path.join(folder, 'top_tags.txt'), 'r', encoding='utf8') as f:
+ with open(os.path.join(folder, 'top_tags.txt'), encoding='utf8') as f:
tags = [line.strip() for line in f.readlines() if line.strip()]
shared.log.info(f'Caption: type=vlm model="JoyTag" repo="{MODEL_REPO}" tags={len(tags)}')
sd_models.move_model(model, devices.device)
diff --git a/modules/caption/openclip.py b/modules/caption/openclip.py
index c7fca4d5f..b43229a93 100644
--- a/modules/caption/openclip.py
+++ b/modules/caption/openclip.py
@@ -330,11 +330,11 @@ def analyze_image(image, clip_model, blip_model):
top_movements = ci.movements.rank(image_features, 5)
top_trendings = ci.trendings.rank(image_features, 5)
top_flavors = ci.flavors.rank(image_features, 5)
- medium_ranks = dict(sorted(zip(top_mediums, ci.similarities(image_features, top_mediums)), key=lambda x: x[1], reverse=True))
- artist_ranks = dict(sorted(zip(top_artists, ci.similarities(image_features, top_artists)), key=lambda x: x[1], reverse=True))
- movement_ranks = dict(sorted(zip(top_movements, ci.similarities(image_features, top_movements)), key=lambda x: x[1], reverse=True))
- trending_ranks = dict(sorted(zip(top_trendings, ci.similarities(image_features, top_trendings)), key=lambda x: x[1], reverse=True))
- flavor_ranks = dict(sorted(zip(top_flavors, ci.similarities(image_features, top_flavors)), key=lambda x: x[1], reverse=True))
+ medium_ranks = dict(sorted(zip(top_mediums, ci.similarities(image_features, top_mediums), strict=False), key=lambda x: x[1], reverse=True))
+ artist_ranks = dict(sorted(zip(top_artists, ci.similarities(image_features, top_artists), strict=False), key=lambda x: x[1], reverse=True))
+ movement_ranks = dict(sorted(zip(top_movements, ci.similarities(image_features, top_movements), strict=False), key=lambda x: x[1], reverse=True))
+ trending_ranks = dict(sorted(zip(top_trendings, ci.similarities(image_features, top_trendings), strict=False), key=lambda x: x[1], reverse=True))
+ flavor_ranks = dict(sorted(zip(top_flavors, ci.similarities(image_features, top_flavors), strict=False), key=lambda x: x[1], reverse=True))
shared.log.debug(f'CLIP analyze: complete time={time.time()-t0:.2f}')
# Format labels as text
diff --git a/modules/caption/vqa.py b/modules/caption/vqa.py
index d840b6d25..9c358145f 100644
--- a/modules/caption/vqa.py
+++ b/modules/caption/vqa.py
@@ -708,7 +708,7 @@ class VQA:
debug(f'VQA caption: handler=qwen output_ids_shape={output_ids.shape}')
generated_ids = [
output_ids[len(input_ids):]
- for input_ids, output_ids in zip(inputs.input_ids, output_ids)
+ for input_ids, output_ids in zip(inputs.input_ids, output_ids, strict=False)
]
response = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
if debug_enabled:
@@ -887,7 +887,7 @@ class VQA:
def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
try:
- import flash_attn # pylint: disable=unused-import
+        import flash_attn # noqa: F401 # pylint: disable=unused-import
except Exception:
shared.log.error(f'Caption: vlm="{repo}" flash-attn is not available')
return ''
diff --git a/modules/caption/waifudiffusion.py b/modules/caption/waifudiffusion.py
index 4189fc989..416a9b1bd 100644
--- a/modules/caption/waifudiffusion.py
+++ b/modules/caption/waifudiffusion.py
@@ -126,7 +126,7 @@ class WaifuDiffusionTagger:
self.tags = []
self.tag_categories = []
- with open(csv_path, 'r', encoding='utf-8') as f:
+ with open(csv_path, encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
self.tags.append(row['name'])
@@ -269,7 +269,7 @@ class WaifuDiffusionTagger:
character_count = 0
rating_count = 0
- for i, (tag_name, prob) in enumerate(zip(self.tags, probs)):
+ for i, (tag_name, prob) in enumerate(zip(self.tags, probs, strict=False)):
category = self.tag_categories[i]
tag_lower = tag_name.lower()
diff --git a/modules/civitai/metadata_civitai.py b/modules/civitai/metadata_civitai.py
index 62de254b5..d3594c3f7 100644
--- a/modules/civitai/metadata_civitai.py
+++ b/modules/civitai/metadata_civitai.py
@@ -10,7 +10,9 @@ selected_model = None
class CivitModel:
- def __init__(self, name, fn, sha = None, meta = {}):
+ def __init__(self, name, fn, sha = None, meta = None):
+ if meta is None:
+ meta = {}
self.name = name
self.file = name
self.id = meta.get('id', 0)
diff --git a/modules/civitai/search_civitai.py b/modules/civitai/search_civitai.py
index 3b534662d..b30e32311 100644
--- a/modules/civitai/search_civitai.py
+++ b/modules/civitai/search_civitai.py
@@ -10,7 +10,7 @@ full_html = False
base_models = ['', 'AuraFlow', 'Chroma', 'CogVideoX', 'Flux.1 S', 'Flux.1 D', 'Flux.1 Krea', 'Flux.1 Kontext', 'Flux.2 D', 'HiDream', 'Hunyuan 1', 'Hunyuan Video', 'Illustrious', 'Kolors', 'LTXV', 'Lumina', 'Mochi', 'NoobAI', 'PixArt a', 'PixArt E', 'Pony', 'Pony V7', 'Qwen', 'SD 1.4', 'SD 1.5', 'SD 1.5 LCM', 'SD 1.5 Hyper', 'SD 2.0', 'SD 2.1', 'SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper', 'Wan Video 1.3B t2v', 'Wan Video 14B t2v', 'Wan Video 14B i2v 480p', 'Wan Video 14B i2v 720p', 'Wan Video 2.2 TI2V-5B', 'Wan Video 2.2 I2V-A14B', 'Wan Video 2.2 T2V-A14B', 'Wan Video 2.5 T2V', 'Wan Video 2.5 I2V', 'ZImageTurbo', 'Other']
@dataclass
-class ModelImage():
+class ModelImage:
def __init__(self, dct: dict):
if isinstance(dct, str):
dct = json.loads(dct)
@@ -26,7 +26,7 @@ class ModelImage():
@dataclass
-class ModelFile():
+class ModelFile:
def __init__(self, dct: dict):
if isinstance(dct, str):
dct = json.loads(dct)
@@ -43,7 +43,7 @@ class ModelFile():
@dataclass
-class ModelVersion():
+class ModelVersion:
def __init__(self, dct: dict):
import bs4
if isinstance(dct, str):
@@ -65,7 +65,7 @@ class ModelVersion():
@dataclass
-class Model():
+class Model:
def __init__(self, dct: dict):
import bs4
if isinstance(dct, str):
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 0e7f421d6..78fda012a 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -129,7 +129,7 @@ def main_args():
def compatibility_args():
# removed args are added here as hidden in fixed format for compatbility reasons
- from modules.paths import data_path, models_path
+ from modules.paths import data_path
group_compat = parser.add_argument_group('Compatibility options')
group_compat.add_argument('--backend', type=str, choices=['diffusers', 'original'], help=argparse.SUPPRESS)
group_compat.add_argument("--allow-code", default=os.environ.get("SD_ALLOWCODE", False), action='store_true', help=argparse.SUPPRESS)
diff --git a/modules/control/processor.py b/modules/control/processor.py
index 5a58c09d4..db5947fe6 100644
--- a/modules/control/processor.py
+++ b/modules/control/processor.py
@@ -46,11 +46,17 @@ def preprocess_image(
input_mask:Image.Image = None,
input_type:str = 0,
unit_type:str = 'controlnet',
- active_process:list = [],
- active_model:list = [],
- selected_models:list = [],
+ active_process:list = None,
+ active_model:list = None,
+ selected_models:list = None,
has_models:bool = False,
):
+ if selected_models is None:
+ selected_models = []
+ if active_model is None:
+ active_model = []
+ if active_process is None:
+ active_process = []
t0 = time.time()
jobid = shared.state.begin('Preprocess')
diff --git a/modules/control/processors.py b/modules/control/processors.py
index a5985639f..2968443fe 100644
--- a/modules/control/processors.py
+++ b/modules/control/processors.py
@@ -161,7 +161,7 @@ def update_settings(*settings):
update(['Depth Pro', 'params', 'color_map'], settings[28])
-class Processor():
+class Processor:
def __init__(self, processor_id: str = None, resize = True):
self.model = None
self.processor_id = None
@@ -268,7 +268,9 @@ class Processor():
display(e, 'Control Processor load')
return f'Processor load filed: {processor_id}'
- def __call__(self, image_input: Image, mode: str = 'RGB', width: int = 0, height: int = 0, resize_mode: int = 0, resize_name: str = 'None', scale_tab: int = 1, scale_by: float = 1.0, local_config: dict = {}):
+ def __call__(self, image_input: Image, mode: str = 'RGB', width: int = 0, height: int = 0, resize_mode: int = 0, resize_name: str = 'None', scale_tab: int = 1, scale_by: float = 1.0, local_config: dict = None):
+ if local_config is None:
+ local_config = {}
if self.override is not None:
debug(f'Control Processor: id="{self.processor_id}" override={self.override}')
width = image_input.width if image_input is not None else width
diff --git a/modules/control/run.py b/modules/control/run.py
index 243840ff1..f678fcf1d 100644
--- a/modules/control/run.py
+++ b/modules/control/run.py
@@ -1,6 +1,5 @@
import os
import sys
-from typing import List, Union
import cv2
from PIL import Image
from modules.control import util # helper functions
@@ -141,12 +140,12 @@ def set_pipe(p, has_models, unit_type, selected_models, active_model, active_str
def check_active(p, unit_type, units):
- active_process: List[processors.Processor] = [] # all active preprocessors
- active_model: List[Union[controlnet.ControlNet, xs.ControlNetXS, t2iadapter.Adapter]] = [] # all active models
- active_strength: List[float] = [] # strength factors for all active models
- active_start: List[float] = [] # start step for all active models
- active_end: List[float] = [] # end step for all active models
- active_units: List[unit.Unit] = [] # all active units
+ active_process: list[processors.Processor] = [] # all active preprocessors
+ active_model: list[controlnet.ControlNet | xs.ControlNetXS | t2iadapter.Adapter] = [] # all active models
+ active_strength: list[float] = [] # strength factors for all active models
+ active_start: list[float] = [] # start step for all active models
+ active_end: list[float] = [] # end step for all active models
+ active_units: list[unit.Unit] = [] # all active units
num_units = 0
for u in units:
if u.type != unit_type:
@@ -218,7 +217,7 @@ def check_active(p, unit_type, units):
def check_enabled(p, unit_type, units, active_model, active_strength, active_start, active_end):
has_models = False
- selected_models: List[Union[controlnet.ControlNetModel, xs.ControlNetXSModel, t2iadapter.AdapterModel]] = None
+ selected_models: list[controlnet.ControlNetModel | xs.ControlNetXSModel | t2iadapter.AdapterModel] = None
control_conditioning = None
control_guidance_start = None
control_guidance_end = None
@@ -254,7 +253,7 @@ def control_set(kwargs):
p_extra_args[k] = v
-def init_units(units: List[unit.Unit]):
+def init_units(units: list[unit.Unit]):
for u in units:
if not u.enabled:
continue
@@ -271,9 +270,9 @@ def init_units(units: List[unit.Unit]):
def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
- units: List[unit.Unit] = [], inputs: List[Image.Image] = [], inits: List[Image.Image] = [], mask: Image.Image = None, unit_type: str = None, is_generator: bool = True,
+ units: list[unit.Unit] = None, inputs: list[Image.Image] = None, inits: list[Image.Image] = None, mask: Image.Image = None, unit_type: str = None, is_generator: bool = True,
input_type: int = 0,
- prompt: str = '', negative_prompt: str = '', styles: List[str] = [],
+ prompt: str = '', negative_prompt: str = '', styles: list[str] = None,
steps: int = 20, sampler_index: int = None,
seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1,
guidance_name: str = 'Default', guidance_scale: float = 6.0, guidance_rescale: float = 0.0, guidance_start: float = 0.0, guidance_stop: float = 1.0,
@@ -289,11 +288,23 @@ def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
enable_hr: bool = False, hr_sampler_index: int = None, hr_denoising_strength: float = 0.0, hr_resize_mode: int = 0, hr_resize_context: str = 'None', hr_upscaler: str = None, hr_force: bool = False, hr_second_pass_steps: int = 20,
hr_scale: float = 1.0, hr_resize_x: int = 0, hr_resize_y: int = 0, refiner_steps: int = 5, refiner_start: float = 0.0, refiner_prompt: str = '', refiner_negative: str = '',
video_skip_frames: int = 0, video_type: str = 'None', video_duration: float = 2.0, video_loop: bool = False, video_pad: int = 0, video_interpolate: int = 0,
- extra: dict = {},
+ extra: dict = None,
override_script_name: str = None,
- override_script_args = [],
+ override_script_args = None,
*input_script_args,
):
+ if override_script_args is None:
+ override_script_args = []
+ if extra is None:
+ extra = {}
+ if styles is None:
+ styles = []
+ if inits is None:
+ inits = []
+ if inputs is None:
+ inputs = []
+ if units is None:
+ units = []
global pipe, original_pipeline # pylint: disable=global-statement
if 'refine' in state:
enable_hr = True
@@ -303,7 +314,7 @@ def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
init_units(units)
if inputs is None or (type(inputs) is list and len(inputs) == 0):
inputs = [None]
- output_images: List[Image.Image] = [] # output images
+ output_images: list[Image.Image] = [] # output images
processed_image: Image.Image = None # last processed image
if mask is not None and input_type == 0:
input_type = 1 # inpaint always requires control_image
diff --git a/modules/control/unit.py b/modules/control/unit.py
index 2eba804af..94f104558 100644
--- a/modules/control/unit.py
+++ b/modules/control/unit.py
@@ -1,4 +1,3 @@
-from typing import Union
from PIL import Image
import gradio as gr
from installer import log
@@ -7,7 +6,6 @@ from modules.control.units import controlnet
from modules.control.units import xs
from modules.control.units import lite
from modules.control.units import t2iadapter
-from modules.control.units import reference # pylint: disable=unused-import
default_device = None
@@ -16,7 +14,7 @@ unit_types = ['t2i adapter', 'controlnet', 'xs', 'lite', 'reference', 'ip']
current = []
-class Unit(): # mashup of gradio controls and mapping to actual implementation classes
+class Unit: # mashup of gradio controls and mapping to actual implementation classes
def update_choices(self, model_id=None):
name = model_id or self.model_name
if name == 'InstantX Union F1':
@@ -57,8 +55,10 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
control_mode = None,
control_tile = None,
result_txt = None,
- extra_controls: list = [],
+ extra_controls: list = None,
):
+ if extra_controls is None:
+ extra_controls = []
self.model_id = model_id
self.process_id = process_id
self.controls = [gr.Label(value=unit_type, visible=False)] # separator
@@ -77,7 +77,7 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
self.process_name = None
self.process: processors.Processor = processors.Processor()
self.adapter: t2iadapter.Adapter = None
- self.controlnet: Union[controlnet.ControlNet, xs.ControlNetXS] = None
+ self.controlnet: controlnet.ControlNet | xs.ControlNetXS | None = None
# map to input image
self.override: Image = None
# global settings but passed per-unit
diff --git a/modules/devices.py b/modules/devices.py
index 35391bf33..99276e30e 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -104,7 +104,7 @@ def get_gpu_info():
elif torch.cuda.is_available() and torch.version.cuda:
try:
import subprocess
- result = subprocess.run('nvidia-smi --query-gpu=driver_version --format=csv,noheader', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ result = subprocess.run('nvidia-smi --query-gpu=driver_version --format=csv,noheader', shell=True, check=False, env=os.environ, capture_output=True)
version = result.stdout.decode(encoding="utf8", errors="ignore").strip()
return version
except Exception:
@@ -307,7 +307,7 @@ def set_cuda_tunable():
lines={0}
try:
if os.path.exists(fn):
- with open(fn, 'r', encoding='utf8') as f:
+ with open(fn, encoding='utf8') as f:
lines = sum(1 for _line in f)
except Exception:
pass
diff --git a/modules/dml/Generator.py b/modules/dml/Generator.py
index ea273310c..c6c53f084 100644
--- a/modules/dml/Generator.py
+++ b/modules/dml/Generator.py
@@ -1,7 +1,6 @@
-from typing import Optional
import torch
class Generator(torch.Generator):
- def __init__(self, device: Optional[torch.device] = None):
+ def __init__(self, device: torch.device | None = None):
super().__init__("cpu")
diff --git a/modules/dml/__init__.py b/modules/dml/__init__.py
index c0ddbc1a1..82376c2d3 100644
--- a/modules/dml/__init__.py
+++ b/modules/dml/__init__.py
@@ -1,5 +1,6 @@
import platform
-from typing import NamedTuple, Callable, Optional
+from typing import NamedTuple, Optional
+from collections.abc import Callable
import torch
from modules.errors import log
from modules.sd_hijack_utils import CondFunc
@@ -86,8 +87,8 @@ def directml_do_hijack():
class OverrideItem(NamedTuple):
value: str
- condition: Optional[Callable]
- message: Optional[str]
+ condition: Callable | None
+ message: str | None
opts_override_table = {
diff --git a/modules/dml/amp/autocast_mode.py b/modules/dml/amp/autocast_mode.py
index 401d26d9e..2c344d53b 100644
--- a/modules/dml/amp/autocast_mode.py
+++ b/modules/dml/amp/autocast_mode.py
@@ -1,5 +1,5 @@
import importlib
-from typing import Any, Optional
+from typing import Any
import torch
@@ -52,7 +52,7 @@ class autocast:
fast_dtype: torch.dtype = torch.float16
prev_fast_dtype: torch.dtype
- def __init__(self, dtype: Optional[torch.dtype] = torch.float16):
+ def __init__(self, dtype: torch.dtype | None = torch.float16):
self.fast_dtype = dtype
def __enter__(self):
diff --git a/modules/dml/backend.py b/modules/dml/backend.py
index 7947dc81b..712e17591 100644
--- a/modules/dml/backend.py
+++ b/modules/dml/backend.py
@@ -1,5 +1,5 @@
# pylint: disable=no-member,no-self-argument,no-method-argument
-from typing import Optional, Callable
+from collections.abc import Callable
import torch
import torch_directml # pylint: disable=import-error
import modules.dml.amp as amp
@@ -9,17 +9,17 @@ from .Generator import Generator
from .device_properties import DeviceProperties
-def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
+def amd_mem_get_info(device: rDevice | None=None) -> tuple[int, int]:
from .memory_amd import AMDMemoryProvider
return AMDMemoryProvider.mem_get_info(get_device(device).index)
-def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
+def pdh_mem_get_info(device: rDevice | None=None) -> tuple[int, int]:
mem_info = DirectML.memory_provider.get_memory(get_device(device).index)
return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"])
-def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument
+def mem_get_info(device: rDevice | None=None) -> tuple[int, int]: # pylint: disable=unused-argument
return (8589934592, 8589934592)
@@ -28,7 +28,7 @@ class DirectML:
device = Device
Generator = Generator
- context_device: Optional[torch.device] = None
+ context_device: torch.device | None = None
is_autocast_enabled = False
autocast_gpu_dtype = torch.float16
@@ -41,7 +41,7 @@ class DirectML:
def is_directml_device(device: torch.device) -> bool:
return device.type == "privateuseone"
- def has_float64_support(device: Optional[rDevice]=None) -> bool:
+ def has_float64_support(device: rDevice | None=None) -> bool:
return torch_directml.has_float64_support(get_device(device).index)
def device_count() -> int:
@@ -53,16 +53,16 @@ class DirectML:
def default_device() -> torch.device:
return torch_directml.device(torch_directml.default_device())
- def get_device_string(device: Optional[rDevice]=None) -> str:
+ def get_device_string(device: rDevice | None=None) -> str:
return f"privateuseone:{get_device(device).index}"
- def get_device_name(device: Optional[rDevice]=None) -> str:
+ def get_device_name(device: rDevice | None=None) -> str:
return torch_directml.device_name(get_device(device).index)
- def get_device_properties(device: Optional[rDevice]=None) -> DeviceProperties:
+ def get_device_properties(device: rDevice | None=None) -> DeviceProperties:
return DeviceProperties(get_device(device))
- def memory_stats(device: Optional[rDevice]=None):
+ def memory_stats(device: rDevice | None=None):
return {
"num_ooms": 0,
"num_alloc_retries": 0,
@@ -70,11 +70,11 @@ class DirectML:
mem_get_info: Callable = mem_get_info
- def memory_allocated(device: Optional[rDevice]=None) -> int:
+ def memory_allocated(device: rDevice | None=None) -> int:
return sum(torch_directml.gpu_memory(get_device(device).index)) * (1 << 20)
- def max_memory_allocated(device: Optional[rDevice]=None):
+ def max_memory_allocated(device: rDevice | None=None):
return DirectML.memory_allocated(device) # DirectML does not empty GPU memory
- def reset_peak_memory_stats(device: Optional[rDevice]=None):
+ def reset_peak_memory_stats(device: rDevice | None=None):
return
diff --git a/modules/dml/device.py b/modules/dml/device.py
index ae4d32a99..cd7006333 100644
--- a/modules/dml/device.py
+++ b/modules/dml/device.py
@@ -1,4 +1,3 @@
-from typing import Optional
import torch
from .utils import rDevice, get_device
@@ -6,11 +5,11 @@ from .utils import rDevice, get_device
class Device:
idx: int
- def __enter__(self, device: Optional[rDevice]=None):
+ def __enter__(self, device: rDevice | None=None):
torch.dml.context_device = get_device(device)
self.idx = torch.dml.context_device.index
- def __init__(self, device: Optional[rDevice]=None) -> torch.device: # pylint: disable=return-in-init
+ def __init__(self, device: rDevice | None=None) -> torch.device: # pylint: disable=return-in-init
self.idx = get_device(device).index
def __exit__(self, t, v, tb):
diff --git a/modules/dml/hijack/tomesd.py b/modules/dml/hijack/tomesd.py
index 79de721df..0acc657ae 100644
--- a/modules/dml/hijack/tomesd.py
+++ b/modules/dml/hijack/tomesd.py
@@ -1,9 +1,8 @@
-from typing import Type
import torch
from modules.dml.hijack.utils import catch_nan
-def make_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
+def make_tome_block(block_class: type[torch.nn.Module]) -> type[torch.nn.Module]:
class ToMeBlock(block_class):
# Save for unpatching later
_parent = block_class
diff --git a/modules/dml/hijack/transformers.py b/modules/dml/hijack/transformers.py
index 78ddb20a2..6b4e090e1 100644
--- a/modules/dml/hijack/transformers.py
+++ b/modules/dml/hijack/transformers.py
@@ -1,4 +1,3 @@
-from typing import Optional
import torch
import transformers.models.clip.modeling_clip
@@ -22,9 +21,9 @@ def _make_causal_mask(
def CLIPTextEmbeddings_forward(
self: transformers.models.clip.modeling_clip.CLIPTextEmbeddings,
- input_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
) -> torch.Tensor:
from modules.devices import dtype
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
diff --git a/modules/dml/hijack/utils.py b/modules/dml/hijack/utils.py
index 659431c22..8817251b0 100644
--- a/modules/dml/hijack/utils.py
+++ b/modules/dml/hijack/utils.py
@@ -1,5 +1,5 @@
import torch
-from typing import Callable
+from collections.abc import Callable
from modules.shared import log, opts
diff --git a/modules/dml/pdh/apis.py b/modules/dml/pdh/apis.py
index f01222b45..9486e3321 100644
--- a/modules/dml/pdh/apis.py
+++ b/modules/dml/pdh/apis.py
@@ -1,6 +1,6 @@
from ctypes import CDLL, POINTER
from ctypes.wintypes import LPCWSTR, LPDWORD, DWORD
-from typing import Callable
+from collections.abc import Callable
from .structures import PDH_HQUERY, PDH_HCOUNTER, PPDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
from .defines import PDH_FUNCTION, PZZWSTR, DWORD_PTR
diff --git a/modules/dml/utils.py b/modules/dml/utils.py
index cb19ed900..58dd3238b 100644
--- a/modules/dml/utils.py
+++ b/modules/dml/utils.py
@@ -1,9 +1,9 @@
-from typing import Optional, Union
+from typing import Union
import torch
rDevice = Union[torch.device, int]
-def get_device(device: Optional[rDevice]=None) -> torch.device:
+def get_device(device: rDevice | None=None) -> torch.device:
if device is None:
device = torch.dml.current_device()
return torch.device(device)
diff --git a/modules/errors.py b/modules/errors.py
index a3397143a..110b1f5bb 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -10,13 +10,17 @@ install_traceback()
already_displayed = {}
-def install(suppress=[]):
+def install(suppress=None):
+ if suppress is None:
+ suppress = []
warnings.filterwarnings("ignore", category=UserWarning)
install_traceback(suppress=suppress)
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s')
-def display(e: Exception, task: str, suppress=[]):
+def display(e: Exception, task: str, suppress=None):
+ if suppress is None:
+ suppress = []
if isinstance(e, ErrorLimiterAbort):
return
log.critical(f"{task or 'error'}: {type(e).__name__}")
@@ -45,7 +49,9 @@ def run(code, task: str):
display(e, task)
-def exception(suppress=[]):
+def exception(suppress=None):
+ if suppress is None:
+ suppress = []
console = get_console()
console.print_exception(show_locals=False, max_frames=16, extra_lines=2, suppress=suppress, theme="ansi_dark", word_wrap=False, width=min([console.width, 200]))
diff --git a/modules/extensions.py b/modules/extensions.py
index b29a6cbb0..c2643a4ad 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -186,7 +186,7 @@ class Extension:
continue
priority = '50'
if os.path.isfile(os.path.join(dirpath, "..", ".priority")):
- with open(os.path.join(dirpath, "..", ".priority"), "r", encoding="utf-8") as f:
+ with open(os.path.join(dirpath, "..", ".priority"), encoding="utf-8") as f:
priority = str(f.read().strip())
res.append(scripts_manager.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
if priority != '50':
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 01913b187..66edf76a3 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -73,8 +73,12 @@ def is_stepwise(en_obj):
return any([len(str(x).split("@")) > 1 for x in all_args]) # noqa C419 # pylint: disable=use-a-generator
-def activate(p, extra_network_data=None, step=0, include=[], exclude=[]):
+def activate(p, extra_network_data=None, step=0, include=None, exclude=None):
"""call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list"""
+ if exclude is None:
+ exclude = []
+ if include is None:
+ include = []
if p.disable_extra_networks:
return
extra_network_data = extra_network_data or p.network_data
diff --git a/modules/extras.py b/modules/extras.py
index 221bbe6a0..7b2c9e6d6 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -33,7 +33,7 @@ def run_modelmerger(id_task, **kwargs): # pylint: disable=unused-argument
from installer import install
install('tensordict', quiet=True)
try:
- from tensordict import TensorDict # pylint: disable=unused-import
+ from tensordict import TensorDict # noqa: F401 # pylint: disable=unused-import
except Exception as e:
shared.log.error(f"Merge: {e}")
return [*[gr.update() for _ in range(4)], "tensordict not available"]
diff --git a/modules/face/faceid.py b/modules/face/faceid.py
index 8986e92b1..933df80ff 100644
--- a/modules/face/faceid.py
+++ b/modules/face/faceid.py
@@ -1,4 +1,3 @@
-from typing import List
import os
import cv2
import torch
@@ -34,7 +33,7 @@ def hijack_load_ip_adapter(self):
def face_id(
p: processing.StableDiffusionProcessing,
app,
- source_images: List[Image.Image],
+ source_images: list[Image.Image],
model: str,
override: bool,
cache: bool,
diff --git a/modules/face/faceswap.py b/modules/face/faceswap.py
index df3765fb2..a3e4f27fc 100644
--- a/modules/face/faceswap.py
+++ b/modules/face/faceswap.py
@@ -1,4 +1,3 @@
-from typing import List
import os
import cv2
import numpy as np
@@ -12,7 +11,7 @@ insightface_app = None
swapper = None
-def face_swap(p: processing.StableDiffusionProcessing, app, input_images: List[Image.Image], source_image: Image.Image, cache: bool):
+def face_swap(p: processing.StableDiffusionProcessing, app, input_images: list[Image.Image], source_image: Image.Image, cache: bool):
global swapper # pylint: disable=global-statement
if swapper is None:
import insightface.model_zoo
diff --git a/modules/face/instantid_model.py b/modules/face/instantid_model.py
index 8af9a2907..2511ec27a 100644
--- a/modules/face/instantid_model.py
+++ b/modules/face/instantid_model.py
@@ -14,7 +14,8 @@
import math
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any
+from collections.abc import Callable
import cv2
import numpy as np
@@ -544,40 +545,40 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
- prompt: Union[str, List[str]] = None,
- prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt: str | list[str] | None = None,
+ prompt_2: str | list[str] | None = None,
image: PipelineImageInput = None,
- height: Optional[int] = None,
- width: Optional[int] = None,
+ height: int | None = None,
+ width: int | None = None,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
- num_images_per_prompt: Optional[int] = 1,
+ negative_prompt: str | list[str] | None = None,
+ negative_prompt_2: str | list[str] | None = None,
+ num_images_per_prompt: int | None = 1,
eta: float = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- image_embeds: Optional[torch.FloatTensor] = None,
- output_type: Optional[str] = "pil",
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.FloatTensor | None = None,
+ prompt_embeds: torch.FloatTensor | None = None,
+ negative_prompt_embeds: torch.FloatTensor | None = None,
+ pooled_prompt_embeds: torch.FloatTensor | None = None,
+ negative_pooled_prompt_embeds: torch.FloatTensor | None = None,
+ image_embeds: torch.FloatTensor | None = None,
+ output_type: str | None = "pil",
return_dict: bool = True,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ cross_attention_kwargs: dict[str, Any] | None = None,
+ controlnet_conditioning_scale: float | list[float] = 1.0,
guess_mode: bool = False,
- control_guidance_start: Union[float, List[float]] = 0.0,
- control_guidance_end: Union[float, List[float]] = 1.0,
- original_size: Tuple[int, int] = None,
- crops_coords_top_left: Tuple[int, int] = (0, 0),
- target_size: Tuple[int, int] = None,
- negative_original_size: Optional[Tuple[int, int]] = None,
- negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
- negative_target_size: Optional[Tuple[int, int]] = None,
- clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = None,
+ control_guidance_start: float | list[float] = 0.0,
+ control_guidance_end: float | list[float] = 1.0,
+ original_size: tuple[int, int] | None = None,
+ crops_coords_top_left: tuple[int, int] = (0, 0),
+ target_size: tuple[int, int] | None = None,
+ negative_original_size: tuple[int, int] | None = None,
+ negative_crops_coords_top_left: tuple[int, int] = (0, 0),
+ negative_target_size: tuple[int, int] | None = None,
+ clip_skip: int | None = None,
+ callback_on_step_end: Callable[[int, int, dict], None] | None = None,
+ callback_on_step_end_tensor_inputs: list[str] | None = None,
**kwargs,
):
r"""
@@ -890,7 +891,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
- for s, e in zip(control_guidance_start, control_guidance_end)
+ for s, e in zip(control_guidance_start, control_guidance_end, strict=False)
]
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
@@ -970,7 +971,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
controlnet_added_cond_kwargs = added_cond_kwargs
if isinstance(controlnet_keep[i], list):
- cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i], strict=False)]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
diff --git a/modules/face/photomaker_pipeline.py b/modules/face/photomaker_pipeline.py
index 45006a7e1..191c4f352 100644
--- a/modules/face/photomaker_pipeline.py
+++ b/modules/face/photomaker_pipeline.py
@@ -1,7 +1,8 @@
### original
import inspect
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Union
+from collections.abc import Callable
import PIL
import torch
from transformers import CLIPImageProcessor
@@ -26,8 +27,8 @@ from modules.face.photomaker_model_v2 import PhotoMakerIDEncoder_CLIPInsightface
PipelineImageInput = Union[
PIL.Image.Image,
torch.FloatTensor,
- List[PIL.Image.Image],
- List[torch.FloatTensor],
+ list[PIL.Image.Image],
+ list[torch.FloatTensor],
]
@@ -49,10 +50,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- sigmas: Optional[List[float]] = None,
+ num_inference_steps: int | None = None,
+ device: str | torch.device | None = None,
+ timesteps: list[int] | None = None,
+ sigmas: list[float] | None = None,
**kwargs,
):
"""
@@ -110,7 +111,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
@validate_hf_hub_args
def load_photomaker_adapter(
self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
weight_name: str,
subfolder: str = '',
trigger_word: str = 'img',
@@ -214,21 +215,21 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
def encode_prompt_with_trigger_word(
self,
prompt: str,
- prompt_2: Optional[str] = None,
- device: Optional[torch.device] = None,
+ prompt_2: str | None = None,
+ device: torch.device | None = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
- negative_prompt: Optional[str] = None,
- negative_prompt_2: Optional[str] = None,
- prompt_embeds: Optional[torch.Tensor] = None,
- negative_prompt_embeds: Optional[torch.Tensor] = None,
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
- negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
- lora_scale: Optional[float] = None,
- clip_skip: Optional[int] = None,
+ negative_prompt: str | None = None,
+ negative_prompt_2: str | None = None,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ pooled_prompt_embeds: torch.Tensor | None = None,
+ negative_pooled_prompt_embeds: torch.Tensor | None = None,
+ lora_scale: float | None = None,
+ clip_skip: int | None = None,
### Added args
num_id_images: int = 1,
- class_tokens_mask: Optional[torch.LongTensor] = None,
+ class_tokens_mask: torch.LongTensor | None = None,
):
device = device or self._execution_device
@@ -273,7 +274,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
- for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): # pylint: disable=redefined-argument-from-local
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False): # pylint: disable=redefined-argument-from-local
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer)
@@ -362,7 +363,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
- uncond_tokens: List[str]
+ uncond_tokens: list[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@@ -377,7 +378,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
- for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): # pylint: disable=redefined-argument-from-local
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders, strict=False): # pylint: disable=redefined-argument-from-local
if isinstance(self, TextualInversionLoaderMixin):
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
@@ -444,49 +445,47 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
@torch.no_grad()
def __call__(
self,
- prompt: Union[str, List[str]] = None,
- prompt_2: Optional[Union[str, List[str]]] = None,
- height: Optional[int] = None,
- width: Optional[int] = None,
+ prompt: str | list[str] | None = None,
+ prompt_2: str | list[str] | None = None,
+ height: int | None = None,
+ width: int | None = None,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
- sigmas: List[float] = None,
- denoising_end: Optional[float] = None,
+ timesteps: list[int] | None = None,
+ sigmas: list[float] | None = None,
+ denoising_end: float | None = None,
guidance_scale: float = 5.0,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
- num_images_per_prompt: Optional[int] = 1,
+ negative_prompt: str | list[str] | None = None,
+ negative_prompt_2: str | list[str] | None = None,
+ num_images_per_prompt: int | None = 1,
eta: float = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.Tensor] = None,
- prompt_embeds: Optional[torch.Tensor] = None,
- negative_prompt_embeds: Optional[torch.Tensor] = None,
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
- negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
- ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
- output_type: Optional[str] = "pil",
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ pooled_prompt_embeds: torch.Tensor | None = None,
+ negative_pooled_prompt_embeds: torch.Tensor | None = None,
+ ip_adapter_image: PipelineImageInput | None = None,
+ ip_adapter_image_embeds: list[torch.Tensor] | None = None,
+ output_type: str | None = "pil",
return_dict: bool = True,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ cross_attention_kwargs: dict[str, Any] | None = None,
guidance_rescale: float = 0.0,
- original_size: Optional[Tuple[int, int]] = None,
- crops_coords_top_left: Tuple[int, int] = (0, 0),
- target_size: Optional[Tuple[int, int]] = None,
- negative_original_size: Optional[Tuple[int, int]] = None,
- negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
- negative_target_size: Optional[Tuple[int, int]] = None,
- clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
- ] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ original_size: tuple[int, int] | None = None,
+ crops_coords_top_left: tuple[int, int] = (0, 0),
+ target_size: tuple[int, int] | None = None,
+ negative_original_size: tuple[int, int] | None = None,
+ negative_crops_coords_top_left: tuple[int, int] = (0, 0),
+ negative_target_size: tuple[int, int] | None = None,
+ clip_skip: int | None = None,
+ callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
+ callback_on_step_end_tensor_inputs: list[str] | None = None,
# Added parameters (for PhotoMaker)
input_id_images: PipelineImageInput = None,
start_merge_step: int = 10,
- class_tokens_mask: Optional[torch.LongTensor] = None,
- id_embeds: Optional[torch.FloatTensor] = None,
- prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
- pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
+ class_tokens_mask: torch.LongTensor | None = None,
+ id_embeds: torch.FloatTensor | None = None,
+ prompt_embeds_text_only: torch.FloatTensor | None = None,
+ pooled_prompt_embeds_text_only: torch.FloatTensor | None = None,
**kwargs,
):
r"""
@@ -512,6 +511,8 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
+ if callback_on_step_end_tensor_inputs is None:
+ callback_on_step_end_tensor_inputs = ["latents"]
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
diff --git a/modules/face/reswapper.py b/modules/face/reswapper.py
index 77328a1ff..d688a51d2 100644
--- a/modules/face/reswapper.py
+++ b/modules/face/reswapper.py
@@ -1,4 +1,3 @@
-from typing import List
import os
import cv2
import torch
@@ -43,8 +42,8 @@ def get_model(model_name: str):
def reswapper(
p: processing.StableDiffusionProcessing,
app,
- source_images: List[Image.Image],
- target_images: List[Image.Image],
+ source_images: list[Image.Image],
+ target_images: list[Image.Image],
model_name: str,
original: bool,
):
diff --git a/modules/face/reswapper_model.py b/modules/face/reswapper_model.py
index de68d8566..606189bf8 100644
--- a/modules/face/reswapper_model.py
+++ b/modules/face/reswapper_model.py
@@ -6,7 +6,7 @@ import torch.nn.functional as F
class ReSwapperModel(nn.Module):
def __init__(self):
- super(ReSwapperModel, self).__init__()
+ super().__init__()
# self.pad = nn.ReflectionPad2d(3)
# Encoder for target face
@@ -87,7 +87,7 @@ class ReSwapperModel(nn.Module):
class StyleBlock(nn.Module):
def __init__(self, in_channels, out_channels, blockIndex):
- super(StyleBlock, self).__init__()
+ super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
self.style1 = nn.Linear(512, 2048)
diff --git a/modules/files_cache.py b/modules/files_cache.py
index d65e0f4f4..5ffc4d383 100644
--- a/modules/files_cache.py
+++ b/modules/files_cache.py
@@ -2,7 +2,8 @@ import itertools
import os
from collections import UserDict
from dataclasses import dataclass, field
-from typing import Callable, Dict, Iterator, List, Optional, Union
+from typing import Union
+from collections.abc import Callable, Iterator
from installer import log
@@ -10,19 +11,19 @@ do_cache_folders = os.environ.get('SD_NO_CACHE', None) is None
class Directory: # forward declaration
...
-FilePathList = List[str]
+FilePathList = list[str]
FilePathIterator = Iterator[str]
-DirectoryPathList = List[str]
+DirectoryPathList = list[str]
DirectoryPathIterator = Iterator[str]
-DirectoryList = List[Directory]
+DirectoryList = list[Directory]
DirectoryIterator = Iterator[Directory]
-DirectoryCollection = Dict[str, Directory]
+DirectoryCollection = dict[str, Directory]
ExtensionFilter = Callable
ExtensionList = list[str]
RecursiveType = Union[bool,Callable]
-def real_path(directory_path:str) -> Union[str, None]:
+def real_path(directory_path:str) -> str | None:
try:
return os.path.abspath(os.path.expanduser(directory_path))
except Exception:
@@ -52,7 +53,7 @@ class Directory(Directory): # pylint: disable=E0102
def clear(self) -> None:
self._update(Directory.from_dict({
'path': None,
- 'mtime': float(),
+ 'mtime': 0.0,
'files': [],
'directories': []
}))
@@ -125,7 +126,7 @@ def clean_directory(directory: Directory, /, recursive: RecursiveType=False) ->
return is_clean
-def get_directory(directory_or_path: str, /, fetch: bool=True) -> Union[Directory, None]:
+def get_directory(directory_or_path: str, /, fetch: bool=True) -> Directory | None:
if isinstance(directory_or_path, Directory):
if directory_or_path.is_directory:
return directory_or_path
@@ -143,7 +144,7 @@ def get_directory(directory_or_path: str, /, fetch: bool=True) -> Union[Director
return cache_folders[directory_or_path] if directory_or_path in cache_folders else None
-def fetch_directory(directory_path: str) -> Union[Directory, None]:
+def fetch_directory(directory_path: str) -> Directory | None:
directory: Directory
for directory in _walk(directory_path, recurse=False):
return directory # The return is intentional, we get a generator, we only need the one
@@ -255,7 +256,7 @@ def get_directories(*directory_paths: DirectoryPathList, fetch:bool=True, recurs
return filter(bool, directories)
-def directory_files(*directories_or_paths: Union[DirectoryPathList, DirectoryList], recursive: RecursiveType=True) -> FilePathIterator:
+def directory_files(*directories_or_paths: DirectoryPathList | DirectoryList, recursive: RecursiveType=True) -> FilePathIterator:
return itertools.chain.from_iterable(
itertools.chain(
directory_object.files,
@@ -275,7 +276,7 @@ def directory_files(*directories_or_paths: Union[DirectoryPathList, DirectoryLis
)
-def extension_filter(ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None) -> ExtensionFilter:
+def extension_filter(ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None) -> ExtensionFilter:
if ext_filter:
ext_filter = [*map(str.upper, ext_filter)]
if ext_blacklist:
@@ -289,11 +290,11 @@ def not_hidden(filepath: str) -> bool:
return not os.path.basename(filepath).startswith('.')
-def filter_files(file_paths: FilePathList, ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None) -> FilePathIterator:
+def filter_files(file_paths: FilePathList, ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None) -> FilePathIterator:
return filter(extension_filter(ext_filter, ext_blacklist), file_paths)
-def list_files(*directory_paths:DirectoryPathList, ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None, recursive:RecursiveType=True) -> FilePathIterator:
+def list_files(*directory_paths:DirectoryPathList, ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None, recursive:RecursiveType=True) -> FilePathIterator:
return filter_files(itertools.chain.from_iterable(
directory_files(directory, recursive=recursive)
for directory in get_directories(*directory_paths, recursive=recursive)
diff --git a/modules/framepack/framepack_api.py b/modules/framepack/framepack_api.py
index 0c2ecb25b..ad3bd0e70 100644
--- a/modules/framepack/framepack_api.py
+++ b/modules/framepack/framepack_api.py
@@ -1,4 +1,3 @@
-from typing import Optional, List
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from fastapi.exceptions import HTTPException
from modules import shared
@@ -8,39 +7,39 @@ class ReqFramepack(BaseModel):
variant: str = Field(default=None, title="Model variant", description="Model variant to use")
prompt: str = Field(default=None, title="Prompt", description="Prompt for the model")
init_image: str = Field(default=None, title="Initial image", description="Base64 encoded initial image")
- end_image: Optional[str] = Field(default=None, title="End image", description="Base64 encoded end image")
- start_weight: Optional[float] = Field(default=1.0, title="Start weight", description="Weight of the initial image")
- end_weight: Optional[float] = Field(default=1.0, title="End weight", description="Weight of the end image")
- vision_weight: Optional[float] = Field(default=1.0, title="Vision weight", description="Weight of the vision model")
- system_prompt: Optional[str] = Field(default=None, title="System prompt", description="System prompt for the model")
- optimized_prompt: Optional[bool] = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model")
- section_prompt: Optional[str] = Field(default=None, title="Section prompt", description="Prompt for each section")
- negative_prompt: Optional[str] = Field(default=None, title="Negative prompt", description="Negative prompt for the model")
- styles: Optional[List[str]] = Field(default=None, title="Styles", description="Styles for the model")
- seed: Optional[int] = Field(default=None, title="Seed", description="Seed for the model")
- resolution: Optional[int] = Field(default=640, title="Resolution", description="Resolution of the image")
- duration: Optional[float] = Field(default=4, title="Duration", description="Duration of the video in seconds")
- latent_ws: Optional[int] = Field(default=9, title="Latent window size", description="Size of the latent window")
- steps: Optional[int] = Field(default=25, title="Video steps", description="Number of steps for the video generation")
- cfg_scale: Optional[float] = Field(default=1.0, title="CFG scale", description="CFG scale for the model")
- cfg_distilled: Optional[float] = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model")
- cfg_rescale: Optional[float] = Field(default=0.0, title="CFG re-scale", description="CFG re-scale for the model")
- shift: Optional[float] = Field(default=0, title="Sampler shift", description="Shift for the sampler")
- use_teacache: Optional[bool] = Field(default=True, title="Enable TeaCache", description="Use TeaCache for the model")
- use_cfgzero: Optional[bool] = Field(default=False, title="Enable CFGZero", description="Use CFGZero for the model")
- mp4_fps: Optional[int] = Field(default=30, title="FPS", description="Frames per second for the video")
- mp4_codec: Optional[str] = Field(default="libx264", title="Codec", description="Codec for the video")
- mp4_sf: Optional[bool] = Field(default=False, title="Save SafeTensors", description="Save SafeTensors for the video")
- mp4_video: Optional[bool] = Field(default=True, title="Save Video", description="Save video")
- mp4_frames: Optional[bool] = Field(default=False, title="Save Frames", description="Save frames for the video")
- mp4_opt: Optional[str] = Field(default="crf:16", title="Options", description="Options for the video codec")
- mp4_ext: Optional[str] = Field(default="mp4", title="Format", description="Format for the video")
- mp4_interpolate: Optional[int] = Field(default=0, title="Interpolation", description="Interpolation for the video")
- attention: Optional[str] = Field(default="Default", title="Attention", description="Attention type for the model")
- vae_type: Optional[str] = Field(default="Local", title="VAE", description="VAE type for the model")
- vlm_enhance: Optional[bool] = Field(default=False, title="VLM enhance", description="Enable VLM enhance")
- vlm_model: Optional[str] = Field(default=None, title="VLM model", description="VLM model to use")
- vlm_system_prompt: Optional[str] = Field(default=None, title="VLM system prompt", description="System prompt for the VLM model")
+ end_image: str | None = Field(default=None, title="End image", description="Base64 encoded end image")
+ start_weight: float | None = Field(default=1.0, title="Start weight", description="Weight of the initial image")
+ end_weight: float | None = Field(default=1.0, title="End weight", description="Weight of the end image")
+ vision_weight: float | None = Field(default=1.0, title="Vision weight", description="Weight of the vision model")
+ system_prompt: str | None = Field(default=None, title="System prompt", description="System prompt for the model")
+ optimized_prompt: bool | None = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model")
+ section_prompt: str | None = Field(default=None, title="Section prompt", description="Prompt for each section")
+ negative_prompt: str | None = Field(default=None, title="Negative prompt", description="Negative prompt for the model")
+ styles: list[str] | None = Field(default=None, title="Styles", description="Styles for the model")
+ seed: int | None = Field(default=None, title="Seed", description="Seed for the model")
+ resolution: int | None = Field(default=640, title="Resolution", description="Resolution of the image")
+ duration: float | None = Field(default=4, title="Duration", description="Duration of the video in seconds")
+ latent_ws: int | None = Field(default=9, title="Latent window size", description="Size of the latent window")
+ steps: int | None = Field(default=25, title="Video steps", description="Number of steps for the video generation")
+ cfg_scale: float | None = Field(default=1.0, title="CFG scale", description="CFG scale for the model")
+ cfg_distilled: float | None = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model")
+ cfg_rescale: float | None = Field(default=0.0, title="CFG re-scale", description="CFG re-scale for the model")
+ shift: float | None = Field(default=0, title="Sampler shift", description="Shift for the sampler")
+ use_teacache: bool | None = Field(default=True, title="Enable TeaCache", description="Use TeaCache for the model")
+ use_cfgzero: bool | None = Field(default=False, title="Enable CFGZero", description="Use CFGZero for the model")
+ mp4_fps: int | None = Field(default=30, title="FPS", description="Frames per second for the video")
+ mp4_codec: str | None = Field(default="libx264", title="Codec", description="Codec for the video")
+ mp4_sf: bool | None = Field(default=False, title="Save SafeTensors", description="Save SafeTensors for the video")
+ mp4_video: bool | None = Field(default=True, title="Save Video", description="Save video")
+ mp4_frames: bool | None = Field(default=False, title="Save Frames", description="Save frames for the video")
+ mp4_opt: str | None = Field(default="crf:16", title="Options", description="Options for the video codec")
+ mp4_ext: str | None = Field(default="mp4", title="Format", description="Format for the video")
+ mp4_interpolate: int | None = Field(default=0, title="Interpolation", description="Interpolation for the video")
+ attention: str | None = Field(default="Default", title="Attention", description="Attention type for the model")
+ vae_type: str | None = Field(default="Local", title="VAE", description="VAE type for the model")
+ vlm_enhance: bool | None = Field(default=False, title="VLM enhance", description="Enable VLM enhance")
+ vlm_model: str | None = Field(default=None, title="VLM model", description="VLM model to use")
+ vlm_system_prompt: str | None = Field(default=None, title="VLM system prompt", description="System prompt for the VLM model")
class ResFramepack(BaseModel):
diff --git a/modules/framepack/framepack_worker.py b/modules/framepack/framepack_worker.py
index 345333ad9..9df245d59 100644
--- a/modules/framepack/framepack_worker.py
+++ b/modules/framepack/framepack_worker.py
@@ -43,8 +43,10 @@ def worker(
mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate,
vae_type,
variant,
- metadata:dict={},
+    metadata: dict | None = None,
):
+ if metadata is None:
+ metadata = {}
timer.process.reset()
memstats.reset_stats()
if stream is None or shared.state.interrupted or shared.state.skipped:
diff --git a/modules/framepack/pipeline/hunyuan_video_packed.py b/modules/framepack/pipeline/hunyuan_video_packed.py
index 0a3f8f62b..a852fbc10 100644
--- a/modules/framepack/pipeline/hunyuan_video_packed.py
+++ b/modules/framepack/pipeline/hunyuan_video_packed.py
@@ -1,4 +1,3 @@
-from typing import Optional, Tuple
import torch
import torch.nn as nn
@@ -251,7 +250,7 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
class HunyuanVideoAdaNorm(nn.Module):
- def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
+ def __init__(self, in_features: int, out_features: int | None = None) -> None:
super().__init__()
out_features = out_features or 2 * in_features
@@ -260,7 +259,7 @@ class HunyuanVideoAdaNorm(nn.Module):
def forward(
self, temb: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
temb = self.linear(self.nonlinearity(temb))
gate_msa, gate_mlp = temb.chunk(2, dim=-1)
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
@@ -298,7 +297,7 @@ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
@@ -346,7 +345,7 @@ class HunyuanVideoIndividualTokenRefiner(nn.Module):
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: torch.Tensor | None = None,
) -> None:
self_attn_mask = None
if attention_mask is not None:
@@ -396,7 +395,7 @@ class HunyuanVideoTokenRefiner(nn.Module):
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
- attention_mask: Optional[torch.LongTensor] = None,
+ attention_mask: torch.LongTensor | None = None,
) -> torch.Tensor:
if attention_mask is None:
pooled_projections = hidden_states.mean(dim=1)
@@ -464,8 +463,8 @@ class AdaLayerNormZero(nn.Module):
def forward(
self,
x: torch.Tensor,
- emb: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ emb: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = emb.unsqueeze(-2)
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
@@ -487,8 +486,8 @@ class AdaLayerNormZeroSingle(nn.Module):
def forward(
self,
x: torch.Tensor,
- emb: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ emb: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = emb.unsqueeze(-2)
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
@@ -558,8 +557,8 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_mask: torch.Tensor | None = None,
+ image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -636,9 +635,9 @@ class HunyuanVideoTransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ attention_mask: torch.Tensor | None = None,
+ freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
@@ -734,7 +733,7 @@ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterM
text_embed_dim: int = 4096,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
- rope_axes_dim: Tuple[int] = (16, 56, 56),
+        rope_axes_dim: tuple[int, int, int] = (16, 56, 56),
has_image_proj=False,
image_proj_dim=1152,
has_clean_x_embedder=False,
diff --git a/modules/framepack/pipeline/utils.py b/modules/framepack/pipeline/utils.py
index 9cd99571d..a14e1d7d9 100644
--- a/modules/framepack/pipeline/utils.py
+++ b/modules/framepack/pipeline/utils.py
@@ -102,14 +102,14 @@ def just_crop(image, w, h):
def write_to_json(data, file_path):
temp_file_path = file_path + ".tmp"
- with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
+ with open(temp_file_path, 'w', encoding='utf-8') as temp_file:
json.dump(data, temp_file, indent=4)
os.replace(temp_file_path, file_path)
return
def read_from_json(file_path):
- with open(file_path, 'rt', encoding='utf-8') as file:
+ with open(file_path, encoding='utf-8') as file:
data = json.load(file)
return data
@@ -283,7 +283,7 @@ def add_tensors_with_padding(tensor1, tensor2):
shape1 = tensor1.shape
shape2 = tensor2.shape
- new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2, strict=False))
padded_tensor1 = torch.zeros(new_shape)
padded_tensor2 = torch.zeros(new_shape)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 63ffc545e..42f6aa864 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -5,7 +5,7 @@ import os
from PIL import Image
import gradio as gr
from modules import shared, gr_tempdir, script_callbacks, images
-from modules.infotext import parse, mapping, quote, unquote # pylint: disable=unused-import
+from modules.infotext import parse, mapping # pylint: disable=unused-import
type_of_gr_update = type(gr.update())
@@ -259,7 +259,7 @@ def connect_paste(button, local_paste_fields, input_comp, override_settings_comp
from modules.paths import params_path
if prompt is None or len(prompt.strip()) == 0:
if os.path.exists(params_path):
- with open(params_path, "r", encoding="utf8") as file:
+ with open(params_path, encoding="utf8") as file:
prompt = file.read()
shared.log.debug(f'Prompt parse: type="params" prompt="{prompt}"')
else:
diff --git a/modules/ggml/gguf_utils.py b/modules/ggml/gguf_utils.py
index c6c937380..f3fdc2146 100644
--- a/modules/ggml/gguf_utils.py
+++ b/modules/ggml/gguf_utils.py
@@ -1,7 +1,8 @@
# Original: invokeai.backend.quantization.gguf.utils
# Largely based on https://github.com/city96/ComfyUI-GGUF
-from typing import Callable, Optional, Union
+from typing import Union
+from collections.abc import Callable
import gguf
import torch
@@ -28,7 +29,7 @@ def get_scale_min(scales: torch.Tensor):
# Legacy Quants #
def dequantize_blocks_Q8_0(
- blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+ blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
d, x = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
@@ -37,7 +38,7 @@ def dequantize_blocks_Q8_0(
def dequantize_blocks_Q5_1(
- blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+ blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@@ -58,7 +59,7 @@ def dequantize_blocks_Q5_1(
def dequantize_blocks_Q5_0(
- blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+ blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@@ -79,7 +80,7 @@ def dequantize_blocks_Q5_0(
def dequantize_blocks_Q4_1(
- blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+ blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@@ -96,7 +97,7 @@ def dequantize_blocks_Q4_1(
def dequantize_blocks_Q4_0(
- blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+ blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@@ -111,13 +112,13 @@ def dequantize_blocks_Q4_0(
def dequantize_blocks_BF16(
- blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+ blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
def dequantize_blocks_Q6_K(
- blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+ blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@@ -147,7 +148,7 @@ def dequantize_blocks_Q6_K(
def dequantize_blocks_Q5_K(
- blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+ blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@@ -175,7 +176,7 @@ def dequantize_blocks_Q5_K(
def dequantize_blocks_Q4_K(
- blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+ blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@@ -197,7 +198,7 @@ def dequantize_blocks_Q4_K(
def dequantize_blocks_Q3_K(
- blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+ blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@@ -232,7 +233,7 @@ def dequantize_blocks_Q3_K(
def dequantize_blocks_Q2_K(
- blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+ blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
@@ -254,7 +255,7 @@ def dequantize_blocks_Q2_K(
DEQUANTIZE_FUNCTIONS: dict[
- gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor]
+ gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, torch.dtype | None], torch.Tensor]
] = {
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
@@ -270,7 +271,7 @@ DEQUANTIZE_FUNCTIONS: dict[
}
-def is_torch_compatible(tensor: Optional[torch.Tensor]):
+def is_torch_compatible(tensor: torch.Tensor | None):
return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES
@@ -279,7 +280,7 @@ def is_quantized(tensor: torch.Tensor):
def dequantize(
- data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None
+ data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: torch.dtype | None = None
):
"""
Dequantize tensor back to usable shape/dtype
diff --git a/modules/history.py b/modules/history.py
index 9c0c9c0c6..2540aa7a0 100644
--- a/modules/history.py
+++ b/modules/history.py
@@ -9,8 +9,10 @@ import torch
from modules import shared, devices
-class Item():
- def __init__(self, latent, preview=None, info=None, ops=[]):
+class Item:
+ def __init__(self, latent, preview=None, info=None, ops=None):
+ if ops is None:
+ ops = []
self.ts = datetime.datetime.now().replace(microsecond=0)
self.name = self.ts.strftime('%Y-%m-%d %H:%M:%S')
self.latent = latent.detach().clone().to(devices.cpu)
@@ -20,7 +22,7 @@ class Item():
self.size = sys.getsizeof(self.latent.storage())
-class History():
+class History:
def __init__(self):
self.index = -1
self.latents = deque(maxlen=1024)
@@ -58,7 +60,9 @@ class History():
return i
return -1
- def add(self, latent, preview=None, info=None, ops=[]):
+ def add(self, latent, preview=None, info=None, ops=None):
+ if ops is None:
+ ops = []
shared.state.latent_history += 1
if shared.opts.latent_history == 0:
return
diff --git a/modules/image/grid.py b/modules/image/grid.py
index 5827bdb95..2b6bc2ed1 100644
--- a/modules/image/grid.py
+++ b/modules/image/grid.py
@@ -171,7 +171,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
calc_img = Image.new("RGB", (1, 1), shared.opts.grid_background)
calc_d = ImageDraw.Draw(calc_img)
title_texts = [title] if title else [[GridAnnotation()]]
- for texts, allowed_width in zip(x_texts + y_texts + title_texts, [width] * len(x_texts) + [pad_left] * len(y_texts) + [(width+margin)*cols]):
+ for texts, allowed_width in zip(x_texts + y_texts + title_texts, [width] * len(x_texts) + [pad_left] * len(y_texts) + [(width+margin)*cols], strict=False):
items = [] + texts
texts.clear()
for line in items:
diff --git a/modules/image/resize.py b/modules/image/resize.py
index 029e2549e..a26c92714 100644
--- a/modules/image/resize.py
+++ b/modules/image/resize.py
@@ -1,4 +1,3 @@
-from typing import Union
import sys
import time
import numpy as np
@@ -8,7 +7,7 @@ from modules import shared, upscaler
from modules.image import sharpfin
-def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None):
+def resize_image(resize_mode: int, im: Image.Image | torch.Tensor, width: int, height: int, upscaler_name: str | None = None, output_type: str = 'image', context: str | None = None):
upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img
def verify_image(image):
@@ -34,7 +33,7 @@ def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width:
im = vae_decode(latents, shared.sd_model, output_type='pil', vae_type='Tiny')[0]
return im
- def resize(im: Union[Image.Image, torch.Tensor], w, h):
+ def resize(im: Image.Image | torch.Tensor, w, h):
w, h = int(w), int(h)
if upscaler_name is None or upscaler_name == "None" or (hasattr(im, 'mode') and im.mode == 'L'):
return sharpfin.resize(im, (w, h), linearize=False) # force for mask
diff --git a/modules/image/sharpfin.py b/modules/image/sharpfin.py
index 987e98bba..ecdebdf0a 100644
--- a/modules/image/sharpfin.py
+++ b/modules/image/sharpfin.py
@@ -22,7 +22,6 @@ _triton_ok = False
def check_sharpfin():
global _sharpfin_checked, _sharpfin_ok, _triton_ok # pylint: disable=global-statement
if not _sharpfin_checked:
- from modules.sharpfin.functional import scale # pylint: disable=unused-import
_sharpfin_ok = True
try:
from modules.sharpfin import TRITON_AVAILABLE
diff --git a/modules/images.py b/modules/images.py
index 070beb9a4..068141e77 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,10 +1,11 @@
-from modules.image.util import flatten, draw_text # pylint: disable=unused-import
-from modules.image.save import save_image # pylint: disable=unused-import
-from modules.image.convert import to_pil, to_tensor # pylint: disable=unused-import
-from modules.image.metadata import read_info_from_image, image_data # pylint: disable=unused-import
-from modules.image.resize import resize_image # pylint: disable=unused-import
-from modules.image.sharpfin import resize # pylint: disable=unused-import
-from modules.image.namegen import FilenameGenerator, get_next_sequence_number # pylint: disable=unused-import
-from modules.image.watermark import set_watermark, get_watermark # pylint: disable=unused-import
-from modules.image.grid import image_grid, get_grid_size, split_grid, combine_grid, check_grid_size, get_font, draw_grid_annotations, draw_prompt_matrix, GridAnnotation, Grid # pylint: disable=unused-import
-from modules.video import save_video # pylint: disable=unused-import
+from modules.image.metadata import image_data, read_info_from_image
+from modules.image.save import save_image, sanitize_filename_part
+from modules.image.resize import resize_image
+from modules.image.grid import image_grid, check_grid_size, get_grid_size, draw_grid_annotations, draw_prompt_matrix
+
+__all__ = [
+ 'image_data', 'read_info_from_image',
+ 'save_image', 'sanitize_filename_part',
+ 'resize_image',
+ 'image_grid', 'check_grid_size', 'get_grid_size', 'draw_grid_annotations', 'draw_prompt_matrix'
+]
diff --git a/modules/img2img.py b/modules/img2img.py
index bc080f732..9b832e7a0 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -77,7 +77,7 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args)
caption_file = os.path.splitext(image_file)[0] + '.txt'
prompt_type='default'
if os.path.exists(caption_file):
- with open(caption_file, 'r', encoding='utf8') as f:
+ with open(caption_file, encoding='utf8') as f:
p.prompt = f.read()
prompt_type='file'
else:
diff --git a/modules/infotext.py b/modules/infotext.py
index dbee9fbd9..75a955e2d 100644
--- a/modules/infotext.py
+++ b/modules/infotext.py
@@ -129,7 +129,7 @@ if __name__ == '__main__':
import sys
if len(sys.argv) > 1:
if os.path.exists(sys.argv[1]):
- with open(sys.argv[1], 'r', encoding='utf8') as f:
+ with open(sys.argv[1], encoding='utf8') as f:
parse(f.read())
else:
parse(sys.argv[1])
diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py
index 6d947fd3f..b32c1e558 100644
--- a/modules/infotext_utils.py
+++ b/modules/infotext_utils.py
@@ -1,3 +1,2 @@
# a1111 compatibility module: unused
-from modules.infotext import parse as parse_generation_parameters # pylint: disable=unused-import
diff --git a/modules/intel/openvino/__init__.py b/modules/intel/openvino/__init__.py
index e14b26ec0..0491525ad 100644
--- a/modules/intel/openvino/__init__.py
+++ b/modules/intel/openvino/__init__.py
@@ -81,7 +81,9 @@ def warn_once(msg):
warned = True
class OpenVINOGraphModule(torch.nn.Module):
- def __init__(self, gm, partition_id, use_python_fusion_cache, model_hash_str: str = None, file_name="", int_inputs=[]):
+    def __init__(self, gm, partition_id, use_python_fusion_cache, model_hash_str: str | None = None, file_name="", int_inputs=None):
+ if int_inputs is None:
+ int_inputs = []
super().__init__()
self.gm = gm
self.int_inputs = int_inputs
@@ -192,7 +194,7 @@ def execute(
elif executor == "strictly_openvino":
return openvino_execute(gm, *args, executor_parameters=executor_parameters, file_name=file_name)
- msg = "Received unexpected value for 'executor': {0}. Allowed values are: openvino, strictly_openvino.".format(executor)
+ msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: openvino, strictly_openvino."
raise ValueError(msg)
@@ -373,7 +375,7 @@ def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition
ov_inputs = []
for arg in flat_args:
if not isinstance(arg, int):
- ov_inputs.append((arg.detach().cpu().numpy()))
+ ov_inputs.append(arg.detach().cpu().numpy())
res = req.infer(ov_inputs, share_inputs=True, share_outputs=True)
@@ -423,7 +425,9 @@ def openvino_execute_partitioned(gm: GraphModule, *args, executor_parameters=Non
return shared.compiled_model_state.partitioned_modules[signature][0](*ov_inputs)
-def partition_graph(gm: GraphModule, use_python_fusion_cache: bool, model_hash_str: str = None, file_name="", int_inputs=[]):
+def partition_graph(gm: GraphModule, use_python_fusion_cache: bool, model_hash_str: str | None = None, file_name="", int_inputs=None):
+ if int_inputs is None:
+ int_inputs = []
for node in gm.graph.nodes:
if node.op == "call_module" and "fused_" in node.name:
openvino_submodule = getattr(gm, node.name)
@@ -509,7 +513,7 @@ def openvino_fx(subgraph, example_inputs, options=None):
if os.path.isfile(maybe_fs_cached_name + ".xml") and os.path.isfile(maybe_fs_cached_name + ".bin"):
example_inputs_reordered = []
if (os.path.isfile(maybe_fs_cached_name + ".txt")):
- f = open(maybe_fs_cached_name + ".txt", "r")
+ f = open(maybe_fs_cached_name + ".txt")
for input_data in example_inputs:
shape = f.readline()
if (str(input_data.size()) != shape):
@@ -532,7 +536,7 @@ def openvino_fx(subgraph, example_inputs, options=None):
if (shared.compiled_model_state.cn_model != [] and str(shared.compiled_model_state.cn_model) in maybe_fs_cached_name):
args_reordered = []
if (os.path.isfile(maybe_fs_cached_name + ".txt")):
- f = open(maybe_fs_cached_name + ".txt", "r")
+ f = open(maybe_fs_cached_name + ".txt")
for input_data in args:
shape = f.readline()
if (str(input_data.size()) != shape):
diff --git a/modules/ipadapter.py b/modules/ipadapter.py
index f29fc0d77..97ac171ae 100644
--- a/modules/ipadapter.py
+++ b/modules/ipadapter.py
@@ -288,7 +288,19 @@ def parse_params(p: processing.StableDiffusionProcessing, adapters: list, adapte
return adapter_images, adapter_masks, adapter_scales, adapter_crops, adapter_starts, adapter_ends
-def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapter_scales=[1.0], adapter_crops=[False], adapter_starts=[0.0], adapter_ends=[1.0], adapter_images=[]):
+def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=None, adapter_scales=None, adapter_crops=None, adapter_starts=None, adapter_ends=None, adapter_images=None):
+ if adapter_images is None:
+ adapter_images = []
+ if adapter_ends is None:
+ adapter_ends = [1.0]
+ if adapter_starts is None:
+ adapter_starts = [0.0]
+ if adapter_crops is None:
+ adapter_crops = [False]
+ if adapter_scales is None:
+ adapter_scales = [1.0]
+ if adapter_names is None:
+ adapter_names = []
global adapters_loaded # pylint: disable=global-statement
# overrides
if hasattr(p, 'ip_adapter_names'):
@@ -361,7 +373,7 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
if adapter_starts[i] > 0:
adapter_scales[i] = 0.00
pipe.set_ip_adapter_scale(adapter_scales if len(adapter_scales) > 1 else adapter_scales[0])
- ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}:{crop}' for adapter, scale, start, end, crop in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends, adapter_crops)]
+ ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}:{crop}' for adapter, scale, start, end, crop in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends, adapter_crops, strict=False)]
if hasattr(pipe, 'transformer') and 'Nunchaku' in pipe.transformer.__class__.__name__:
if isinstance(repos, str):
sd_models.clear_caches(full=True)
diff --git a/modules/loader.py b/modules/loader.py
index b34e01dd6..f00f39882 100644
--- a/modules/loader.py
+++ b/modules/loader.py
@@ -63,7 +63,7 @@ try:
except Exception:
pass
try:
- import torch.distributed.distributed_c10d as _c10d # pylint: disable=unused-import,ungrouped-imports
+ pass # pylint: disable=unused-import,ungrouped-imports
except Exception:
errors.log.warning('Loader: torch is not built with distributed support')
@@ -73,7 +73,6 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvisi
torchvision = None
try:
import torchvision # pylint: disable=W0611,C0411
- import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them # pylint: disable=W0611,C0411
except Exception as e:
errors.log.error(f'Loader: torchvision=={torchvision.__version__ if "torchvision" in sys.modules else None} {e}')
if '_no_nep' in str(e):
@@ -100,7 +99,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
timer.startup.record("torch")
try:
- import bitsandbytes # pylint: disable=W0611,C0411
+ import bitsandbytes # pylint: disable=unused-import
_bnb = True
except Exception:
_bnb = False
@@ -132,7 +131,6 @@ except Exception as e:
errors.log.warning(f'Torch onnxruntime: {e}')
timer.startup.record("onnx")
-from fastapi import FastAPI # pylint: disable=W0611,C0411
timer.startup.record("fastapi")
import gradio # pylint: disable=W0611,C0411
@@ -161,17 +159,16 @@ except Exception as e:
sys.exit(1)
try:
- import pillow_jxl # pylint: disable=W0611,C0411
+ pass # pylint: disable=W0611,C0411
except Exception:
pass
-from PIL import Image # pylint: disable=W0611,C0411
timer.startup.record("pillow")
import cv2 # pylint: disable=W0611,C0411
timer.startup.record("cv2")
-class _tqdm_cls():
+class _tqdm_cls:
def __call__(self, *args, **kwargs):
bar_format = 'Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + '{desc}' + '\x1b[0m'
return tqdm_lib.tqdm(*args, bar_format=bar_format, ncols=80, colour='#327fba', **kwargs)
diff --git a/modules/localization.py b/modules/localization.py
index e3cc19959..e6fc9fdbe 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -28,7 +28,7 @@ def localization_js(current_localization_name):
data = {}
if fn is not None:
try:
- with open(fn, "r", encoding="utf8") as file:
+ with open(fn, encoding="utf8") as file:
data = json.load(file)
except Exception as e:
errors.log.error(f"Error loading localization from {fn}:")
diff --git a/modules/lora/extra_networks_lora.py b/modules/lora/extra_networks_lora.py
index cd23998a8..1dabaaa01 100644
--- a/modules/lora/extra_networks_lora.py
+++ b/modules/lora/extra_networks_lora.py
@@ -1,4 +1,3 @@
-from typing import List
import os
import re
import numpy as np
@@ -19,7 +18,7 @@ def get_stepwise(param, step, steps): # from https://github.com/cheald/sd-webui-
return steps[0][0]
steps = [[s[0], s[1] if len(s) == 2 else 1] for s in steps] # Add implicit 1s to any steps which don't have a weight
steps.sort(key=lambda k: k[1]) # Sort by index
- steps = [list(v) for v in zip(*steps)]
+ steps = [list(v) for v in zip(*steps, strict=False)]
return steps
def calculate_weight(m, step, max_steps, step_offset=2):
@@ -170,10 +169,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
self.model = None
self.errors = {}
- def signature(self, names: List[str], te_multipliers: List, unet_multipliers: List):
- return [f'{name}:{te}:{unet}' for name, te, unet in zip(names, te_multipliers, unet_multipliers)]
+ def signature(self, names: list[str], te_multipliers: list, unet_multipliers: list):
+ return [f'{name}:{te}:{unet}' for name, te, unet in zip(names, te_multipliers, unet_multipliers, strict=False)]
- def changed(self, requested: List[str], include: List[str] = None, exclude: List[str] = None) -> bool:
+ def changed(self, requested: list[str], include: list[str] = None, exclude: list[str] = None) -> bool:
if shared.opts.lora_force_reload:
debug_log(f'Network check: type=LoRA requested={requested} status=forced')
return True
@@ -190,7 +189,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
sd_model.loaded_loras[key] = requested
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=changed')
return True
- for req, load in zip(requested, loaded):
+ for req, load in zip(requested, loaded, strict=False):
if req != load:
sd_model.loaded_loras[key] = requested
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=changed')
@@ -198,7 +197,11 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=same')
return False
- def activate(self, p, params_list, step=0, include=[], exclude=[]):
+ def activate(self, p, params_list, step=0, include=None, exclude=None):
+ if exclude is None:
+ exclude = []
+ if include is None:
+ include = []
self.errors.clear()
if self.active:
if self.model != shared.opts.sd_model_checkpoint: # reset if model changed
diff --git a/modules/lora/lora_apply.py b/modules/lora/lora_apply.py
index e79306c9f..5cf0a4d38 100644
--- a/modules/lora/lora_apply.py
+++ b/modules/lora/lora_apply.py
@@ -1,4 +1,3 @@
-from typing import Union
import re
import time
import torch
@@ -12,7 +11,7 @@ bnb = None
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
-def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, wanted_names: tuple):
+def network_backup_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, network_layer_name: str, wanted_names: tuple):
global bnb # pylint: disable=W0603
backup_size = 0
if len(l.loaded_networks) > 0 and network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in l.loaded_networks]): # noqa: C419 # pylint: disable=R1729
@@ -76,7 +75,7 @@ def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n
return backup_size
-def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, use_previous: bool = False):
+def network_calc_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, network_layer_name: str, use_previous: bool = False):
if shared.opts.diffusers_offload_mode == "none":
try:
self.to(devices.device)
@@ -147,7 +146,7 @@ def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
return batch_updown, batch_ex_bias
-def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], model_weights: Union[None, torch.Tensor] = None, lora_weights: torch.Tensor = None, deactivate: bool = False, device: torch.device = None, bias: bool = False):
+def network_add_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, model_weights: None | torch.Tensor = None, lora_weights: torch.Tensor = None, deactivate: bool = False, device: torch.device = None, bias: bool = False):
if lora_weights is None:
return
if deactivate:
@@ -239,7 +238,7 @@ def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.G
del model_weights, lora_weights, new_weight, weight # required to avoid memory leak
-def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False, device: torch.device = devices.device):
+def network_apply_direct(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False, device: torch.device = devices.device):
weights_backup = getattr(self, "network_weights_backup", False)
bias_backup = getattr(self, "network_bias_backup", False)
if not isinstance(weights_backup, bool): # remove previous backup if we switched settings
@@ -266,7 +265,7 @@ def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
l.timer.apply += time.time() - t0
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, device: torch.device, deactivate: bool = False):
+def network_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, updown: torch.Tensor, ex_bias: torch.Tensor, device: torch.device, deactivate: bool = False):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
if weights_backup is None and bias_backup is None:
diff --git a/modules/lora/lora_common.py b/modules/lora/lora_common.py
index a6b15ae13..7f171846b 100644
--- a/modules/lora/lora_common.py
+++ b/modules/lora/lora_common.py
@@ -1,4 +1,3 @@
-from typing import List
import os
from modules.lora import lora_timers
from modules.lora import network_lora, network_hada, network_ia3, network_oft, network_lokr, network_full, network_norm, network_glora
@@ -16,6 +15,6 @@ module_types = [
network_norm.ModuleTypeNorm(),
network_glora.ModuleTypeGLora(),
]
-loaded_networks: List = [] # no type due to circular import
-previously_loaded_networks: List = [] # no type due to circular import
+loaded_networks: list = [] # no type due to circular import
+previously_loaded_networks: list = [] # no type due to circular import
extra_network_lora = None # initialized in extra_networks.py
diff --git a/modules/lora/lora_convert.py b/modules/lora/lora_convert.py
index aaef92e43..f019c450e 100644
--- a/modules/lora/lora_convert.py
+++ b/modules/lora/lora_convert.py
@@ -1,7 +1,6 @@
import os
import re
import bisect
-from typing import Dict
import torch
from modules import shared
@@ -23,7 +22,7 @@ re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}
-def make_unet_conversion_map() -> Dict[str, str]:
+def make_unet_conversion_map() -> dict[str, str]:
unet_conversion_map_layer = []
for i in range(4): # num_blocks is 3 in sdxl
@@ -213,10 +212,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
ait_sd.update({k: down_weight for k in ait_down_keys})
# up_weight is split to each split
- ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension
+ ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0), strict=False)}) # noqa: C416 # pylint: disable=unnecessary-comprehension
else:
# down_weight is chunked to each split
- ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension
+ ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0), strict=False)}) # noqa: C416 # pylint: disable=unnecessary-comprehension
# up_weight is sparse: only non-zero values are copied to each split
i = 0
diff --git a/modules/lora/lora_diffusers.py b/modules/lora/lora_diffusers.py
index eb1515ca0..02179fb21 100644
--- a/modules/lora/lora_diffusers.py
+++ b/modules/lora/lora_diffusers.py
@@ -1,4 +1,3 @@
-from typing import Union
import os
import time
import diffusers
@@ -50,7 +49,7 @@ def load_per_module(sd_model: diffusers.DiffusionPipeline, filename: str, adapte
return adapter_name
-def load_diffusers(name: str, network_on_disk: network.NetworkOnDisk, lora_scale:float=shared.opts.extra_networks_default_multiplier, lora_module=None) -> Union[network.Network, None]:
+def load_diffusers(name: str, network_on_disk: network.NetworkOnDisk, lora_scale:float=shared.opts.extra_networks_default_multiplier, lora_module=None) -> network.Network | None:
t0 = time.time()
name = name.replace(".", "_")
sd_model: diffusers.DiffusionPipeline = getattr(shared.sd_model, "pipe", shared.sd_model)
diff --git a/modules/lora/lora_load.py b/modules/lora/lora_load.py
index a836b5323..ff5659ea5 100644
--- a/modules/lora/lora_load.py
+++ b/modules/lora/lora_load.py
@@ -1,4 +1,3 @@
-from typing import Union
import os
import time
import concurrent
@@ -39,7 +38,7 @@ def lora_dump(lora, dct):
f.write(line + "\n")
-def load_safetensors(name, network_on_disk: network.NetworkOnDisk) -> Union[network.Network, None]:
+def load_safetensors(name, network_on_disk: network.NetworkOnDisk) -> network.Network | None:
if not shared.sd_loaded:
return None
@@ -241,7 +240,7 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
lora_diffusers.diffuser_scales.clear()
t0 = time.time()
- for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
+ for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names, strict=False)):
net = None
if network_on_disk is not None:
shorthash = getattr(network_on_disk, 'shorthash', '').lower()
diff --git a/modules/lora/lora_nunchaku.py b/modules/lora/lora_nunchaku.py
index de4773158..318e25398 100644
--- a/modules/lora/lora_nunchaku.py
+++ b/modules/lora/lora_nunchaku.py
@@ -10,7 +10,7 @@ def load_nunchaku(names, strengths):
global previously_loaded # pylint: disable=global-statement
strengths = [s[0] if isinstance(s, list) else s for s in strengths]
networks = lora_load.gather_networks(names)
- networks = [(network, strength) for network, strength in zip(networks, strengths) if network is not None and strength > 0]
+ networks = [(network, strength) for network, strength in zip(networks, strengths, strict=False) if network is not None and strength > 0]
loras = [(network.filename, strength) for network, strength in networks]
is_changed = loras != previously_loaded
if not is_changed:
diff --git a/modules/lora/lora_timers.py b/modules/lora/lora_timers.py
index 30c35a728..6f3e48c33 100644
--- a/modules/lora/lora_timers.py
+++ b/modules/lora/lora_timers.py
@@ -1,4 +1,4 @@
-class Timer():
+class Timer:
list: float = 0
load: float = 0
backup: float = 0
diff --git a/modules/lora/network.py b/modules/lora/network.py
index b8a09913b..a959a3338 100644
--- a/modules/lora/network.py
+++ b/modules/lora/network.py
@@ -1,6 +1,5 @@
import os
import enum
-from typing import Union
from collections import namedtuple
from modules import sd_models, hashes, shared
@@ -120,7 +119,7 @@ class NetworkOnDisk:
if self.filename is not None:
fn = os.path.splitext(self.filename)[0] + '.txt'
if os.path.exists(fn):
- with open(fn, "r", encoding="utf-8") as file:
+ with open(fn, encoding="utf-8") as file:
return file.read()
return None
@@ -144,7 +143,7 @@ class Network: # LoraModule
class ModuleType:
- def create_module(self, net: Network, weights: NetworkWeights) -> Union[Network, None]: # pylint: disable=W0613
+ def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: # pylint: disable=W0613
return None
diff --git a/modules/lora/networks.py b/modules/lora/networks.py
index 69df992cc..512a63698 100644
--- a/modules/lora/networks.py
+++ b/modules/lora/networks.py
@@ -11,7 +11,11 @@ applied_layers: list[str] = []
default_components = ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'text_encoder_4', 'unet', 'transformer', 'transformer_2']
-def network_activate(include=[], exclude=[]):
+def network_activate(include=None, exclude=None):
+ if exclude is None:
+ exclude = []
+ if include is None:
+ include = []
t0 = time.time()
with limit_errors("network_activate"):
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
@@ -77,7 +81,11 @@ def network_activate(include=[], exclude=[]):
sd_models.set_diffuser_offload(sd_model, op="model")
-def network_deactivate(include=[], exclude=[]):
+def network_deactivate(include=None, exclude=None):
+ if exclude is None:
+ exclude = []
+ if include is None:
+ include = []
if not shared.opts.lora_fuse_native or shared.opts.lora_force_diffusers:
return
if len(l.previously_loaded_networks) == 0:
diff --git a/modules/masking.py b/modules/masking.py
index ea1844c19..92de3bb75 100644
--- a/modules/masking.py
+++ b/modules/masking.py
@@ -1,5 +1,4 @@
from types import SimpleNamespace
-from typing import List
import os
import sys
import time
@@ -235,7 +234,7 @@ def run_segment(input_image: gr.Image, input_mask: np.ndarray):
combined_mask = np.zeros(input_mask.shape, dtype='uint8')
input_mask_size = np.count_nonzero(input_mask)
debug(f'Segment SAM: {vars(opts)}')
- for mask, score in zip(outputs['masks'], outputs['scores']):
+ for mask, score in zip(outputs['masks'], outputs['scores'], strict=False):
mask = mask.astype('uint8')
mask_size = np.count_nonzero(mask)
if mask_size == 0:
@@ -561,7 +560,7 @@ def create_segment_ui():
return controls
-def bind_controls(image_controls: List[gr.Image], preview_image: gr.Image, output_image: gr.Image):
+def bind_controls(image_controls: list[gr.Image], preview_image: gr.Image, output_image: gr.Image):
for image_control in image_controls:
btn_mask.click(run_mask, inputs=[image_control], outputs=[preview_image])
btn_lama.click(run_lama, inputs=[image_control], outputs=[output_image])
diff --git a/modules/memmon.py b/modules/memmon.py
index 944d85d83..eb521f93f 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -2,7 +2,7 @@ from collections import defaultdict
import torch
-class MemUsageMonitor():
+class MemUsageMonitor:
device = None
disabled = False
opts = None
diff --git a/modules/memstats.py b/modules/memstats.py
index c9bb13238..0d67da44d 100644
--- a/modules/memstats.py
+++ b/modules/memstats.py
@@ -24,7 +24,7 @@ def get_docker_limit():
if docker_limit is not None:
return docker_limit
try:
- with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r', encoding='utf8') as f:
+ with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', encoding='utf8') as f:
docker_limit = float(f.read())
except Exception:
docker_limit = sys.float_info.max
@@ -145,7 +145,9 @@ class Object:
return f'{self.fn}.{self.name} type={self.type} size={self.size} ref={self.refcount}'
-def get_objects(gcl={}, threshold:int=0):
+def get_objects(gcl=None, threshold:int=0):
+ if gcl is None:
+ gcl = {}
objects = []
seen = []
diff --git a/modules/merging/convert_sdxl.py b/modules/merging/convert_sdxl.py
index 93fc71f5d..3238cfd35 100644
--- a/modules/merging/convert_sdxl.py
+++ b/modules/merging/convert_sdxl.py
@@ -260,7 +260,9 @@ def calculate_model_hash(state_dict):
return func.hexdigest()
-def convert(model_path:str, checkpoint_path:str, metadata:dict={}):
+def convert(model_path:str, checkpoint_path:str, metadata:dict=None):
+ if metadata is None:
+ metadata = {}
unet_path = os.path.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = os.path.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = os.path.join(model_path, "text_encoder", "model.safetensors")
diff --git a/modules/merging/merge.py b/modules/merging/merge.py
index d0e48bfba..ec493de2e 100644
--- a/modules/merging/merge.py
+++ b/modules/merging/merge.py
@@ -1,7 +1,6 @@
import os
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
-from typing import Dict, Optional, Tuple, Set
import safetensors.torch
import torch
import modules.memstats
@@ -37,7 +36,7 @@ KEY_POSITION_IDS = ".".join(
)
-def fix_clip(model: Dict) -> Dict:
+def fix_clip(model: dict) -> dict:
if KEY_POSITION_IDS in model.keys():
model[KEY_POSITION_IDS] = torch.tensor(
[list(range(MAX_TOKENS))],
@@ -48,7 +47,7 @@ def fix_clip(model: Dict) -> Dict:
return model
-def prune_sd_model(model: Dict, keyset: Set) -> Dict:
+def prune_sd_model(model: dict, keyset: set) -> dict:
keys = list(model.keys())
for k in keys:
if (
@@ -60,7 +59,7 @@ def prune_sd_model(model: Dict, keyset: Set) -> Dict:
return model
-def restore_sd_model(original_model: Dict, merged_model: Dict) -> Dict:
+def restore_sd_model(original_model: dict, merged_model: dict) -> dict:
for k in original_model:
if k not in merged_model:
merged_model[k] = original_model[k]
@@ -72,11 +71,11 @@ def log_vram(txt=""):
def load_thetas(
- models: Dict[str, os.PathLike],
+ models: dict[str, os.PathLike],
prune: bool,
device: torch.device,
precision: str,
-) -> Dict:
+) -> dict:
from tensordict import TensorDict
thetas = {k: TensorDict.from_dict(read_state_dict(m, "cpu")) for k, m in models.items()}
if prune:
@@ -95,7 +94,7 @@ def load_thetas(
def merge_models(
- models: Dict[str, os.PathLike],
+ models: dict[str, os.PathLike],
merge_mode: str,
precision: str = "fp16",
weights_clip: bool = False,
@@ -104,7 +103,7 @@ def merge_models(
prune: bool = False,
threads: int = 4,
**kwargs,
-) -> Dict:
+) -> dict:
thetas = load_thetas(models, prune, device, precision)
# log.info(f'Merge start: models={models.values()} precision={precision} clip={weights_clip} rebasin={re_basin} prune={prune} threads={threads}')
weight_matcher = WeightClass(thetas["model_a"], **kwargs)
@@ -136,13 +135,13 @@ def merge_models(
def un_prune_model(
- merged: Dict,
- thetas: Dict,
- models: Dict,
+ merged: dict,
+ thetas: dict,
+ models: dict,
device: torch.device,
prune: bool,
precision: str,
-) -> Dict:
+) -> dict:
if prune:
log.info("Merge restoring pruned keys")
del thetas
@@ -180,7 +179,7 @@ def un_prune_model(
def simple_merge(
- thetas: Dict[str, Dict],
+ thetas: dict[str, dict],
weight_matcher: WeightClass,
merge_mode: str,
precision: str = "fp16",
@@ -188,7 +187,7 @@ def simple_merge(
device: torch.device = None,
work_device: torch.device = None,
threads: int = 4,
-) -> Dict:
+) -> dict:
futures = []
import rich.progress as p
with p.Progress(p.TextColumn('[cyan]{task.description}'), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TextColumn('[cyan]keys={task.fields[keys]}'), console=console) as progress:
@@ -227,7 +226,7 @@ def simple_merge(
def rebasin_merge(
- thetas: Dict[str, os.PathLike],
+ thetas: dict[str, os.PathLike],
weight_matcher: WeightClass,
merge_mode: str,
precision: str = "fp16",
@@ -306,14 +305,14 @@ def simple_merge_key(progress, task, key, thetas, *args, **kwargs):
def merge_key( # pylint: disable=inconsistent-return-statements
key: str,
- thetas: Dict,
+ thetas: dict,
weight_matcher: WeightClass,
merge_mode: str,
precision: str = "fp16",
weights_clip: bool = False,
device: torch.device = None,
work_device: torch.device = None,
-) -> Optional[Tuple[str, Dict]]:
+) -> tuple[str, dict] | None:
if work_device is None:
work_device = device
@@ -376,11 +375,11 @@ def merge_key_context(*args, **kwargs):
def get_merge_method_args(
- current_bases: Dict,
- thetas: Dict,
+ current_bases: dict,
+ thetas: dict,
key: str,
work_device: torch.device,
-) -> Dict:
+) -> dict:
merge_method_args = {
"a": thetas["model_a"][key].to(work_device),
"b": thetas["model_b"][key].to(work_device),
diff --git a/modules/merging/merge_methods.py b/modules/merging/merge_methods.py
index ce196b60c..3e54d501f 100644
--- a/modules/merging/merge_methods.py
+++ b/modules/merging/merge_methods.py
@@ -1,5 +1,4 @@
import math
-from typing import Tuple
import torch
from torch import Tensor
@@ -151,7 +150,7 @@ def kth_abs_value(a: Tensor, k: int) -> Tensor:
return torch.kthvalue(torch.abs(a.float()), k)[0]
-def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
+def ratio_to_region(width: float, offset: float, n: int) -> tuple[int, int, bool]:
if width < 0:
offset += width
width = -width
@@ -233,7 +232,7 @@ def ties_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: flo
delta_filters = (signs == final_sign).float()
res = torch.zeros_like(c, device=c.device)
- for delta_filter, delta in zip(delta_filters, deltas):
+ for delta_filter, delta in zip(delta_filters, deltas, strict=False):
res += delta_filter * delta
param_count = torch.sum(delta_filters, dim=0)
diff --git a/modules/merging/modules_sdxl.py b/modules/merging/modules_sdxl.py
index 959ad36e5..529066fc7 100644
--- a/modules/merging/modules_sdxl.py
+++ b/modules/merging/modules_sdxl.py
@@ -206,7 +206,7 @@ def test_model(pipe: diffusers.StableDiffusionXLPipeline, fn: str, **kwargs):
if not test.generate:
return
try:
- generator = torch.Generator(devices.device).manual_seed(int(4242))
+ generator = torch.Generator(devices.device).manual_seed(4242)
args = {
'prompt': test.prompt,
'negative_prompt': test.negative,
@@ -278,7 +278,7 @@ def save_model(pipe: diffusers.StableDiffusionXLPipeline):
yield msg(f'pretrained={folder}')
shared.log.info(f'Modules merge save: type=sdxl diffusers="{folder}"')
pipe.save_pretrained(folder, safe_serialization=True, push_to_hub=False)
- with open(os.path.join(folder, 'vae', 'config.json'), 'r', encoding='utf8') as f:
+ with open(os.path.join(folder, 'vae', 'config.json'), encoding='utf8') as f:
vae_config = json.load(f)
vae_config['force_upcast'] = False
vae_config['scaling_factor'] = 0.13025
diff --git a/modules/mit_nunchaku.py b/modules/mit_nunchaku.py
index be6f84564..6ba77007e 100644
--- a/modules/mit_nunchaku.py
+++ b/modules/mit_nunchaku.py
@@ -53,8 +53,6 @@ def install_nunchaku():
import os
import sys
import platform
- import importlib
- import importlib.metadata
import torch
python_ver = f'{sys.version_info.major}{sys.version_info.minor}'
if python_ver not in ['311', '312', '313']:
diff --git a/modules/modeldata.py b/modules/modeldata.py
index 7e2164095..48172baa3 100644
--- a/modules/modeldata.py
+++ b/modules/modeldata.py
@@ -220,5 +220,13 @@ class Shared(sys.modules[__name__].__class__):
model_type = 'unknown'
return model_type
+ @property
+ def console(self):
+ try:
+ from installer import get_console
+ return get_console()
+ except ImportError:
+ return None
+
model_data = ModelData()
diff --git a/modules/modelloader.py b/modules/modelloader.py
index ea45bea82..c168d981b 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -4,13 +4,12 @@ import time
import shutil
import importlib
import contextlib
-from typing import Dict
from urllib.parse import urlparse
import huggingface_hub as hf
from installer import install, log
from modules import shared, errors, files_cache
from modules.upscaler import Upscaler
-from modules.paths import script_path, models_path
+from modules import paths
loggedin = None
@@ -55,7 +54,7 @@ def hf_login(token=None):
return True
-def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config: Dict[str, str] = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None):
+def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config: dict[str, str] = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None):
if hub_id is None or len(hub_id) == 0:
return None
from diffusers import DiffusionPipeline
@@ -117,7 +116,7 @@ def load_diffusers_models(clear=True):
# t0 = time.time()
place = shared.opts.diffusers_dir
if place is None or len(place) == 0 or not os.path.isdir(place):
- place = os.path.join(models_path, 'Diffusers')
+ place = os.path.join(paths.models_path, 'Diffusers')
if clear:
diffuser_repos.clear()
already_found = []
@@ -382,25 +381,25 @@ def cleanup_models():
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
# somehow auto-register and just do these things...
- root_path = script_path
- src_path = models_path
- dest_path = os.path.join(models_path, "Stable-diffusion")
+ root_path = paths.script_path
+ src_path = paths.models_path
+ dest_path = os.path.join(paths.models_path, "Stable-diffusion")
# move_files(src_path, dest_path, ".ckpt")
# move_files(src_path, dest_path, ".safetensors")
src_path = os.path.join(root_path, "ESRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
+ dest_path = os.path.join(paths.models_path, "ESRGAN")
move_files(src_path, dest_path)
- src_path = os.path.join(models_path, "BSRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
+ src_path = os.path.join(paths.models_path, "BSRGAN")
+ dest_path = os.path.join(paths.models_path, "ESRGAN")
move_files(src_path, dest_path, ".pth")
src_path = os.path.join(root_path, "SwinIR")
- dest_path = os.path.join(models_path, "SwinIR")
+ dest_path = os.path.join(paths.models_path, "SwinIR")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
- dest_path = os.path.join(models_path, "LDSR")
+ dest_path = os.path.join(paths.models_path, "LDSR")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "SCUNet")
- dest_path = os.path.join(models_path, "SCUNet")
+ dest_path = os.path.join(paths.models_path, "SCUNet")
move_files(src_path, dest_path)
@@ -430,7 +429,7 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced, so we'll try to import any _model.py files before looking in __subclasses__
t0 = time.time()
- modules_dir = os.path.join(shared.script_path, "modules", "postprocess")
+ modules_dir = os.path.join(paths.script_path, "modules", "postprocess")
for file in os.listdir(modules_dir):
if "_model.py" in file:
model_name = file.replace("_model.py", "")
diff --git a/modules/modelstats.py b/modules/modelstats.py
index 2ab1a6b62..b1fc54d8d 100644
--- a/modules/modelstats.py
+++ b/modules/modelstats.py
@@ -28,7 +28,7 @@ def stat(fn: str):
return size, mtime
-class Module():
+class Module:
name: str = ''
cls: str = None
device: str = None
@@ -61,7 +61,7 @@ class Module():
return s
-class Model():
+class Model:
name: str = ''
fn: str = ''
type: str = ''
diff --git a/modules/olive_script.py b/modules/olive_script.py
index c881679ed..c194e7f42 100644
--- a/modules/olive_script.py
+++ b/modules/olive_script.py
@@ -1,17 +1,18 @@
import os
-from typing import Type, Callable, TypeVar, Dict, Any
+from typing import TypeVar, Any
+from collections.abc import Callable
import torch
import diffusers
from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextModelWithProjection
class ENVStore:
- __DESERIALIZER: Dict[Type, Callable[[str,], Any]] = {
+ __DESERIALIZER: dict[type, Callable[[str,], Any]] = {
bool: lambda x: bool(int(x)),
int: int,
str: lambda x: x,
}
- __SERIALIZER: Dict[Type, Callable[[Any,], str]] = {
+ __SERIALIZER: dict[type, Callable[[Any,], str]] = {
bool: lambda x: str(int(x)),
int: str,
str: lambda x: x,
@@ -89,7 +90,7 @@ def get_loader_arguments(no_variant: bool = False):
T = TypeVar("T")
-def from_pretrained(cls: Type[T], pretrained_model_name_or_path: os.PathLike, *args, no_variant: bool = False, **kwargs) -> T:
+def from_pretrained(cls: type[T], pretrained_model_name_or_path: os.PathLike, *args, no_variant: bool = False, **kwargs) -> T:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if pretrained_model_name_or_path.endswith(".onnx"):
cls = diffusers.OnnxRuntimeModel
diff --git a/modules/onnx_impl/__init__.py b/modules/onnx_impl/__init__.py
index f7ef1b8be..1decd1d41 100644
--- a/modules/onnx_impl/__init__.py
+++ b/modules/onnx_impl/__init__.py
@@ -16,7 +16,7 @@ except Exception as e:
class DynamicSessionOptions(ort.SessionOptions):
- config: Optional[Dict] = None
+ config: dict | None = None
def __init__(self):
super().__init__()
@@ -28,7 +28,7 @@ class DynamicSessionOptions(ort.SessionOptions):
return sess_options.copy()
return DynamicSessionOptions()
- def enable_static_dims(self, config: Dict):
+ def enable_static_dims(self, config: dict):
self.config = config
self.add_free_dimension_override_by_name("unet_sample_batch", config["hidden_batch_size"])
self.add_free_dimension_override_by_name("unet_sample_channels", 4)
@@ -103,9 +103,9 @@ class OnnxRuntimeModel(TorchCompatibleModule, diffusers.OnnxRuntimeModel):
class VAEConfig:
DEFAULTS = { "scaling_factor": 0.18215 }
- config: Dict
+ config: dict
- def __init__(self, config: Dict):
+ def __init__(self, config: dict):
self.config = config
def __getattr__(self, key):
diff --git a/modules/onnx_impl/execution_providers.py b/modules/onnx_impl/execution_providers.py
index dd2622f1f..b220692de 100644
--- a/modules/onnx_impl/execution_providers.py
+++ b/modules/onnx_impl/execution_providers.py
@@ -1,6 +1,5 @@
import sys
from enum import Enum
-from typing import Tuple, List
from installer import log
from modules import devices
@@ -33,7 +32,7 @@ TORCH_DEVICE_TO_EP = {
try:
import onnxruntime as ort
- available_execution_providers: List[ExecutionProvider] = ort.get_available_providers()
+ available_execution_providers: list[ExecutionProvider] = ort.get_available_providers()
except Exception as e:
log.error(f'ONNX import error: {e}')
available_execution_providers = []
@@ -90,7 +89,7 @@ def get_execution_provider_options():
return execution_provider_options
-def get_provider() -> Tuple:
+def get_provider() -> tuple:
from modules.shared import opts
return (opts.onnx_execution_provider, get_execution_provider_options(),)
diff --git a/modules/onnx_impl/pipelines/__init__.py b/modules/onnx_impl/pipelines/__init__.py
index 99682b1aa..9b57d6577 100644
--- a/modules/onnx_impl/pipelines/__init__.py
+++ b/modules/onnx_impl/pipelines/__init__.py
@@ -103,12 +103,12 @@ class OnnxRawPipeline(PipelineBase):
path: os.PathLike
original_filename: str
- constructor: Type[PipelineBase]
- init_dict: Dict[str, Tuple[str]] = {}
+ constructor: type[PipelineBase]
+ init_dict: dict[str, tuple[str]] = {}
default_scheduler: Any = None # for Img2Img
- def __init__(self, constructor: Type[PipelineBase], path: os.PathLike): # pylint: disable=super-init-not-called
+ def __init__(self, constructor: type[PipelineBase], path: os.PathLike): # pylint: disable=super-init-not-called
self._is_sdxl = check_pipeline_sdxl(constructor)
self.from_diffusers_cache = check_diffusers_cache(path)
self.path = path
@@ -150,7 +150,7 @@ class OnnxRawPipeline(PipelineBase):
pipeline.scheduler = self.default_scheduler
return pipeline
- def convert(self, submodels: List[str], in_dir: os.PathLike, out_dir: os.PathLike):
+ def convert(self, submodels: list[str], in_dir: os.PathLike, out_dir: os.PathLike):
install('onnx') # may not be installed yet, this performs check and installs as needed
import onnx
shutil.rmtree("cache", ignore_errors=True)
@@ -218,7 +218,7 @@ class OnnxRawPipeline(PipelineBase):
with open(os.path.join(out_dir, "model_index.json"), 'w', encoding="utf-8") as file:
json.dump(model_index, file)
- def run_olive(self, submodels: List[str], in_dir: os.PathLike, out_dir: os.PathLike):
+ def run_olive(self, submodels: list[str], in_dir: os.PathLike, out_dir: os.PathLike):
from olive.model import ONNXModelHandler
from olive.workflows import run as run_workflows
@@ -235,8 +235,8 @@ class OnnxRawPipeline(PipelineBase):
for submodel in submodels:
log.info(f"\nProcessing {submodel}")
- with open(os.path.join(sd_configs_path, "olive", 'sdxl' if self._is_sdxl else 'sd', f"{submodel}.json"), "r", encoding="utf-8") as config_file:
- olive_config: Dict[str, Dict[str, Dict]] = json.load(config_file)
+ with open(os.path.join(sd_configs_path, "olive", 'sdxl' if self._is_sdxl else 'sd', f"{submodel}.json"), encoding="utf-8") as config_file:
+ olive_config: dict[str, dict[str, dict]] = json.load(config_file)
for flow in olive_config["pass_flows"]:
for i in range(len(flow)):
@@ -257,7 +257,7 @@ class OnnxRawPipeline(PipelineBase):
run_workflows(olive_config)
- with open(os.path.join("footprints", f"{submodel}_{EP_TO_NAME[shared.opts.onnx_execution_provider]}_footprints.json"), "r", encoding="utf-8") as footprint_file:
+ with open(os.path.join("footprints", f"{submodel}_{EP_TO_NAME[shared.opts.onnx_execution_provider]}_footprints.json"), encoding="utf-8") as footprint_file:
footprints = json.load(footprint_file)
processor_final_pass_footprint = None
for _, footprint in footprints.items():
diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_img2img_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_img2img_pipeline.py
index 6d8ea5946..82c9740a8 100644
--- a/modules/onnx_impl/pipelines/onnx_stable_diffusion_img2img_pipeline.py
+++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_img2img_pipeline.py
@@ -1,5 +1,6 @@
import inspect
-from typing import Union, Optional, Callable, List, Any
+from typing import Any
+from collections.abc import Callable
import numpy as np
import torch
import diffusers
@@ -33,20 +34,20 @@ class OnnxStableDiffusionImg2ImgPipeline(diffusers.OnnxStableDiffusionImg2ImgPip
def __call__(
self,
- prompt: Union[str, List[str]],
+ prompt: str | list[str],
image: PipelineImageInput = None,
strength: float = 0.8,
- num_inference_steps: Optional[int] = 50,
- guidance_scale: Optional[float] = 7.5,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- num_images_per_prompt: Optional[int] = 1,
- eta: Optional[float] = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- prompt_embeds: Optional[np.ndarray] = None,
- negative_prompt_embeds: Optional[np.ndarray] = None,
- output_type: Optional[str] = "pil",
+ num_inference_steps: int | None = 50,
+ guidance_scale: float | None = 7.5,
+ negative_prompt: str | list[str] | None = None,
+ num_images_per_prompt: int | None = 1,
+ eta: float | None = 0.0,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ prompt_embeds: np.ndarray | None = None,
+ negative_prompt_embeds: np.ndarray | None = None,
+ output_type: str | None = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback: Callable[[int, int, np.ndarray], None] | None = None,
callback_steps: int = 1,
):
# check inputs. Raise error if not correct
diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_inpaint_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_inpaint_pipeline.py
index dccfb808d..e8ce33fc4 100644
--- a/modules/onnx_impl/pipelines/onnx_stable_diffusion_inpaint_pipeline.py
+++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_inpaint_pipeline.py
@@ -1,5 +1,6 @@
import inspect
-from typing import Union, Optional, Callable, List, Any
+from typing import Any
+from collections.abc import Callable
import numpy as np
import torch
import diffusers
@@ -31,25 +32,25 @@ class OnnxStableDiffusionInpaintPipeline(diffusers.OnnxStableDiffusionInpaintPip
@torch.no_grad()
def __call__(
self,
- prompt: Union[str, List[str]],
+ prompt: str | list[str],
image: PipelineImageInput,
mask_image: PipelineImageInput,
masked_image_latents: torch.FloatTensor = None,
- height: Optional[int] = 512,
- width: Optional[int] = 512,
+ height: int | None = 512,
+ width: int | None = 512,
strength: float = 1.0,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- num_images_per_prompt: Optional[int] = 1,
+ negative_prompt: str | list[str] | None = None,
+ num_images_per_prompt: int | None = 1,
eta: float = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[np.ndarray] = None,
- prompt_embeds: Optional[np.ndarray] = None,
- negative_prompt_embeds: Optional[np.ndarray] = None,
- output_type: Optional[str] = "pil",
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: np.ndarray | None = None,
+ prompt_embeds: np.ndarray | None = None,
+ negative_prompt_embeds: np.ndarray | None = None,
+ output_type: str | None = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback: Callable[[int, int, np.ndarray], None] | None = None,
callback_steps: int = 1,
):
# check inputs. Raise error if not correct
diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_pipeline.py
index 112241996..2b583e8f5 100644
--- a/modules/onnx_impl/pipelines/onnx_stable_diffusion_pipeline.py
+++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_pipeline.py
@@ -1,5 +1,6 @@
import inspect
-from typing import Union, Optional, Callable, List, Any
+from typing import Any
+from collections.abc import Callable
import numpy as np
import torch
import diffusers
@@ -29,21 +30,21 @@ class OnnxStableDiffusionPipeline(diffusers.OnnxStableDiffusionPipeline, Callabl
def __call__(
self,
- prompt: Union[str, List[str]] = None,
- height: Optional[int] = 512,
- width: Optional[int] = 512,
- num_inference_steps: Optional[int] = 50,
- guidance_scale: Optional[float] = 7.5,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- num_images_per_prompt: Optional[int] = 1,
- eta: Optional[float] = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[np.ndarray] = None,
- prompt_embeds: Optional[np.ndarray] = None,
- negative_prompt_embeds: Optional[np.ndarray] = None,
- output_type: Optional[str] = "pil",
+ prompt: str | list[str] = None,
+ height: int | None = 512,
+ width: int | None = 512,
+ num_inference_steps: int | None = 50,
+ guidance_scale: float | None = 7.5,
+ negative_prompt: str | list[str] | None = None,
+ num_images_per_prompt: int | None = 1,
+ eta: float | None = 0.0,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: np.ndarray | None = None,
+ prompt_embeds: np.ndarray | None = None,
+ negative_prompt_embeds: np.ndarray | None = None,
+ output_type: str | None = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback: Callable[[int, int, np.ndarray], None] | None = None,
callback_steps: int = 1,
):
# check inputs. Raise error if not correct
diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_upscale_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_upscale_pipeline.py
index 5bdc09794..f575959ab 100644
--- a/modules/onnx_impl/pipelines/onnx_stable_diffusion_upscale_pipeline.py
+++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_upscale_pipeline.py
@@ -1,5 +1,6 @@
import inspect
-from typing import Union, Optional, Callable, Any, List
+from typing import Any
+from collections.abc import Callable
import torch
import numpy as np
import diffusers
@@ -31,22 +32,22 @@ class OnnxStableDiffusionUpscalePipeline(diffusers.OnnxStableDiffusionUpscalePip
def __call__(
self,
- prompt: Union[str, List[str]],
+ prompt: str | list[str],
image: PipelineImageInput = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- num_images_per_prompt: Optional[int] = 1,
+ negative_prompt: str | list[str] | None = None,
+ num_images_per_prompt: int | None = 1,
eta: float = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[np.ndarray] = None,
- prompt_embeds: Optional[np.ndarray] = None,
- negative_prompt_embeds: Optional[np.ndarray] = None,
- output_type: Optional[str] = "pil",
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: np.ndarray | None = None,
+ prompt_embeds: np.ndarray | None = None,
+ negative_prompt_embeds: np.ndarray | None = None,
+ output_type: str | None = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
- callback_steps: Optional[int] = 1,
+ callback: Callable[[int, int, np.ndarray], None] | None = None,
+ callback_steps: int | None = 1,
):
# 1. Check inputs
self.check_inputs(
diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_img2img_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_img2img_pipeline.py
index 2627ba074..7a30a9a99 100644
--- a/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_img2img_pipeline.py
+++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_img2img_pipeline.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict, Any
+from typing import Any
import numpy as np
import torch
import onnxruntime as ort
@@ -17,16 +17,16 @@ class OnnxStableDiffusionXLImg2ImgPipeline(CallablePipelineBase, optimum.onnxrun
vae_decoder: ort.InferenceSession,
text_encoder: ort.InferenceSession,
unet: ort.InferenceSession,
- config: Dict[str, Any],
+ config: dict[str, Any],
tokenizer: Any,
scheduler: Any,
feature_extractor = None,
- vae_encoder: Optional[ort.InferenceSession] = None,
- text_encoder_2: Optional[ort.InferenceSession] = None,
+ vae_encoder: ort.InferenceSession | None = None,
+ text_encoder_2: ort.InferenceSession | None = None,
tokenizer_2: Any = None,
- use_io_binding: Optional[bool] = None,
+ use_io_binding: bool | None = None,
model_save_dir = None,
- add_watermarker: Optional[bool] = None
+ add_watermarker: bool | None = None
):
optimum.onnxruntime.ORTStableDiffusionXLImg2ImgPipeline.__init__(self, vae_decoder, text_encoder, unet, config, tokenizer, scheduler, feature_extractor, vae_encoder, text_encoder_2, tokenizer_2, use_io_binding, model_save_dir, add_watermarker)
super().__init__()
diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_pipeline.py
index 452e4f892..bbc541965 100644
--- a/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_pipeline.py
+++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_pipeline.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict, Any
+from typing import Any
import onnxruntime as ort
import optimum.onnxruntime
from modules.onnx_impl.pipelines import CallablePipelineBase
@@ -14,16 +14,16 @@ class OnnxStableDiffusionXLPipeline(CallablePipelineBase, optimum.onnxruntime.OR
vae_decoder: ort.InferenceSession,
text_encoder: ort.InferenceSession,
unet: ort.InferenceSession,
- config: Dict[str, Any],
+ config: dict[str, Any],
tokenizer: Any,
scheduler: Any,
feature_extractor: Any = None,
- vae_encoder: Optional[ort.InferenceSession] = None,
- text_encoder_2: Optional[ort.InferenceSession] = None,
+ vae_encoder: ort.InferenceSession | None = None,
+ text_encoder_2: ort.InferenceSession | None = None,
tokenizer_2: Any = None,
- use_io_binding: Optional[bool] = None,
+ use_io_binding: bool | None = None,
model_save_dir = None,
- add_watermarker: Optional[bool] = None
+ add_watermarker: bool | None = None
):
optimum.onnxruntime.ORTStableDiffusionXLPipeline.__init__(self, vae_decoder, text_encoder, unet, config, tokenizer, scheduler, feature_extractor, vae_encoder, text_encoder_2, tokenizer_2, use_io_binding, model_save_dir, add_watermarker)
super().__init__()
diff --git a/modules/onnx_impl/pipelines/utils.py b/modules/onnx_impl/pipelines/utils.py
index c389cac01..f6b980302 100644
--- a/modules/onnx_impl/pipelines/utils.py
+++ b/modules/onnx_impl/pipelines/utils.py
@@ -1,9 +1,8 @@
-from typing import Union, List
import numpy as np
import torch
-def extract_generator_seed(generator: Union[torch.Generator, List[torch.Generator]]) -> List[int]:
+def extract_generator_seed(generator: torch.Generator | list[torch.Generator]) -> list[int]:
if isinstance(generator, list):
generator = [g.seed() for g in generator]
else:
@@ -11,7 +10,7 @@ def extract_generator_seed(generator: Union[torch.Generator, List[torch.Generato
return generator
-def randn_tensor(shape, dtype: np.dtype, generator: Union[torch.Generator, List[torch.Generator], int, List[int]]):
+def randn_tensor(shape, dtype: np.dtype, generator: torch.Generator | list[torch.Generator] | int | list[int]):
if hasattr(generator, "seed") or (isinstance(generator, list) and hasattr(generator[0], "seed")):
generator = extract_generator_seed(generator)
if len(generator) == 1:
@@ -25,8 +24,8 @@ def prepare_latents(
height: int,
width: int,
dtype: np.dtype,
- generator: Union[torch.Generator, List[torch.Generator]],
- latents: Union[np.ndarray, None] = None,
+ generator: torch.Generator | list[torch.Generator],
+ latents: np.ndarray | None = None,
num_channels_latents = 4,
vae_scale_factor = 8,
):
diff --git a/modules/onnx_impl/ui.py b/modules/onnx_impl/ui.py
index 703392d82..0a8ca2d22 100644
--- a/modules/onnx_impl/ui.py
+++ b/modules/onnx_impl/ui.py
@@ -1,11 +1,10 @@
import os
import json
import shutil
-from typing import Dict, List, Union
import gradio as gr
-def get_recursively(d: Union[Dict, List], *args):
+def get_recursively(d: dict | list, *args):
if len(args) == 0:
return d
return get_recursively(d.get(args[0]), *args[1:])
@@ -112,19 +111,19 @@ def create_ui():
with gr.TabItem("Stable Diffusion", id="sd"):
sd_config_path = os.path.join(sd_configs_path, "olive", "sd")
sd_submodels = os.listdir(sd_config_path)
- sd_configs: Dict[str, Dict[str, Dict[str, Dict]]] = {}
- sd_pass_config_components: Dict[str, Dict[str, Dict]] = {}
+ sd_configs: dict[str, dict[str, dict[str, dict]]] = {}
+ sd_pass_config_components: dict[str, dict[str, dict]] = {}
with gr.Tabs(elem_id="tabs_sd_submodel"):
def sd_create_change_listener(*args):
- def listener(v: Dict):
+ def listener(v: dict):
get_recursively(sd_configs, *args[:-1])[args[-1]] = v
return listener
for submodel in sd_submodels:
- config: Dict = None
+ config: dict = None
sd_pass_config_components[submodel] = {}
- with open(os.path.join(sd_config_path, submodel), "r", encoding="utf-8") as file:
+ with open(os.path.join(sd_config_path, submodel), encoding="utf-8") as file:
config = json.load(file)
sd_configs[submodel] = config
@@ -175,19 +174,19 @@ def create_ui():
with gr.TabItem("Stable Diffusion XL", id="sdxl"):
sdxl_config_path = os.path.join(sd_configs_path, "olive", "sdxl")
sdxl_submodels = os.listdir(sdxl_config_path)
- sdxl_configs: Dict[str, Dict[str, Dict[str, Dict]]] = {}
- sdxl_pass_config_components: Dict[str, Dict[str, Dict]] = {}
+ sdxl_configs: dict[str, dict[str, dict[str, dict]]] = {}
+ sdxl_pass_config_components: dict[str, dict[str, dict]] = {}
with gr.Tabs(elem_id="tabs_sdxl_submodel"):
def sdxl_create_change_listener(*args):
- def listener(v: Dict):
+ def listener(v: dict):
get_recursively(sdxl_configs, *args[:-1])[args[-1]] = v
return listener
for submodel in sdxl_submodels:
- config: Dict = None
+ config: dict = None
sdxl_pass_config_components[submodel] = {}
- with open(os.path.join(sdxl_config_path, submodel), "r", encoding="utf-8") as file:
+ with open(os.path.join(sdxl_config_path, submodel), encoding="utf-8") as file:
config = json.load(file)
sdxl_configs[submodel] = config
diff --git a/modules/onnx_impl/utils.py b/modules/onnx_impl/utils.py
index 80b75cb4a..5d3f7c06e 100644
--- a/modules/onnx_impl/utils.py
+++ b/modules/onnx_impl/utils.py
@@ -1,12 +1,12 @@
import os
import json
import importlib
-from typing import Type, Tuple, Union, List, Dict, Any
+from typing import Any
import torch
import diffusers
-def extract_device(args: List, kwargs: Dict):
+def extract_device(args: list, kwargs: dict):
device = kwargs.get("device", None)
if device is None:
@@ -42,7 +42,7 @@ def check_diffusers_cache(path: os.PathLike):
return opts.diffusers_dir in os.path.abspath(path)
-def check_pipeline_sdxl(cls: Type[diffusers.DiffusionPipeline]) -> bool:
+def check_pipeline_sdxl(cls: type[diffusers.DiffusionPipeline]) -> bool:
return 'XL' in cls.__name__
@@ -57,7 +57,7 @@ def check_cache_onnx(path: os.PathLike) -> bool:
init_dict = None
- with open(init_dict_path, "r", encoding="utf-8") as file:
+ with open(init_dict_path, encoding="utf-8") as file:
init_dict = file.read()
if "OnnxRuntimeModel" not in init_dict:
@@ -66,15 +66,15 @@ def check_cache_onnx(path: os.PathLike) -> bool:
return True
-def load_init_dict(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike):
- merged: Dict[str, Any] = {}
+def load_init_dict(cls: type[diffusers.DiffusionPipeline], path: os.PathLike):
+ merged: dict[str, Any] = {}
extracted = cls.extract_init_dict(diffusers.DiffusionPipeline.load_config(path))
for item in extracted:
merged.update(item)
merged = merged.items()
- R: Dict[str, Tuple[str]] = {}
+ R: dict[str, tuple[str]] = {}
for k, v in merged:
if isinstance(v, list):
@@ -85,7 +85,7 @@ def load_init_dict(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike):
return R
-def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: List[Union[str, None]], **kwargs_ort):
+def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: list[str | None], **kwargs_ort):
lib, atr = item
if lib is None or atr is None:
@@ -107,7 +107,7 @@ def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: Li
return attribute.from_pretrained(path)
-def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: Dict[str, Type], **kwargs_ort):
+def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: dict[str, type], **kwargs_ort):
loaded = {}
for k, v in init_dict.items():
@@ -122,14 +122,14 @@ def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: Dict[str, Type],
return loaded
-def load_pipeline(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike, **kwargs_ort) -> diffusers.DiffusionPipeline:
+def load_pipeline(cls: type[diffusers.DiffusionPipeline], path: os.PathLike, **kwargs_ort) -> diffusers.DiffusionPipeline:
if os.path.isdir(path):
return cls(**patch_kwargs(cls, load_submodels(path, check_pipeline_sdxl(cls), load_init_dict(cls, path), **kwargs_ort)))
else:
return cls.from_single_file(path)
-def patch_kwargs(cls: Type[diffusers.DiffusionPipeline], kwargs: Dict) -> Dict:
+def patch_kwargs(cls: type[diffusers.DiffusionPipeline], kwargs: dict) -> dict:
if cls == diffusers.OnnxStableDiffusionPipeline or cls == diffusers.OnnxStableDiffusionImg2ImgPipeline or cls == diffusers.OnnxStableDiffusionInpaintPipeline:
kwargs["safety_checker"] = None
kwargs["requires_safety_checker"] = False
@@ -140,7 +140,7 @@ def patch_kwargs(cls: Type[diffusers.DiffusionPipeline], kwargs: Dict) -> Dict:
return kwargs
-def get_base_constructor(cls: Type[diffusers.DiffusionPipeline], is_refiner: bool):
+def get_base_constructor(cls: type[diffusers.DiffusionPipeline], is_refiner: bool):
if cls == diffusers.OnnxStableDiffusionImg2ImgPipeline or cls == diffusers.OnnxStableDiffusionInpaintPipeline:
return diffusers.OnnxStableDiffusionPipeline
@@ -153,8 +153,8 @@ def get_base_constructor(cls: Type[diffusers.DiffusionPipeline], is_refiner: boo
def get_io_config(submodel: str, is_sdxl: bool):
from modules.paths import sd_configs_path
- with open(os.path.join(sd_configs_path, "olive", 'sdxl' if is_sdxl else 'sd', f"{submodel}.json"), "r", encoding="utf-8") as config_file:
- io_config: Dict[str, Any] = json.load(config_file)["input_model"]["config"]["io_config"]
+ with open(os.path.join(sd_configs_path, "olive", 'sdxl' if is_sdxl else 'sd', f"{submodel}.json"), encoding="utf-8") as config_file:
+ io_config: dict[str, Any] = json.load(config_file)["input_model"]["config"]["io_config"]
for axe in io_config["dynamic_axes"]:
io_config["dynamic_axes"][axe] = { int(k): v for k, v in io_config["dynamic_axes"][axe].items() }
diff --git a/modules/options_handler.py b/modules/options_handler.py
index b087c0529..61ee1a29f 100644
--- a/modules/options_handler.py
+++ b/modules/options_handler.py
@@ -18,13 +18,15 @@ cmd_opts = cmd_args.parse_args()
compatibility_opts = ['clip_skip', 'uni_pc_lower_order_final', 'uni_pc_order']
-class Options():
+class Options:
data_labels: dict[str, OptionInfo | LegacyOption]
data: dict[str, Any]
typemap = {int: float}
debug = os.environ.get('SD_CONFIG_DEBUG', None) is not None
- def __init__(self, options_templates: dict[str, OptionInfo | LegacyOption] = {}, restricted_opts: set[str] | None = None, *, filename = ''):
+ def __init__(self, options_templates: dict[str, OptionInfo | LegacyOption] = None, restricted_opts: set[str] | None = None, *, filename = ''):
+ if options_templates is None:
+ options_templates = {}
if restricted_opts is None:
restricted_opts = set()
super().__setattr__('data_labels', options_templates)
@@ -48,21 +50,21 @@ class Options():
log.warning(f'Settings set: {key}={value} legacy')
self.data[key] = value
return
- return super(Options, self).__setattr__(key, value) # pylint: disable=super-with-arguments
+ return super().__setattr__(key, value) # pylint: disable=super-with-arguments
def get(self, item):
if item in self.data:
return self.data[item]
if item in self.data_labels:
return self.data_labels[item].default
- return super(Options, self).__getattribute__(item) # pylint: disable=super-with-arguments
+ return super().__getattribute__(item) # pylint: disable=super-with-arguments
def __getattr__(self, item):
if item in self.data:
return self.data[item]
if item in self.data_labels:
return self.data_labels[item].default
- return super(Options, self).__getattribute__(item) # pylint: disable=super-with-arguments
+ return super().__getattribute__(item) # pylint: disable=super-with-arguments
def set(self, key, value):
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
diff --git a/modules/patches.py b/modules/patches.py
index f24a38293..655cf4534 100644
--- a/modules/patches.py
+++ b/modules/patches.py
@@ -1,5 +1,4 @@
from collections import defaultdict
-from typing import Optional
from modules.errors import log
@@ -55,13 +54,13 @@ def original(key, obj, field):
return originals[key].get(patch_key, None)
-def patch_method(cls, key:Optional[str]=None):
+def patch_method(cls, key:str | None=None):
def decorator(func):
patch(func.__module__ if key is None else key, cls, func.__name__, func)
return decorator
-def add_method(cls, key:Optional[str]=None):
+def add_method(cls, key:str | None=None):
def decorator(func):
patch(func.__module__ if key is None else key, cls, func.__name__, func, True)
return decorator
diff --git a/modules/paths.py b/modules/paths.py
index 63ed005c7..cd7b6e5d4 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -2,7 +2,6 @@
import os
import sys
import json
-import shlex
import argparse
import tempfile
from installer import log
@@ -19,7 +18,7 @@ cli = parser.parse_known_args(argv)[0]
config_path = cli.config if os.path.isabs(cli.config) else os.path.join(cli.data_dir, cli.config)
try:
- with open(config_path, 'r', encoding='utf8') as f:
+ with open(config_path, encoding='utf8') as f:
config = json.load(f)
except Exception:
config = {}
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index a9dabdd0f..f304361aa 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -1,3 +1,2 @@
# no longer used, all paths are defined in paths.py
-from modules.paths import modules_path, script_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, data_path, models_path, extensions_dir, extensions_builtin_dir # pylint: disable=unused-import
diff --git a/modules/postprocess/aurasr_model.py b/modules/postprocess/aurasr_model.py
index 9d77af93f..c2030d184 100644
--- a/modules/postprocess/aurasr_model.py
+++ b/modules/postprocess/aurasr_model.py
@@ -1,5 +1,4 @@
import torch
-import diffusers
from PIL import Image
from modules import shared, devices
from modules.upscaler import Upscaler, UpscalerData
diff --git a/modules/postprocess/esrgan_model_arch.py b/modules/postprocess/esrgan_model_arch.py
index bf9f0ac6e..e50b441b7 100644
--- a/modules/postprocess/esrgan_model_arch.py
+++ b/modules/postprocess/esrgan_model_arch.py
@@ -14,7 +14,7 @@ class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
finalact=None, gaussian_noise=False, plus=False):
- super(RRDBNet, self).__init__()
+ super().__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
@@ -69,7 +69,7 @@ class RRDB(nn.Module):
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
spectral_norm=False, gaussian_noise=False, plus=False):
- super(RRDB, self).__init__()
+ super().__init__()
# This is for backwards compatibility with existing models
if nr == 3:
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
@@ -111,7 +111,7 @@ class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
spectral_norm=False, gaussian_noise=False, plus=False):
- super(ResidualDenseBlock_5C, self).__init__()
+ super().__init__()
self.noise = GaussianNoise() if gaussian_noise else None
self.conv1x1 = conv1x1(nf, gc) if plus else None
@@ -185,7 +185,7 @@ class SRVGGNetCompact(nn.Module):
"""
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
- super(SRVGGNetCompact, self).__init__()
+ super().__init__()
self.num_in_ch = num_in_ch
self.num_out_ch = num_out_ch
self.num_feat = num_feat
@@ -245,7 +245,7 @@ class Upsample(nn.Module):
"""
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
- super(Upsample, self).__init__()
+ super().__init__()
if isinstance(scale_factor, tuple):
self.scale_factor = tuple(float(factor) for factor in scale_factor)
else:
@@ -354,7 +354,7 @@ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
class Identity(nn.Module):
def __init__(self, *kwargs):
- super(Identity, self).__init__()
+ super().__init__()
def forward(self, x, *kwargs):
return x
@@ -399,7 +399,7 @@ def get_valid_padding(kernel_size, dilation):
class ShortcutBlock(nn.Module):
""" Elementwise sum the output of a submodule to its input """
def __init__(self, submodule):
- super(ShortcutBlock, self).__init__()
+ super().__init__()
self.sub = submodule
def forward(self, x):
diff --git a/modules/postprocess/pixelart.py b/modules/postprocess/pixelart.py
index 3295a0a21..00b9b9636 100644
--- a/modules/postprocess/pixelart.py
+++ b/modules/postprocess/pixelart.py
@@ -1,4 +1,3 @@
-from typing import List
import math
import torch
@@ -225,8 +224,8 @@ class JPEGEncoder(ImageProcessingMixin, ConfigMixin):
block_size: int = 16,
cbcr_downscale: int = 2,
norm: str = "ortho",
- latents_std: List[float] = None,
- latents_mean: List[float] = None,
+ latents_std: list[float] = None,
+ latents_mean: list[float] = None,
):
self.block_size = block_size
self.cbcr_downscale = cbcr_downscale
diff --git a/modules/postprocess/realesrgan_model_arch.py b/modules/postprocess/realesrgan_model_arch.py
index bfdfffad6..dd350e4d9 100644
--- a/modules/postprocess/realesrgan_model_arch.py
+++ b/modules/postprocess/realesrgan_model_arch.py
@@ -14,7 +14,7 @@ from modules.upscaler import compile_upscaler
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-class RealESRGANer():
+class RealESRGANer:
"""A helper class for upsampling images with RealESRGAN.
Args:
@@ -340,7 +340,7 @@ class SRVGGNetCompact(nn.Module):
"""
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
- super(SRVGGNetCompact, self).__init__()
+ super().__init__()
self.num_in_ch = num_in_ch
self.num_out_ch = num_out_ch
self.num_feat = num_feat
diff --git a/modules/postprocess/scunet_model_arch.py b/modules/postprocess/scunet_model_arch.py
index b51a88062..2441e06e2 100644
--- a/modules/postprocess/scunet_model_arch.py
+++ b/modules/postprocess/scunet_model_arch.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
@@ -12,7 +11,7 @@ class WMSA(nn.Module):
"""
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
- super(WMSA, self).__init__()
+ super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.head_dim = head_dim
@@ -103,7 +102,7 @@ class Block(nn.Module):
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
""" SwinTransformer Block
"""
- super(Block, self).__init__()
+ super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
assert type in ['W', 'SW']
@@ -131,7 +130,7 @@ class ConvTransBlock(nn.Module):
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
""" SwinTransformer and Conv Block
"""
- super(ConvTransBlock, self).__init__()
+ super().__init__()
self.conv_dim = conv_dim
self.trans_dim = trans_dim
self.head_dim = head_dim
@@ -170,7 +169,7 @@ class ConvTransBlock(nn.Module):
class SCUNet(nn.Module):
# def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
- super(SCUNet, self).__init__()
+ super().__init__()
if config is None:
config = [2, 2, 2, 2, 2, 2, 2]
self.config = config
diff --git a/modules/postprocess/swinir_model.py b/modules/postprocess/swinir_model.py
index 86cc2e77f..60a0d267f 100644
--- a/modules/postprocess/swinir_model.py
+++ b/modules/postprocess/swinir_model.py
@@ -4,7 +4,7 @@ from PIL import Image
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn
from modules.postprocess.swinir_model_arch import SwinIR as net
from modules.postprocess.swinir_model_arch_v2 import Swin2SR as net2
-from modules import devices, script_callbacks, shared
+from modules import devices, shared
from modules.upscaler import Upscaler, compile_upscaler
diff --git a/modules/postprocess/swinir_model_arch.py b/modules/postprocess/swinir_model_arch.py
index d5ae4dd32..4b306433d 100644
--- a/modules/postprocess/swinir_model_arch.py
+++ b/modules/postprocess/swinir_model_arch.py
@@ -232,7 +232,7 @@ class SwinTransformerBlock(nn.Module):
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, (-100.0)).masked_fill(attn_mask == 0, 0.0)
return attn_mask
@@ -442,7 +442,7 @@ class RSTB(nn.Module):
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
+ super().__init__()
self.dim = dim
self.input_resolution = input_resolution
@@ -587,7 +587,7 @@ class Upsample(nn.Sequential):
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
+ super().__init__(*m)
class UpsampleOneStep(nn.Sequential):
@@ -606,7 +606,7 @@ class UpsampleOneStep(nn.Sequential):
m = []
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
+ super().__init__(*m)
def flops(self):
H, W = self.input_resolution
@@ -649,7 +649,7 @@ class SwinIR(nn.Module):
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
- super(SwinIR, self).__init__()
+ super().__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
diff --git a/modules/postprocess/swinir_model_arch_v2.py b/modules/postprocess/swinir_model_arch_v2.py
index ca69e2969..d61e92668 100644
--- a/modules/postprocess/swinir_model_arch_v2.py
+++ b/modules/postprocess/swinir_model_arch_v2.py
@@ -260,7 +260,7 @@ class SwinTransformerBlock(nn.Module):
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, (-100.0)).masked_fill(attn_mask == 0, 0.0)
return attn_mask
@@ -518,7 +518,7 @@ class RSTB(nn.Module):
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
+ super().__init__()
self.dim = dim
self.input_resolution = input_resolution
@@ -619,7 +619,7 @@ class Upsample(nn.Sequential):
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
+ super().__init__(*m)
class Upsample_hf(nn.Sequential):
"""Upsample module.
@@ -640,7 +640,7 @@ class Upsample_hf(nn.Sequential):
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
+ super().__init__(*m)
class UpsampleOneStep(nn.Sequential):
@@ -659,7 +659,7 @@ class UpsampleOneStep(nn.Sequential):
m = []
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
+ super().__init__(*m)
def flops(self):
H, W = self.input_resolution
@@ -702,7 +702,7 @@ class Swin2SR(nn.Module):
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
- super(Swin2SR, self).__init__()
+ super().__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
diff --git a/modules/postprocess/yolo.py b/modules/postprocess/yolo.py
index 61f5f5661..473f8ec3f 100644
--- a/modules/postprocess/yolo.py
+++ b/modules/postprocess/yolo.py
@@ -26,7 +26,9 @@ load_lock = threading.Lock()
class YoloResult:
- def __init__(self, cls: int, label: str, score: float, box: list[int], mask: Image.Image = None, item: Image.Image = None, width = 0, height = 0, args = {}):
+ def __init__(self, cls: int, label: str, score: float, box: list[int], mask: Image.Image = None, item: Image.Image = None, width = 0, height = 0, args = None):
+ if args is None:
+ args = {}
self.cls = cls
self.label = label
self.score = score
@@ -138,7 +140,7 @@ class YoloRestorer(Detailer):
masks = prediction.masks.data.cpu().float().numpy() if prediction.masks is not None else []
if len(masks) < len(classes):
masks = len(classes) * [None]
- for score, box, cls, seg in zip(scores, boxes, classes, masks):
+ for score, box, cls, seg in zip(scores, boxes, classes, masks, strict=False):
if seg is not None:
try:
seg = (255 * seg).astype(np.uint8)
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index b624e3a7e..318750493 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -1,6 +1,5 @@
import os
import tempfile
-from typing import List
from PIL import Image
@@ -9,7 +8,7 @@ from modules.shared import opts
from modules.paths import resolve_output_path
-def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemporaryFile], input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
+def run_postprocessing(extras_mode, image, image_folder: list[tempfile.NamedTemporaryFile], input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc()
shared.state.begin('Extras')
image_data = []
@@ -61,7 +60,7 @@ def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemp
else:
outpath = resolve_output_path(opts.outdir_samples, opts.outdir_extras_samples)
processed_images = []
- for image, name, ext in zip(image_data, image_names, image_ext): # pylint: disable=redefined-argument-from-local
+ for image, name, ext in zip(image_data, image_names, image_ext, strict=False): # pylint: disable=redefined-argument-from-local
shared.log.debug(f'Process: image={image} {args}')
info = ''
if shared.state.interrupted:
diff --git a/modules/processing.py b/modules/processing.py
index 13f2c275c..8c47ce83a 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -5,7 +5,13 @@ import numpy as np
from PIL import Image, ImageOps
from modules import shared, devices, errors, images, scripts_manager, memstats, script_callbacks, extra_networks, detailer, sd_models, sd_checkpoint, sd_vae, processing_helpers, timer
from modules.sd_hijack_hypertile import context_hypertile_vae, context_hypertile_unet
-from modules.processing_class import StableDiffusionProcessing, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, StableDiffusionProcessingControl, StableDiffusionProcessingVideo # pylint: disable=unused-import
+from modules.processing_class import ( # pylint: disable=unused-import
+ StableDiffusionProcessing,
+ StableDiffusionProcessingTxt2Img,
+ StableDiffusionProcessingImg2Img,
+ StableDiffusionProcessingVideo,
+ StableDiffusionProcessingControl,
+)
from modules.processing_info import create_infotext
from modules.modeldata import model_data
@@ -433,7 +439,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
results = p.scripts.process_images(p)
if results is not None:
samples = results.images
- for script_image, script_infotext in zip(results.images, results.infotexts):
+ for script_image, script_infotext in zip(results.images, results.infotexts, strict=False):
output_images.append(script_image)
infotexts.append(script_infotext)
@@ -467,7 +473,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
output_binary = samples.bytes
else:
batch_images, batch_infotexts = process_samples(p, samples)
- for batch_image, batch_infotext in zip(batch_images, batch_infotexts):
+ for batch_image, batch_infotext in zip(batch_images, batch_infotexts, strict=False):
if batch_image is not None and batch_image not in output_images:
output_images.append(batch_image)
infotexts.append(batch_infotext)
diff --git a/modules/processing_args.py b/modules/processing_args.py
index 7a827ebea..97f4d69c9 100644
--- a/modules/processing_args.py
+++ b/modules/processing_args.py
@@ -1,4 +1,3 @@
-import typing
import os
import re
import math
@@ -9,7 +8,7 @@ import numpy as np
from PIL import Image
from modules import shared, sd_models, processing, processing_vae, processing_helpers, sd_hijack_hypertile, extra_networks, sd_vae
from modules.processing_callbacks import diffusers_callback_legacy, diffusers_callback, set_callbacks_p
-from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, get_generator, set_latents, apply_circular # pylint: disable=unused-import
+from modules.processing_helpers import get_generator, apply_circular # pylint: disable=unused-import
from modules.processing_prompt import set_prompt
from modules.api import helpers
@@ -185,7 +184,7 @@ def get_params(model):
return possible
-def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:typing.Optional[list]=None, negative_prompts_2:typing.Optional[list]=None, prompt_attention:typing.Optional[str]=None, desc:typing.Optional[str]='', **kwargs):
+def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:list | None=None, negative_prompts_2:list | None=None, prompt_attention:str | None=None, desc:str | None='', **kwargs):
t0 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
argsid = shared.state.begin('Params')
diff --git a/modules/processing_callbacks.py b/modules/processing_callbacks.py
index eed7f985f..ea5720eef 100644
--- a/modules/processing_callbacks.py
+++ b/modules/processing_callbacks.py
@@ -1,4 +1,3 @@
-import typing
import os
import time
import torch
@@ -33,7 +32,7 @@ def prompt_callback(step, kwargs):
return kwargs
-def diffusers_callback_legacy(step: int, timestep: int, latents: typing.Union[torch.FloatTensor, np.ndarray]):
+def diffusers_callback_legacy(step: int, timestep: int, latents: torch.FloatTensor | np.ndarray):
if p is None:
return
if isinstance(latents, np.ndarray): # latents from Onnx pipelines is ndarray.
@@ -51,7 +50,9 @@ def diffusers_callback_legacy(step: int, timestep: int, latents: typing.Union[to
time.sleep(0.1)
-def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}):
+def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = None):
+ if kwargs is None:
+ kwargs = {}
t0 = time.time()
if devices.backend == "ipex":
torch.xpu.synchronize(devices.device)
diff --git a/modules/processing_class.py b/modules/processing_class.py
index 26d942efa..93f2bf135 100644
--- a/modules/processing_class.py
+++ b/modules/processing_class.py
@@ -2,7 +2,7 @@ import os
import sys
import inspect
import hashlib
-from typing import Any, Dict, List
+from typing import Any
from dataclasses import dataclass, field
import numpy as np
from PIL import Image, ImageOps
@@ -51,7 +51,7 @@ class StableDiffusionProcessing:
pag_scale: float = 0.0,
pag_adaptive: float = 0.5,
# styles
- styles: List[str] = [],
+ styles: list[str] = None,
# vae
tiling: bool = False,
vae_type: str = 'Full',
@@ -79,8 +79,8 @@ class StableDiffusionProcessing:
hdr_color_picker: str = None,
hdr_tint_ratio: float = 0,
# img2img
- init_images: list = [],
- init_control: list = [],
+ init_images: list = None,
+ init_control: list = None,
denoising_strength: float = 0.3,
image_cfg_scale: float = None,
initial_noise_multiplier: float = None, # pylint: disable=unused-argument # a1111 compatibility
@@ -150,9 +150,9 @@ class StableDiffusionProcessing:
# xyz flag
xyz: bool = False,
# scripts
- script_args: list = [],
+ script_args: list = None,
# overrides
- override_settings: Dict[str, Any] = {},
+ override_settings: dict[str, Any] = None,
override_settings_restore_afterwards: bool = True,
# metadata
# extra_generation_params: Dict[Any, Any] = {},
@@ -161,6 +161,16 @@ class StableDiffusionProcessing:
**kwargs,
):
+ if override_settings is None:
+ override_settings = {}
+ if script_args is None:
+ script_args = []
+ if init_control is None:
+ init_control = []
+ if init_images is None:
+ init_images = []
+ if styles is None:
+ styles = []
for k, v in kwargs.items():
setattr(self, k, v)
diff --git a/modules/processing_vae.py b/modules/processing_vae.py
index 53f8c3751..c2cbffb57 100644
--- a/modules/processing_vae.py
+++ b/modules/processing_vae.py
@@ -365,7 +365,7 @@ def reprocess(gallery):
shared.log.info(f'Reprocessing: latent={latent.shape}')
reprocessed = vae_decode(latent, shared.sd_model, output_type='pil')
outputs = []
- for i0, i1 in zip(gallery, reprocessed):
+ for i0, i1 in zip(gallery, reprocessed, strict=False):
if isinstance(i1, np.ndarray):
i1 = Image.fromarray(i1)
fn = i0['name']
diff --git a/modules/progress.py b/modules/progress.py
index f0fe52877..bbdbee553 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -2,7 +2,6 @@ import base64
import os
import io
import time
-from typing import Union
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
import modules.shared as shared
@@ -48,7 +47,7 @@ class ProgressRequest(BaseModel):
class InternalProgressResponse(BaseModel):
job: str = Field(default=None, title="Job name", description="Internal job name")
- textinfo: Union[str|None] = Field(default=None, title="Info text", description="Info text used by WebUI.")
+ textinfo: str|None = Field(default=None, title="Info text", description="Info text used by WebUI.")
# status fields
active: bool = Field(title="Whether the task is being worked on right now")
queued: bool = Field(title="Whether the task is in queue")
@@ -62,10 +61,10 @@ class InternalProgressResponse(BaseModel):
batch_count: int = Field(default=None, title="Total batches", description="Total number of batches")
# calculated fields
progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
- eta: Union[float|None] = Field(default=None, title="ETA in secs")
+ eta: float|None = Field(default=None, title="ETA in secs")
# image fields
- live_preview: Union[str|None] = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
- id_live_preview: Union[int|None] = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
+ live_preview: str|None = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
+ id_live_preview: int|None = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
def api_progress(req: ProgressRequest):
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 57587e3b5..06df1a9e7 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -10,7 +10,6 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import os
import re
from collections import namedtuple
-from typing import List
import lark
import torch
from compel import Compel
@@ -181,7 +180,7 @@ def get_learned_conditioning(model, prompts, steps):
res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
cache = {}
- for prompt, prompt_schedule in zip(prompts, prompt_schedules):
+ for prompt, prompt_schedule in zip(prompts, prompt_schedules, strict=False):
debug(f'Prompt schedule: {prompt_schedule}')
cached = cache.get(prompt, None)
if cached is not None:
@@ -220,14 +219,14 @@ def get_multicond_prompt_list(prompts):
class ComposableScheduledPromptConditioning:
def __init__(self, schedules, weight=1.0):
- self.schedules: List[ScheduledPromptConditioning] = schedules
+ self.schedules: list[ScheduledPromptConditioning] = schedules
self.weight: float = weight
class MulticondLearnedConditioning:
def __init__(self, shape, batch):
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
- self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+ self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
@@ -243,7 +242,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
-def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
+def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c):
diff --git a/modules/prompt_parser_diffusers.py b/modules/prompt_parser_diffusers.py
index 25349e583..17f4821f0 100644
--- a/modules/prompt_parser_diffusers.py
+++ b/modules/prompt_parser_diffusers.py
@@ -1,7 +1,6 @@
import os
import math
import time
-import typing
from collections import OrderedDict
import torch
from compel.embeddings_provider import BaseTextualInversionManager, EmbeddingsProvider
@@ -85,7 +84,7 @@ class PromptEmbedder:
return
seen_prompts = {}
# per prompt in batch
- for batchidx, (prompt, negative_prompt) in enumerate(zip(self.prompts, self.negative_prompts)):
+ for batchidx, (prompt, negative_prompt) in enumerate(zip(self.prompts, self.negative_prompts, strict=False)):
self.prepare_schedule(prompt, negative_prompt)
schedule_key = (
tuple(self.positive_schedule) if self.positive_schedule is not None else None,
@@ -300,7 +299,7 @@ class PromptEmbedder:
return None
-def compel_hijack(self, token_ids: torch.Tensor, attention_mask: typing.Optional[torch.Tensor] = None) -> torch.Tensor:
+def compel_hijack(self, token_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
needs_hidden_states = self.returned_embeddings_type != 1
text_encoder_output = self.text_encoder(token_ids, attention_mask, output_hidden_states=needs_hidden_states, return_dict=True)
@@ -323,7 +322,7 @@ def compel_hijack(self, token_ids: torch.Tensor, attention_mask: typing.Optional
return hidden_state
-def sd3_compel_hijack(self, token_ids: torch.Tensor, attention_mask: typing.Optional[torch.Tensor] = None) -> torch.Tensor:
+def sd3_compel_hijack(self, token_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
needs_hidden_states = True
text_encoder_output = self.text_encoder(token_ids, attention_mask, output_hidden_states=needs_hidden_states, return_dict=True)
clip_skip = int(self.returned_embeddings_type)
@@ -353,10 +352,10 @@ class DiffusersTextualInversionManager(BaseTextualInversionManager):
# code from
# https://github.com/huggingface/diffusers/blob/705c592ea98ba4e288d837b9cba2767623c78603/src/diffusers/loaders.py
- def maybe_convert_prompt(self, prompt: typing.Union[str, typing.List[str]], tokenizer: PreTrainedTokenizer):
- prompts = [prompt] if not isinstance(prompt, typing.List) else prompt
+ def maybe_convert_prompt(self, prompt: str | list[str], tokenizer: PreTrainedTokenizer):
+ prompts = [prompt] if not isinstance(prompt, list) else prompt
prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
- if not isinstance(prompt, typing.List):
+ if not isinstance(prompt, list):
return prompts[0]
return prompts
@@ -378,7 +377,7 @@ class DiffusersTextualInversionManager(BaseTextualInversionManager):
debug(f'Prompt: convert="{prompt}"')
return prompt
- def expand_textual_inversion_token_ids_if_necessary(self, token_ids: typing.List[int]) -> typing.List[int]:
+ def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
if len(token_ids) == 0:
return token_ids
prompt = self.pipe.tokenizer.decode(token_ids)
@@ -470,7 +469,7 @@ def get_prompts_with_weights(pipe, prompt: str):
texts_and_weights = prompt_parser.parse_prompt_attention(prompt)
if shared.opts.prompt_mean_norm:
texts_and_weights = normalize_prompt(texts_and_weights)
- texts, text_weights = zip(*texts_and_weights)
+ texts, text_weights = zip(*texts_and_weights, strict=False)
avg_weight = 0
min_weight = 1
max_weight = 0
@@ -478,7 +477,7 @@ def get_prompts_with_weights(pipe, prompt: str):
try:
all_tokens = 0
- for text, weight in zip(texts, text_weights):
+ for text, weight in zip(texts, text_weights, strict=False):
tokens = get_tokens(pipe, 'section', text)
all_tokens += tokens
avg_weight += tokens*weight
@@ -627,8 +626,8 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c
ps = 2 * [get_prompts_with_weights(pipe, prompt)]
ns = 2 * [get_prompts_with_weights(pipe, neg_prompt)]
- positives, positive_weights = zip(*ps)
- negatives, negative_weights = zip(*ns)
+ positives, positive_weights = zip(*ps, strict=False)
+ negatives, negative_weights = zip(*ns, strict=False)
if hasattr(pipe, "tokenizer_2") and not hasattr(pipe, "tokenizer"):
positives.pop(0)
positive_weights.pop(0)
diff --git a/modules/ras/ras_attention.py b/modules/ras/ras_attention.py
index 4989cc931..5be0db7de 100644
--- a/modules/ras/ras_attention.py
+++ b/modules/ras/ras_attention.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
import math
import torch
import torch.nn.functional as F
@@ -38,10 +37,10 @@ class RASLuminaAttnProcessor2_0:
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- query_rotary_emb: Optional[torch.Tensor] = None,
- key_rotary_emb: Optional[torch.Tensor] = None,
- base_sequence_length: Optional[int] = None,
+ attention_mask: torch.Tensor | None = None,
+ query_rotary_emb: torch.Tensor | None = None,
+ key_rotary_emb: torch.Tensor | None = None,
+ base_sequence_length: int | None = None,
) -> torch.Tensor:
from diffusers.models.embeddings import apply_rotary_emb
@@ -165,7 +164,7 @@ class RASJointAttnProcessor2_0:
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
+ attention_mask: torch.FloatTensor | None = None,
*args,
**kwargs,
) -> torch.FloatTensor:
diff --git a/modules/ras/ras_forward.py b/modules/ras/ras_forward.py
index 63c71428e..ef0f245ea 100644
--- a/modules/ras/ras_forward.py
+++ b/modules/ras/ras_forward.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Union
+from typing import Any
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
@@ -25,11 +25,11 @@ def ras_forward(
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
- block_controlnet_hidden_states: List = None,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ block_controlnet_hidden_states: list = None,
+ joint_attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
- skip_layers: Optional[List[int]] = None,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ skip_layers: list[int] | None = None,
+ ) -> torch.FloatTensor | Transformer2DModelOutput:
"""
The [`SD3Transformer2DModel`] forward method.
diff --git a/modules/ras/ras_scheduler.py b/modules/ras/ras_scheduler.py
index a5143a067..f2131e4c6 100644
--- a/modules/ras/ras_scheduler.py
+++ b/modules/ras/ras_scheduler.py
@@ -15,7 +15,6 @@
# limitations under the License.
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
import torch
from diffusers.configuration_utils import register_to_config
from diffusers.utils import BaseOutput, logging
@@ -66,10 +65,10 @@ class RASFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
num_train_timesteps: int = 1000,
shift: float = 1.0,
use_dynamic_shifting=False,
- base_shift: Optional[float] = 0.5,
- max_shift: Optional[float] = 1.15,
- base_image_seq_len: Optional[int] = 256,
- max_image_seq_len: Optional[int] = 4096,
+ base_shift: float | None = 0.5,
+ max_shift: float | None = 1.15,
+ base_image_seq_len: int | None = 256,
+ max_image_seq_len: int | None = 4096,
invert_sigmas: bool = False,
):
super().__init__(num_train_timesteps=num_train_timesteps,
@@ -120,15 +119,15 @@ class RASFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def step(
self,
model_output: torch.FloatTensor,
- timestep: Union[float, torch.FloatTensor],
+ timestep: float | torch.FloatTensor,
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
- generator: Optional[torch.Generator] = None,
+ generator: torch.Generator | None = None,
return_dict: bool = True,
- ) -> Union[RASFlowMatchEulerDiscreteSchedulerOutput, Tuple]:
+ ) -> RASFlowMatchEulerDiscreteSchedulerOutput | tuple:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
diff --git a/modules/res4lyf/abnorsett_scheduler.py b/modules/res4lyf/abnorsett_scheduler.py
index e2ba0a686..810b0deab 100644
--- a/modules/res4lyf/abnorsett_scheduler.py
+++ b/modules/res4lyf/abnorsett_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import ClassVar, List, Literal, Optional, Tuple, Union
+from typing import ClassVar, Literal
import numpy as np
import torch
@@ -31,7 +31,7 @@ class ABNorsettScheduler(SchedulerMixin, ConfigMixin):
Adams-Bashforth Norsett (ABNorsett) scheduler.
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
@@ -41,7 +41,7 @@ class ABNorsettScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
variant: Literal["abnorsett_2m", "abnorsett_3m", "abnorsett_4m"] = "abnorsett_2m",
use_analytic_solution: bool = True,
@@ -87,23 +87,22 @@ class ABNorsettScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
- get_sigmas_flow,
get_sigmas_karras,
)
@@ -183,7 +182,7 @@ class ABNorsettScheduler(SchedulerMixin, ConfigMixin):
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -195,10 +194,10 @@ class ABNorsettScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/bong_tangent_scheduler.py b/modules/res4lyf/bong_tangent_scheduler.py
index a0c827218..d3b7eaa84 100644
--- a/modules/res4lyf/bong_tangent_scheduler.py
+++ b/modules/res4lyf/bong_tangent_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import ClassVar, List, Optional, Tuple, Union
+from typing import ClassVar
import numpy as np
import torch
@@ -29,7 +29,7 @@ class BongTangentScheduler(SchedulerMixin, ConfigMixin):
BongTangent scheduler using Exponential Integrator step.
"""
- _compatibles: ClassVar[List[str]] = []
+ _compatibles: ClassVar[list[str]] = []
order = 1
@register_to_config
@@ -86,17 +86,17 @@ class BongTangentScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -105,7 +105,7 @@ class BongTangentScheduler(SchedulerMixin, ConfigMixin):
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
@@ -210,7 +210,7 @@ class BongTangentScheduler(SchedulerMixin, ConfigMixin):
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
- def _get_bong_tangent_sigmas(self, steps: int, slope: float, pivot: int, start: float, end: float, dtype: torch.dtype = torch.float32) -> List[float]:
+ def _get_bong_tangent_sigmas(self, steps: int, slope: float, pivot: int, start: float, end: float, dtype: torch.dtype = torch.float32) -> list[float]:
x = torch.arange(steps, dtype=dtype)
def bong_fn(val):
@@ -228,10 +228,10 @@ class BongTangentScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/common_sigma_scheduler.py b/modules/res4lyf/common_sigma_scheduler.py
index 202d289af..bfe32a875 100644
--- a/modules/res4lyf/common_sigma_scheduler.py
+++ b/modules/res4lyf/common_sigma_scheduler.py
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import ClassVar, List, Literal, Optional, Tuple, Union
+from typing import ClassVar, Literal
import numpy as np
import torch
@@ -30,7 +30,7 @@ class CommonSigmaScheduler(SchedulerMixin, ConfigMixin):
Common Sigma scheduler using Exponential Integrator step.
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order: ClassVar[int] = 1
@register_to_config
@@ -88,17 +88,17 @@ class CommonSigmaScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = None
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
@@ -200,7 +200,7 @@ class CommonSigmaScheduler(SchedulerMixin, ConfigMixin):
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -212,10 +212,10 @@ class CommonSigmaScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/deis_scheduler_alt.py b/modules/res4lyf/deis_scheduler_alt.py
index 70c63cecf..bcb3a266e 100644
--- a/modules/res4lyf/deis_scheduler_alt.py
+++ b/modules/res4lyf/deis_scheduler_alt.py
@@ -1,4 +1,3 @@
-from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -33,16 +32,16 @@ class RESDEISMultistepScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
- sigma_min: Optional[float] = None,
- sigma_max: Optional[float] = None,
+ sigma_min: float | None = None,
+ sigma_max: float | None = None,
rho: float = 7.0,
- shift: Optional[float] = None,
+ shift: float | None = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
@@ -87,8 +86,8 @@ class RESDEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
- device: Union[str, torch.device] = None,
- mu: Optional[float] = None,
+ device: str | torch.device | None = None,
+ mu: float | None = None,
dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
@@ -225,7 +224,7 @@ class RESDEISMultistepScheduler(SchedulerMixin, ConfigMixin):
if self._step_index is None:
self._step_index = self.index_for_timestep(timestep)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -236,10 +235,10 @@ class RESDEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/etdrk_scheduler.py b/modules/res4lyf/etdrk_scheduler.py
index 07b624ff6..cc7f693fe 100644
--- a/modules/res4lyf/etdrk_scheduler.py
+++ b/modules/res4lyf/etdrk_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import ClassVar, List, Literal, Optional, Tuple, Union
+from typing import ClassVar, Literal
import numpy as np
import torch
@@ -31,7 +31,7 @@ class ETDRKScheduler(SchedulerMixin, ConfigMixin):
Exponential Time Differencing Runge-Kutta (ETDRK) scheduler.
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
@@ -41,7 +41,7 @@ class ETDRKScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
variant: Literal["etdrk2_2s", "etdrk3_a_3s", "etdrk3_b_3s", "etdrk4_4s", "etdrk4_4s_alt"] = "etdrk4_4s",
use_analytic_solution: bool = True,
@@ -87,17 +87,17 @@ class ETDRKScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
@@ -171,7 +171,7 @@ class ETDRKScheduler(SchedulerMixin, ConfigMixin):
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -183,10 +183,10 @@ class ETDRKScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/gauss_legendre_scheduler.py b/modules/res4lyf/gauss_legendre_scheduler.py
index 38db308b8..0cbb5ea03 100644
--- a/modules/res4lyf/gauss_legendre_scheduler.py
+++ b/modules/res4lyf/gauss_legendre_scheduler.py
@@ -1,4 +1,3 @@
-from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -22,17 +21,17 @@ class GaussLegendreScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
variant: str = "gauss-legendre_2s", # 2s to 8s variants
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
- sigma_min: Optional[float] = None,
- sigma_max: Optional[float] = None,
+ sigma_min: float | None = None,
+ sigma_max: float | None = None,
rho: float = 7.0,
- shift: Optional[float] = None,
+ shift: float | None = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
@@ -147,8 +146,8 @@ class GaussLegendreScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
- device: Union[str, torch.device] = None,
- mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ device: str | torch.device | None = None,
+ mu: float | None = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
@@ -248,7 +247,7 @@ class GaussLegendreScheduler(SchedulerMixin, ConfigMixin):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -259,10 +258,10 @@ class GaussLegendreScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/langevin_dynamics_scheduler.py b/modules/res4lyf/langevin_dynamics_scheduler.py
index 8e3c2eb48..af7213b52 100644
--- a/modules/res4lyf/langevin_dynamics_scheduler.py
+++ b/modules/res4lyf/langevin_dynamics_scheduler.py
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import ClassVar, List, Optional, Tuple, Union
+from typing import ClassVar
import numpy as np
import torch
@@ -30,7 +30,7 @@ class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin):
Langevin Dynamics sigma scheduler using Exponential Integrator step.
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order: ClassVar[int] = 1
@register_to_config
@@ -85,11 +85,11 @@ class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = None
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
@@ -98,9 +98,9 @@ class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
- device: Union[str, torch.device] = None,
- generator: Optional[torch.Generator] = None,
- mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ device: str | torch.device | None = None,
+ generator: torch.Generator | None = None,
+ mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
@@ -187,7 +187,7 @@ class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin):
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -199,10 +199,10 @@ class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/lawson_scheduler.py b/modules/res4lyf/lawson_scheduler.py
index 0af304eb2..3631024bf 100644
--- a/modules/res4lyf/lawson_scheduler.py
+++ b/modules/res4lyf/lawson_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import ClassVar, List, Literal, Optional, Tuple, Union
+from typing import ClassVar, Literal
import numpy as np
import torch
@@ -29,7 +29,7 @@ class LawsonScheduler(SchedulerMixin, ConfigMixin):
Lawson's integration method scheduler.
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
@@ -39,7 +39,7 @@ class LawsonScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
variant: Literal["lawson2a_2s", "lawson2b_2s", "lawson4_4s"] = "lawson4_4s",
use_analytic_solution: bool = True,
@@ -85,17 +85,17 @@ class LawsonScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
@@ -169,7 +169,7 @@ class LawsonScheduler(SchedulerMixin, ConfigMixin):
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -181,10 +181,10 @@ class LawsonScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/linear_rk_scheduler.py b/modules/res4lyf/linear_rk_scheduler.py
index 8e2a9aac1..955e4af8e 100644
--- a/modules/res4lyf/linear_rk_scheduler.py
+++ b/modules/res4lyf/linear_rk_scheduler.py
@@ -1,4 +1,3 @@
-from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -22,17 +21,17 @@ class LinearRKScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
variant: str = "rk4", # euler, heun, rk2, rk3, rk4, ralston, midpoint
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
- sigma_min: Optional[float] = None,
- sigma_max: Optional[float] = None,
+ sigma_min: float | None = None,
+ sigma_max: float | None = None,
rho: float = 7.0,
- shift: Optional[float] = None,
+ shift: float | None = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
@@ -103,8 +102,8 @@ class LinearRKScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
- device: Union[str, torch.device] = None,
- mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ device: str | torch.device | None = None,
+ mu: float | None = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
@@ -204,7 +203,7 @@ class LinearRKScheduler(SchedulerMixin, ConfigMixin):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -215,10 +214,10 @@ class LinearRKScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
a_mat, b_vec, c_vec = self._get_tableau()
diff --git a/modules/res4lyf/lobatto_scheduler.py b/modules/res4lyf/lobatto_scheduler.py
index 97d073e88..e1698b935 100644
--- a/modules/res4lyf/lobatto_scheduler.py
+++ b/modules/res4lyf/lobatto_scheduler.py
@@ -1,4 +1,3 @@
-from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -23,17 +22,17 @@ class LobattoScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
variant: str = "lobatto_iiia_3s", # Available: iiia, iiib, iiic
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
- sigma_min: Optional[float] = None,
- sigma_max: Optional[float] = None,
+ sigma_min: float | None = None,
+ sigma_max: float | None = None,
rho: float = 7.0,
- shift: Optional[float] = None,
+ shift: float | None = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
@@ -103,8 +102,8 @@ class LobattoScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
- device: Union[str, torch.device] = None,
- mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ device: str | torch.device | None = None,
+ mu: float | None = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
@@ -204,7 +203,7 @@ class LobattoScheduler(SchedulerMixin, ConfigMixin):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -215,10 +214,10 @@ class LobattoScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
a_mat, b_vec, c_vec = self._get_tableau()
diff --git a/modules/res4lyf/pec_scheduler.py b/modules/res4lyf/pec_scheduler.py
index f6df4f449..d5951b937 100644
--- a/modules/res4lyf/pec_scheduler.py
+++ b/modules/res4lyf/pec_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import ClassVar, List, Literal, Optional, Tuple, Union
+from typing import ClassVar, Literal
import numpy as np
import torch
@@ -31,7 +31,7 @@ class PECScheduler(SchedulerMixin, ConfigMixin):
Predictor-Corrector (PEC) scheduler.
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
@@ -41,7 +41,7 @@ class PECScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
variant: Literal["pec423_2h2s", "pec433_2h3s"] = "pec423_2h2s",
use_analytic_solution: bool = True,
@@ -87,11 +87,11 @@ class PECScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
@@ -100,8 +100,8 @@ class PECScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
- device: Union[str, torch.device] = None,
- mu: Optional[float] = None,
+ device: str | torch.device | None = None,
+ mu: float | None = None,
dtype: torch.dtype = torch.float32,
):
from .scheduler_utils import (
@@ -177,7 +177,7 @@ class PECScheduler(SchedulerMixin, ConfigMixin):
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -188,10 +188,10 @@ class PECScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/phi_functions.py b/modules/res4lyf/phi_functions.py
index 7941f7c2a..ddd859585 100644
--- a/modules/res4lyf/phi_functions.py
+++ b/modules/res4lyf/phi_functions.py
@@ -13,7 +13,6 @@
# limitations under the License.
import math
-from typing import Dict, List, Tuple, Union
import torch
from mpmath import exp as mp_exp
@@ -89,10 +88,10 @@ class Phi:
Supports both standard torch-based and high-precision mpmath-based solutions.
"""
- def __init__(self, h: torch.Tensor, c: List[Union[float, mpf]], analytic_solution: bool = True):
+ def __init__(self, h: torch.Tensor, c: list[float | mpf], analytic_solution: bool = True):
self.h = h
self.c = c
- self.cache: Dict[Tuple[int, int], Union[float, torch.Tensor]] = {}
+ self.cache: dict[tuple[int, int], float | torch.Tensor] = {}
self.analytic_solution = analytic_solution
if analytic_solution:
@@ -102,7 +101,7 @@ class Phi:
else:
self.phi_f = phi_standard_torch
- def __call__(self, j: int, i: int = -1) -> Union[float, torch.Tensor]:
+ def __call__(self, j: int, i: int = -1) -> float | torch.Tensor:
if (j, i) in self.cache:
return self.cache[(j, i)]
diff --git a/modules/res4lyf/radau_iia_scheduler.py b/modules/res4lyf/radau_iia_scheduler.py
index 2cd5d85e3..4d072205b 100644
--- a/modules/res4lyf/radau_iia_scheduler.py
+++ b/modules/res4lyf/radau_iia_scheduler.py
@@ -1,4 +1,3 @@
-from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -23,17 +22,17 @@ class RadauIIAScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
variant: str = "radau_iia_3s", # 2s to 11s variants
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
- sigma_min: Optional[float] = None,
- sigma_max: Optional[float] = None,
+ sigma_min: float | None = None,
+ sigma_max: float | None = None,
rho: float = 7.0,
- shift: Optional[float] = None,
+ shift: float | None = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
@@ -137,8 +136,8 @@ class RadauIIAScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
- device: Union[str, torch.device] = None,
- mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ device: str | torch.device | None = None,
+ mu: float | None = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
@@ -238,7 +237,7 @@ class RadauIIAScheduler(SchedulerMixin, ConfigMixin):
return np.abs(schedule_timesteps - timestep).argmin().item()
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -257,10 +256,10 @@ class RadauIIAScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
a_mat, b_vec, c_vec = self._get_tableau()
diff --git a/modules/res4lyf/res_multistep_scheduler.py b/modules/res4lyf/res_multistep_scheduler.py
index e324408ee..081e83307 100644
--- a/modules/res4lyf/res_multistep_scheduler.py
+++ b/modules/res4lyf/res_multistep_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import ClassVar, List, Literal, Optional, Tuple, Union
+from typing import ClassVar, Literal
import numpy as np
import torch
@@ -49,7 +49,7 @@ class RESMultistepScheduler(SchedulerMixin, ConfigMixin):
Whether to use high-precision analytic solutions for phi functions.
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
@@ -102,17 +102,17 @@ class RESMultistepScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -120,13 +120,12 @@ class RESMultistepScheduler(SchedulerMixin, ConfigMixin):
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
- get_sigmas_flow,
get_sigmas_karras,
)
@@ -208,10 +207,10 @@ class RESMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/res_multistep_sde_scheduler.py b/modules/res4lyf/res_multistep_sde_scheduler.py
index 8ed98688b..adc40f832 100644
--- a/modules/res4lyf/res_multistep_sde_scheduler.py
+++ b/modules/res4lyf/res_multistep_sde_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import ClassVar, List, Literal, Optional, Tuple, Union
+from typing import ClassVar, Literal
import numpy as np
import torch
@@ -38,7 +38,7 @@ class RESMultistepSDEScheduler(SchedulerMixin, ConfigMixin):
The amount of noise to add during sampling (0.0 for ODE, 1.0 for full SDE).
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
@@ -92,17 +92,17 @@ class RESMultistepSDEScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -111,7 +111,7 @@ class RESMultistepSDEScheduler(SchedulerMixin, ConfigMixin):
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
@@ -188,11 +188,11 @@ class RESMultistepSDEScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
- generator: Optional[torch.Generator] = None,
+ generator: torch.Generator | None = None,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/res_singlestep_scheduler.py b/modules/res4lyf/res_singlestep_scheduler.py
index 29146029f..86d10fd24 100644
--- a/modules/res4lyf/res_singlestep_scheduler.py
+++ b/modules/res4lyf/res_singlestep_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import ClassVar, List, Literal, Optional, Tuple, Union
+from typing import ClassVar, Literal
import numpy as np
import torch
@@ -29,7 +29,7 @@ class RESSinglestepScheduler(SchedulerMixin, ConfigMixin):
RESSinglestepScheduler (Multistage Exponential Integrator) ported from RES4LYF.
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
@@ -78,17 +78,17 @@ class RESSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -96,7 +96,7 @@ class RESSinglestepScheduler(SchedulerMixin, ConfigMixin):
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
@@ -183,10 +183,10 @@ class RESSinglestepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/res_singlestep_sde_scheduler.py b/modules/res4lyf/res_singlestep_sde_scheduler.py
index ef7fea5b9..a83b5b403 100644
--- a/modules/res4lyf/res_singlestep_sde_scheduler.py
+++ b/modules/res4lyf/res_singlestep_sde_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import ClassVar, List, Literal, Optional, Tuple, Union
+from typing import ClassVar, Literal
import numpy as np
import torch
@@ -30,7 +30,7 @@ class RESSinglestepSDEScheduler(SchedulerMixin, ConfigMixin):
RESSinglestepSDEScheduler (Stochastic Multistage Exponential Integrator) ported from RES4LYF.
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
@@ -80,17 +80,17 @@ class RESSinglestepSDEScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -99,7 +99,7 @@ class RESSinglestepSDEScheduler(SchedulerMixin, ConfigMixin):
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
@@ -173,11 +173,11 @@ class RESSinglestepSDEScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
- generator: Optional[torch.Generator] = None,
+ generator: torch.Generator | None = None,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/res_unified_scheduler.py b/modules/res4lyf/res_unified_scheduler.py
index 5aa619db6..061517f10 100644
--- a/modules/res4lyf/res_unified_scheduler.py
+++ b/modules/res4lyf/res_unified_scheduler.py
@@ -1,4 +1,4 @@
-from typing import ClassVar, List, Optional, Tuple, Union
+from typing import ClassVar
import numpy as np
import torch
@@ -15,7 +15,7 @@ class RESUnifiedScheduler(SchedulerMixin, ConfigMixin):
Supports DEIS 1S, 2M, 3M
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order: ClassVar[int] = 1
@register_to_config
@@ -74,17 +74,17 @@ class RESUnifiedScheduler(SchedulerMixin, ConfigMixin):
self._step_index = None
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -92,13 +92,12 @@ class RESUnifiedScheduler(SchedulerMixin, ConfigMixin):
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
- get_sigmas_flow,
get_sigmas_karras,
)
@@ -236,10 +235,10 @@ class RESUnifiedScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/riemannian_flow_scheduler.py b/modules/res4lyf/riemannian_flow_scheduler.py
index 926c31c46..2b2ada55d 100644
--- a/modules/res4lyf/riemannian_flow_scheduler.py
+++ b/modules/res4lyf/riemannian_flow_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import ClassVar, List, Literal, Optional, Tuple, Union
+from typing import ClassVar, Literal
import numpy as np
import torch
@@ -29,7 +29,7 @@ class RiemannianFlowScheduler(SchedulerMixin, ConfigMixin):
Riemannian Flow scheduler using Exponential Integrator step.
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order: ClassVar[int] = 1
@register_to_config
@@ -84,17 +84,17 @@ class RiemannianFlowScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = None
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
@@ -202,7 +202,7 @@ class RiemannianFlowScheduler(SchedulerMixin, ConfigMixin):
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -214,10 +214,10 @@ class RiemannianFlowScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/rungekutta_44s_scheduler.py b/modules/res4lyf/rungekutta_44s_scheduler.py
index be6efe9da..d18941e01 100644
--- a/modules/res4lyf/rungekutta_44s_scheduler.py
+++ b/modules/res4lyf/rungekutta_44s_scheduler.py
@@ -1,4 +1,3 @@
-from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -23,16 +22,16 @@ class RungeKutta44Scheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
- sigma_min: Optional[float] = None,
- sigma_max: Optional[float] = None,
+ sigma_min: float | None = None,
+ sigma_max: float | None = None,
rho: float = 7.0,
- shift: Optional[float] = None,
+ shift: float | None = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
@@ -69,7 +68,7 @@ class RungeKutta44Scheduler(SchedulerMixin, ConfigMixin):
self._sigmas_cpu = None
self._step_index = None
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Base sigmas
@@ -141,7 +140,7 @@ class RungeKutta44Scheduler(SchedulerMixin, ConfigMixin):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -152,10 +151,10 @@ class RungeKutta44Scheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/rungekutta_57s_scheduler.py b/modules/res4lyf/rungekutta_57s_scheduler.py
index d3f6b2297..5d118bff7 100644
--- a/modules/res4lyf/rungekutta_57s_scheduler.py
+++ b/modules/res4lyf/rungekutta_57s_scheduler.py
@@ -1,4 +1,3 @@
-from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -21,16 +20,16 @@ class RungeKutta57Scheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
- sigma_min: Optional[float] = None,
- sigma_max: Optional[float] = None,
+ sigma_min: float | None = None,
+ sigma_max: float | None = None,
rho: float = 7.0,
- shift: Optional[float] = None,
+ shift: float | None = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
@@ -72,8 +71,8 @@ class RungeKutta57Scheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
- device: Union[str, torch.device] = None,
- mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ device: str | torch.device | None = None,
+ mu: float | None = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
@@ -178,7 +177,7 @@ class RungeKutta57Scheduler(SchedulerMixin, ConfigMixin):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -189,10 +188,10 @@ class RungeKutta57Scheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/rungekutta_67s_scheduler.py b/modules/res4lyf/rungekutta_67s_scheduler.py
index b2c13ad47..55af1b16c 100644
--- a/modules/res4lyf/rungekutta_67s_scheduler.py
+++ b/modules/res4lyf/rungekutta_67s_scheduler.py
@@ -1,4 +1,3 @@
-from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -22,16 +21,16 @@ class RungeKutta67Scheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
- sigma_min: Optional[float] = None,
- sigma_max: Optional[float] = None,
+ sigma_min: float | None = None,
+ sigma_max: float | None = None,
rho: float = 7.0,
- shift: Optional[float] = None,
+ shift: float | None = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
@@ -72,8 +71,8 @@ class RungeKutta67Scheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
- device: Union[str, torch.device] = None,
- mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ device: str | torch.device | None = None,
+ mu: float | None = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
@@ -177,7 +176,7 @@ class RungeKutta67Scheduler(SchedulerMixin, ConfigMixin):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -188,10 +187,10 @@ class RungeKutta67Scheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/simple_exponential_scheduler.py b/modules/res4lyf/simple_exponential_scheduler.py
index 52e678ca9..01a901e12 100644
--- a/modules/res4lyf/simple_exponential_scheduler.py
+++ b/modules/res4lyf/simple_exponential_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import ClassVar, List, Optional, Tuple, Union
+from typing import ClassVar
import numpy as np
import torch
@@ -29,7 +29,7 @@ class SimpleExponentialScheduler(SchedulerMixin, ConfigMixin):
Simple Exponential sigma scheduler using Exponential Integrator step.
"""
- _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
+ _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers]
order: ClassVar[int] = 1
@register_to_config
@@ -85,17 +85,17 @@ class SimpleExponentialScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = None
@property
- def step_index(self) -> Optional[int]:
+ def step_index(self) -> int | None:
return self._step_index
@property
- def begin_index(self) -> Optional[int]:
+ def begin_index(self) -> int | None:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, mu: float | None = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
@@ -152,7 +152,7 @@ class SimpleExponentialScheduler(SchedulerMixin, ConfigMixin):
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -164,10 +164,10 @@ class SimpleExponentialScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
if self._step_index is None:
self._init_step_index(timestep)
diff --git a/modules/res4lyf/specialized_rk_scheduler.py b/modules/res4lyf/specialized_rk_scheduler.py
index fa9b23a2e..33b6df815 100644
--- a/modules/res4lyf/specialized_rk_scheduler.py
+++ b/modules/res4lyf/specialized_rk_scheduler.py
@@ -1,4 +1,3 @@
-from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -23,17 +22,17 @@ class SpecializedRKScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ trained_betas: np.ndarray | list[float] | None = None,
prediction_type: str = "epsilon",
variant: str = "ssprk3_3s", # ssprk3_3s, ssprk4_4s, tsi_7s, ralston_4s, bogacki-shampine_4s
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
- sigma_min: Optional[float] = None,
- sigma_max: Optional[float] = None,
+ sigma_min: float | None = None,
+ sigma_max: float | None = None,
rho: float = 7.0,
- shift: Optional[float] = None,
+ shift: float | None = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
@@ -107,8 +106,8 @@ class SpecializedRKScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
- device: Union[str, torch.device] = None,
- mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
+ device: str | torch.device | None = None,
+ mu: float | None = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
@@ -211,7 +210,7 @@ class SpecializedRKScheduler(SchedulerMixin, ConfigMixin):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
- def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
@@ -222,10 +221,10 @@ class SpecializedRKScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
- timestep: Union[float, torch.Tensor],
+ timestep: float | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> SchedulerOutput | tuple:
self._init_step_index(timestep)
a_mat, b_vec, c_vec = self._get_tableau()
num_stages = len(c_vec)
diff --git a/modules/rife/__init__.py b/modules/rife/__init__.py
index c4db35645..ba2a66d8b 100644
--- a/modules/rife/__init__.py
+++ b/modules/rife/__init__.py
@@ -12,7 +12,7 @@ from torch.nn import functional as F
from tqdm.rich import tqdm
from modules.rife.ssim import ssim_matlab
from modules.rife.model_rife import RifeModel
-from modules import devices, shared
+from modules import devices, shared, paths
model_url = 'https://github.com/vladmandic/rife/raw/main/model/flownet-v46.pkl'
@@ -23,7 +23,7 @@ def load(model_path: str = 'rife/flownet-v46.pkl'):
global model # pylint: disable=global-statement
if model is None:
from modules import modelloader
- model_dir = os.path.join(shared.models_path, 'RIFE')
+ model_dir = os.path.join(paths.models_path, 'RIFE')
model_path = modelloader.load_file_from_url(url=model_url, model_dir=model_dir, file_name='flownet-v46.pkl')
shared.log.debug(f'Video interpolate: model="{model_path}"')
model = RifeModel()
@@ -104,7 +104,7 @@ def interpolate(images: list, count: int = 2, scale: float = 1.0, pad: int = 1,
else:
output = execute(I0, I1, count-1)
for mid in output:
- mid = (((mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0)))
+ mid = ((mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0))
buffer.put(mid[:h, :w])
buffer.put(frame)
pbar.update(1)
diff --git a/modules/rife/loss.py b/modules/rife/loss.py
index 8b6309006..f525ff443 100644
--- a/modules/rife/loss.py
+++ b/modules/rife/loss.py
@@ -8,7 +8,7 @@ from modules import devices
class EPE(nn.Module):
def __init__(self):
- super(EPE, self).__init__()
+ super().__init__()
def forward(self, flow, gt, loss_mask):
loss_map = (flow - gt.detach()) ** 2
@@ -18,7 +18,7 @@ class EPE(nn.Module):
class Ternary(nn.Module):
def __init__(self):
- super(Ternary, self).__init__()
+ super().__init__()
patch_size = 7
out_channels = patch_size * patch_size
self.w = np.eye(out_channels).reshape(
@@ -56,7 +56,7 @@ class Ternary(nn.Module):
class SOBEL(nn.Module):
def __init__(self):
- super(SOBEL, self).__init__()
+ super().__init__()
self.kernelX = torch.tensor([
[1, 0, -1],
[2, 0, -2],
@@ -82,7 +82,7 @@ class SOBEL(nn.Module):
class MeanShift(nn.Conv2d):
def __init__(self, data_mean, data_std, data_range=1, norm=True):
c = len(data_mean)
- super(MeanShift, self).__init__(c, c, kernel_size=1)
+ super().__init__(c, c, kernel_size=1)
std = torch.Tensor(data_std)
self.weight.data = torch.eye(c).view(c, c, 1, 1)
if norm:
@@ -97,7 +97,7 @@ class MeanShift(nn.Conv2d):
class VGGPerceptualLoss(torch.nn.Module):
def __init__(self, rank=0): # pylint: disable=unused-argument
- super(VGGPerceptualLoss, self).__init__()
+ super().__init__()
pretrained = True
self.vgg_pretrained_features = models.vgg19(
pretrained=pretrained).features
diff --git a/modules/rife/model_ifnet.py b/modules/rife/model_ifnet.py
index 843430bee..df32b59a1 100644
--- a/modules/rife/model_ifnet.py
+++ b/modules/rife/model_ifnet.py
@@ -28,7 +28,7 @@ def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=
class ResConv(nn.Module):
def __init__(self, c, dilation=1):
- super(ResConv, self).__init__()
+ super().__init__()
self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1\
)
self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
@@ -39,7 +39,7 @@ class ResConv(nn.Module):
class IFBlock(nn.Module):
def __init__(self, in_planes, c=64):
- super(IFBlock, self).__init__()
+ super().__init__()
self.conv0 = nn.Sequential(
conv(in_planes, c//2, 3, 2, 1),
conv(c//2, c, 3, 2, 1),
@@ -74,7 +74,7 @@ class IFBlock(nn.Module):
class IFNet(nn.Module):
def __init__(self):
- super(IFNet, self).__init__()
+ super().__init__()
self.block0 = IFBlock(7, c=192)
self.block1 = IFBlock(8+4, c=128)
self.block2 = IFBlock(8+4, c=96)
@@ -82,7 +82,9 @@ class IFNet(nn.Module):
# self.contextnet = Contextnet()
# self.unet = Unet()
- def forward( self, x, timestep=0.5, scale_list=[8, 4, 2, 1], training=False, fastmode=True, ensemble=False): # pylint: disable=dangerous-default-value, unused-argument
+ def forward( self, x, timestep=0.5, scale_list=None, training=False, fastmode=True, ensemble=False): # pylint: disable=unused-argument
+ if scale_list is None:
+ scale_list = [8, 4, 2, 1]
if training is False:
channel = x.shape[1] // 2
img0 = x[:, :channel]
diff --git a/modules/rife/refine.py b/modules/rife/refine.py
index 5d77582cc..9ea076d1f 100644
--- a/modules/rife/refine.py
+++ b/modules/rife/refine.py
@@ -29,7 +29,7 @@ def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): # pylint:
class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2):
- super(Conv2, self).__init__()
+ super().__init__()
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
@@ -41,7 +41,7 @@ class Conv2(nn.Module):
class Contextnet(nn.Module):
def __init__(self):
- super(Contextnet, self).__init__()
+ super().__init__()
self.conv1 = Conv2(3, c)
self.conv2 = Conv2(c, 2*c)
self.conv3 = Conv2(2*c, 4*c)
@@ -65,7 +65,7 @@ class Contextnet(nn.Module):
class Unet(nn.Module):
def __init__(self):
- super(Unet, self).__init__()
+ super().__init__()
self.down0 = Conv2(17, 2*c)
self.down1 = Conv2(4*c, 4*c)
self.down2 = Conv2(8*c, 8*c)
diff --git a/modules/rife/ssim.py b/modules/rife/ssim.py
index e2261ca7a..8233ec8f3 100644
--- a/modules/rife/ssim.py
+++ b/modules/rife/ssim.py
@@ -142,7 +142,7 @@ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normal
# Classes to re-use window
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, val_range=None):
- super(SSIM, self).__init__()
+ super().__init__()
self.window_size = window_size
self.size_average = size_average
self.val_range = val_range
@@ -165,7 +165,7 @@ class SSIM(torch.nn.Module):
class MSSSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, channel=3):
- super(MSSSIM, self).__init__()
+ super().__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = channel
diff --git a/modules/rocm.py b/modules/rocm.py
index c0b8b8df1..42af4db75 100644
--- a/modules/rocm.py
+++ b/modules/rocm.py
@@ -5,14 +5,14 @@ import ctypes
import shutil
import subprocess
from types import ModuleType
-from typing import Union, overload, TYPE_CHECKING
+from typing import overload, TYPE_CHECKING
from enum import Enum
from functools import wraps
if TYPE_CHECKING:
import torch
-rocm_sdk: Union[ModuleType, None] = None
+rocm_sdk: ModuleType | None = None
def resolve_link(path_: str) -> str:
@@ -27,7 +27,7 @@ def dirname(path_: str, r: int = 1) -> str:
return path_
-def spawn(command: Union[str, list[str]], cwd: os.PathLike = '.') -> str:
+def spawn(command: str | list[str], cwd: os.PathLike = '.') -> str:
process = subprocess.run(command, cwd=cwd, shell=True, check=False, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
return process.stdout.decode(encoding="utf8", errors="ignore")
@@ -116,7 +116,7 @@ class Agent:
return self.name
@property
- def therock(self) -> Union[str, None]:
+ def therock(self) -> str | None:
if (self.gfx_version & 0xFFF0) == 0x1200:
return "v2/gfx120X-all"
if (self.gfx_version & 0xFFF0) == 0x1100:
@@ -141,7 +141,7 @@ class Agent:
# return "gfx950-dcgpu"
return None
- def get_gfx_version(self) -> Union[str, None]:
+ def get_gfx_version(self) -> str | None:
if self.gfx_version is None:
return None
if self.gfx_version >= 0x1100 and self.gfx_version < 0x1200:
@@ -153,7 +153,7 @@ class Agent:
return None
-def find() -> Union[ROCmEnvironment, None]:
+def find() -> ROCmEnvironment | None:
hip_path = shutil.which("hipconfig")
if hip_path is not None:
return ROCmEnvironment(dirname(resolve_link(hip_path), 2))
@@ -364,7 +364,6 @@ else: # sys.platform != "win32"
def rocm_init():
try:
- import torch
from installer import log
from modules.devices import get_hip_agent
@@ -377,10 +376,10 @@ else: # sys.platform != "win32"
is_wsl: bool = os.environ.get('WSL_DISTRO_NAME', 'unknown' if spawn('wslpath -w /') else None) is not None
-environment: Union[Environment, None] = None
-blaslt_tensile_libpath: Union[str, None] = None
+environment: Environment | None = None
+blaslt_tensile_libpath: str | None = None
is_installed: bool = False
-version: Union[str, None] = None
+version: str | None = None
refresh()
# amdgpu-arch.exe written in Python
diff --git a/modules/rocm_triton_windows.py b/modules/rocm_triton_windows.py
index 4bcbaff18..022e27b10 100644
--- a/modules/rocm_triton_windows.py
+++ b/modules/rocm_triton_windows.py
@@ -1,5 +1,4 @@
import sys
-from typing import Union
import torch
from modules import shared, devices
from modules.rocm import Agent
@@ -58,7 +57,7 @@ if sys.platform == "win32":
from modules import zluda
return zluda.core.to_hip_stream(_cuda_getCurrentRawStream(device))
- def get_default_agent() -> Union[Agent, None]:
+ def get_default_agent() -> Agent | None:
if shared.devices.has_rocm():
return devices.get_hip_agent()
else:
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index a0c85a283..fad126ce1 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -2,7 +2,7 @@ import os
import sys
import time
from collections import namedtuple
-from typing import Optional, Dict, Any
+from typing import Any
from fastapi import FastAPI
from gradio import Blocks
import modules.errors as errors
@@ -149,7 +149,7 @@ def clear_callbacks():
callback_list.clear()
-def app_started_callback(demo: Optional[Blocks], app: FastAPI):
+def app_started_callback(demo: Blocks | None, app: FastAPI):
for c in callback_map['callbacks_app_started']:
try:
t0 = time.time()
@@ -319,7 +319,7 @@ def image_grid_callback(params: ImageGridLoopParams):
report_exception(e, c, 'image_grid')
-def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
+def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
for c in callback_map['callbacks_infotext_pasted']:
try:
t0 = time.time()
diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py
index a1ebc104e..2b1cbb847 100644
--- a/modules/scripts_auto_postprocessing.py
+++ b/modules/scripts_auto_postprocessing.py
@@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts_manager.Script):
return self.postprocessing_controls.values()
def postprocess_image(self, p, script_pp, *args): # pylint: disable=arguments-differ
- args_dict = dict(zip(self.postprocessing_controls, args))
+ args_dict = dict(zip(self.postprocessing_controls, args, strict=False))
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
pp.info = {}
self.script.process(pp, **args_dict)
diff --git a/modules/scripts_manager.py b/modules/scripts_manager.py
index 7ca38f5cd..1ac69bdf5 100644
--- a/modules/scripts_manager.py
+++ b/modules/scripts_manager.py
@@ -234,7 +234,7 @@ def list_scripts(scriptdirname, extension):
else:
priority = '9'
if os.path.isfile(os.path.join(base, "..", ".priority")):
- with open(os.path.join(base, "..", ".priority"), "r", encoding="utf-8") as f:
+ with open(os.path.join(base, "..", ".priority"), encoding="utf-8") as f:
priority = priority + str(f.read().strip())
errors.log.debug(f'Script priority override: ${script.name}:{priority}')
else:
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
index e72e16d55..1bea4e219 100644
--- a/modules/scripts_postprocessing.py
+++ b/modules/scripts_postprocessing.py
@@ -4,7 +4,9 @@ from modules import errors, shared
class PostprocessedImage:
- def __init__(self, image, info = {}):
+ def __init__(self, image, info = None):
+ if info is None:
+ info = {}
self.image = image
self.info = info
@@ -99,7 +101,7 @@ class ScriptPostprocessingRunner:
jobid = shared.state.begin(script.name)
script_args = args[script.args_from:script.args_to]
process_args = {}
- for (name, _component), value in zip(script.controls.items(), script_args):
+ for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_args[name] = value
shared.log.debug(f'Process: script="{script.name}" args={process_args}')
script.process(pp, **process_args)
@@ -129,7 +131,7 @@ class ScriptPostprocessingRunner:
jobid = shared.state.begin(script.name)
script_args = args[script.args_from:script.args_to]
process_args = {}
- for (name, _component), value in zip(script.controls.items(), script_args):
+ for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_args[name] = value
shared.log.debug(f'Postprocess: script={script.name} args={process_args}')
script.postprocess(filenames, **process_args)
diff --git a/modules/sd_checkpoint.py b/modules/sd_checkpoint.py
index 874cb6729..c050811fe 100644
--- a/modules/sd_checkpoint.py
+++ b/modules/sd_checkpoint.py
@@ -149,11 +149,11 @@ def list_models():
checkpoint_info.register()
if shared.cmd_opts.ckpt is not None:
checkpoint_info = CheckpointInfo(shared.cmd_opts.ckpt)
- if checkpoint_info.name is not None:
+ if checkpoint_info.name is not None and os.path.exists(checkpoint_info.filename):
checkpoint_info.register()
shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
- elif shared.cmd_opts.ckpt != shared.default_sd_model_file and shared.cmd_opts.ckpt is not None:
- shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found')
+ elif shared.cmd_opts.ckpt != shared.default_sd_model_file:
+ shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found')
shared.log.info(f'Available Models: safetensors="{shared.opts.ckpt_dir}":{len(safetensors_list)} diffusers="{shared.opts.diffusers_dir}":{len(diffusers_list)} reference={len(list(shared.reference_models))} items={len(checkpoints_list)} time={time.time()-t0:.2f}')
checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename))
diff --git a/modules/sd_hijack_accelerate.py b/modules/sd_hijack_accelerate.py
index 834038a41..fd8f4a3cb 100644
--- a/modules/sd_hijack_accelerate.py
+++ b/modules/sd_hijack_accelerate.py
@@ -1,4 +1,3 @@
-from typing import Optional, Union
import time
import torch
import torch.nn as nn
@@ -17,10 +16,10 @@ orig_torch_conv = torch.nn.modules.conv.Conv2d._conv_forward # pylint: disable=p
def hijack_set_module_tensor(
module: nn.Module,
tensor_name: str,
- device: Union[int, str, torch.device],
- value: Optional[torch.Tensor] = None,
- dtype: Optional[Union[str, torch.dtype]] = None, # pylint: disable=unused-argument
- fp16_statistics: Optional[torch.HalfTensor] = None, # pylint: disable=unused-argument
+ device: int | str | torch.device,
+ value: torch.Tensor | None = None,
+ dtype: str | torch.dtype | None = None, # pylint: disable=unused-argument
+ fp16_statistics: torch.HalfTensor | None = None, # pylint: disable=unused-argument
):
global tensor_to_timer # pylint: disable=global-statement
if device == 'cpu': # override to load directly to gpu
@@ -46,10 +45,10 @@ def hijack_set_module_tensor(
def hijack_set_module_tensor_simple(
module: nn.Module,
tensor_name: str,
- device: Union[int, str, torch.device],
- value: Optional[torch.Tensor] = None,
- dtype: Optional[Union[str, torch.dtype]] = None, # pylint: disable=unused-argument
- fp16_statistics: Optional[torch.HalfTensor] = None, # pylint: disable=unused-argument
+ device: int | str | torch.device,
+ value: torch.Tensor | None = None,
+ dtype: str | torch.dtype | None = None, # pylint: disable=unused-argument
+ fp16_statistics: torch.HalfTensor | None = None, # pylint: disable=unused-argument
):
global tensor_to_timer # pylint: disable=global-statement
if device == 'cpu': # override to load directly to gpu
diff --git a/modules/sd_hijack_dynamic_atten.py b/modules/sd_hijack_dynamic_atten.py
index c3202ad16..6410b5a71 100644
--- a/modules/sd_hijack_dynamic_atten.py
+++ b/modules/sd_hijack_dynamic_atten.py
@@ -1,8 +1,6 @@
-from typing import Tuple, Optional
from functools import cache, wraps
import torch
-from diffusers.utils import USE_PEFT_BACKEND # pylint: disable=unused-import
from modules import shared, devices
@@ -21,7 +19,7 @@ def find_split_size(original_size: int, slice_block_size: int, slice_rate: int =
# Find slice sizes for SDPA
@cache
-def find_sdpa_slice_sizes(query_shape: Tuple[int], key_shape: Tuple[int], query_element_size: int, slice_rate: int = 2, trigger_rate: int = 3) -> Tuple[bool, int]:
+def find_sdpa_slice_sizes(query_shape: tuple[int], key_shape: tuple[int], query_element_size: int, slice_rate: int = 2, trigger_rate: int = 3) -> tuple[bool, int]:
batch_size, attn_heads, query_len, _ = query_shape
_, _, key_len, _ = key_shape
@@ -55,7 +53,7 @@ def find_sdpa_slice_sizes(query_shape: Tuple[int], key_shape: Tuple[int], query_
if devices.sdpa_pre_dyanmic_atten is None:
devices.sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(devices.sdpa_pre_dyanmic_atten)
-def dynamic_scaled_dot_product_attention(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.FloatTensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
+def dynamic_scaled_dot_product_attention(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.FloatTensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
is_unsqueezed = False
if query.dim() == 3:
query = query.unsqueeze(0)
diff --git a/modules/sd_hijack_hypertile.py b/modules/sd_hijack_hypertile.py
index ec8f0c64f..f25ac8078 100644
--- a/modules/sd_hijack_hypertile.py
+++ b/modules/sd_hijack_hypertile.py
@@ -2,7 +2,7 @@
# based on: https://github.com/tfernd/HyperTile/tree/main/hyper_tile/utils.py + https://github.com/tfernd/HyperTile/tree/main/hyper_tile/hyper_tile.py
from __future__ import annotations
-from typing import Callable
+from collections.abc import Callable
from functools import wraps, cache
from contextlib import contextmanager, nullcontext
import random
diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py
index 179ebc78e..e20a247b3 100644
--- a/modules/sd_hijack_utils.py
+++ b/modules/sd_hijack_utils.py
@@ -2,7 +2,7 @@ import importlib
class CondFunc:
def __new__(cls, orig_func, sub_func, cond_func):
- self = super(CondFunc, cls).__new__(cls)
+ self = super().__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
for i in range(len(func_path)-1, -1, -1):
diff --git a/modules/sd_models.py b/modules/sd_models.py
index a88acded9..6590ebf3c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,9 +14,9 @@ from installer import log
from modules import timer, paths, shared, shared_items, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, sd_models_compile, sd_detect, model_quant, sd_hijack_te, sd_hijack_accelerate, sd_hijack_safetensors, attention
from modules.memstats import memory_stats
from modules.modeldata import model_data
-from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, sd_metadata_file, checkpoints_list, checkpoint_titles, get_closest_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import
+from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoint_titles, get_closest_checkpoint_match, update_model_hashes, write_metadata, checkpoints_list # pylint: disable=unused-import
from modules.sd_offload import get_module_names, disable_offload, set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import
-from modules.sd_models_utils import NoWatermark, get_signature, get_call, path_to_repo, patch_diffuser_config, convert_to_faketensors, read_state_dict, get_state_dict_from_checkpoint, apply_function_to_model # pylint: disable=unused-import
+from modules.sd_models_utils import NoWatermark, get_signature, path_to_repo, apply_function_to_model, read_state_dict, get_state_dict_from_checkpoint # pylint: disable=unused-import
model_dir = "Stable-diffusion"
diff --git a/modules/sd_models_compile.py b/modules/sd_models_compile.py
index 7c91b9bfc..564a64556 100644
--- a/modules/sd_models_compile.py
+++ b/modules/sd_models_compile.py
@@ -65,7 +65,7 @@ def ipex_optimize(sd_model, apply_to_components=True, op="Model"):
 def optimize_openvino(sd_model, clear_cache=True):
     try:
-        from modules.intel.openvino import openvino_fx # pylint: disable=unused-import
+        from modules.intel.openvino import openvino_fx # pylint: disable=unused-import # import side-effect registers the 'openvino_fx' torch.compile backend
if clear_cache and shared.compiled_model_state is not None:
shared.compiled_model_state.compiled_cache.clear()
shared.compiled_model_state.req_cache.clear()
@@ -124,12 +124,12 @@ def compile_stablefast(sd_model):
         return sd_model
     config = sf.CompilationConfig.Default()
     try:
-        import xformers # pylint: disable=unused-import
+        import xformers # pylint: disable=unused-import # availability probe: ImportError leaves the option disabled
        config.enable_xformers = True
     except Exception:
        pass
     try:
-        import triton # pylint: disable=unused-import
+        import triton # pylint: disable=unused-import # availability probe: ImportError leaves the option disabled
        config.enable_triton = True
     except Exception:
        pass
@@ -196,7 +193,7 @@ def compile_torch(sd_model, apply_to_components=True, op="Model"):
shared.compiled_model_state = CompiledModelState()
return sd_model
elif shared.opts.cuda_compile_backend == "migraphx":
- import torch_migraphx # pylint: disable=unused-import
+            import torch_migraphx # pylint: disable=unused-import # import side-effect registers the 'migraphx' torch.compile backend
log_level = logging.WARNING if 'verbose' in shared.opts.cuda_compile_options else logging.CRITICAL # pylint: disable=protected-access
if hasattr(torch, '_logging'):
torch._logging.set_logs(dynamo=log_level, aot=log_level, inductor=log_level) # pylint: disable=protected-access
diff --git a/modules/sd_models_utils.py b/modules/sd_models_utils.py
index 05f44ef6c..f766270b7 100644
--- a/modules/sd_models_utils.py
+++ b/modules/sd_models_utils.py
@@ -8,8 +8,7 @@ import torch
import safetensors.torch
from modules import paths, shared, errors
-from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closest_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import
-from modules.sd_offload import disable_offload, set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import
+from modules.sd_checkpoint import CheckpointInfo # pylint: disable=unused-import
class NoWatermark:
@@ -124,11 +123,11 @@ def patch_diffuser_config(sd_model, model_file):
cfg_file = f'{model_file}_{k}.json'
try:
if os.path.exists(cfg_file):
- with open(cfg_file, 'r', encoding='utf-8') as f:
+ with open(cfg_file, encoding='utf-8') as f:
return json.load(f)
cfg_file = f'{os.path.join(paths.sd_configs_path, os.path.basename(model_file))}_{k}.json'
if os.path.exists(cfg_file):
- with open(cfg_file, 'r', encoding='utf-8') as f:
+ with open(cfg_file, encoding='utf-8') as f:
return json.load(f)
except Exception:
pass
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 8c1eb1ecd..c0e45f7a5 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,7 +1,6 @@
import os
import copy
from modules import shared
-from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # pylint: disable=unused-import
debug = shared.log.trace if os.environ.get('SD_SAMPLER_DEBUG', None) is not None else lambda *args, **kwargs: None
diff --git a/modules/sd_te_remote.py b/modules/sd_te_remote.py
index 264eeb145..cdc743e0d 100644
--- a/modules/sd_te_remote.py
+++ b/modules/sd_te_remote.py
@@ -1,4 +1,3 @@
-from typing import List, Optional, Union
import os
import time
import json
@@ -8,11 +7,11 @@ from modules import devices, errors
def get_t5_prompt_embeds(
- prompt: Union[str, List[str]] = None,
+    prompt: str | list[str] | None = None,
num_images_per_prompt: int = 1, # pylint: disable=unused-argument
max_sequence_length: int = 512, # pylint: disable=unused-argument
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
+ device: torch.device | None = None,
+ dtype: torch.dtype | None = None,
):
device = device or devices.device
dtype = dtype or devices.dtype
diff --git a/modules/sdnq/dequantizer.py b/modules/sdnq/dequantizer.py
index ff1036260..882af5802 100644
--- a/modules/sdnq/dequantizer.py
+++ b/modules/sdnq/dequantizer.py
@@ -1,6 +1,5 @@
# pylint: disable=redefined-builtin,no-member,protected-access
-from typing import List, Tuple, Optional
from dataclasses import dataclass
import torch
@@ -13,7 +12,7 @@ from .layers import SDNQLayer
@devices.inference_context()
-def dequantize_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor:
+def dequantize_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor:
result = torch.addcmul(zero_point, weight.to(dtype=scale.dtype), scale)
if result_shape is not None:
result = result.view(result_shape)
@@ -34,7 +33,7 @@ def dequantize_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, ze
@devices.inference_context()
-def dequantize_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor:
+def dequantize_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor:
result = weight.to(dtype=scale.dtype).mul_(scale)
if skip_quantized_matmul and not re_quantize_for_matmul:
result.t_()
@@ -57,7 +56,7 @@ def dequantize_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, svd
@devices.inference_context()
-def dequantize_symmetric_with_bias(weight: torch.CharTensor, scale: torch.FloatTensor, bias: torch.FloatTensor, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None) -> torch.FloatTensor:
+def dequantize_symmetric_with_bias(weight: torch.CharTensor, scale: torch.FloatTensor, bias: torch.FloatTensor, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None) -> torch.FloatTensor:
result = torch.addcmul(bias, weight.to(dtype=scale.dtype), scale)
if result_shape is not None:
result = result.view(result_shape)
@@ -67,48 +66,48 @@ def dequantize_symmetric_with_bias(weight: torch.CharTensor, scale: torch.FloatT
@devices.inference_context()
-def dequantize_packed_int_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor:
+def dequantize_packed_int_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor:
return dequantize_asymmetric(unpack_int_asymetric(weight, shape, weights_dtype), scale, zero_point, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=result_shape, skip_quantized_matmul=skip_quantized_matmul)
@devices.inference_context()
-def dequantize_packed_int_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor:
+def dequantize_packed_int_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor:
return dequantize_symmetric(unpack_int_symetric(weight, shape, weights_dtype, dtype=scale.dtype), scale, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=result_shape, skip_quantized_matmul=skip_quantized_matmul, re_quantize_for_matmul=re_quantize_for_matmul)
@devices.inference_context()
-def dequantize_packed_float_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor:
+def dequantize_packed_float_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor:
return dequantize_asymmetric(unpack_float(weight, shape, weights_dtype), scale, zero_point, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=result_shape, skip_quantized_matmul=skip_quantized_matmul)
@devices.inference_context()
-def dequantize_packed_float_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor:
+def dequantize_packed_float_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor:
return dequantize_symmetric(unpack_float(weight, shape, weights_dtype), scale, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=result_shape, skip_quantized_matmul=skip_quantized_matmul, re_quantize_for_matmul=re_quantize_for_matmul)
@devices.inference_context()
-def quantize_int_mm(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "int8") -> Tuple[torch.Tensor, torch.FloatTensor]:
+def quantize_int_mm(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "int8") -> tuple[torch.Tensor, torch.FloatTensor]:
scale = torch.amax(input.abs(), dim=dim, keepdims=True).div_(dtype_dict[matmul_dtype]["max"])
input = torch.div(input, scale).round_().clamp_(dtype_dict[matmul_dtype]["min"], dtype_dict[matmul_dtype]["max"]).to(dtype=dtype_dict[matmul_dtype]["torch_dtype"])
return input, scale
@devices.inference_context()
-def quantize_int_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "int8") -> Tuple[torch.Tensor, torch.FloatTensor]:
+def quantize_int_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "int8") -> tuple[torch.Tensor, torch.FloatTensor]:
scale = torch.amax(input.abs(), dim=dim, keepdims=True).div_(dtype_dict[matmul_dtype]["max"])
input = torch.div(input, scale).add_(torch.randn_like(input), alpha=0.1).round_().clamp_(dtype_dict[matmul_dtype]["min"], dtype_dict[matmul_dtype]["max"]).to(dtype=dtype_dict[matmul_dtype]["torch_dtype"])
return input, scale
@devices.inference_context()
-def quantize_fp_mm(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]:
+def quantize_fp_mm(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]:
scale = torch.amax(input.abs(), dim=dim, keepdims=True).div_(dtype_dict[matmul_dtype]["max"])
input = torch.div(input, scale).nan_to_num_().clamp_(dtype_dict[matmul_dtype]["min"], dtype_dict[matmul_dtype]["max"]).to(dtype=dtype_dict[matmul_dtype]["torch_dtype"])
return input, scale
@devices.inference_context()
-def quantize_fp_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]:
+def quantize_fp_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]:
mantissa_difference = 1 << (23 - dtype_dict[matmul_dtype]["mantissa"])
scale = torch.amax(input.abs(), dim=dim, keepdims=True).div_(dtype_dict[matmul_dtype]["max"])
input = torch.div(input, scale).to(dtype=torch.float32).view(dtype=torch.int32)
@@ -118,7 +117,7 @@ def quantize_fp_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str
@devices.inference_context()
-def re_quantize_int_mm(weight: torch.FloatTensor) -> Tuple[torch.Tensor, torch.FloatTensor]:
+def re_quantize_int_mm(weight: torch.FloatTensor) -> tuple[torch.Tensor, torch.FloatTensor]:
if weight.ndim > 2: # convs
weight = weight.flatten(1,-1)
if use_contiguous_mm:
@@ -130,7 +129,7 @@ def re_quantize_int_mm(weight: torch.FloatTensor) -> Tuple[torch.Tensor, torch.F
@devices.inference_context()
-def re_quantize_fp_mm(weight: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]:
+def re_quantize_fp_mm(weight: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]:
if weight.ndim > 2: # convs
weight = weight.flatten(1,-1)
weight, scale = quantize_fp_mm(weight.contiguous(), dim=-1, matmul_dtype=matmul_dtype)
@@ -141,7 +140,7 @@ def re_quantize_fp_mm(weight: torch.FloatTensor, matmul_dtype: str = "float8_e4m
@devices.inference_context()
-def re_quantize_matmul_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, matmul_dtype: str, result_shape: Optional[torch.Size] = None, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]:
+def re_quantize_matmul_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, matmul_dtype: str, result_shape: torch.Size | None = None, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> tuple[torch.Tensor, torch.FloatTensor]:
weight = dequantize_asymmetric(weight, scale, zero_point, svd_up=svd_up, svd_down=svd_down, dtype=scale.dtype, result_shape=result_shape)
if dtype_dict[matmul_dtype]["is_integer"]:
return re_quantize_int_mm(weight)
@@ -150,7 +149,7 @@ def re_quantize_matmul_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTe
@devices.inference_context()
-def re_quantize_matmul_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, matmul_dtype: str, result_shape: Optional[torch.Size] = None, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]:
+def re_quantize_matmul_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, matmul_dtype: str, result_shape: torch.Size | None = None, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> tuple[torch.Tensor, torch.FloatTensor]:
weight = dequantize_symmetric(weight, scale, svd_up=svd_up, svd_down=svd_down, dtype=scale.dtype, result_shape=result_shape)
if dtype_dict[matmul_dtype]["is_integer"]:
return re_quantize_int_mm(weight)
@@ -159,22 +158,22 @@ def re_quantize_matmul_symmetric(weight: torch.CharTensor, scale: torch.FloatTen
@devices.inference_context()
-def re_quantize_matmul_packed_int_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]:
+def re_quantize_matmul_packed_int_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> tuple[torch.Tensor, torch.FloatTensor]:
return re_quantize_matmul_asymmetric(unpack_int_asymetric(weight, shape, weights_dtype), scale, zero_point, matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=result_shape)
@devices.inference_context()
-def re_quantize_matmul_packed_int_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: Optional[torch.Size] = None, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]:
+def re_quantize_matmul_packed_int_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size | None = None, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> tuple[torch.Tensor, torch.FloatTensor]:
return re_quantize_matmul_symmetric(unpack_int_symetric(weight, shape, weights_dtype, dtype=scale.dtype), scale, matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=result_shape)
@devices.inference_context()
-def re_quantize_matmul_packed_float_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]:
+def re_quantize_matmul_packed_float_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> tuple[torch.Tensor, torch.FloatTensor]:
return re_quantize_matmul_asymmetric(unpack_float(weight, shape, weights_dtype), scale, zero_point, matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=result_shape)
@devices.inference_context()
-def re_quantize_matmul_packed_float_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: Optional[torch.Size] = None, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]:
+def re_quantize_matmul_packed_float_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size | None = None, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> tuple[torch.Tensor, torch.FloatTensor]:
return re_quantize_matmul_symmetric(unpack_float(weight, shape, weights_dtype), scale, matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=result_shape)
@@ -220,7 +219,7 @@ class SDNQDequantizer:
result_dtype: torch.dtype
result_shape: torch.Size
original_shape: torch.Size
- original_stride: List[int]
+ original_stride: list[int]
quantized_weight_shape: torch.Size
weights_dtype: str
quantized_matmul_dtype: str
@@ -241,7 +240,7 @@ class SDNQDequantizer:
result_dtype: torch.dtype,
result_shape: torch.Size,
original_shape: torch.Size,
- original_stride: List[int],
+ original_stride: list[int],
quantized_weight_shape: torch.Size,
weights_dtype: str,
quantized_matmul_dtype: str,
diff --git a/modules/sdnq/forward.py b/modules/sdnq/forward.py
index 9fc99d9f9..2deccf3ac 100644
--- a/modules/sdnq/forward.py
+++ b/modules/sdnq/forward.py
@@ -1,6 +1,6 @@
# pylint: disable=protected-access
-from typing import Callable
+from collections.abc import Callable
from .common import dtype_dict, conv_types, conv_transpose_types, use_tensorwise_fp8_matmul
diff --git a/modules/sdnq/layers/conv/conv_fp16.py b/modules/sdnq/layers/conv/conv_fp16.py
index 8b60767cc..a8f1c4460 100644
--- a/modules/sdnq/layers/conv/conv_fp16.py
+++ b/modules/sdnq/layers/conv/conv_fp16.py
@@ -1,6 +1,5 @@
# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
-from typing import List
import torch
@@ -18,10 +17,10 @@ def conv_fp16_matmul(
weight: torch.Tensor,
scale: torch.FloatTensor,
result_shape: torch.Size,
- reversed_padding_repeated_twice: List[int],
+ reversed_padding_repeated_twice: list[int],
padding_mode: str, conv_type: int,
- groups: int, stride: List[int],
- padding: List[int], dilation: List[int],
+ groups: int, stride: list[int],
+ padding: list[int], dilation: list[int],
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
diff --git a/modules/sdnq/layers/conv/conv_fp8.py b/modules/sdnq/layers/conv/conv_fp8.py
index 994850fb1..a2b864381 100644
--- a/modules/sdnq/layers/conv/conv_fp8.py
+++ b/modules/sdnq/layers/conv/conv_fp8.py
@@ -1,6 +1,5 @@
# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
-from typing import List
import torch
@@ -17,10 +16,10 @@ def conv_fp8_matmul(
weight: torch.Tensor,
scale: torch.FloatTensor,
result_shape: torch.Size,
- reversed_padding_repeated_twice: List[int],
+ reversed_padding_repeated_twice: list[int],
padding_mode: str, conv_type: int,
- groups: int, stride: List[int],
- padding: List[int], dilation: List[int],
+ groups: int, stride: list[int],
+ padding: list[int], dilation: list[int],
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
diff --git a/modules/sdnq/layers/conv/conv_fp8_tensorwise.py b/modules/sdnq/layers/conv/conv_fp8_tensorwise.py
index 9be958923..9fc388873 100644
--- a/modules/sdnq/layers/conv/conv_fp8_tensorwise.py
+++ b/modules/sdnq/layers/conv/conv_fp8_tensorwise.py
@@ -1,6 +1,5 @@
# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
-from typing import List
import torch
@@ -18,10 +17,10 @@ def conv_fp8_matmul_tensorwise(
weight: torch.Tensor,
scale: torch.FloatTensor,
result_shape: torch.Size,
- reversed_padding_repeated_twice: List[int],
+ reversed_padding_repeated_twice: list[int],
padding_mode: str, conv_type: int,
- groups: int, stride: List[int],
- padding: List[int], dilation: List[int],
+ groups: int, stride: list[int],
+ padding: list[int], dilation: list[int],
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
diff --git a/modules/sdnq/layers/conv/conv_int8.py b/modules/sdnq/layers/conv/conv_int8.py
index 9777b3d9b..3e28c11ea 100644
--- a/modules/sdnq/layers/conv/conv_int8.py
+++ b/modules/sdnq/layers/conv/conv_int8.py
@@ -1,6 +1,5 @@
# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
-from typing import List
import torch
@@ -18,10 +17,10 @@ def conv_int8_matmul(
weight: torch.Tensor,
scale: torch.FloatTensor,
result_shape: torch.Size,
- reversed_padding_repeated_twice: List[int],
+ reversed_padding_repeated_twice: list[int],
padding_mode: str, conv_type: int,
- groups: int, stride: List[int],
- padding: List[int], dilation: List[int],
+ groups: int, stride: list[int],
+ padding: list[int], dilation: list[int],
bias: torch.FloatTensor = None,
svd_up: torch.FloatTensor = None,
svd_down: torch.FloatTensor = None,
diff --git a/modules/sdnq/layers/conv/forward.py b/modules/sdnq/layers/conv/forward.py
index 2ed3d816f..74454d2d9 100644
--- a/modules/sdnq/layers/conv/forward.py
+++ b/modules/sdnq/layers/conv/forward.py
@@ -1,6 +1,5 @@
# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
-from typing import Optional
import torch
@@ -78,16 +77,16 @@ def quantized_conv_forward(self, input) -> torch.FloatTensor:
return self._conv_forward(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias)
-def quantized_conv_transpose_1d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor:
+def quantized_conv_transpose_1d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor:
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 1, self.dilation)
return torch.nn.functional.conv_transpose1d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
-def quantized_conv_transpose_2d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor:
+def quantized_conv_transpose_2d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor:
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 2, self.dilation)
return torch.nn.functional.conv_transpose2d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
-def quantized_conv_transpose_3d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor:
+def quantized_conv_transpose_3d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor:
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 3, self.dilation)
return torch.nn.functional.conv_transpose3d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
diff --git a/modules/sdnq/layers/linear/forward.py b/modules/sdnq/layers/linear/forward.py
index 7b3a169d9..be51a66ad 100644
--- a/modules/sdnq/layers/linear/forward.py
+++ b/modules/sdnq/layers/linear/forward.py
@@ -1,13 +1,12 @@
# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
-from typing import Tuple
import torch
from ...common import use_contiguous_mm # noqa: TID252
-def check_mats(input: torch.Tensor, weight: torch.Tensor, allow_contiguous_mm: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
+def check_mats(input: torch.Tensor, weight: torch.Tensor, allow_contiguous_mm: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
input = input.contiguous()
if allow_contiguous_mm and use_contiguous_mm:
weight = weight.contiguous()
diff --git a/modules/sdnq/layers/linear/linear_fp8.py b/modules/sdnq/layers/linear/linear_fp8.py
index 169d318f9..80bf64b0e 100644
--- a/modules/sdnq/layers/linear/linear_fp8.py
+++ b/modules/sdnq/layers/linear/linear_fp8.py
@@ -1,6 +1,5 @@
# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
-from typing import Tuple
import torch
@@ -11,7 +10,7 @@ from ...dequantizer import quantize_fp_mm # noqa: TID252
from .forward import check_mats
-def quantize_fp_mm_input(input: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]:
+def quantize_fp_mm_input(input: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]:
input = input.flatten(0,-2).to(dtype=torch.float32)
input, input_scale = quantize_fp_mm(input, dim=-1, matmul_dtype=matmul_dtype)
return input, input_scale
diff --git a/modules/sdnq/layers/linear/linear_fp8_tensorwise.py b/modules/sdnq/layers/linear/linear_fp8_tensorwise.py
index a5ea71c55..8b4954c35 100644
--- a/modules/sdnq/layers/linear/linear_fp8_tensorwise.py
+++ b/modules/sdnq/layers/linear/linear_fp8_tensorwise.py
@@ -1,6 +1,5 @@
# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
-from typing import Tuple
import torch
@@ -11,7 +10,7 @@ from ...dequantizer import quantize_fp_mm, dequantize_symmetric, dequantize_symm
from .forward import check_mats
-def quantize_fp_mm_input_tensorwise(input: torch.FloatTensor, scale: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]:
+def quantize_fp_mm_input_tensorwise(input: torch.FloatTensor, scale: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]:
input = input.flatten(0,-2).to(dtype=scale.dtype)
input, input_scale = quantize_fp_mm(input, dim=-1, matmul_dtype=matmul_dtype)
scale = torch.mul(input_scale, scale)
diff --git a/modules/sdnq/layers/linear/linear_int8.py b/modules/sdnq/layers/linear/linear_int8.py
index 2d26a6086..2a1213cb8 100644
--- a/modules/sdnq/layers/linear/linear_int8.py
+++ b/modules/sdnq/layers/linear/linear_int8.py
@@ -1,6 +1,5 @@
# pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access
-from typing import Tuple
import torch
@@ -11,7 +10,7 @@ from ...dequantizer import quantize_int_mm, dequantize_symmetric, dequantize_sym
from .forward import check_mats
-def quantize_int_mm_input(input: torch.FloatTensor, scale: torch.FloatTensor) -> Tuple[torch.CharTensor, torch.FloatTensor]:
+def quantize_int_mm_input(input: torch.FloatTensor, scale: torch.FloatTensor) -> tuple[torch.CharTensor, torch.FloatTensor]:
input = input.flatten(0,-2).to(dtype=scale.dtype)
input, input_scale = quantize_int_mm(input, dim=-1)
scale = torch.mul(input_scale, scale)
diff --git a/modules/sdnq/loader.py b/modules/sdnq/loader.py
index 91be08394..789fd6e96 100644
--- a/modules/sdnq/loader.py
+++ b/modules/sdnq/loader.py
@@ -72,14 +72,14 @@ def load_sdnq_model(model_path: str, model_cls: ModelMixin = None, file_name: st
if model_config is None:
if os.path.exists(model_config_path):
- with open(model_config_path, "r", encoding="utf-8") as f:
+ with open(model_config_path, encoding="utf-8") as f:
model_config = json.load(f)
else:
model_config = {}
if quantization_config is None:
if os.path.exists(quantization_config_path):
- with open(quantization_config_path, "r", encoding="utf-8") as f:
+ with open(quantization_config_path, encoding="utf-8") as f:
quantization_config = json.load(f)
else:
quantization_config = model_config.get("quantization_config", None)
diff --git a/modules/sdnq/packed_int.py b/modules/sdnq/packed_int.py
index 09a38efbc..0cee35309 100644
--- a/modules/sdnq/packed_int.py
+++ b/modules/sdnq/packed_int.py
@@ -1,6 +1,5 @@
# pylint: disable=redefined-builtin,no-member,protected-access
-from typing import Optional
import torch
@@ -15,7 +14,7 @@ def pack_int_asymetric(tensor: torch.CharTensor, weights_dtype: str) -> torch.By
return packed_int_function_dict[weights_dtype]["pack"](tensor.to(dtype=dtype_dict[weights_dtype]["storage_dtype"]))
-def unpack_int_symetric(packed_tensor: torch.ByteTensor, shape: torch.Size, weights_dtype: str, dtype: Optional[torch.dtype] = None) -> torch.CharTensor:
+def unpack_int_symetric(packed_tensor: torch.ByteTensor, shape: torch.Size, weights_dtype: str, dtype: torch.dtype | None = None) -> torch.CharTensor:
if dtype is None:
dtype = dtype_dict[weights_dtype]["torch_dtype"]
return packed_int_function_dict[weights_dtype]["unpack"](packed_tensor, shape).to(dtype=dtype).add_(dtype_dict[weights_dtype]["min"])
diff --git a/modules/sdnq/quantizer.py b/modules/sdnq/quantizer.py
index a88da98cc..7035adc88 100644
--- a/modules/sdnq/quantizer.py
+++ b/modules/sdnq/quantizer.py
@@ -1,6 +1,6 @@
# pylint: disable=redefined-builtin,no-member,protected-access
-from typing import Dict, List, Tuple, Optional, Union
+from typing import Union
from dataclasses import dataclass
from enum import Enum
@@ -29,7 +29,7 @@ class QuantizationMethod(str, Enum):
@devices.inference_context()
-def get_scale_asymmetric(weight: torch.FloatTensor, reduction_axes: Union[int, List[int]], weights_dtype: str) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+def get_scale_asymmetric(weight: torch.FloatTensor, reduction_axes: int | list[int], weights_dtype: str) -> tuple[torch.FloatTensor, torch.FloatTensor]:
zero_point = torch.amin(weight, dim=reduction_axes, keepdims=True)
scale = torch.amax(weight, dim=reduction_axes, keepdims=True).sub_(zero_point).div_(dtype_dict[weights_dtype]["max"] - dtype_dict[weights_dtype]["min"])
if dtype_dict[weights_dtype]["min"] != 0:
@@ -38,12 +38,12 @@ def get_scale_asymmetric(weight: torch.FloatTensor, reduction_axes: Union[int, L
@devices.inference_context()
-def get_scale_symmetric(weight: torch.FloatTensor, reduction_axes: Union[int, List[int]], weights_dtype: str) -> torch.FloatTensor:
+def get_scale_symmetric(weight: torch.FloatTensor, reduction_axes: int | list[int], weights_dtype: str) -> torch.FloatTensor:
return torch.amax(weight.abs(), dim=reduction_axes, keepdims=True).div_(dtype_dict[weights_dtype]["max"])
@devices.inference_context()
-def quantize_weight(weight: torch.FloatTensor, reduction_axes: Union[int, List[int]], weights_dtype: str, dtype: torch.dtype = None, use_stochastic_rounding: bool = False) -> Tuple[torch.Tensor, torch.FloatTensor, torch.FloatTensor]:
+def quantize_weight(weight: torch.FloatTensor, reduction_axes: int | list[int], weights_dtype: str, dtype: torch.dtype | None = None, use_stochastic_rounding: bool = False) -> tuple[torch.Tensor, torch.FloatTensor, torch.FloatTensor]:
weight = weight.to(dtype=torch.float32)
if dtype_dict[weights_dtype]["is_unsigned"]:
@@ -73,7 +73,7 @@ def quantize_weight(weight: torch.FloatTensor, reduction_axes: Union[int, List[i
@devices.inference_context()
-def apply_svdquant(weight: torch.FloatTensor, rank: int = 32, niter: int = 8, dtype: torch.dtype = None) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+def apply_svdquant(weight: torch.FloatTensor, rank: int = 32, niter: int = 8, dtype: torch.dtype | None = None) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
reshape_weight = False
if weight.ndim > 2: # convs
reshape_weight = True
@@ -102,7 +102,7 @@ def prepare_weight_for_matmul(weight: torch.Tensor) -> torch.Tensor:
@devices.inference_context()
-def prepare_svd_for_matmul(svd_up: torch.FloatTensor, svd_down: torch.FloatTensor, use_quantized_matmul: bool) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+def prepare_svd_for_matmul(svd_up: torch.FloatTensor, svd_down: torch.FloatTensor, use_quantized_matmul: bool) -> tuple[torch.FloatTensor, torch.FloatTensor]:
if svd_up is not None:
if use_quantized_matmul:
svd_up = prepare_weight_for_matmul(svd_up)
@@ -113,7 +113,7 @@ def prepare_svd_for_matmul(svd_up: torch.FloatTensor, svd_down: torch.FloatTenso
return svd_up, svd_down
-def check_param_name_in(param_name: str, param_list: List[str]) -> str:
+def check_param_name_in(param_name: str, param_list: list[str]) -> str | None:
split_param_name = param_name.split(".")
for param in param_list:
if param.startswith("."):
@@ -153,7 +153,7 @@ def get_quant_args_from_config(quantization_config: Union["SDNQConfig", dict]) -
return quantization_config_dict
-def get_minimum_dtype(weights_dtype: str, param_name: str, modules_dtype_dict: Dict[str, List[str]]):
+def get_minimum_dtype(weights_dtype: str, param_name: str, modules_dtype_dict: dict[str, list[str]]):
if len(modules_dtype_dict.keys()) > 0:
for key, value in modules_dtype_dict.items():
if check_param_name_in(param_name, value) is not None:
@@ -180,7 +180,7 @@ def get_minimum_dtype(weights_dtype: str, param_name: str, modules_dtype_dict: D
return weights_dtype
-def get_quant_kwargs(quant_kwargs: dict, modules_quant_config: Dict[str, dict]) -> dict:
+def get_quant_kwargs(quant_kwargs: dict, modules_quant_config: dict[str, dict]) -> dict:
param_key = check_param_name_in(quant_kwargs["param_name"], modules_quant_config.keys())
if param_key is not None:
for key, value in modules_quant_config[param_key].items():
@@ -189,7 +189,7 @@ def get_quant_kwargs(quant_kwargs: dict, modules_quant_config: Dict[str, dict])
return quant_kwargs
-def add_module_skip_keys(model, modules_to_not_convert: List[str] = None, modules_dtype_dict: Dict[str, List[str]] = None):
+def add_module_skip_keys(model, modules_to_not_convert: list[str] | None = None, modules_dtype_dict: dict[str, list[str]] | None = None):
if modules_to_not_convert is None:
modules_to_not_convert = []
if modules_dtype_dict is None:
@@ -552,7 +552,7 @@ def sdnq_quantize_layer(layer, weights_dtype="int8", quantized_matmul_dtype=None
@devices.inference_context()
-def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=False, non_blocking=False, modules_to_not_convert: List[str] = None, modules_dtype_dict: Dict[str, List[str]] = None, modules_quant_config: Dict[str, dict] = None, quantization_device=None, return_device=None, full_param_name=""): # pylint: disable=unused-argument
+def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=False, non_blocking=False, modules_to_not_convert: list[str] | None = None, modules_dtype_dict: dict[str, list[str]] | None = None, modules_quant_config: dict[str, dict] | None = None, quantization_device=None, return_device=None, full_param_name=""): # pylint: disable=unused-argument
has_children = list(model.children())
if not has_children:
return model, modules_to_not_convert, modules_dtype_dict
@@ -648,11 +648,11 @@ def sdnq_post_load_quant(
dequantize_fp32: bool = False,
non_blocking: bool = False,
add_skip_keys:bool = True,
- quantization_device: Optional[torch.device] = None,
- return_device: Optional[torch.device] = None,
- modules_to_not_convert: Optional[List[str]] = None,
- modules_dtype_dict: Optional[Dict[str, List[str]]] = None,
- modules_quant_config: Optional[Dict[str, dict]] = None,
+ quantization_device: torch.device | None = None,
+ return_device: torch.device | None = None,
+ modules_to_not_convert: list[str] | None = None,
+ modules_dtype_dict: dict[str, list[str]] | None = None,
+ modules_quant_config: dict[str, dict] | None = None,
):
if modules_to_not_convert is None:
modules_to_not_convert = []
@@ -733,7 +733,7 @@ def sdnq_post_load_quant(
return model
-class SDNQQuantize():
+class SDNQQuantize:
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer
@@ -887,7 +887,7 @@ class SDNQQuantizer(DiffusersQuantizer, HfQuantizer):
def get_quantize_ops(self):
return SDNQQuantize(self)
- def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
max_memory = {key: val * 0.80 for key, val in max_memory.items()}
return max_memory
@@ -908,7 +908,7 @@ class SDNQQuantizer(DiffusersQuantizer, HfQuantizer):
self,
model,
device_map, # pylint: disable=unused-argument
- keep_in_fp32_modules: List[str] = None,
+ keep_in_fp32_modules: list[str] | None = None,
**kwargs, # pylint: disable=unused-argument
):
if self.pre_quantized:
@@ -1067,11 +1067,11 @@ class SDNQConfig(QuantizationConfigMixin):
dequantize_fp32: bool = False,
non_blocking: bool = False,
add_skip_keys: bool = True,
- quantization_device: Optional[torch.device] = None,
- return_device: Optional[torch.device] = None,
- modules_to_not_convert: Optional[List[str]] = None,
- modules_dtype_dict: Optional[Dict[str, List[str]]] = None,
- modules_quant_config: Optional[Dict[str, dict]] = None,
+ quantization_device: torch.device | None = None,
+ return_device: torch.device | None = None,
+ modules_to_not_convert: list[str] | None = None,
+ modules_dtype_dict: dict[str, list[str]] | None = None,
+ modules_quant_config: dict[str, dict] | None = None,
is_training: bool = False,
**kwargs, # pylint: disable=unused-argument
):
diff --git a/modules/server.py b/modules/server.py
index 4b757a1d3..8f1229a73 100644
--- a/modules/server.py
+++ b/modules/server.py
@@ -41,9 +41,8 @@ class UvicornServer(uvicorn.Server):
self.start()
-class HypercornServer():
+class HypercornServer:
def __init__(self, app: fastapi.FastAPI, listen = None, port = None, keyfile = None, certfile = None, loop = "auto", http = None):
- import asyncio
import hypercorn
self.app: fastapi.FastAPI = app
self.server: HypercornServer = None
diff --git a/modules/shared.py b/modules/shared.py
index c48b46a03..fc8b71a83 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -8,17 +8,16 @@ import contextlib
from enum import Enum
from typing import TYPE_CHECKING
import gradio as gr
-from installer import log, print_dict, console, get_version # pylint: disable=unused-import
+from installer import log, print_dict # pylint: disable=unused-import
log.debug('Initializing: shared module')
import modules.memmon
import modules.paths as paths
-from modules.json_helpers import readfile, writefile # pylint: disable=W0611
-from modules.shared_helpers import listdir, walk_files, html_path, html, req, total_tqdm # pylint: disable=W0611
+from modules.json_helpers import readfile # pylint: disable=W0611
+from modules.shared_helpers import listdir, req # pylint: disable=W0611
from modules import errors, devices, shared_state, cmd_args, theme, history, files_cache
from modules.shared_defaults import get_default_modes
-from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611
-from modules.memstats import memory_stats, ram_stats # pylint: disable=unused-import
+from modules.memstats import memory_stats # pylint: disable=unused-import
log.debug('Initializing: pipelines')
from modules import shared_items
@@ -74,6 +73,9 @@ sdnq_quant_modes = ["int8", "int7", "int6", "uint5", "uint4", "uint3", "uint2",
sdnq_matmul_modes = ["auto", "int8", "float8_e4m3fn", "float16"]
default_hfcache_dir = os.environ.get("SD_HFCACHEDIR", None) or os.path.join(paths.models_path, 'huggingface')
state = shared_state.State()
+models_path = paths.models_path
+script_path = paths.script_path
+data_path = paths.data_path
# early select backend
@@ -120,6 +122,7 @@ def list_checkpoint_titles():
list_checkpoint_tiles = list_checkpoint_titles # alias for legacy typo
+default_sd_model_file = paths.default_sd_model_file
default_checkpoint = list_checkpoint_titles()[0] if len(list_checkpoint_titles()) > 0 else "model.safetensors"
@@ -862,7 +865,6 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", devices.device)
history = history.History()
if devices.backend == "directml":
directml_do_hijack()
-from modules import sdnq # pylint: disable=unused-import # register to diffusers and transformers
log.debug('Quantization: registered=SDNQ')
try:
diff --git a/modules/shared_state.py b/modules/shared_state.py
index ebc412265..f7a086849 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -148,7 +148,9 @@ class State:
return job
return None
- def history(self, op:str, task_id:str=None, results:list=[]):
+ def history(self, op:str, task_id:str|None=None, results:list|None=None):
+ if results is None:
+ results = []
job = {
'id': task_id or self.id,
'job': self.job.lower(),
diff --git a/modules/styles.py b/modules/styles.py
index 072b03390..bb555e616 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -5,14 +5,13 @@ import csv
import json
import time
import random
-from typing import Dict
from modules import files_cache, shared, infotext, sd_models, sd_vae
debug_enabled = os.environ.get('SD_STYLES_DEBUG', None) is not None
-class Style():
+class Style:
def __init__(self, name: str, desc: str = "", prompt: str = "", negative_prompt: str = "", extra: str = "", wildcards: str = "", filename: str = "", preview: str = "", mtime: float = 0):
self.name = name
self.description = desc
@@ -50,7 +49,7 @@ def select_from_weighted_list(inner: str) -> str:
return ''
parts = [p.strip() for p in inner.split('|') if p.strip()]
- weighted: Dict[str, float] = {}
+ weighted: dict[str, float] = {}
unweighted = []
for p in parts:
@@ -102,7 +101,7 @@ def select_from_weighted_list(inner: str) -> str:
if total <= 0.0:
return items[0][0]
- names, weights = zip(*items)
+ names, weights = zip(*items, strict=False)
return random.choices(names, weights=weights, k=1)[0]
@@ -130,7 +129,11 @@ def apply_curly_braces_to_prompt(prompt, seed=-1):
return prompt
-def apply_file_wildcards(prompt, replaced = [], not_found = [], recursion=0, seed=-1):
+def apply_file_wildcards(prompt, replaced = None, not_found = None, recursion=0, seed=-1):
+ if not_found is None:
+ not_found = []
+ if replaced is None:
+ replaced = []
def check_wildcard_files(prompt, wildcard, files, file_only=True):
trimmed = wildcard.replace('\\', os.path.sep).replace('/', os.path.sep).strip().lower()
for file in files:
@@ -141,7 +144,7 @@ def apply_file_wildcards(prompt, replaced = [], not_found = [], recursion=0, see
paths.insert(0, os.path.splitext(file)[0].lower())
if (trimmed in paths) or (os.path.sep in trimmed and trimmed in paths[0]):
try:
- with open(file, 'r', encoding='utf-8') as f:
+ with open(file, encoding='utf-8') as f:
lines = f.readlines()
lines = [line.split('#')[0].strip('\n').strip() for line in lines]
lines = [line for line in lines if len(line) > 0]
@@ -317,7 +320,7 @@ class StyleDatabase:
pass
def load_style(self, fn, prefix=None):
- with open(fn, 'r', encoding='utf-8') as f:
+ with open(fn, encoding='utf-8') as f:
new_style = None
try:
all_styles = json.load(f)
@@ -508,7 +511,7 @@ class StyleDatabase:
def load_csv(self, legacy_file):
if not os.path.isfile(legacy_file):
return
- with open(legacy_file, "r", encoding="utf-8-sig", newline='') as file:
+ with open(legacy_file, encoding="utf-8-sig", newline='') as file:
reader = csv.DictReader(file, skipinitialspace=True)
num = 0
for row in reader:
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index f7302ca8e..3b233cfcf 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -12,7 +12,7 @@
from functools import partial
import math
-from typing import Optional, NamedTuple, List
+from typing import NamedTuple
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
@@ -97,10 +97,10 @@ def _query_chunk_attention(
)
return summarize_chunk(query, key_chunk, value_chunk)
- chunks: List[AttnChunk] = [
+ chunks: list[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
- acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
+ acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks, strict=False)))
chunk_values, chunk_weights, chunk_max = acc_chunk
global_max, _ = torch.max(chunk_max, 0, keepdim=True)
@@ -142,8 +142,8 @@ def efficient_dot_product_attention(
key: Tensor,
value: Tensor,
query_chunk_size=1024,
- kv_chunk_size: Optional[int] = None,
- kv_chunk_size_min: Optional[int] = None,
+ kv_chunk_size: int | None = None,
+ kv_chunk_size_min: int | None = None,
use_checkpoint=True,
):
"""Computes efficient dot-product attention given query, key, and value.
diff --git a/modules/taesd/hybrid_small.py b/modules/taesd/hybrid_small.py
index a59b0b4d7..8ca1135ab 100644
--- a/modules/taesd/hybrid_small.py
+++ b/modules/taesd/hybrid_small.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -74,19 +73,19 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self,
in_channels: int = 3,
out_channels: int = 3,
- down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
- up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
- block_out_channels: Tuple[int] = (64,),
- encoder_block_out_channels: Tuple[int] = None,
- decoder_block_out_channels: Tuple[int] = None,
+ down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",),
+ up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
+ block_out_channels: tuple[int, ...] = (64,),
+ encoder_block_out_channels: tuple[int, ...] | None = None,
+ decoder_block_out_channels: tuple[int, ...] | None = None,
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
scaling_factor: float = 0.18215,
- latents_mean: Optional[Tuple[float]] = None,
- latents_std: Optional[Tuple[float]] = None,
+ latents_mean: tuple[float, ...] | None = None,
+ latents_std: tuple[float, ...] | None = None,
force_upcast: float = True,
):
super().__init__()
@@ -177,7 +176,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ def attn_processors(self) -> dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
@@ -186,7 +185,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# set recursively
processors = {}
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
@@ -201,7 +200,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
r"""
Sets the attention processor to use to compute attention.
@@ -254,7 +253,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
"""
Encode a batch of images into latents.
@@ -284,7 +283,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return AutoencoderKLOutput(latent_dist=posterior)
- def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> DecoderOutput | torch.FloatTensor:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
@@ -299,7 +298,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
- ) -> Union[DecoderOutput, torch.FloatTensor]:
+ ) -> DecoderOutput | torch.FloatTensor:
"""
Decode a batch of images.
@@ -391,7 +390,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return AutoencoderKLOutput(latent_dist=posterior)
- def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> DecoderOutput | torch.FloatTensor:
r"""
Decode a batch of images using a tiled decoder.
@@ -444,8 +443,8 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin):
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
- generator: Optional[torch.Generator] = None,
- ) -> Union[DecoderOutput, torch.FloatTensor]:
+ generator: torch.Generator | None = None,
+ ) -> DecoderOutput | torch.FloatTensor:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
diff --git a/modules/textual_inversion.py b/modules/textual_inversion.py
index 064d7d214..d01118e64 100644
--- a/modules/textual_inversion.py
+++ b/modules/textual_inversion.py
@@ -1,4 +1,3 @@
-from typing import List, Union
import os
import time
import torch
@@ -83,7 +82,7 @@ def get_text_encoders():
text_encoders = []
tokenizers = []
hidden_sizes = []
- for te, tok in zip(te_names, tokenizers_names):
+ for te, tok in zip(te_names, tokenizers_names, strict=False):
text_encoder = getattr(pipe, te, None)
if text_encoder is None:
continue
@@ -135,14 +134,14 @@ def insert_vectors(embedding, tokenizers, text_encoders, hiddensizes):
this may cause collisions.
"""
with devices.inference_context():
- for vector, size in zip(embedding.vec, embedding.vector_sizes):
+ for vector, size in zip(embedding.vec, embedding.vector_sizes, strict=False):
if size not in hiddensizes:
continue
idx = hiddensizes.index(size)
unk_token_id = tokenizers[idx].convert_tokens_to_ids(tokenizers[idx].unk_token)
if text_encoders[idx].get_input_embeddings().weight.data.shape[0] != len(tokenizers[idx]):
text_encoders[idx].resize_token_embeddings(len(tokenizers[idx]))
- for token, v in zip(embedding.tokens, vector.unbind()):
+ for token, v in zip(embedding.tokens, vector.unbind(), strict=False):
token_id = tokenizers[idx].convert_tokens_to_ids(token)
if token_id > unk_token_id:
text_encoders[idx].get_input_embeddings().weight.data[token_id] = v
@@ -254,7 +253,7 @@ class EmbeddingDatabase:
self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
return embedding
- def load_diffusers_embedding(self, filename: Union[str, List[str]] = None, data: dict = None):
+ def load_diffusers_embedding(self, filename: str | list[str] | None = None, data: dict | None = None):
"""
File names take precidence over bundled embeddings passed as a dict.
Bundled embeddings are automatically set to overwrite previous embeddings.
diff --git a/modules/theme.py b/modules/theme.py
index da6fa562e..0b384ac20 100644
--- a/modules/theme.py
+++ b/modules/theme.py
@@ -18,7 +18,7 @@ def refresh_themes(no_update=False):
res = []
if os.path.exists(themes_file):
try:
- with open(themes_file, 'r', encoding='utf8') as f:
+ with open(themes_file, encoding='utf8') as f:
res = json.load(f)
except Exception:
modules.shared.log.error('Exception loading UI themes')
diff --git a/modules/todo/todo_merge.py b/modules/todo/todo_merge.py
index 77840d6ed..cde8381fe 100644
--- a/modules/todo/todo_merge.py
+++ b/modules/todo/todo_merge.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Callable
+from collections.abc import Callable
import math
import torch
import torch.nn.functional as F
@@ -136,7 +136,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
sy: int,
r: int,
no_rand: bool = False,
- generator: torch.Generator = None) -> Tuple[Callable, Callable]:
+ generator: torch.Generator | None = None) -> tuple[Callable, Callable]:
"""
Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
@@ -305,9 +305,9 @@ class TokenMergeAttentionProcessor:
self,
attn: Attention,
hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: torch.FloatTensor | None = None,
+ attention_mask: torch.FloatTensor | None = None,
+ temb: torch.FloatTensor | None = None,
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
diff --git a/modules/todo/todo_utils.py b/modules/todo/todo_utils.py
index 34a24bb82..0077b5558 100644
--- a/modules/todo/todo_utils.py
+++ b/modules/todo/todo_utils.py
@@ -29,7 +29,9 @@ def remove_tome_patch(pipe: torch.nn.Module):
if hasattr(m, "processor"):
m.processor = AttnProcessor2_0()
-def patch_attention_proc(unet, token_merge_args={}):
+def patch_attention_proc(unet, token_merge_args=None):
+ if token_merge_args is None:
+ token_merge_args = {}
unet._tome_info = { # pylint: disable=protected-access
"size": None,
"timestep": None,
diff --git a/modules/ui.py b/modules/ui.py
index 3fe74b32d..b1c0033b6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -2,7 +2,6 @@ import gradio as gr
import gradio.routes
import gradio.utils
from modules import errors, timer, gr_hijack, shared, script_callbacks, ui_common, ui_symbols, ui_javascript, ui_sections, generation_parameters_copypaste, call_queue, scripts_manager
-from modules.paths import script_path, data_path # pylint: disable=unused-import
from modules.api import mime
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 89dd2038a..740d44385 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -38,7 +38,9 @@ def update_generation_info(generation_info, html_info, img_index):
return html_info, html_info
-def plaintext_to_html(text, elem_classes=[]):
+def plaintext_to_html(text, elem_classes=None):
+ if elem_classes is None:
+ elem_classes = []
res = f'' + '
\n'.join([f"{html.escape(x)}" for x in text.split('\n')]) + '
'
return res
diff --git a/modules/ui_control.py b/modules/ui_control.py
index 1049b5f54..2c1c74748 100644
--- a/modules/ui_control.py
+++ b/modules/ui_control.py
@@ -73,7 +73,7 @@ def return_controls(res, t: float = None):
def get_units(*values):
update = []
what = None
- for c, v in zip(controls, values):
+ for c, v in zip(controls, values, strict=False):
if isinstance(c, gr.Label): # unit type indicator
what = c.value['label']
c.value = v
diff --git a/modules/ui_docs.py b/modules/ui_docs.py
index 08306e5ac..a8e824b30 100644
--- a/modules/ui_docs.py
+++ b/modules/ui_docs.py
@@ -5,7 +5,7 @@ from modules import ui_symbols, ui_components
from installer import install, log
-class Page():
+class Page:
def __init__(self, fn, full: bool = True):
self.fn = fn
self.title = ''
@@ -21,7 +21,7 @@ class Page():
try:
self.title = ' ' + os.path.basename(self.fn).replace('.md', '').replace('-', ' ') + ' '
self.mtime = time.localtime(os.path.getmtime(self.fn))
- with open(self.fn, 'r', encoding='utf-8') as f:
+ with open(self.fn, encoding='utf-8') as f:
content = f.read()
self.size = len(content)
self.lines = [line.strip().lower() + ' ' for line in content.splitlines() if len(line)>1]
@@ -80,7 +80,7 @@ class Page():
log.error(f'Search docs: page="{self.fn}" does not exist')
return f'page="{self.fn}" does not exist'
try:
- with open(self.fn, 'r', encoding='utf-8') as f:
+ with open(self.fn, encoding='utf-8') as f:
content = f.read()
return content
except Exception as e:
@@ -91,7 +91,7 @@ class Page():
return f'Page(title="{self.title.strip()}" fn="{self.fn}" mtime={self.mtime} h1={[h.strip() for h in self.h1]} h2={len(self.h2)} h3={len(self.h3)} lines={len(self.lines)} size={self.size})'
-class Pages():
+class Pages:
def __init__(self):
self.time = time.time()
self.size = 0
@@ -117,7 +117,7 @@ class Pages():
text = text.lower()
scores = [page.search(text) for page in self.pages]
mtimes = [page.mtime for page in self.pages]
- found = sorted(zip(scores, mtimes, self.pages), key=lambda x: (x[0], x[1]), reverse=True)
+ found = sorted(zip(scores, mtimes, self.pages, strict=False), key=lambda x: (x[0], x[1]), reverse=True)
found = [item for item in found if item[0] > 0]
return [(item[0], item[2]) for item in found][:topk]
except Exception as e:
@@ -177,7 +177,7 @@ def search_docs(search_term):
def get_github_page(page):
try:
- with open(os.path.join('wiki', f'{page}.md'), 'r', encoding='utf-8') as f:
+ with open(os.path.join('wiki', f'{page}.md'), encoding='utf-8') as f:
content = f.read()
log.debug(f'Search wiki: page="{page}" size={len(content)}')
except Exception as e:
@@ -230,7 +230,7 @@ def search_github(search_term):
def create_ui_logs():
def get_changelog():
- with open('CHANGELOG.md', 'r', encoding='utf-8') as f:
+ with open('CHANGELOG.md', encoding='utf-8') as f:
content = f.read()
content = content.replace('# Change Log for SD.Next', ' ')
return content
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 68204650e..063580f51 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -391,7 +391,7 @@ class ExtraNetworksPage:
r = random.randint(100, 255)
g = random.randint(100, 255)
b = random.randint(100, 255)
- return '#{:02x}{:02x}{:02x}'.format(r, g, b) # pylint: disable=consider-using-f-string
+ return f'#{r:02x}{g:02x}{b:02x}'
try:
onclick = f'cardClicked({item.get("prompt", None)})'
@@ -515,7 +515,7 @@ class ExtraNetworksPage:
fn = os.path.splitext(path)[0] + '.txt'
if os.path.exists(fn):
try:
- with open(fn, "r", encoding="utf-8", errors="replace") as f:
+ with open(fn, encoding="utf-8", errors="replace") as f:
txt = f.read()
txt = re.sub('[<>]', '', txt)
return txt
@@ -588,7 +588,6 @@ def register_pages():
if shared.opts.diffusers_enable_embed:
from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
register_page(ExtraNetworksPageTextualInversion())
- from modules.video_models.models_def import models # pylint: disable=unused-import
def get_pages(title=None):
@@ -1044,7 +1043,7 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
params, text = get_last_args()
if (not params) or (not text) or (len(text) == 0):
if os.path.exists(paths.params_path):
- with open(paths.params_path, "r", encoding="utf8") as file:
+ with open(paths.params_path, encoding="utf8") as file:
text = file.read()
else:
text = ''
@@ -1062,7 +1061,7 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
params, text = get_last_args()
if (not params) or (not text) or (len(text) == 0):
if os.path.exists(paths.params_path):
- with open(paths.params_path, "r", encoding="utf8") as file:
+ with open(paths.params_path, encoding="utf8") as file:
text = file.read()
else:
text = ''
diff --git a/modules/ui_img2img.py b/modules/ui_img2img.py
index 5e651de91..7828bc8ab 100644
--- a/modules/ui_img2img.py
+++ b/modules/ui_img2img.py
@@ -57,7 +57,7 @@ def create_ui():
def add_copy_image_controls(tab_name, elem):
with gr.Row(variant="compact", elem_id=f"img2img_copy_{tab_name}_row"):
- for title, name in zip(['➠ Image', '➠ Inpaint', '➠ Sketch', '➠ Composite'], ['img2img', 'inpaint', 'sketch', 'composite']):
+ for title, name in zip(['➠ Image', '➠ Inpaint', '➠ Sketch', '➠ Composite'], ['img2img', 'inpaint', 'sketch', 'composite'], strict=False):
if name == tab_name:
gr.Button(title, elem_id=f'{tab_name}_copy_to_{name}', interactive=False)
copy_image_destinations[name] = elem
diff --git a/modules/ui_javascript.py b/modules/ui_javascript.py
index dcbf14731..5c5f966a8 100644
--- a/modules/ui_javascript.py
+++ b/modules/ui_javascript.py
@@ -55,7 +55,7 @@ def html_body():
def html_login():
fn = os.path.join(script_path, "javascript", "login.js")
- with open(fn, 'r', encoding='utf8') as f:
+ with open(fn, encoding='utf8') as f:
inline = f.read()
js = f'\n'
return js
@@ -110,11 +110,11 @@ def reload_javascript():
def template_response(*args, **kwargs):
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
- res.body = res.body.replace(b'', f'{title}'.encode("utf8"))
- res.body = res.body.replace(b'', f'{manifest}'.encode("utf8"))
- res.body = res.body.replace(b'', f'{login}'.encode("utf8"))
- res.body = res.body.replace(b'', f'{js}'.encode("utf8"))
- res.body = res.body.replace(b'
{title}'.encode())
+ res.body = res.body.replace(b'', f'{manifest}'.encode())
+ res.body = res.body.replace(b'', f'{login}'.encode())
+ res.body = res.body.replace(b'', f'{js}'.encode())
+ res.body = res.body.replace(b'', f'{css}{body}'.encode())
lines = res.body.decode("utf8").split('\n')
for line in lines:
if 'meta name="twitter:' in line:
diff --git a/modules/ui_models.py b/modules/ui_models.py
index def5ec3c8..34b9b9798 100644
--- a/modules/ui_models.py
+++ b/modules/ui_models.py
@@ -346,7 +346,7 @@ def create_ui():
preset = interpolate(presets, ratio)
else:
preset = presets[0]
- preset = ['%.3f' % x if int(x) != x else str(x) for x in preset] # pylint: disable=consider-using-f-string
+ preset = [f'{x:.3f}' if int(x) != x else str(x) for x in preset] # pylint: disable=consider-using-f-string
preset = [preset[0], ",".join(preset[1:13]), preset[13], ",".join(preset[14:])]
return [gr.update(value=x) for x in preset] + [gr.update(selected=2)]
@@ -498,7 +498,7 @@ def create_ui():
def civitai_download(model_urls, model_names, model_types, model_path, civit_token, model_output):
from modules.civitai.download_civitai import download_civit_model
- for model_url, model_name, model_type in zip(model_urls, model_names, model_types):
+ for model_url, model_name, model_type in zip(model_urls, model_names, model_types, strict=False):
msg = f"
"
yield msg + model_output
download_civit_model(model_url, model_name, model_path, model_type, civit_token)
diff --git a/modules/ui_models_load.py b/modules/ui_models_load.py
index 58e0a0f58..d63a63d20 100644
--- a/modules/ui_models_load.py
+++ b/modules/ui_models_load.py
@@ -1,6 +1,5 @@
import os
import re
-import json # pylint: disable=unused-import
import inspect
import gradio as gr
import torch
@@ -101,7 +100,7 @@ def process_huggingface_url(url):
return repo, subfolder, fn, download
-class Component():
+class Component:
def __init__(self, signature, name=None, cls=None, val=None, local=None, remote=None, typ=None, dtype=None, quant=False, loadable=None):
self.id = len(components) + 1
self.name = signature.name if signature else name
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 345077ec9..247325b3f 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -119,7 +119,7 @@ def create_dirty_indicator(key, keys_to_reset, **kwargs):
def run_settings(*args):
changed = []
- for key, value, comp in zip(shared.opts.data_labels.keys(), args, components):
+ for key, value, comp in zip(shared.opts.data_labels.keys(), args, components, strict=False):
if comp == dummy_component or value=='dummy': # or getattr(comp, 'visible', True) is False or key in hidden_list:
# actual = shared.opts.data.get(key, None) # ensure the key is in data
# default = shared.opts.data_labels[key].default
@@ -173,7 +173,9 @@ def run_settings_single(value, key, progress=False):
return get_value_for_setting(key), shared.opts.dumpjson()
-def create_ui(disabled_tabs=[]):
+def create_ui(disabled_tabs=None):
+ if disabled_tabs is None:
+ disabled_tabs = []
shared.log.debug('UI initialize: tab=settings')
global text_settings # pylint: disable=global-statement
text_settings = gr.Textbox(elem_id="settings_json", elem_classes=["settings_json"], value=lambda: shared.opts.dumpjson(), visible=False)
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 0cdd7892c..10e816569 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -1,7 +1,7 @@
import os
from abc import abstractmethod
from PIL import Image
-from modules import modelloader, shared
+from modules import modelloader, shared, paths
models = None
@@ -39,14 +39,14 @@
if self.user_path is not None and len(self.user_path) > 0 and not os.path.exists(self.user_path):
shared.log.info(f'Upscaler create: folder="{self.user_path}"')
if self.model_path is None and self.name:
- self.model_path = os.path.join(shared.models_path, self.name)
+ self.model_path = os.path.join(paths.models_path, self.name)
try:
if self.model_path and create_dirs:
os.makedirs(self.model_path, exist_ok=True)
except Exception:
pass
try:
import cv2 # pylint: disable=unused-import
self.can_tile = True
except Exception:
pass
diff --git a/modules/vae/sd_vae_fal.py b/modules/vae/sd_vae_fal.py
index bd482a779..0f0e5b3bb 100644
--- a/modules/vae/sd_vae_fal.py
+++ b/modules/vae/sd_vae_fal.py
@@ -49,17 +49,25 @@ class Flux2TinyAutoEncoder(ModelMixin, ConfigMixin):
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 128,
- encoder_block_out_channels: list[int] = [64, 64, 64, 64],
- decoder_block_out_channels: list[int] = [64, 64, 64, 64],
+ encoder_block_out_channels: list[int] | None = None,
+ decoder_block_out_channels: list[int] | None = None,
act_fn: str = "silu",
upsampling_scaling_factor: int = 2,
- num_encoder_blocks: list[int] = [1, 3, 3, 3],
- num_decoder_blocks: list[int] = [3, 3, 3, 1],
+ num_encoder_blocks: list[int] | None = None,
+ num_decoder_blocks: list[int] | None = None,
latent_magnitude: float = 3.0,
latent_shift: float = 0.5,
force_upcast: bool = False,
scaling_factor: float = 0.13025,
) -> None:
+ if num_decoder_blocks is None:
+ num_decoder_blocks = [3, 3, 3, 1]
+ if num_encoder_blocks is None:
+ num_encoder_blocks = [1, 3, 3, 3]
+ if decoder_block_out_channels is None:
+ decoder_block_out_channels = [64, 64, 64, 64]
+ if encoder_block_out_channels is None:
+ encoder_block_out_channels = [64, 64, 64, 64]
super().__init__()
self.tiny_vae = AutoencoderTiny(
in_channels=in_channels,
diff --git a/modules/vae/sd_vae_natten.py b/modules/vae/sd_vae_natten.py
index 478e9b654..246816c8d 100644
--- a/modules/vae/sd_vae_natten.py
+++ b/modules/vae/sd_vae_natten.py
@@ -1,7 +1,6 @@
# copied from https://github.com/Birch-san/sdxl-play/blob/main/src/attn/natten_attn_processor.py
import os
-from typing import Optional
from diffusers.models.attention import Attention
import torch
from torch.nn import Linear
@@ -45,9 +44,9 @@ class NattenAttnProcessor:
self,
attn: Attention,
hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.BoolTensor] = None,
- temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: torch.FloatTensor | None = None,
+ attention_mask: torch.BoolTensor | None = None,
+ temb: torch.FloatTensor | None = None,
):
import natten
assert hasattr(attn, 'qkv'), "Did not find property qkv on attn. Expected you to fuse its q_proj, k_proj, v_proj weights and biases beforehand, and multiply attn.scale into the q weights and bias."
diff --git a/modules/video_models/google_veo.py b/modules/video_models/google_veo.py
index aebc3f22f..49893b260 100644
--- a/modules/video_models/google_veo.py
+++ b/modules/video_models/google_veo.py
@@ -43,7 +43,7 @@ def get_size_buckets(width: int, height: int) -> str:
return closest_size, closest_aspect_ratio
-class GoogleVeoVideoPipeline():
+class GoogleVeoVideoPipeline:
def __init__(self, model_name: str):
self.model = model_name
self.client = None
diff --git a/modules/video_models/models_def.py b/modules/video_models/models_def.py
index 06b64fdd1..351d1c33a 100644
--- a/modules/video_models/models_def.py
+++ b/modules/video_models/models_def.py
@@ -6,7 +6,7 @@ from installer import log
@dataclass
-class Model():
+class Model:
name: str
url: str = ''
repo: str = None
diff --git a/modules/video_models/video_load.py b/modules/video_models/video_load.py
index aae853a8a..e60463368 100644
--- a/modules/video_models/video_load.py
+++ b/modules/video_models/video_load.py
@@ -2,7 +2,6 @@ import os
import sys
import copy
import time
-import transformers # pylint: disable=unused-import
import diffusers
from modules import shared, errors, sd_models, sd_checkpoint, model_quant, devices, sd_hijack_te, sd_hijack_vae
from modules.video_models import models_def, video_utils, video_overrides, video_cache
diff --git a/modules/video_models/video_save.py b/modules/video_models/video_save.py
index e137ca966..9f75df9c4 100644
--- a/modules/video_models/video_save.py
+++ b/modules/video_models/video_save.py
@@ -136,9 +136,11 @@ def atomic_save_video(filename: str,
pix_fmt:str='yuv420p',
options:str='',
aac:int=24000,
- metadata:dict={},
+ metadata:dict|None=None,
pbar=None,
):
+ if metadata is None:
+ metadata = {}
av = check_av()
if av is None or av is False:
shared.log.error('Video: ffmpeg/av not available')
@@ -205,9 +207,11 @@ def save_video(
mp4_interpolate:int=0, # rife interpolation
aac_sample_rate:int=24000, # audio sample rate
stream=None, # async progress reporting stream
- metadata:dict={}, # metadata for video
+ metadata:dict|None=None, # metadata for video
pbar=None, # progress bar for video
):
+ if metadata is None:
+ metadata = {}
output_video = None
if binary is not None:
diff --git a/modules/zluda.py b/modules/zluda.py
index 7b85eec62..258958cc7 100644
--- a/modules/zluda.py
+++ b/modules/zluda.py
@@ -1,13 +1,11 @@
import sys
-from typing import Union
-from modules.zluda_installer import core, default_agent # pylint: disable=unused-import
PLATFORM = sys.platform
do_nothing = lambda _: None # pylint: disable=unnecessary-lambda-assignment
-def test(device) -> Union[Exception, None]:
+def test(device) -> Exception | None:
import torch
device = torch.device(device)
try:
diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py
index b5055b049..5326588eb 100644
--- a/modules/zluda_installer.py
+++ b/modules/zluda_installer.py
@@ -6,7 +6,6 @@ import ctypes
import shutil
import zipfile
import urllib.request
-from typing import Union
from installer import args, log
from modules import rocm
@@ -23,7 +22,7 @@ HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll', 'rocsparse.dll', 'hipfft.dll',
MIOpen_enabled = False
path = os.path.abspath(os.environ.get('ZLUDA', '.zluda'))
-default_agent: Union[rocm.Agent, None] = None
+default_agent: rocm.Agent | None = None
hipBLASLt_enabled = False
diff --git a/webui.py b/webui.py
index 399e0ff4e..e6a995533 100644
--- a/webui.py
+++ b/webui.py
@@ -90,7 +90,7 @@ def initialize():
timer.startup.record("te")
modules.modelloader.cleanup_models()
- modules.sd_models.setup_model()
+ modules.sd_checkpoint.setup_model()
timer.startup.record("models")
from modules.lora import lora_load