diff --git a/CHANGELOG.md b/CHANGELOG.md index 16541b7c2..b813f4591 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,7 +45,7 @@ TBD see [docs](https://vladmandic.github.io/sdnext-docs/Python/) for details - remove hard-dependnecies: `clip, numba, skimage, torchsde, omegaconf, antlr, patch-ng, patch-ng, astunparse, addict, inflection, jsonmerge, kornia`, - `resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets` + `resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets, imp` these are now installed on-demand when needed - refactor: to/from image/tensor logic, thanks @CalamitousFelicitousness - refactor: switch to `pyproject.toml` for tool configs @@ -57,6 +57,8 @@ TBD - refactor: unified command line parsing - refactor: launch use threads to async execute non-critical tasks - refactor: switch from deprecated `pkg_resources` to `importlib` + - refactor: modernize typing and type annotations + - refactor: improve `pydantic==2.x` compatibility - update `lint` rules, thanks @awsr - remove requirements: `clip`, `open-clip` - update `requirements` diff --git a/installer.py b/installer.py index fb44e4053..e8e2fbe72 100644 --- a/installer.py +++ b/installer.py @@ -1,4 +1,4 @@ -from typing import overload, List, Optional +from typing import overload import os import sys import json @@ -106,10 +106,12 @@ def str_to_bool(val: str | bool | None) -> bool | None: return val -def install_traceback(suppress: list = []): +def install_traceback(suppress: list = None): from rich.traceback import install as traceback_install from rich.pretty import install as pretty_install + if suppress is None: + suppress = [] width = os.environ.get("SD_TRACEWIDTH", console.width if console else None) if width is not None: width = int(width) @@ -133,7 +135,7 @@ def setup_logging(): from functools import partial, partialmethod from logging.handlers import RotatingFileHandler try: - import rich # pylint: 
disable=unused-import + pass # pylint: disable=unused-import except Exception: log.error('Please restart SD.Next so changes take effect') sys.exit(1) @@ -187,7 +189,7 @@ def setup_logging(): _Segment = Segment left = _Segment(" " * self.left, style) if self.left else None right = [_Segment.line()] - blank_line: Optional[List[Segment]] = None + blank_line: list[Segment] | None = None if self.top: blank_line = [_Segment(f'{" " * width}\n', style)] yield from blank_line * self.top @@ -215,8 +217,10 @@ def setup_logging(): logging.Logger.trace = partialmethod(logging.Logger.log, logging.TRACE) logging.trace = partial(logging.log, logging.TRACE) - def exception_hook(e: Exception, suppress=[]): + def exception_hook(e: Exception, suppress=None): from rich.traceback import Traceback + if suppress is None: + suppress = [] tb = Traceback.from_exception(type(e), e, e.__traceback__, show_locals=False, max_frames=16, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=console.width) # print-to-console, does not get printed-to-file exc_type, exc_value, exc_traceback = sys.exc_info() @@ -416,7 +420,7 @@ def uninstall(package, quiet = False): def run(cmd: str, arg: str): - result = subprocess.run(f'"{cmd}" {arg}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + result = subprocess.run(f'"{cmd}" {arg}', shell=True, check=False, env=os.environ, capture_output=True) txt = result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stderr) > 0: txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore") @@ -461,7 +465,7 @@ def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True): all_args = f'{pip_log}{arg} {env_args}'.strip() if not quiet: log.debug(f'Running: {pipCmd}="{all_args}"') - result = subprocess.run(f'"{sys.executable}" -m {pipCmd} {all_args}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + result = 
subprocess.run(f'"{sys.executable}" -m {pipCmd} {all_args}', shell=True, check=False, env=os.environ, capture_output=True) txt = result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stderr) > 0: if uv and result.returncode != 0: @@ -509,7 +513,7 @@ def git(arg: str, folder: str = None, ignore: bool = False, optional: bool = Fal git_cmd = os.environ.get('GIT', "git") if git_cmd != "git": git_cmd = os.path.abspath(git_cmd) - result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder or '.') + result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, capture_output=True, cwd=folder or '.') stdout = result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stderr) > 0: stdout += ('\n' if len(stdout) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore") @@ -639,7 +643,11 @@ def get_platform(): # check python version -def check_python(supported_minors=[], experimental_minors=[], reason=None): +def check_python(supported_minors=None, experimental_minors=None, reason=None): + if experimental_minors is None: + experimental_minors = [] + if supported_minors is None: + supported_minors = [] if supported_minors is None or len(supported_minors) == 0: supported_minors = [10, 11, 12, 13] experimental_minors = [14] @@ -911,8 +919,6 @@ def install_torch_addons(): if 'xformers' in xformers_package: try: install(xformers_package, ignore=True, no_deps=True) - import torch # pylint: disable=unused-import - import xformers # pylint: disable=unused-import except Exception as e: log.debug(f'xFormers cannot install: {e}') elif not args.experimental and not args.use_xformers and opts.get('cross_attention_optimization', '') != 'xFormers': @@ -1126,7 +1132,7 @@ def run_extension_installer(folder): if os.environ.get('PYTHONPATH', None) is not None: seperator = ';' if sys.platform == 'win32' else ':' env['PYTHONPATH'] += seperator + 
os.environ.get('PYTHONPATH', None) - result = subprocess.run(f'"{sys.executable}" "{path_installer}"', shell=True, env=env, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder) + result = subprocess.run(f'"{sys.executable}" "{path_installer}"', shell=True, env=env, check=False, capture_output=True, cwd=folder) txt = result.stdout.decode(encoding="utf8", errors="ignore") debug(f'Extension installer: file="{path_installer}" {txt}') if result.returncode != 0: @@ -1265,7 +1271,7 @@ def ensure_base_requirements(): local_log = logging.getLogger('sdnext.installer') global setuptools, distutils # pylint: disable=global-statement # python may ship with incompatible setuptools - subprocess.run(f'"{sys.executable}" -m pip install setuptools=={setuptools_version}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + subprocess.run(f'"{sys.executable}" -m pip install setuptools=={setuptools_version}', shell=True, check=False, env=os.environ, capture_output=True) # need to delete all references to modules to be able to reload them otherwise python will use cached version modules = [m for m in sys.modules if m.startswith('setuptools') or m.startswith('distutils')] for m in modules: @@ -1399,7 +1405,7 @@ def install_requirements(): log.info('Install: verifying requirements') if args.new: log.debug('Install: flag=new') - with open('requirements.txt', 'r', encoding='utf8') as f: + with open('requirements.txt', encoding='utf8') as f: lines = [line.strip() for line in f.readlines() if line.strip() != '' and not line.startswith('#') and line is not None] for line in lines: if not installed(line, quiet=True): @@ -1495,20 +1501,20 @@ def get_version(force=False): t_start = time.time() if (version is None) or (version.get('branch', 'unknown') == 'unknown') or force: try: - subprocess.run('git config log.showsignature false', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True) + subprocess.run('git config 
log.showsignature false', capture_output=True, shell=True, check=True) except Exception: pass try: - res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True) + res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', capture_output=True, shell=True, check=True) ver = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ' ' commit, updated = ver.split(' ') version['commit'], version['updated'] = commit, updated except Exception as e: log.warning(f'Version: where=commit {e}') try: - res = subprocess.run('git remote get-url origin', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True) + res = subprocess.run('git remote get-url origin', capture_output=True, shell=True, check=True) origin = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else '' - res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True) + res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True) branch_name = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else '' version['url'] = origin.replace('\n', '').removesuffix('.git') + '/tree/' + branch_name.replace('\n', '') version['branch'] = branch_name.replace('\n', '') @@ -1520,7 +1526,7 @@ def get_version(force=False): try: if os.path.exists('extensions-builtin/sdnext-modernui'): os.chdir('extensions-builtin/sdnext-modernui') - res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True) + res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True) branch_ui = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else '' branch_ui = 'dev' if 'dev' in branch_ui else 'main' version['ui'] = 
branch_ui @@ -1536,7 +1542,7 @@ def get_version(force=False): version['kanvas'] = 'disabled' elif os.path.exists('extensions-builtin/sdnext-kanvas'): os.chdir('extensions-builtin/sdnext-kanvas') - res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True) + res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True) branch_kanvas = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else '' branch_kanvas = 'dev' if 'dev' in branch_kanvas else 'main' version['kanvas'] = branch_kanvas @@ -1723,7 +1729,7 @@ def check_timestamp(): ok = True setup_time = -1 version_time = -1 - with open(log_file, 'r', encoding='utf8') as f: + with open(log_file, encoding='utf8') as f: lines = f.readlines() for line in lines: if 'Setup complete without errors' in line: @@ -1752,7 +1758,6 @@ def check_timestamp(): def add_args(parser): - import argparse group_install = parser.add_argument_group('Install') group_install.add_argument('--quick', default=os.environ.get("SD_QUICK",False), action='store_true', help="Bypass version checks, default: %(default)s") group_install.add_argument('--reset', default=os.environ.get("SD_RESET",False), action='store_true', help="Reset main repository to latest version, default: %(default)s") @@ -1832,7 +1837,7 @@ def read_options(): t_start = time.time() global opts # pylint: disable=global-statement if os.path.isfile(args.config): - with open(args.config, "r", encoding="utf8") as file: + with open(args.config, encoding="utf8") as file: try: opts = json.load(file) if type(opts) is str: diff --git a/launch.py b/launch.py index e51de7510..3ba43ef5b 100755 --- a/launch.py +++ b/launch.py @@ -72,7 +72,7 @@ def get_custom_args(): rec('args') -@lru_cache() +@lru_cache def commit_hash(): # compatbility function global stored_commit_hash # pylint: disable=global-statement if stored_commit_hash is not None: @@ -85,7 +85,7 @@ def 
commit_hash(): # compatbility function return stored_commit_hash -@lru_cache() +@lru_cache def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compatbility function if desc is not None: installer.log.info(desc) @@ -94,7 +94,7 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compat if result.returncode != 0: raise RuntimeError(f"""{errdesc or 'Error running command'} Command: {command} Error code: {result.returncode}""") return '' - result = subprocess.run(command, stdout=subprocess.PIPE, check=False, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env) + result = subprocess.run(command, capture_output=True, check=False, shell=True, env=os.environ if custom_env is None else custom_env) if result.returncode != 0: raise RuntimeError(f"""{errdesc or 'Error running command'}: {command} code: {result.returncode} {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''} @@ -104,26 +104,26 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compat def check_run(command): # compatbility function - result = subprocess.run(command, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) + result = subprocess.run(command, check=False, capture_output=True, shell=True) return result.returncode == 0 -@lru_cache() +@lru_cache def is_installed(pkg): # compatbility function return installer.installed(pkg) -@lru_cache() +@lru_cache def repo_dir(name): # compatbility function return os.path.join(script_path, dir_repos, name) -@lru_cache() +@lru_cache def run_python(code, desc=None, errdesc=None): # compatbility function return run(f'"{sys.executable}" -c "{code}"', desc, errdesc) -@lru_cache() +@lru_cache def run_pip(pkg, desc=None): # compatbility function forbidden = ['onnxruntime', 'opencv-python'] if desc is None: @@ -136,7 +136,7 @@ def run_pip(pkg, desc=None): # compatbility function return run(f'"{sys.executable}" -m pip 
{pkg} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") -@lru_cache() +@lru_cache def check_run_python(code): # compatbility function return check_run(f'"{sys.executable}" -c "{code}"') diff --git a/modules/apg/pipeline_stable_cascade_prior_apg.py b/modules/apg/pipeline_stable_cascade_prior_apg.py index 0e311ad4e..6ffaf8089 100644 --- a/modules/apg/pipeline_stable_cascade_prior_apg.py +++ b/modules/apg/pipeline_stable_cascade_prior_apg.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from math import ceil -from typing import Callable, Dict, List, Optional, Union +from collections.abc import Callable import numpy as np import PIL @@ -63,11 +63,11 @@ class StableCascadePriorPipelineOutput(BaseOutput): Text embeddings for the negative prompt. """ - image_embeddings: Union[torch.Tensor, np.ndarray] - prompt_embeds: Union[torch.Tensor, np.ndarray] - prompt_embeds_pooled: Union[torch.Tensor, np.ndarray] - negative_prompt_embeds: Union[torch.Tensor, np.ndarray] - negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray] + image_embeddings: torch.Tensor | np.ndarray + prompt_embeds: torch.Tensor | np.ndarray + prompt_embeds_pooled: torch.Tensor | np.ndarray + negative_prompt_embeds: torch.Tensor | np.ndarray + negative_prompt_embeds_pooled: torch.Tensor | np.ndarray class StableCascadePriorPipelineAPG(DiffusionPipeline): @@ -109,8 +109,8 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline): prior: StableCascadeUNet, scheduler: DDPMWuerstchenScheduler, resolution_multiple: float = 42.67, - feature_extractor: Optional[CLIPImageProcessor] = None, - image_encoder: Optional[CLIPVisionModelWithProjection] = None, + feature_extractor: CLIPImageProcessor | None = None, + image_encoder: CLIPVisionModelWithProjection | None = None, ) -> None: super().__init__() self.register_modules( @@ -151,10 +151,10 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline): do_classifier_free_guidance, prompt=None, 
negative_prompt=None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_pooled: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_pooled: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_pooled: torch.Tensor | None = None, ): if prompt_embeds is None: # get prompt text embeddings @@ -196,7 +196,7 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline): prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0) if negative_prompt_embeds is None and do_classifier_free_guidance: - uncond_tokens: List[str] + uncond_tokens: list[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): @@ -367,26 +367,26 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Optional[Union[str, List[str]]] = None, - images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None, + prompt: str | list[str] | None = None, + images: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image] = None, height: int = 1024, width: int = 1024, num_inference_steps: int = 20, - timesteps: List[float] = None, + timesteps: list[float] = None, guidance_scale: float = 4.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_pooled: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, - image_embeds: Optional[torch.Tensor] = None, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = 
None, - output_type: Optional[str] = "pt", + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_pooled: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_pooled: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pt", return_dict: bool = True, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = None, ): """ Function invoked when calling the pipeline for generation. @@ -460,6 +460,8 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline): """ # 0. Define commonly used variables + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["latents"] device = self._execution_device dtype = next(self.prior.parameters()).dtype self._guidance_scale = guidance_scale diff --git a/modules/apg/pipeline_stable_diffision_xl_apg.py b/modules/apg/pipeline_stable_diffision_xl_apg.py index 3371877fd..b09ae242f 100644 --- a/modules/apg/pipeline_stable_diffision_xl_apg.py +++ b/modules/apg/pipeline_stable_diffision_xl_apg.py @@ -13,7 +13,8 @@ # limitations under the License. 
import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any +from collections.abc import Callable import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection @@ -28,7 +29,6 @@ from diffusers.utils import USE_PEFT_BACKEND, deprecate, is_invisible_watermark_ from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput -from diffusers.models.attention_processor import Attention from modules import apg if is_invisible_watermark_available(): @@ -76,10 +76,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, **kwargs, ): """ @@ -217,7 +217,7 @@ class StableDiffusionXLPipelineAPG( image_encoder: CLIPVisionModelWithProjection = None, feature_extractor: CLIPImageProcessor = None, force_zeros_for_empty_prompt: bool = True, - add_watermarker: Optional[bool] = None, + add_watermarker: bool | None = None, ): super().__init__() @@ -248,18 +248,18 @@ class StableDiffusionXLPipelineAPG( def encode_prompt( self, prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, + prompt_2: str | None = None, + device: torch.device | None = None, num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - 
negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -343,7 +343,7 @@ class StableDiffusionXLPipelineAPG( # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) @@ -396,7 +396,7 @@ class StableDiffusionXLPipelineAPG( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) - uncond_tokens: List[str] + uncond_tokens: list[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" @@ -412,7 +412,7 @@ class StableDiffusionXLPipelineAPG( uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders, strict=False): if isinstance(self, TextualInversionLoaderMixin): 
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) @@ -521,7 +521,7 @@ class StableDiffusionXLPipelineAPG( ) for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers, strict=False ): output_hidden_state = not isinstance(image_proj_layer, ImageProjection) single_image_embeds, single_negative_image_embeds = self.encode_image( @@ -793,42 +793,40 @@ class StableDiffusionXLPipelineAPG( @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, num_inference_steps: int = 50, - timesteps: List[int] = None, - sigmas: List[float] = None, - denoising_end: Optional[float] = None, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, - ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, - output_type: Optional[str] = "pil", + 
generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, + cross_attention_kwargs: dict[str, Any] | None = None, guidance_rescale: float = 0.0, - original_size: Optional[Tuple[int, int]] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Optional[Tuple[int, int]] = None, - negative_original_size: Optional[Tuple[int, int]] = None, - negative_crops_coords_top_left: Tuple[int, int] = (0, 0), - negative_target_size: Optional[Tuple[int, int]] = None, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = None, **kwargs, ): r""" @@ -976,6 +974,8 @@ class StableDiffusionXLPipelineAPG( `tuple`. When returning a tuple, the first element is a list with the generated images. 
""" + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["latents"] callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) diff --git a/modules/apg/pipeline_stable_diffusion_apg.py b/modules/apg/pipeline_stable_diffusion_apg.py index 6eb6bae90..ae1eb26e0 100644 --- a/modules/apg/pipeline_stable_diffusion_apg.py +++ b/modules/apg/pipeline_stable_diffusion_apg.py @@ -13,7 +13,8 @@ # limitations under the License. import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any +from collections.abc import Callable import torch from packaging import version @@ -71,10 +72,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def retrieve_timesteps( scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, **kwargs, ): """ @@ -273,9 +274,9 @@ class StableDiffusionPipelineAPG( num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
@@ -305,10 +306,10 @@ class StableDiffusionPipelineAPG( num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -421,7 +422,7 @@ class StableDiffusionPipelineAPG( # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] + uncond_tokens: list[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): @@ -520,7 +521,7 @@ class StableDiffusionPipelineAPG( ) for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers, strict=False ): output_hidden_state = not isinstance(image_proj_layer, ImageProjection) single_image_embeds, single_negative_image_embeds = self.encode_image( @@ -748,31 +749,29 @@ class StableDiffusionPipelineAPG( @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]] = None, - height: Optional[int] = None, - width: Optional[int] = None, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, num_inference_steps: int = 50, - timesteps: List[int] = None, - sigmas: List[float] = None, + timesteps: list[int] = None, + sigmas: list[float] = None, guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | 
None = 1, eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, - ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, - output_type: Optional[str] = "pil", + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, + cross_attention_kwargs: dict[str, Any] | None = None, guidance_rescale: float = 0.0, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = None, **kwargs, ): r""" @@ -861,6 +860,8 @@ class StableDiffusionPipelineAPG( "not-safe-for-work" (nsfw) content. 
""" + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["latents"] callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) diff --git a/modules/api/api.py b/modules/api/api.py index 9ba641b23..669362d7d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,4 +1,3 @@ -from typing import List, Optional from threading import Lock from secrets import compare_digest from fastapi import FastAPI, APIRouter, Depends, Request @@ -19,7 +18,7 @@ class Api: user, password = auth.split(":") self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip() if shared.cmd_opts.auth_file: - with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file: + with open(shared.cmd_opts.auth_file, encoding="utf8") as file: for line in file.readlines(): user, password = line.split(":") self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip() @@ -41,7 +40,7 @@ class Api: self.add_api_route("/js", server.get_js, methods=["GET"], auth=False) # server api self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str) - self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=List[str]) + self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=list[str]) self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"]) self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"]) self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"]) @@ -56,7 +55,7 @@ class Api: self.add_api_route("/sdapi/v1/options", server.get_config, methods=["GET"], response_model=models.OptionsModel) self.add_api_route("/sdapi/v1/options", server.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) - self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], 
response_model=List[models.ResGPU]) + self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], response_model=list[models.ResGPU]) # core api using locking self.add_api_route("/sdapi/v1/txt2img", self.generate.post_text2img, methods=["POST"], response_model=models.ResTxt2Img) @@ -71,21 +70,21 @@ class Api: # api dealing with optional scripts self.add_api_route("/sdapi/v1/scripts", script.get_scripts_list, methods=["GET"], response_model=models.ResScripts) - self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=List[models.ItemScript]) + self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=list[models.ItemScript]) # enumerator api - self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=List[process.ItemPreprocess]) + self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=list[process.ItemPreprocess]) self.add_api_route("/sdapi/v1/masking", self.process.get_mask, methods=["GET"], response_model=process.ItemMask) - self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=List[models.ItemSampler]) - self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=List[models.ItemUpscaler]) - self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=List[models.ItemModel]) - self.add_api_route("/sdapi/v1/controlnets", endpoints.get_controlnets, methods=["GET"], response_model=List[str]) - self.add_api_route("/sdapi/v1/detailers", endpoints.get_detailers, methods=["GET"], response_model=List[models.ItemDetailer]) - self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=List[models.ItemStyle]) + self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], 
response_model=list[models.ItemSampler]) + self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=list[models.ItemUpscaler]) + self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=list[models.ItemModel]) + self.add_api_route("/sdapi/v1/controlnets", endpoints.get_controlnets, methods=["GET"], response_model=list[str]) + self.add_api_route("/sdapi/v1/detailers", endpoints.get_detailers, methods=["GET"], response_model=list[models.ItemDetailer]) + self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=list[models.ItemStyle]) self.add_api_route("/sdapi/v1/embeddings", endpoints.get_embeddings, methods=["GET"], response_model=models.ResEmbeddings) - self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae]) - self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension]) - self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork]) + self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=list[models.ItemVae]) + self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=list[models.ItemExtension]) + self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=list[models.ItemExtraNetwork]) # functional api self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo) @@ -96,7 +95,7 @@ class Api: self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"]) self.add_api_route("/sdapi/v1/lock-checkpoint", endpoints.post_lock_checkpoint, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-vae", 
endpoints.post_refresh_vae, methods=["POST"]) - self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=List[str]) + self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=list[str]) self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int) self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"]) self.add_api_route("/sdapi/v1/sampler", endpoints.get_sampler, methods=["GET"], response_model=dict) @@ -146,7 +145,7 @@ class Api: shared.log.error(f'API authentication: user="{credentials.username}"') raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"}) - def get_session_start(self, req: Request, agent: Optional[str] = None): + def get_session_start(self, req: Request, agent: str | None = None): token = req.cookies.get("access-token") or req.cookies.get("access-token-unsecure") user = self.app.tokens.get(token) if hasattr(self.app, 'tokens') else None shared.log.info(f'Browser session: user={user} client={req.client.host} agent={agent}') diff --git a/modules/api/caption.py b/modules/api/caption.py index ad164889b..82110da74 100644 --- a/modules/api/caption.py +++ b/modules/api/caption.py @@ -25,7 +25,7 @@ Core processing logic is shared between direct and dispatch handlers via ``do_openclip``, ``do_tagger``, and ``do_vqa`` functions to avoid duplication. """ -from typing import Optional, List, Union, Literal, Annotated +from typing import Literal, Annotated from pydantic import BaseModel, Field # pylint: disable=no-name-in-module from fastapi.exceptions import HTTPException from modules import shared @@ -49,21 +49,21 @@ class ReqCaption(BaseModel): mode: str = Field(default="best", title="Mode", description="Caption mode. 'best': Most thorough analysis, slowest but highest quality. 'fast': Quick caption with minimal flavor terms. 
'classic': Standard captioning with balanced quality and speed. 'caption': BLIP caption only, no CLIP flavor matching. 'negative': Generate terms suitable for use as a negative prompt.") analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed image analysis breakdown (medium, artist, movement, trending, flavor) in addition to caption.") # Advanced settings (optional per-request overrides) - max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.") - chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up captioning but increase VRAM usage.") - min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.") - max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.") - flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.") - num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.") + max_length: int | None = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.") + chunk_size: int | None = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). 
Higher values speed up captioning but increase VRAM usage.") + min_flavors: int | None = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.") + max_flavors: int | None = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.") + flavor_count: int | None = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.") + num_beams: int | None = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.") class ResCaption(BaseModel): """Response model for image captioning results.""" - caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption/prompt describing the image content and style.") - medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (e.g., 'oil painting', 'digital art', 'photograph'). Only returned when analyze=True.") - artist: Optional[str] = Field(default=None, title="Artist", description="Detected similar artist style (e.g., 'by greg rutkowski'). Only returned when analyze=True.") - movement: Optional[str] = Field(default=None, title="Movement", description="Detected art movement (e.g., 'art nouveau', 'impressionism'). Only returned when analyze=True.") - trending: Optional[str] = Field(default=None, title="Trending", description="Trending/platform tags (e.g., 'trending on artstation'). Only returned when analyze=True.") - flavor: Optional[str] = Field(default=None, title="Flavor", description="Additional descriptive elements (e.g., 'cinematic lighting', 'highly detailed'). 
Only returned when analyze=True.") + caption: str | None = Field(default=None, title="Caption", description="Generated caption/prompt describing the image content and style.") + medium: str | None = Field(default=None, title="Medium", description="Detected artistic medium (e.g., 'oil painting', 'digital art', 'photograph'). Only returned when analyze=True.") + artist: str | None = Field(default=None, title="Artist", description="Detected similar artist style (e.g., 'by greg rutkowski'). Only returned when analyze=True.") + movement: str | None = Field(default=None, title="Movement", description="Detected art movement (e.g., 'art nouveau', 'impressionism'). Only returned when analyze=True.") + trending: str | None = Field(default=None, title="Trending", description="Trending/platform tags (e.g., 'trending on artstation'). Only returned when analyze=True.") + flavor: str | None = Field(default=None, title="Flavor", description="Additional descriptive elements (e.g., 'cinematic lighting', 'highly detailed'). Only returned when analyze=True.") class ReqVQA(BaseModel): """Request model for Vision-Language Model (VLM) captioning. @@ -74,32 +74,32 @@ class ReqVQA(BaseModel): image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string containing the image data.") model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="Select which model to use for Visual Language tasks. Use GET /sdapi/v1/vqa/models for full list. Models which support thinking mode are indicated in capabilities.") question: str = Field(default="describe the image", title="Question/Task", description="Task for the model to perform. Common tasks: 'Short Caption', 'Normal Caption', 'Long Caption'. Set to 'Use Prompt' to pass custom text via the prompt field. Florence-2 tasks: 'Object Detection', 'OCR (Read Text)', 'Phrase Grounding', 'Dense Region Caption', 'Region Proposal', 'OCR with Regions'. 
PromptGen tasks: 'Analyze', 'Generate Tags', 'Mixed Caption'. Moondream tasks: 'Point at...', 'Detect all...', 'Detect Gaze' (Moondream 2 only). Use GET /sdapi/v1/vqa/prompts?model= to list tasks available for a specific model.") - prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text. Required when question is 'Use Prompt'. For 'Point at...' tasks, specify what to find (e.g., 'the red car'). For 'Detect all...' tasks, specify what to detect (e.g., 'faces').") + prompt: str | None = Field(default=None, title="Prompt", description="Custom prompt text. Required when question is 'Use Prompt'. For 'Point at...' tasks, specify what to find (e.g., 'the red car'). For 'Detect all...' tasks, specify what to detect (e.g., 'faces').") system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt controls behavior of the LLM. Processed first and persists throughout conversation. Has highest priority weighting and is always appended at the beginning of the sequence. Use for: Response formatting rules, role definition, style.") include_annotated: bool = Field(default=False, title="Include Annotated Image", description="If True and the task produces detection results (object detection, point detection, gaze), returns annotated image with bounding boxes/points drawn. Only applicable for detection tasks on models like Florence-2 and Moondream.") # LLM generation parameters (optional overrides) - max_tokens: Optional[int] = Field(default=None, title="Max Tokens", description="Maximum number of tokens the model can generate in its response. The model is not aware of this limit during generation; it simply sets the hard limit for the length and will forcefully cut off the response when reached.") - temperature: Optional[float] = Field(default=None, title="Temperature", description="Controls randomness in token selection. 
Lower values (e.g., 0.1) make outputs more focused and deterministic, always choosing high-probability tokens. Higher values (e.g., 0.9) increase creativity and diversity by allowing less probable tokens. Set to 0 for fully deterministic output.") - top_k: Optional[int] = Field(default=None, title="Top-K", description="Limits token selection to the K most likely candidates at each step. Lower values (e.g., 40) make outputs more focused and predictable, while higher values allow more diverse choices. Set to 0 to disable.") - top_p: Optional[float] = Field(default=None, title="Top-P", description="Selects tokens from the smallest set whose cumulative probability exceeds P (e.g., 0.9). Dynamically adapts the number of candidates based on model confidence; fewer options when certain, more when uncertain. Set to 1 to disable.") - num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Maintains multiple candidate paths simultaneously and selects the overall best sequence. More thorough but much slower and less creative than random sampling. Generally not recommended; most modern VLMs perform better with sampling methods. Set to 1 to disable.") - do_sample: Optional[bool] = Field(default=None, title="Use Samplers", description="Enable to use sampling (randomly selecting tokens based on sampling methods like Top-K or Top-P) or disable to use greedy decoding (selecting the most probable token at each step). Enabling makes outputs more diverse and creative but less deterministic.") - thinking_mode: Optional[bool] = Field(default=None, title="Thinking Mode", description="Enables thinking/reasoning, allowing the model to take more time to generate responses. Can lead to more thoughtful and detailed answers but increases response time. 
Only works with models that support this feature.") - prefill: Optional[str] = Field(default=None, title="Prefill Text", description="Pre-fills the start of the model's response to guide its output format or content by forcing it to continue the prefill text. Prefill is filtered out and does not appear in the final response unless keep_prefill is True. Leave empty to let the model generate from scratch.") - keep_thinking: Optional[bool] = Field(default=None, title="Keep Thinking Trace", description="Include the model's reasoning process in the final output. Useful for understanding how the model arrived at its answer. Only works with models that support thinking mode.") - keep_prefill: Optional[bool] = Field(default=None, title="Keep Prefill", description="Include the prefill text at the beginning of the final output. If disabled, the prefill text used to guide the model is removed from the result.") + max_tokens: int | None = Field(default=None, title="Max Tokens", description="Maximum number of tokens the model can generate in its response. The model is not aware of this limit during generation; it simply sets the hard limit for the length and will forcefully cut off the response when reached.") + temperature: float | None = Field(default=None, title="Temperature", description="Controls randomness in token selection. Lower values (e.g., 0.1) make outputs more focused and deterministic, always choosing high-probability tokens. Higher values (e.g., 0.9) increase creativity and diversity by allowing less probable tokens. Set to 0 for fully deterministic output.") + top_k: int | None = Field(default=None, title="Top-K", description="Limits token selection to the K most likely candidates at each step. Lower values (e.g., 40) make outputs more focused and predictable, while higher values allow more diverse choices. 
Set to 0 to disable.") + top_p: float | None = Field(default=None, title="Top-P", description="Selects tokens from the smallest set whose cumulative probability exceeds P (e.g., 0.9). Dynamically adapts the number of candidates based on model confidence; fewer options when certain, more when uncertain. Set to 1 to disable.") + num_beams: int | None = Field(default=None, title="Num Beams", description="Maintains multiple candidate paths simultaneously and selects the overall best sequence. More thorough but much slower and less creative than random sampling. Generally not recommended; most modern VLMs perform better with sampling methods. Set to 1 to disable.") + do_sample: bool | None = Field(default=None, title="Use Samplers", description="Enable to use sampling (randomly selecting tokens based on sampling methods like Top-K or Top-P) or disable to use greedy decoding (selecting the most probable token at each step). Enabling makes outputs more diverse and creative but less deterministic.") + thinking_mode: bool | None = Field(default=None, title="Thinking Mode", description="Enables thinking/reasoning, allowing the model to take more time to generate responses. Can lead to more thoughtful and detailed answers but increases response time. Only works with models that support this feature.") + prefill: str | None = Field(default=None, title="Prefill Text", description="Pre-fills the start of the model's response to guide its output format or content by forcing it to continue the prefill text. Prefill is filtered out and does not appear in the final response unless keep_prefill is True. Leave empty to let the model generate from scratch.") + keep_thinking: bool | None = Field(default=None, title="Keep Thinking Trace", description="Include the model's reasoning process in the final output. Useful for understanding how the model arrived at its answer. 
Only works with models that support thinking mode.") + keep_prefill: bool | None = Field(default=None, title="Keep Prefill", description="Include the prefill text at the beginning of the final output. If disabled, the prefill text used to guide the model is removed from the result.") class ResVQA(BaseModel): """Response model for VLM captioning results.""" - answer: Optional[str] = Field(default=None, title="Answer", description="Generated caption, answer, or analysis from the VLM. Format depends on the question/task type.") - annotated_image: Optional[str] = Field(default=None, title="Annotated Image", description="Base64 encoded PNG image with detection results drawn (bounding boxes, points). Only returned when include_annotated=True and the task produces detection results.") + answer: str | None = Field(default=None, title="Answer", description="Generated caption, answer, or analysis from the VLM. Format depends on the question/task type.") + annotated_image: str | None = Field(default=None, title="Annotated Image", description="Base64 encoded PNG image with detection results drawn (bounding boxes, points). Only returned when include_annotated=True and the task produces detection results.") class ItemVLMModel(BaseModel): """VLM model information.""" name: str = Field(title="Name", description="Display name of the model") repo: str = Field(title="Repository", description="HuggingFace repository ID") - prompts: List[str] = Field(title="Prompts", description="Available prompts/tasks for this model") - capabilities: List[str] = Field(title="Capabilities", description="Model capabilities. Possible values: 'caption' (image captioning), 'vqa' (visual question answering), 'detection' (object/point detection), 'ocr' (text recognition), 'thinking' (reasoning mode support).") + prompts: list[str] = Field(title="Prompts", description="Available prompts/tasks for this model") + capabilities: list[str] = Field(title="Capabilities", description="Model capabilities. 
Possible values: 'caption' (image captioning), 'vqa' (visual question answering), 'detection' (object/point detection), 'ocr' (text recognition), 'thinking' (reasoning mode support).") class ResVLMPrompts(BaseModel): """Available VLM prompts grouped by category. @@ -107,12 +107,12 @@ class ResVLMPrompts(BaseModel): When called without ``model`` parameter, returns all prompt categories. When called with ``model``, returns only the ``available`` field with prompts for that model. """ - common: Optional[List[str]] = Field(default=None, title="Common", description="Prompts available for all models: Use Prompt, Short/Normal/Long Caption.") - florence: Optional[List[str]] = Field(default=None, title="Florence", description="Florence-2 base model tasks: Phrase Grounding, Object Detection, Dense Region Caption, Region Proposal, OCR (Read Text), OCR with Regions.") - promptgen: Optional[List[str]] = Field(default=None, title="PromptGen", description="MiaoshouAI PromptGen fine-tune tasks: Analyze, Generate Tags, Mixed Caption, Mixed Caption+. Only available on PromptGen models.") - moondream: Optional[List[str]] = Field(default=None, title="Moondream", description="Moondream 2 and 3 tasks: Point at..., Detect all...") - moondream2_only: Optional[List[str]] = Field(default=None, title="Moondream 2 Only", description="Moondream 2 exclusive tasks: Detect Gaze. Not available in Moondream 3.") - available: Optional[List[str]] = Field(default=None, title="Available", description="Populated only when filtering by model. 
Contains the combined list of prompts available for the specified model.") + common: list[str] | None = Field(default=None, title="Common", description="Prompts available for all models: Use Prompt, Short/Normal/Long Caption.") + florence: list[str] | None = Field(default=None, title="Florence", description="Florence-2 base model tasks: Phrase Grounding, Object Detection, Dense Region Caption, Region Proposal, OCR (Read Text), OCR with Regions.") + promptgen: list[str] | None = Field(default=None, title="PromptGen", description="MiaoshouAI PromptGen fine-tune tasks: Analyze, Generate Tags, Mixed Caption, Mixed Caption+. Only available on PromptGen models.") + moondream: list[str] | None = Field(default=None, title="Moondream", description="Moondream 2 and 3 tasks: Point at..., Detect all...") + moondream2_only: list[str] | None = Field(default=None, title="Moondream 2 Only", description="Moondream 2 exclusive tasks: Detect Gaze. Not available in Moondream 3.") + available: list[str] | None = Field(default=None, title="Available", description="Populated only when filtering by model. 
Contains the combined list of prompts available for the specified model.") class ItemTaggerModel(BaseModel): """Tagger model information.""" @@ -136,7 +136,7 @@ class ReqTagger(BaseModel): class ResTagger(BaseModel): """Response model for image tagging results.""" tags: str = Field(title="Tags", description="Comma-separated list of detected tags") - scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (when show_scores=True)") + scores: dict | None = Field(default=None, title="Scores", description="Tag confidence scores (when show_scores=True)") # ============================================================================= @@ -158,12 +158,12 @@ class ReqCaptionOpenCLIP(BaseModel): blip_model: str = Field(default="blip-large", title="Caption Model", description="BLIP model used to generate the initial image caption.") mode: str = Field(default="best", title="Mode", description="Caption mode: 'best' (highest quality, slowest), 'fast' (quick, fewer flavors), 'classic' (balanced), 'caption' (BLIP only, no CLIP matching), 'negative' (for negative prompts).") analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed breakdown (medium, artist, movement, trending, flavor).") - max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum tokens in generated caption.") - chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing flavors.") - min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum descriptive tags to keep.") - max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum descriptive tags to keep.") - flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of intermediate candidate pool.") - num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beams for beam search during caption 
generation.") + max_length: int | None = Field(default=None, title="Max Length", description="Maximum tokens in generated caption.") + chunk_size: int | None = Field(default=None, title="Chunk Size", description="Batch size for processing flavors.") + min_flavors: int | None = Field(default=None, title="Min Flavors", description="Minimum descriptive tags to keep.") + max_flavors: int | None = Field(default=None, title="Max Flavors", description="Maximum descriptive tags to keep.") + flavor_count: int | None = Field(default=None, title="Intermediates", description="Size of intermediate candidate pool.") + num_beams: int | None = Field(default=None, title="Num Beams", description="Beams for beam search during caption generation.") class ReqCaptionTagger(BaseModel): @@ -196,24 +196,24 @@ class ReqCaptionVLM(BaseModel): image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string.") model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="VLM model to use. See GET /sdapi/v1/vqa/models for full list.") question: str = Field(default="describe the image", title="Question/Task", description="Task to perform: 'Short Caption', 'Normal Caption', 'Long Caption', 'Use Prompt' (custom text via prompt field). 
Model-specific tasks available via GET /sdapi/v1/vqa/prompts.") - prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text when question is 'Use Prompt'.") + prompt: str | None = Field(default=None, title="Prompt", description="Custom prompt text when question is 'Use Prompt'.") system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt for LLM behavior.") include_annotated: bool = Field(default=False, title="Include Annotated Image", description="Return annotated image for detection tasks.") - max_tokens: Optional[int] = Field(default=None, title="Max Tokens", description="Maximum tokens in response.") - temperature: Optional[float] = Field(default=None, title="Temperature", description="Randomness in token selection (0=deterministic, 0.9=creative).") - top_k: Optional[int] = Field(default=None, title="Top-K", description="Limit to K most likely tokens per step.") - top_p: Optional[float] = Field(default=None, title="Top-P", description="Nucleus sampling threshold.") - num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beam search width (1=disabled).") - do_sample: Optional[bool] = Field(default=None, title="Use Samplers", description="Enable sampling vs greedy decoding.") - thinking_mode: Optional[bool] = Field(default=None, title="Thinking Mode", description="Enable reasoning mode (supported models only).") - prefill: Optional[str] = Field(default=None, title="Prefill Text", description="Pre-fill response start to guide output.") - keep_thinking: Optional[bool] = Field(default=None, title="Keep Thinking Trace", description="Include reasoning in output.") - keep_prefill: Optional[bool] = Field(default=None, title="Keep Prefill", description="Keep prefill text in final output.") + max_tokens: int | None = Field(default=None, title="Max Tokens", description="Maximum tokens in response.") + temperature: float | None 
= Field(default=None, title="Temperature", description="Randomness in token selection (0=deterministic, 0.9=creative).") + top_k: int | None = Field(default=None, title="Top-K", description="Limit to K most likely tokens per step.") + top_p: float | None = Field(default=None, title="Top-P", description="Nucleus sampling threshold.") + num_beams: int | None = Field(default=None, title="Num Beams", description="Beam search width (1=disabled).") + do_sample: bool | None = Field(default=None, title="Use Samplers", description="Enable sampling vs greedy decoding.") + thinking_mode: bool | None = Field(default=None, title="Thinking Mode", description="Enable reasoning mode (supported models only).") + prefill: str | None = Field(default=None, title="Prefill Text", description="Pre-fill response start to guide output.") + keep_thinking: bool | None = Field(default=None, title="Keep Thinking Trace", description="Include reasoning in output.") + keep_prefill: bool | None = Field(default=None, title="Keep Prefill", description="Keep prefill text in final output.") # Discriminated union for the dispatch endpoint ReqCaptionDispatch = Annotated[ - Union[ReqCaptionOpenCLIP, ReqCaptionTagger, ReqCaptionVLM], + ReqCaptionOpenCLIP | ReqCaptionTagger | ReqCaptionVLM, Field(discriminator="backend") ] @@ -226,18 +226,18 @@ class ResCaptionDispatch(BaseModel): # Common backend: str = Field(title="Backend", description="The backend that processed the request: 'openclip', 'tagger', or 'vlm'.") # OpenCLIP fields - caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption (OpenCLIP backend).") - medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (OpenCLIP with analyze=True).") - artist: Optional[str] = Field(default=None, title="Artist", description="Detected artist style (OpenCLIP with analyze=True).") - movement: Optional[str] = Field(default=None, title="Movement", description="Detected art movement 
(OpenCLIP with analyze=True).") - trending: Optional[str] = Field(default=None, title="Trending", description="Trending tags (OpenCLIP with analyze=True).") - flavor: Optional[str] = Field(default=None, title="Flavor", description="Flavor descriptors (OpenCLIP with analyze=True).") + caption: str | None = Field(default=None, title="Caption", description="Generated caption (OpenCLIP backend).") + medium: str | None = Field(default=None, title="Medium", description="Detected artistic medium (OpenCLIP with analyze=True).") + artist: str | None = Field(default=None, title="Artist", description="Detected artist style (OpenCLIP with analyze=True).") + movement: str | None = Field(default=None, title="Movement", description="Detected art movement (OpenCLIP with analyze=True).") + trending: str | None = Field(default=None, title="Trending", description="Trending tags (OpenCLIP with analyze=True).") + flavor: str | None = Field(default=None, title="Flavor", description="Flavor descriptors (OpenCLIP with analyze=True).") # Tagger fields - tags: Optional[str] = Field(default=None, title="Tags", description="Comma-separated tags (Tagger backend).") - scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (Tagger with show_scores=True).") + tags: str | None = Field(default=None, title="Tags", description="Comma-separated tags (Tagger backend).") + scores: dict | None = Field(default=None, title="Scores", description="Tag confidence scores (Tagger with show_scores=True).") # VLM fields - answer: Optional[str] = Field(default=None, title="Answer", description="VLM response (VLM backend).") - annotated_image: Optional[str] = Field(default=None, title="Annotated Image", description="Base64 annotated image (VLM with include_annotated=True).") + answer: str | None = Field(default=None, title="Answer", description="VLM response (VLM backend).") + annotated_image: str | None = Field(default=None, title="Annotated Image", description="Base64 
annotated image (VLM with include_annotated=True).") # ============================================================================= @@ -596,7 +596,7 @@ def get_vqa_models(): return models_list -def get_vqa_prompts(model: Optional[str] = None): +def get_vqa_prompts(model: str | None = None): """ List available prompts/tasks for VLM models. @@ -653,11 +653,11 @@ def get_tagger_models(): def register_api(): from modules.shared import api - api.add_api_route("/sdapi/v1/openclip", get_caption, methods=["GET"], response_model=List[str], tags=["Caption"]) + api.add_api_route("/sdapi/v1/openclip", get_caption, methods=["GET"], response_model=list[str], tags=["Caption"]) api.add_api_route("/sdapi/v1/caption", post_caption_dispatch, methods=["POST"], response_model=ResCaptionDispatch, tags=["Caption"]) api.add_api_route("/sdapi/v1/openclip", post_caption, methods=["POST"], response_model=ResCaption, tags=["Caption"]) api.add_api_route("/sdapi/v1/vqa", post_vqa, methods=["POST"], response_model=ResVQA, tags=["Caption"]) - api.add_api_route("/sdapi/v1/vqa/models", get_vqa_models, methods=["GET"], response_model=List[ItemVLMModel], tags=["Caption"]) + api.add_api_route("/sdapi/v1/vqa/models", get_vqa_models, methods=["GET"], response_model=list[ItemVLMModel], tags=["Caption"]) api.add_api_route("/sdapi/v1/vqa/prompts", get_vqa_prompts, methods=["GET"], response_model=ResVLMPrompts, tags=["Caption"]) api.add_api_route("/sdapi/v1/tagger", post_tagger, methods=["POST"], response_model=ResTagger, tags=["Caption"]) - api.add_api_route("/sdapi/v1/tagger/models", get_tagger_models, methods=["GET"], response_model=List[ItemTaggerModel], tags=["Caption"]) + api.add_api_route("/sdapi/v1/tagger/models", get_tagger_models, methods=["GET"], response_model=list[ItemTaggerModel], tags=["Caption"]) diff --git a/modules/api/control.py b/modules/api/control.py index 63b7636dd..5f5882e04 100644 --- a/modules/api/control.py +++ b/modules/api/control.py @@ -1,4 +1,4 @@ -from typing import 
Optional, List +from typing import Optional from threading import Lock from pydantic import BaseModel, Field # pylint: disable=no-name-in-module from modules import errors, shared, processing_helpers @@ -43,9 +43,9 @@ ReqControl = models.create_model_from_signature( {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, - {"key": "ip_adapter", "type": Optional[List[models.ItemIPAdapter]], "default": None, "exclude": True}, + {"key": "ip_adapter", "type": Optional[list[models.ItemIPAdapter]], "default": None, "exclude": True}, {"key": "face", "type": Optional[models.ItemFace], "default": None, "exclude": True}, - {"key": "control", "type": Optional[List[ItemControl]], "default": [], "exclude": True}, + {"key": "control", "type": Optional[list[ItemControl]], "default": [], "exclude": True}, {"key": "xyz", "type": Optional[ItemXYZ], "default": None, "exclude": True}, # {"key": "extra", "type": Optional[dict], "default": {}, "exclude": True}, ] @@ -55,13 +55,13 @@ if not hasattr(ReqControl, "__config__"): class ResControl(BaseModel): - images: List[str] = Field(default=None, title="Images", description="") - processed: List[str] = Field(default=None, title="Processed", description="") + images: list[str] = Field(default=None, title="Images", description="") + processed: list[str] = Field(default=None, title="Processed", description="") params: dict = Field(default={}, title="Settings", description="") info: str = Field(default="", title="Info", description="") -class APIControl(): +class APIControl: def __init__(self, queue_lock: Lock): self.queue_lock = queue_lock self.default_script_arg = [] diff --git a/modules/api/endpoints.py b/modules/api/endpoints.py index 0ad141b48..543756906 100644 --- a/modules/api/endpoints.py +++ b/modules/api/endpoints.py @@ -1,4 +1,3 @@ -from typing import Optional from modules import shared from modules.api import models, 
helpers @@ -43,7 +42,7 @@ def get_sd_models(): checkpoints.append(model) return checkpoints -def get_controlnets(model_type: Optional[str] = None): +def get_controlnets(model_type: str | None = None): from modules.control.units.controlnet import api_list_models return api_list_models(model_type) @@ -60,7 +59,7 @@ def get_embeddings(): return models.ResEmbeddings(loaded=[], skipped=[]) return models.ResEmbeddings(loaded=list(db.word_embeddings.keys()), skipped=list(db.skipped_embeddings.keys())) -def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, filename: Optional[str] = None, title: Optional[str] = None, fullname: Optional[str] = None, hash: Optional[str] = None): # pylint: disable=redefined-builtin +def get_extra_networks(page: str | None = None, name: str | None = None, filename: str | None = None, title: str | None = None, fullname: str | None = None, hash: str | None = None): # pylint: disable=redefined-builtin res = [] for pg in shared.extra_networks: if page is not None and pg.name != page.lower(): diff --git a/modules/api/gallery.py b/modules/api/gallery.py index e4dc8ba0b..ac6cf93e1 100644 --- a/modules/api/gallery.py +++ b/modules/api/gallery.py @@ -2,7 +2,6 @@ import io import os import time import base64 -from typing import List, Union from urllib.parse import quote, unquote from fastapi import FastAPI from fastapi.responses import JSONResponse @@ -52,7 +51,7 @@ class ConnectionManager: debug(f'Browser WS disconnect: client={ws.client.host}') self.active.remove(ws) - async def send(self, ws: WebSocket, data: Union[str, dict, bytes]): + async def send(self, ws: WebSocket, data: str | dict | bytes): # debug(f'Browser WS send: client={ws.client.host} data={type(data)}') if ws.client_state != WebSocketState.CONNECTED: return @@ -65,7 +64,7 @@ class ConnectionManager: else: debug(f'Browser WS send: client={ws.client.host} data={type(data)} unknown') - async def broadcast(self, data: Union[str, dict, bytes]): + async def 
broadcast(self, data: str | dict | bytes): for ws in self.active: await self.send(ws, data) @@ -206,7 +205,7 @@ def register_api(app: FastAPI): # register api shared.log.error(f'Gallery: {folder} {e}') return [] - shared.api.add_api_route("/sdapi/v1/browser/folders", get_folders, methods=["GET"], response_model=List[str]) + shared.api.add_api_route("/sdapi/v1/browser/folders", get_folders, methods=["GET"], response_model=list[str]) shared.api.add_api_route("/sdapi/v1/browser/thumb", get_thumb, methods=["GET"], response_model=dict) shared.api.add_api_route("/sdapi/v1/browser/files", ht_files, methods=["GET"], response_model=list) diff --git a/modules/api/generate.py b/modules/api/generate.py index 102b15f2c..9fd1d2bf3 100644 --- a/modules/api/generate.py +++ b/modules/api/generate.py @@ -9,7 +9,7 @@ from modules.paths import resolve_output_path errors.install() -class APIGenerate(): +class APIGenerate: def __init__(self, queue_lock: Lock): self.queue_lock = queue_lock self.default_script_arg_txt2img = [] diff --git a/modules/api/loras.py b/modules/api/loras.py index c192ec62d..8a9d7a594 100644 --- a/modules/api/loras.py +++ b/modules/api/loras.py @@ -1,4 +1,3 @@ -from typing import List from fastapi.exceptions import HTTPException @@ -25,5 +24,5 @@ def post_refresh_loras(): def register_api(): from modules.shared import api api.add_api_route("/sdapi/v1/lora", get_lora, methods=["GET"], response_model=dict) - api.add_api_route("/sdapi/v1/loras", get_loras, methods=["GET"], response_model=List[dict]) + api.add_api_route("/sdapi/v1/loras", get_loras, methods=["GET"], response_model=list[dict]) api.add_api_route("/sdapi/v1/refresh-loras", post_refresh_loras, methods=["POST"]) diff --git a/modules/api/models.py b/modules/api/models.py index b1425d267..7305fc7bf 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,7 +1,15 @@ import re import inspect -from typing import Any, Optional, Dict, List, Type, Callable, Union -from pydantic import BaseModel, Field, 
create_model # pylint: disable=no-name-in-module +from typing import Any, Optional, Union +from collections.abc import Callable +import pydantic +from pydantic import BaseModel, Field, create_model +try: + from pydantic import ConfigDict + PYDANTIC_V2 = True +except ImportError: + ConfigDict = None + PYDANTIC_V2 = False from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img import modules.shared as shared @@ -41,8 +49,10 @@ class PydanticModelGenerator: model_name: str = None, class_instance = None, additional_fields = None, - exclude_fields: List = [], + exclude_fields: list | None = None, ): + if exclude_fields is None: + exclude_fields = [] def field_type_generator(_k, v): field_type = v.annotation return Optional[field_type] @@ -80,12 +90,15 @@ class PydanticModelGenerator: def generate_model(self): model_fields = { d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def } - DynamicModel = create_model(self._model_name, **model_fields) - try: - DynamicModel.__config__.allow_population_by_field_name = True - DynamicModel.__config__.allow_mutation = True - except Exception: - pass + if PYDANTIC_V2: + config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True) + else: + class Config: + arbitrary_types_allowed = True + orm_mode = True + allow_population_by_field_name = True + config = Config + DynamicModel = create_model(self._model_name, __config__=config, **model_fields) return DynamicModel ### item classes @@ -100,49 +113,49 @@ class ItemVae(BaseModel): class ItemUpscaler(BaseModel): name: str = Field(title="Name") - model_name: Optional[str] = Field(title="Model Name") - model_path: Optional[str] = Field(title="Path") - model_url: Optional[str] = Field(title="URL") - scale: Optional[float] = Field(title="Scale") + model_name: str | None = Field(title="Model Name") + model_path: str | None = Field(title="Path") + model_url:
str | None = Field(title="URL") + scale: float | None = Field(title="Scale") class ItemModel(BaseModel): title: str = Field(title="Title") model_name: str = Field(title="Model Name") filename: str = Field(title="Filename") type: str = Field(title="Model type") - sha256: Optional[str] = Field(title="SHA256 hash") - hash: Optional[str] = Field(title="Short hash") - config: Optional[str] = Field(title="Config file") + sha256: str | None = Field(title="SHA256 hash") + hash: str | None = Field(title="Short hash") + config: str | None = Field(title="Config file") class ItemHypernetwork(BaseModel): name: str = Field(title="Name") - path: Optional[str] = Field(title="Path") + path: str | None = Field(title="Path") class ItemDetailer(BaseModel): name: str = Field(title="Name") - path: Optional[str] = Field(title="Path") + path: str | None = Field(title="Path") class ItemGAN(BaseModel): name: str = Field(title="Name") - path: Optional[str] = Field(title="Path") - scale: Optional[int] = Field(title="Scale") + path: str | None = Field(title="Path") + scale: int | None = Field(title="Scale") class ItemStyle(BaseModel): name: str = Field(title="Name") - prompt: Optional[str] = Field(title="Prompt") - negative_prompt: Optional[str] = Field(title="Negative Prompt") - extra: Optional[str] = Field(title="Extra") - filename: Optional[str] = Field(title="Filename") - preview: Optional[str] = Field(title="Preview") + prompt: str | None = Field(title="Prompt") + negative_prompt: str | None = Field(title="Negative Prompt") + extra: str | None = Field(title="Extra") + filename: str | None = Field(title="Filename") + preview: str | None = Field(title="Preview") class ItemExtraNetwork(BaseModel): name: str = Field(title="Name") type: str = Field(title="Type") - title: Optional[str] = Field(title="Title") - fullname: Optional[str] = Field(title="Fullname") - filename: Optional[str] = Field(title="Filename") - hash: Optional[str] = Field(title="Hash") - preview: Optional[str] = 
Field(title="Preview image URL") + title: str | None = Field(title="Title") + fullname: str | None = Field(title="Fullname") + filename: str | None = Field(title="Filename") + hash: str | None = Field(title="Hash") + preview: str | None = Field(title="Preview image URL") class ItemArtist(BaseModel): name: str = Field(title="Name") @@ -150,16 +163,16 @@ class ItemArtist(BaseModel): category: str = Field(title="Category") class ItemEmbedding(BaseModel): - step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available") - sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available") - sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead") + step: int | None = Field(title="Step", description="The number of steps that were used to train this embedding, if available") + sd_checkpoint: str | None = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available") + sd_checkpoint_name: str | None = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. 
Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead") shape: int = Field(title="Shape", description="The length of each individual vector in the embedding") vectors: int = Field(title="Vectors", description="The number of vectors in the embedding") class ItemIPAdapter(BaseModel): adapter: str = Field(title="Adapter", default="Base", description="IP adapter name") - images: List[str] = Field(title="Image", default=[], description="IP adapter input images") - masks: Optional[List[str]] = Field(title="Mask", default=[], description="IP adapter mask images") + images: list[str] = Field(title="Image", default=[], description="IP adapter input images") + masks: list[str] | None = Field(title="Mask", default=[], description="IP adapter mask images") scale: float = Field(title="Scale", default=0.5, ge=0, le=1, description="IP adapter scale") start: float = Field(title="Start", default=0.0, ge=0, le=1, description="IP adapter start step") end: float = Field(title="End", default=1.0, gt=0, le=1, description="IP adapter end step") @@ -183,17 +196,17 @@ class ItemFace(BaseModel): class ScriptArg(BaseModel): label: str = Field(default=None, title="Label", description="Name of the argument in UI") - value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument") - minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI") - maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI") - step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI") - choices: Optional[Any] = Field(default=None, title="Choices", description="Possible values for the argument") + value: Any | None = Field(default=None, title="Value", description="Default value of the argument") + minimum: Any | None = Field(default=None, 
title="Minimum", description="Minimum allowed value for the argument in UI") + maximum: Any | None = Field(default=None, title="Maximum", description="Maximum allowed value for the argument in UI") + step: Any | None = Field(default=None, title="Step", description="Step for changing value of the argument in UI") + choices: Any | None = Field(default=None, title="Choices", description="Possible values for the argument") class ItemScript(BaseModel): name: str = Field(default=None, title="Name", description="Script name") is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script") is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script") - args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments") + args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments") class ItemExtension(BaseModel): name: str = Field(title="Name", description="Extension name") @@ -201,13 +214,13 @@ class ItemExtension(BaseModel): branch: str = Field(default="uknnown", title="Branch", description="Extension Repository Branch") commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash") version: str = Field(title="Version", description="Extension Version") - commit_date: Union[str, int] = Field(title="Commit Date", description="Extension Repository Commit Date") + commit_date: str | int = Field(title="Commit Date", description="Extension Repository Commit Date") enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled") class ItemScheduler(BaseModel): name: str = Field(title="Name", description="Scheduler name") cls: str = Field(title="Class", description="Scheduler class name") - options: Dict[str, Any] = Field(title="Options", description="Dictionary of scheduler options") + options: dict[str, Any] =
Field(title="Options", description="Dictionary of scheduler options") ### request/response classes @@ -223,7 +236,7 @@ ReqTxt2Img = PydanticModelGenerator( {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, - {"key": "ip_adapter", "type": Optional[List[ItemIPAdapter]], "default": None, "exclude": True}, + {"key": "ip_adapter", "type": Optional[list[ItemIPAdapter]], "default": None, "exclude": True}, {"key": "face", "type": Optional[ItemFace], "default": None, "exclude": True}, {"key": "extra", "type": Optional[dict], "default": {}, "exclude": True}, ] @@ -233,7 +246,7 @@ if not hasattr(ReqTxt2Img, "__config__"): StableDiffusionTxt2ImgProcessingAPI = ReqTxt2Img class ResTxt2Img(BaseModel): - images: List[str] = Field(default=None, title="Image", description="The generated images in base64 format.") + images: list[str] = Field(default=None, title="Image", description="The generated images in base64 format.") parameters: dict info: str @@ -253,7 +266,7 @@ ReqImg2Img = PydanticModelGenerator( {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, - {"key": "ip_adapter", "type": Optional[List[ItemIPAdapter]], "default": None, "exclude": True}, + {"key": "ip_adapter", "type": Optional[list[ItemIPAdapter]], "default": None, "exclude": True}, {"key": "face_id", "type": Optional[ItemFace], "default": None, "exclude": True}, {"key": "extra", "type": Optional[dict], "default": {}, "exclude": True}, ] @@ -263,7 +276,7 @@ if not hasattr(ReqImg2Img, "__config__"): StableDiffusionImg2ImgProcessingAPI = ReqImg2Img class ResImg2Img(BaseModel): - images: List[str] = Field(default=None, title="Image", description="The generated images in base64 format.") + images: list[str] = Field(default=None, title="Image", description="The generated images in base64 
format.") parameters: dict info: str @@ -289,9 +302,9 @@ class ResProcess(BaseModel): class ReqPromptEnhance(BaseModel): prompt: str = Field(title="Prompt", description="Prompt to enhance") type: str = Field(title="Type", default='text', description="Type of enhancement to perform") - model: Optional[str] = Field(title="Model", default=None, description="Model to use for enhancement") - system_prompt: Optional[str] = Field(title="System prompt", default=None, description="Model system prompt") - image: Optional[str] = Field(title="Image", default=None, description="Image to work on, must be a Base64 string containing the image's data.") + model: str | None = Field(title="Model", default=None, description="Model to use for enhancement") + system_prompt: str | None = Field(title="System prompt", default=None, description="Model system prompt") + image: str | None = Field(title="Image", default=None, description="Image to work on, must be a Base64 string containing the image's data.") seed: int = Field(title="Seed", default=-1, description="Seed used to generate the prompt") nsfw: bool = Field(title="NSFW", default=True, description="Should NSFW content be allowed?") @@ -306,10 +319,10 @@ class ResProcessImage(ResProcess): image: str = Field(default=None, title="Image", description="The generated image in base64 format.") class ReqProcessBatch(ReqProcess): - imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") + imageList: list[FileData] = Field(title="Images", description="List of images to work on. 
Must be Base64 strings") class ResProcessBatch(ResProcess): - images: List[str] = Field(title="Images", description="The generated images in base64 format.") + images: list[str] = Field(title="Images", description="The generated images in base64 format.") class ReqImageInfo(BaseModel): image: str = Field(title="Image", description="The base64 encoded image") @@ -325,38 +338,38 @@ class ReqGetLog(BaseModel): class ReqPostLog(BaseModel): - message: Optional[str] = Field(default=None, title="Message", description="The info message to log") - debug: Optional[str] = Field(default=None, title="Debug message", description="The debug message to log") - error: Optional[str] = Field(default=None, title="Error message", description="The error message to log") + message: str | None = Field(default=None, title="Message", description="The info message to log") + debug: str | None = Field(default=None, title="Debug message", description="The debug message to log") + error: str | None = Field(default=None, title="Error message", description="The error message to log") class ReqHistory(BaseModel): - id: Union[int, str, None] = Field(default=None, title="Task ID", description="Task ID") + id: int | str | None = Field(default=None, title="Task ID", description="Task ID") class ReqProgress(BaseModel): skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization") class ResProgress(BaseModel): - id: Union[int, str, None] = Field(title="TaskID", description="Task ID") + id: int | str | None = Field(title="TaskID", description="Task ID") progress: float = Field(title="Progress", description="The progress with a range of 0 to 1") eta_relative: float = Field(title="ETA in secs") state: dict = Field(title="State", description="The current state snapshot") - current_image: Optional[str] = Field(default=None, title="Current image", description="The current image in base64 format. 
opts.show_progress_every_n_steps is required for this to work.") - textinfo: Optional[str] = Field(default=None, title="Info text", description="Info text used by WebUI.") + current_image: str | None = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") + textinfo: str | None = Field(default=None, title="Info text", description="Info text used by WebUI.") class ResHistory(BaseModel): - id: Union[int, str, None] = Field(title="ID", description="Task ID") + id: int | str | None = Field(title="ID", description="Task ID") job: str = Field(title="Job", description="Job name") op: str = Field(title="Operation", description="Job state") - timestamp: Union[float, None] = Field(title="Timestamp", description="Job timestamp") - duration: Union[float, None] = Field(title="Duration", description="Job duration") - outputs: List[str] = Field(title="Outputs", description="List of filenames") + timestamp: float | None = Field(title="Timestamp", description="Job timestamp") + duration: float | None = Field(title="Duration", description="Job duration") + outputs: list[str] = Field(title="Outputs", description="List of filenames") class ResStatus(BaseModel): status: str = Field(title="Status", description="Current status") task: str = Field(title="Task", description="Current job") - timestamp: Optional[str] = Field(title="Timestamp", description="Timestamp of the current job") + timestamp: str | None = Field(title="Timestamp", description="Timestamp of the current job") current: str = Field(title="Task", description="Current job") - id: Union[int, str, None] = Field(title="ID", description="ID of the current task") + id: int | str | None = Field(title="ID", description="ID of the current task") job: int = Field(title="Job", description="Current job") jobs: int = Field(title="Jobs", description="Total jobs") total: int = Field(title="Total Jobs", description="Total jobs") @@ -364,9 
+377,9 @@ class ResStatus(BaseModel): steps: int = Field(title="Steps", description="Total steps") queued: int = Field(title="Queued", description="Number of queued tasks") uptime: int = Field(title="Uptime", description="Uptime of the server") - elapsed: Optional[float] = Field(default=None, title="Elapsed time") - eta: Optional[float] = Field(default=None, title="ETA in secs") - progress: Optional[float] = Field(default=None, title="Progress", description="The progress with a range of 0 to 1") + elapsed: float | None = Field(default=None, title="Elapsed time") + eta: float | None = Field(default=None, title="ETA in secs") + progress: float | None = Field(default=None, title="Progress", description="The progress with a range of 0 to 1") class ReqLatentHistory(BaseModel): name: str = Field(title="Name", description="Name of the history item to select") @@ -392,7 +405,15 @@ for key, metadata in shared.opts.data_labels.items(): else: fields.update({key: (Optional[optType], Field())}) -OptionsModel = create_model("Options", **fields) +if PYDANTIC_V2: + config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True) +else: + class Config: + arbitrary_types_allowed = True + orm_mode = True + allow_population_by_field_name = True + config = Config +OptionsModel = create_model("Options", __config__=config, **fields) flags = {} _options = vars(shared.parser)['_option_string_actions'] @@ -404,7 +425,15 @@ for key in _options: _type = type(_options[key].default) flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))}) -FlagsModel = create_model("Flags", **flags) +if PYDANTIC_V2: + config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True) +else: + class Config: + arbitrary_types_allowed = True + orm_mode = True + allow_population_by_field_name = True + config = Config +FlagsModel = create_model("Flags", __config__=config, **flags) class ResEmbeddings(BaseModel): loaded: list = 
Field(default=None, title="loaded", description="List of loaded embeddings") @@ -426,9 +455,13 @@ class ResGPU(BaseModel): # definition of http response # helper function -def create_model_from_signature(func: Callable, model_name: str, base_model: Type[BaseModel] = BaseModel, additional_fields: List = [], exclude_fields: List[str] = []) -> type[BaseModel]: +def create_model_from_signature(func: Callable, model_name: str, base_model: type[BaseModel] = BaseModel, additional_fields: list = None, exclude_fields: list[str] = None) -> type[BaseModel]: from PIL import Image + if exclude_fields is None: + exclude_fields = [] + if additional_fields is None: + additional_fields = [] class Config: extra = 'allow' @@ -443,13 +476,13 @@ def create_model_from_signature(func: Callable, model_name: str, base_model: Typ defaults = (...,) * non_default_args + defaults keyword_only_params = {param: kwonlydefaults.get(param, Any) for param in kwonlyargs} for k, v in annotations.items(): - if v == List[Image.Image]: - annotations[k] = List[str] + if v == list[Image.Image]: + annotations[k] = list[str] elif v == Image.Image: annotations[k] = str elif str(v) == 'typing.List[modules.control.unit.Unit]': - annotations[k] = List[str] - model_fields = {param: (annotations.get(param, Any), default) for param, default in zip(args, defaults)} + annotations[k] = list[str] + model_fields = {param: (annotations.get(param, Any), default) for param, default in zip(args, defaults, strict=False)} for fld in additional_fields: model_def = ModelDef( @@ -464,16 +497,21 @@ def create_model_from_signature(func: Callable, model_name: str, base_model: Typ if fld in model_fields: del model_fields[fld] + if PYDANTIC_V2: + config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True, extra='allow' if varkw else 'ignore') + else: + class Config: + arbitrary_types_allowed = True + orm_mode = True + allow_population_by_field_name = True + extra = 'allow' if varkw else 'ignore' + 
config = Config + model = create_model( model_name, - **model_fields, - **keyword_only_params, __base__=base_model, __config__=config, + **model_fields, + **keyword_only_params, ) - try: - model.__config__.allow_population_by_field_name = True - model.__config__.allow_mutation = True - except Exception: - pass return model diff --git a/modules/api/process.py b/modules/api/process.py index c106a18e2..c3d102252 100644 --- a/modules/api/process.py +++ b/modules/api/process.py @@ -1,4 +1,3 @@ -from typing import Optional, List from threading import Lock from pydantic import BaseModel, Field # pylint: disable=no-name-in-module from fastapi.responses import JSONResponse @@ -15,7 +14,7 @@ errors.install() class ReqPreprocess(BaseModel): image: str = Field(title="Image", description="The base64 encoded image") model: str = Field(title="Model", description="The model to use for preprocessing") - params: Optional[dict] = Field(default={}, title="Settings", description="Preprocessor settings") + params: dict | None = Field(default={}, title="Settings", description="Preprocessor settings") class ResPreprocess(BaseModel): model: str = Field(default='', title="Model", description="The processor model used") @@ -24,20 +23,20 @@ class ResPreprocess(BaseModel): class ReqMask(BaseModel): image: str = Field(title="Image", description="The base64 encoded image") type: str = Field(title="Mask type", description="Type of masking image to return") - mask: Optional[str] = Field(title="Mask", description="If optional maks image is not provided auto-masking will be performed") - model: Optional[str] = Field(title="Model", description="The model to use for preprocessing") - params: Optional[dict] = Field(default={}, title="Settings", description="Preprocessor settings") + mask: str | None = Field(title="Mask", description="If optional mask image is not provided auto-masking will be performed") + model: str | None = Field(title="Model", description="The model to use for preprocessing") + 
params: dict | None = Field(default={}, title="Settings", description="Preprocessor settings") class ReqFace(BaseModel): image: str = Field(title="Image", description="The base64 encoded image") - model: Optional[str] = Field(title="Model", description="The model to use for detection") + model: str | None = Field(title="Model", description="The model to use for detection") class ResFace(BaseModel): - classes: List[int] = Field(title="Class", description="The class of detected item") - labels: List[str] = Field(title="Label", description="The label of detected item") - boxes: List[List[int]] = Field(title="Box", description="The bounding box of detected item") - images: List[str] = Field(title="Image", description="The base64 encoded images of detected faces") - scores: List[float] = Field(title="Scores", description="The scores of the detected faces") + classes: list[int] = Field(title="Class", description="The class of detected item") + labels: list[str] = Field(title="Label", description="The label of detected item") + boxes: list[list[int]] = Field(title="Box", description="The bounding box of detected item") + images: list[str] = Field(title="Image", description="The base64 encoded images of detected faces") + scores: list[float] = Field(title="Scores", description="The scores of the detected faces") class ResMask(BaseModel): mask: str = Field(default='', title="Image", description="The processed image in base64 format") @@ -47,13 +46,13 @@ class ItemPreprocess(BaseModel): params: dict = Field(title="Params") class ItemMask(BaseModel): - models: List[str] = Field(title="Models") - colormaps: List[str] = Field(title="Color maps") + models: list[str] = Field(title="Models") + colormaps: list[str] = Field(title="Color maps") params: dict = Field(title="Params") - types: List[str] = Field(title="Types") + types: list[str] = Field(title="Types") -class APIProcess(): +class APIProcess: def __init__(self, queue_lock: Lock): self.queue_lock = queue_lock diff --git 
a/modules/api/script.py b/modules/api/script.py index 064d26143..ce8df9339 100644 --- a/modules/api/script.py +++ b/modules/api/script.py @@ -1,4 +1,3 @@ -from typing import Optional from fastapi.exceptions import HTTPException import gradio as gr from modules.api import models @@ -36,7 +35,7 @@ def get_scripts_list(): return models.ResScripts(txt2img = t2ilist, img2img = i2ilist, control = control) -def get_script_info(script_name: Optional[str] = None): +def get_script_info(script_name: str | None = None): res = [] for script_list in [scripts_manager.scripts_txt2img.scripts, scripts_manager.scripts_img2img.scripts, scripts_manager.scripts_control.scripts]: for script in script_list: diff --git a/modules/api/xyz_grid.py b/modules/api/xyz_grid.py index 569ae98b0..ec230c324 100644 --- a/modules/api/xyz_grid.py +++ b/modules/api/xyz_grid.py @@ -1,7 +1,6 @@ -from typing import List -def xyz_grid_enum(option: str = "") -> List[dict]: +def xyz_grid_enum(option: str = "") -> list[dict]: from scripts.xyz import xyz_grid_classes # pylint: disable=no-name-in-module options = [] for x in xyz_grid_classes.axis_options: @@ -23,4 +22,4 @@ def xyz_grid_enum(option: str = "") -> List[dict]: def register_api(): from modules.shared import api as api_instance - api_instance.add_api_route("/sdapi/v1/xyz-grid", xyz_grid_enum, methods=["GET"], response_model=List[dict]) + api_instance.add_api_route("/sdapi/v1/xyz-grid", xyz_grid_enum, methods=["GET"], response_model=list[dict]) diff --git a/modules/attention.py b/modules/attention.py index a0a29bfb1..6490f6abb 100644 --- a/modules/attention.py +++ b/modules/attention.py @@ -1,4 +1,3 @@ -from typing import Optional from functools import wraps import torch from modules import rocm @@ -23,7 +22,7 @@ def set_triton_flash_attention(backend: str): from modules.flash_attn_triton_amd import interface_fa sdpa_pre_triton_flash_atten = torch.nn.functional.scaled_dot_product_attention @wraps(sdpa_pre_triton_flash_atten) - def 
sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: + def sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32: if scale is None: scale = query.shape[-1] ** (-0.5) @@ -56,7 +55,7 @@ def set_flex_attention(): sdpa_pre_flex_atten = torch.nn.functional.scaled_dot_product_attention @wraps(sdpa_pre_flex_atten) - def sdpa_flex_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: # pylint: disable=unused-argument + def sdpa_flex_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: # pylint: disable=unused-argument score_mod = None block_mask = None if attn_mask is not None: @@ -96,7 +95,7 @@ def set_ck_flash_attention(backend: str, device: torch.device): from flash_attn import flash_attn_func sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention @wraps(sdpa_pre_flash_atten) - def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> 
torch.FloatTensor: + def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32: is_unsqueezed = False if query.dim() == 3: @@ -162,7 +161,7 @@ def set_sage_attention(backend: str, device: torch.device): sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention @wraps(sdpa_pre_sage_atten) - def sdpa_sage_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: + def sdpa_sage_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: if (query.shape[-1] in {128, 96, 64}) and (attn_mask is None) and (query.dtype != torch.float32): if enable_gqa: key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) diff --git a/modules/ben2/ben2_model.py b/modules/ben2/ben2_model.py index fe38c571e..2eccfc8ce 100644 --- a/modules/ben2/ben2_model.py +++ b/modules/ben2/ben2_model.py @@ -373,7 +373,7 @@ class BasicLayer(nn.Module): mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill(attn_mask != 0, (-100.0)).masked_fill(attn_mask == 0, 0.0) for blk in self.blocks: 
blk.H, blk.W = H, W @@ -464,8 +464,8 @@ class SwinTransformer(nn.Module): patch_size=4, in_chans=3, embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[3, 6, 12, 24], + depths=None, + num_heads=None, window_size=7, mlp_ratio=4., qkv_bias=True, @@ -479,6 +479,10 @@ class SwinTransformer(nn.Module): out_indices=(0, 1, 2, 3), frozen_stages=-1, use_checkpoint=False): + if num_heads is None: + num_heads = [3, 6, 12, 24] + if depths is None: + depths = [2, 2, 6, 2] super().__init__() self.pretrain_img_size = pretrain_img_size @@ -668,8 +672,10 @@ class PositionEmbeddingSine: class MCLM(nn.Module): - def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]): - super(MCLM, self).__init__() + def __init__(self, d_model, num_heads, pool_ratios=None): + if pool_ratios is None: + pool_ratios = [1, 4, 8] + super().__init__() self.attention = nn.ModuleList([ nn.MultiheadAttention(d_model, num_heads, dropout=0.1), nn.MultiheadAttention(d_model, num_heads, dropout=0.1), @@ -739,7 +745,7 @@ class MCLM(nn.Module): _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w) _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2) outputs_re = [] - for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))): + for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1), strict=False)): outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c @@ -760,8 +766,10 @@ class MCLM(nn.Module): class MCRM(nn.Module): - def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None): # pylint: disable=unused-argument - super(MCRM, self).__init__() + def __init__(self, d_model, num_heads, pool_ratios=None, h=None): # pylint: disable=unused-argument + if pool_ratios is None: + pool_ratios = [4, 8, 16] + super().__init__() self.attention = nn.ModuleList([ nn.MultiheadAttention(d_model, num_heads, dropout=0.1), nn.MultiheadAttention(d_model, 
num_heads, dropout=0.1), @@ -1049,7 +1057,7 @@ class BEN_Base(nn.Module): """ cap = cv2.VideoCapture(video_path) if not cap.isOpened(): - raise IOError(f"Cannot open video: {video_path}") + raise OSError(f"Cannot open video: {video_path}") original_fps = cap.get(cv2.CAP_PROPFPS) original_fps = 30 if original_fps == 0 else original_fps @@ -1225,7 +1233,7 @@ def add_audio_to_video(video_without_audio_path, original_video_path, output_pat '-of', 'csv=p=0', original_video_path ] - result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False) + result = subprocess.run(probe_command, capture_output=True, text=True, check=False) # result.stdout is empty if no audio stream found if not result.stdout.strip(): diff --git a/modules/caption/deepbooru.py b/modules/caption/deepbooru.py index d8324e428..5162a0dc8 100644 --- a/modules/caption/deepbooru.py +++ b/modules/caption/deepbooru.py @@ -4,7 +4,7 @@ import threading import torch import numpy as np from PIL import Image -from modules import modelloader, devices, shared +from modules import modelloader, devices, shared, paths re_special = re.compile(r'([\\()])') load_lock = threading.Lock() @@ -18,7 +18,7 @@ class DeepDanbooru: with load_lock: if self.model is not None: return - model_path = os.path.join(shared.models_path, "DeepDanbooru") + model_path = os.path.join(paths.models_path, "DeepDanbooru") shared.log.debug(f'Caption load: module=DeepDanbooru folder="{model_path}"') files = modelloader.load_models( model_path=model_path, @@ -96,7 +96,7 @@ class DeepDanbooru: x = torch.from_numpy(a).to(device=devices.device, dtype=devices.dtype) y = self.model(x)[0].detach().float().cpu().numpy() probability_dict = {} - for current, probability in zip(self.model.tags, y): + for current, probability in zip(self.model.tags, y, strict=False): if probability < general_threshold: continue if current.startswith("rating:") and not include_rating: diff --git 
a/modules/caption/deepbooru_model.py b/modules/caption/deepbooru_model.py index 2963385c3..9489182ab 100644 --- a/modules/caption/deepbooru_model.py +++ b/modules/caption/deepbooru_model.py @@ -671,4 +671,4 @@ class DeepDanbooruModel(nn.Module): def load_state_dict(self, state_dict, **kwargs): # pylint: disable=arguments-differ,unused-argument self.tags = state_dict.get('tags', []) - super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) # pylint: disable=R1725 + super().load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) # pylint: disable=R1725 diff --git a/modules/caption/deepseek.py b/modules/caption/deepseek.py index e7d3eac0c..44bff3fb6 100644 --- a/modules/caption/deepseek.py +++ b/modules/caption/deepseek.py @@ -21,7 +21,7 @@ vl_chat_processor = None loaded_repo = None -class fake_attrdict(): +class fake_attrdict: class AttrDict(dict): # dot notation access to dictionary attributes __getattr__ = dict.get __setattr__ = dict.__setitem__ diff --git a/modules/caption/joycaption.py b/modules/caption/joycaption.py index 3fc990ba8..99f4ef677 100644 --- a/modules/caption/joycaption.py +++ b/modules/caption/joycaption.py @@ -39,7 +39,7 @@ Extra Options: """ @dataclass -class JoyOptions(): +class JoyOptions: repo: str = "fancyfeast/llama-joycaption-alpha-two-hf-llava" temp: float = 0.5 top_k: float = 10 diff --git a/modules/caption/joytag.py b/modules/caption/joytag.py index efae5bbca..042f87529 100644 --- a/modules/caption/joytag.py +++ b/modules/caption/joytag.py @@ -6,7 +6,6 @@ import os import math import json from pathlib import Path -from typing import Optional from PIL import Image import torch import torch.backends.cuda @@ -126,7 +125,7 @@ class VisionModel(nn.Module): @staticmethod def load_model(path: str) -> 'VisionModel': - with open(Path(path) / 'config.json', 'r', encoding='utf8') as f: + with open(Path(path) / 'config.json', encoding='utf8') as f: config = json.load(f) from 
safetensors.torch import load_file resume = load_file(Path(path) / 'model.safetensors', device='cpu') @@ -244,7 +243,7 @@ class CLIPMlp(nn.Module): class FastCLIPAttention2(nn.Module): """Fast Attention module for CLIP-like. This is NOT a drop-in replacement for CLIPAttention, since it adds additional flexibility. Mainly uses xformers.""" - def __init__(self, hidden_size: int, out_dim: int, num_attention_heads: int, out_seq_len: Optional[int] = None, norm_qk: bool = False): + def __init__(self, hidden_size: int, out_dim: int, num_attention_heads: int, out_seq_len: int | None = None, norm_qk: bool = False): super().__init__() self.out_seq_len = out_seq_len self.embed_dim = hidden_size @@ -308,12 +307,12 @@ class FastCLIPEncoderLayer(nn.Module): self, hidden_size: int, num_attention_heads: int, - out_seq_len: Optional[int], + out_seq_len: int | None, activation_cls = QuickGELUActivation, use_palm_alt: bool = False, norm_qk: bool = False, - skip_init: Optional[float] = None, - stochastic_depth: Optional[float] = None, + skip_init: float | None = None, + stochastic_depth: float | None = None, ): super().__init__() self.use_palm_alt = use_palm_alt @@ -523,8 +522,8 @@ class CLIPLikeModel(VisionModel): norm_qk: bool = False, no_wd_bias: bool = False, use_gap_head: bool = False, - skip_init: Optional[float] = None, - stochastic_depth: Optional[float] = None, + skip_init: float | None = None, + stochastic_depth: float | None = None, ): super().__init__(image_size, n_tags) out_dim = n_tags @@ -939,7 +938,7 @@ class ViT(VisionModel): stochdepth_rate: float, use_sine: bool, loss_type: str, - layerscale_init: Optional[float] = None, + layerscale_init: float | None = None, head_mean_after: bool = False, cnn_stem: str = None, patch_dropout: float = 0.0, @@ -1048,7 +1047,7 @@ def load(): model = VisionModel.load_model(folder) model.to(dtype=devices.dtype) model.eval() # required: custom loader, not from_pretrained - with open(os.path.join(folder, 'top_tags.txt'), 'r', 
encoding='utf8') as f: + with open(os.path.join(folder, 'top_tags.txt'), encoding='utf8') as f: tags = [line.strip() for line in f.readlines() if line.strip()] shared.log.info(f'Caption: type=vlm model="JoyTag" repo="{MODEL_REPO}" tags={len(tags)}') sd_models.move_model(model, devices.device) diff --git a/modules/caption/openclip.py b/modules/caption/openclip.py index c7fca4d5f..b43229a93 100644 --- a/modules/caption/openclip.py +++ b/modules/caption/openclip.py @@ -330,11 +330,11 @@ def analyze_image(image, clip_model, blip_model): top_movements = ci.movements.rank(image_features, 5) top_trendings = ci.trendings.rank(image_features, 5) top_flavors = ci.flavors.rank(image_features, 5) - medium_ranks = dict(sorted(zip(top_mediums, ci.similarities(image_features, top_mediums)), key=lambda x: x[1], reverse=True)) - artist_ranks = dict(sorted(zip(top_artists, ci.similarities(image_features, top_artists)), key=lambda x: x[1], reverse=True)) - movement_ranks = dict(sorted(zip(top_movements, ci.similarities(image_features, top_movements)), key=lambda x: x[1], reverse=True)) - trending_ranks = dict(sorted(zip(top_trendings, ci.similarities(image_features, top_trendings)), key=lambda x: x[1], reverse=True)) - flavor_ranks = dict(sorted(zip(top_flavors, ci.similarities(image_features, top_flavors)), key=lambda x: x[1], reverse=True)) + medium_ranks = dict(sorted(zip(top_mediums, ci.similarities(image_features, top_mediums), strict=False), key=lambda x: x[1], reverse=True)) + artist_ranks = dict(sorted(zip(top_artists, ci.similarities(image_features, top_artists), strict=False), key=lambda x: x[1], reverse=True)) + movement_ranks = dict(sorted(zip(top_movements, ci.similarities(image_features, top_movements), strict=False), key=lambda x: x[1], reverse=True)) + trending_ranks = dict(sorted(zip(top_trendings, ci.similarities(image_features, top_trendings), strict=False), key=lambda x: x[1], reverse=True)) + flavor_ranks = dict(sorted(zip(top_flavors, 
ci.similarities(image_features, top_flavors), strict=False), key=lambda x: x[1], reverse=True)) shared.log.debug(f'CLIP analyze: complete time={time.time()-t0:.2f}') # Format labels as text diff --git a/modules/caption/vqa.py b/modules/caption/vqa.py index d840b6d25..9c358145f 100644 --- a/modules/caption/vqa.py +++ b/modules/caption/vqa.py @@ -708,7 +708,7 @@ class VQA: debug(f'VQA caption: handler=qwen output_ids_shape={output_ids.shape}') generated_ids = [ output_ids[len(input_ids):] - for input_ids, output_ids in zip(inputs.input_ids, output_ids) + for input_ids, output_ids in zip(inputs.input_ids, output_ids, strict=False) ] response = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) if debug_enabled: @@ -887,7 +887,7 @@ class VQA: def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument try: - import flash_attn # pylint: disable=unused-import + pass # pylint: disable=unused-import except Exception: shared.log.error(f'Caption: vlm="{repo}" flash-attn is not available') return '' diff --git a/modules/caption/waifudiffusion.py b/modules/caption/waifudiffusion.py index 4189fc989..416a9b1bd 100644 --- a/modules/caption/waifudiffusion.py +++ b/modules/caption/waifudiffusion.py @@ -126,7 +126,7 @@ class WaifuDiffusionTagger: self.tags = [] self.tag_categories = [] - with open(csv_path, 'r', encoding='utf-8') as f: + with open(csv_path, encoding='utf-8') as f: reader = csv.DictReader(f) for row in reader: self.tags.append(row['name']) @@ -269,7 +269,7 @@ class WaifuDiffusionTagger: character_count = 0 rating_count = 0 - for i, (tag_name, prob) in enumerate(zip(self.tags, probs)): + for i, (tag_name, prob) in enumerate(zip(self.tags, probs, strict=False)): category = self.tag_categories[i] tag_lower = tag_name.lower() diff --git a/modules/civitai/metadata_civitai.py b/modules/civitai/metadata_civitai.py index 62de254b5..d3594c3f7 100644 --- 
a/modules/civitai/metadata_civitai.py +++ b/modules/civitai/metadata_civitai.py @@ -10,7 +10,9 @@ selected_model = None class CivitModel: - def __init__(self, name, fn, sha = None, meta = {}): + def __init__(self, name, fn, sha = None, meta = None): + if meta is None: + meta = {} self.name = name self.file = name self.id = meta.get('id', 0) diff --git a/modules/civitai/search_civitai.py b/modules/civitai/search_civitai.py index 3b534662d..b30e32311 100644 --- a/modules/civitai/search_civitai.py +++ b/modules/civitai/search_civitai.py @@ -10,7 +10,7 @@ full_html = False base_models = ['', 'AuraFlow', 'Chroma', 'CogVideoX', 'Flux.1 S', 'Flux.1 D', 'Flux.1 Krea', 'Flux.1 Kontext', 'Flux.2 D', 'HiDream', 'Hunyuan 1', 'Hunyuan Video', 'Illustrious', 'Kolors', 'LTXV', 'Lumina', 'Mochi', 'NoobAI', 'PixArt a', 'PixArt E', 'Pony', 'Pony V7', 'Qwen', 'SD 1.4', 'SD 1.5', 'SD 1.5 LCM', 'SD 1.5 Hyper', 'SD 2.0', 'SD 2.1', 'SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper', 'Wan Video 1.3B t2v', 'Wan Video 14B t2v', 'Wan Video 14B i2v 480p', 'Wan Video 14B i2v 720p', 'Wan Video 2.2 TI2V-5B', 'Wan Video 2.2 I2V-A14B', 'Wan Video 2.2 T2V-A14B', 'Wan Video 2.5 T2V', 'Wan Video 2.5 I2V', 'ZImageTurbo', 'Other'] @dataclass -class ModelImage(): +class ModelImage: def __init__(self, dct: dict): if isinstance(dct, str): dct = json.loads(dct) @@ -26,7 +26,7 @@ class ModelImage(): @dataclass -class ModelFile(): +class ModelFile: def __init__(self, dct: dict): if isinstance(dct, str): dct = json.loads(dct) @@ -43,7 +43,7 @@ class ModelFile(): @dataclass -class ModelVersion(): +class ModelVersion: def __init__(self, dct: dict): import bs4 if isinstance(dct, str): @@ -65,7 +65,7 @@ class ModelVersion(): @dataclass -class Model(): +class Model: def __init__(self, dct: dict): import bs4 if isinstance(dct, str): diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 0e7f421d6..78fda012a 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -129,7 +129,7 @@ def main_args(): def 
compatibility_args(): # removed args are added here as hidden in fixed format for compatbility reasons - from modules.paths import data_path, models_path + from modules.paths import data_path group_compat = parser.add_argument_group('Compatibility options') group_compat.add_argument('--backend', type=str, choices=['diffusers', 'original'], help=argparse.SUPPRESS) group_compat.add_argument("--allow-code", default=os.environ.get("SD_ALLOWCODE", False), action='store_true', help=argparse.SUPPRESS) diff --git a/modules/control/processor.py b/modules/control/processor.py index 5a58c09d4..db5947fe6 100644 --- a/modules/control/processor.py +++ b/modules/control/processor.py @@ -46,11 +46,17 @@ def preprocess_image( input_mask:Image.Image = None, input_type:str = 0, unit_type:str = 'controlnet', - active_process:list = [], - active_model:list = [], - selected_models:list = [], + active_process:list = None, + active_model:list = None, + selected_models:list = None, has_models:bool = False, ): + if selected_models is None: + selected_models = [] + if active_model is None: + active_model = [] + if active_process is None: + active_process = [] t0 = time.time() jobid = shared.state.begin('Preprocess') diff --git a/modules/control/processors.py b/modules/control/processors.py index a5985639f..2968443fe 100644 --- a/modules/control/processors.py +++ b/modules/control/processors.py @@ -161,7 +161,7 @@ def update_settings(*settings): update(['Depth Pro', 'params', 'color_map'], settings[28]) -class Processor(): +class Processor: def __init__(self, processor_id: str = None, resize = True): self.model = None self.processor_id = None @@ -268,7 +268,9 @@ class Processor(): display(e, 'Control Processor load') return f'Processor load filed: {processor_id}' - def __call__(self, image_input: Image, mode: str = 'RGB', width: int = 0, height: int = 0, resize_mode: int = 0, resize_name: str = 'None', scale_tab: int = 1, scale_by: float = 1.0, local_config: dict = {}): + def __call__(self, 
image_input: Image, mode: str = 'RGB', width: int = 0, height: int = 0, resize_mode: int = 0, resize_name: str = 'None', scale_tab: int = 1, scale_by: float = 1.0, local_config: dict = None): + if local_config is None: + local_config = {} if self.override is not None: debug(f'Control Processor: id="{self.processor_id}" override={self.override}') width = image_input.width if image_input is not None else width diff --git a/modules/control/run.py b/modules/control/run.py index 243840ff1..f678fcf1d 100644 --- a/modules/control/run.py +++ b/modules/control/run.py @@ -1,6 +1,5 @@ import os import sys -from typing import List, Union import cv2 from PIL import Image from modules.control import util # helper functions @@ -141,12 +140,12 @@ def set_pipe(p, has_models, unit_type, selected_models, active_model, active_str def check_active(p, unit_type, units): - active_process: List[processors.Processor] = [] # all active preprocessors - active_model: List[Union[controlnet.ControlNet, xs.ControlNetXS, t2iadapter.Adapter]] = [] # all active models - active_strength: List[float] = [] # strength factors for all active models - active_start: List[float] = [] # start step for all active models - active_end: List[float] = [] # end step for all active models - active_units: List[unit.Unit] = [] # all active units + active_process: list[processors.Processor] = [] # all active preprocessors + active_model: list[controlnet.ControlNet | xs.ControlNetXS | t2iadapter.Adapter] = [] # all active models + active_strength: list[float] = [] # strength factors for all active models + active_start: list[float] = [] # start step for all active models + active_end: list[float] = [] # end step for all active models + active_units: list[unit.Unit] = [] # all active units num_units = 0 for u in units: if u.type != unit_type: @@ -218,7 +217,7 @@ def check_active(p, unit_type, units): def check_enabled(p, unit_type, units, active_model, active_strength, active_start, active_end): has_models = False - 
selected_models: List[Union[controlnet.ControlNetModel, xs.ControlNetXSModel, t2iadapter.AdapterModel]] = None + selected_models: list[controlnet.ControlNetModel | xs.ControlNetXSModel | t2iadapter.AdapterModel] = None control_conditioning = None control_guidance_start = None control_guidance_end = None @@ -254,7 +253,7 @@ def control_set(kwargs): p_extra_args[k] = v -def init_units(units: List[unit.Unit]): +def init_units(units: list[unit.Unit]): for u in units: if not u.enabled: continue @@ -271,9 +270,9 @@ def init_units(units: List[unit.Unit]): def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg - units: List[unit.Unit] = [], inputs: List[Image.Image] = [], inits: List[Image.Image] = [], mask: Image.Image = None, unit_type: str = None, is_generator: bool = True, + units: list[unit.Unit] = None, inputs: list[Image.Image] = None, inits: list[Image.Image] = None, mask: Image.Image = None, unit_type: str = None, is_generator: bool = True, input_type: int = 0, - prompt: str = '', negative_prompt: str = '', styles: List[str] = [], + prompt: str = '', negative_prompt: str = '', styles: list[str] = None, steps: int = 20, sampler_index: int = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, guidance_name: str = 'Default', guidance_scale: float = 6.0, guidance_rescale: float = 0.0, guidance_start: float = 0.0, guidance_stop: float = 1.0, @@ -289,11 +288,23 @@ def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg enable_hr: bool = False, hr_sampler_index: int = None, hr_denoising_strength: float = 0.0, hr_resize_mode: int = 0, hr_resize_context: str = 'None', hr_upscaler: str = None, hr_force: bool = False, hr_second_pass_steps: int = 20, hr_scale: float = 1.0, hr_resize_x: int = 0, hr_resize_y: int = 0, refiner_steps: int = 5, refiner_start: float = 0.0, refiner_prompt: str = '', refiner_negative: str = '', video_skip_frames: int = 0, 
video_type: str = 'None', video_duration: float = 2.0, video_loop: bool = False, video_pad: int = 0, video_interpolate: int = 0, - extra: dict = {}, + extra: dict = None, override_script_name: str = None, - override_script_args = [], + override_script_args = None, *input_script_args, ): + if override_script_args is None: + override_script_args = [] + if extra is None: + extra = {} + if styles is None: + styles = [] + if inits is None: + inits = [] + if inputs is None: + inputs = [] + if units is None: + units = [] global pipe, original_pipeline # pylint: disable=global-statement if 'refine' in state: enable_hr = True @@ -303,7 +314,7 @@ def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg init_units(units) if inputs is None or (type(inputs) is list and len(inputs) == 0): inputs = [None] - output_images: List[Image.Image] = [] # output images + output_images: list[Image.Image] = [] # output images processed_image: Image.Image = None # last processed image if mask is not None and input_type == 0: input_type = 1 # inpaint always requires control_image diff --git a/modules/control/unit.py b/modules/control/unit.py index 2eba804af..94f104558 100644 --- a/modules/control/unit.py +++ b/modules/control/unit.py @@ -1,4 +1,3 @@ -from typing import Union from PIL import Image import gradio as gr from installer import log @@ -7,7 +6,6 @@ from modules.control.units import controlnet from modules.control.units import xs from modules.control.units import lite from modules.control.units import t2iadapter -from modules.control.units import reference # pylint: disable=unused-import default_device = None @@ -16,7 +14,7 @@ unit_types = ['t2i adapter', 'controlnet', 'xs', 'lite', 'reference', 'ip'] current = [] -class Unit(): # mashup of gradio controls and mapping to actual implementation classes +class Unit: # mashup of gradio controls and mapping to actual implementation classes def update_choices(self, model_id=None): name = model_id or self.model_name if 
name == 'InstantX Union F1': @@ -57,8 +55,10 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c control_mode = None, control_tile = None, result_txt = None, - extra_controls: list = [], + extra_controls: list = None, ): + if extra_controls is None: + extra_controls = [] self.model_id = model_id self.process_id = process_id self.controls = [gr.Label(value=unit_type, visible=False)] # separator @@ -77,7 +77,7 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c self.process_name = None self.process: processors.Processor = processors.Processor() self.adapter: t2iadapter.Adapter = None - self.controlnet: Union[controlnet.ControlNet, xs.ControlNetXS] = None + self.controlnet: controlnet.ControlNet | xs.ControlNetXS = None # map to input image self.override: Image = None # global settings but passed per-unit diff --git a/modules/devices.py b/modules/devices.py index 35391bf33..99276e30e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -104,7 +104,7 @@ def get_gpu_info(): elif torch.cuda.is_available() and torch.version.cuda: try: import subprocess - result = subprocess.run('nvidia-smi --query-gpu=driver_version --format=csv,noheader', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + result = subprocess.run('nvidia-smi --query-gpu=driver_version --format=csv,noheader', shell=True, check=False, env=os.environ, capture_output=True) version = result.stdout.decode(encoding="utf8", errors="ignore").strip() return version except Exception: @@ -307,7 +307,7 @@ def set_cuda_tunable(): lines={0} try: if os.path.exists(fn): - with open(fn, 'r', encoding='utf8') as f: + with open(fn, encoding='utf8') as f: lines = sum(1 for _line in f) except Exception: pass diff --git a/modules/dml/Generator.py b/modules/dml/Generator.py index ea273310c..c6c53f084 100644 --- a/modules/dml/Generator.py +++ b/modules/dml/Generator.py @@ -1,7 +1,6 @@ -from typing import Optional import 
torch class Generator(torch.Generator): - def __init__(self, device: Optional[torch.device] = None): + def __init__(self, device: torch.device | None = None): super().__init__("cpu") diff --git a/modules/dml/__init__.py b/modules/dml/__init__.py index c0ddbc1a1..82376c2d3 100644 --- a/modules/dml/__init__.py +++ b/modules/dml/__init__.py @@ -1,5 +1,6 @@ import platform -from typing import NamedTuple, Callable, Optional +from typing import NamedTuple, Optional +from collections.abc import Callable import torch from modules.errors import log from modules.sd_hijack_utils import CondFunc @@ -86,8 +87,8 @@ def directml_do_hijack(): class OverrideItem(NamedTuple): value: str - condition: Optional[Callable] - message: Optional[str] + condition: Callable | None + message: str | None opts_override_table = { diff --git a/modules/dml/amp/autocast_mode.py b/modules/dml/amp/autocast_mode.py index 401d26d9e..2c344d53b 100644 --- a/modules/dml/amp/autocast_mode.py +++ b/modules/dml/amp/autocast_mode.py @@ -1,5 +1,5 @@ import importlib -from typing import Any, Optional +from typing import Any import torch @@ -52,7 +52,7 @@ class autocast: fast_dtype: torch.dtype = torch.float16 prev_fast_dtype: torch.dtype - def __init__(self, dtype: Optional[torch.dtype] = torch.float16): + def __init__(self, dtype: torch.dtype | None = torch.float16): self.fast_dtype = dtype def __enter__(self): diff --git a/modules/dml/backend.py b/modules/dml/backend.py index 7947dc81b..712e17591 100644 --- a/modules/dml/backend.py +++ b/modules/dml/backend.py @@ -1,5 +1,5 @@ # pylint: disable=no-member,no-self-argument,no-method-argument -from typing import Optional, Callable +from collections.abc import Callable import torch import torch_directml # pylint: disable=import-error import modules.dml.amp as amp @@ -9,17 +9,17 @@ from .Generator import Generator from .device_properties import DeviceProperties -def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: +def amd_mem_get_info(device: 
rDevice | None=None) -> tuple[int, int]: from .memory_amd import AMDMemoryProvider return AMDMemoryProvider.mem_get_info(get_device(device).index) -def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: +def pdh_mem_get_info(device: rDevice | None=None) -> tuple[int, int]: mem_info = DirectML.memory_provider.get_memory(get_device(device).index) return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"]) -def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument +def mem_get_info(device: rDevice | None=None) -> tuple[int, int]: # pylint: disable=unused-argument return (8589934592, 8589934592) @@ -28,7 +28,7 @@ class DirectML: device = Device Generator = Generator - context_device: Optional[torch.device] = None + context_device: torch.device | None = None is_autocast_enabled = False autocast_gpu_dtype = torch.float16 @@ -41,7 +41,7 @@ class DirectML: def is_directml_device(device: torch.device) -> bool: return device.type == "privateuseone" - def has_float64_support(device: Optional[rDevice]=None) -> bool: + def has_float64_support(device: rDevice | None=None) -> bool: return torch_directml.has_float64_support(get_device(device).index) def device_count() -> int: @@ -53,16 +53,16 @@ class DirectML: def default_device() -> torch.device: return torch_directml.device(torch_directml.default_device()) - def get_device_string(device: Optional[rDevice]=None) -> str: + def get_device_string(device: rDevice | None=None) -> str: return f"privateuseone:{get_device(device).index}" - def get_device_name(device: Optional[rDevice]=None) -> str: + def get_device_name(device: rDevice | None=None) -> str: return torch_directml.device_name(get_device(device).index) - def get_device_properties(device: Optional[rDevice]=None) -> DeviceProperties: + def get_device_properties(device: rDevice | None=None) -> DeviceProperties: return DeviceProperties(get_device(device)) - def 
memory_stats(device: Optional[rDevice]=None): + def memory_stats(device: rDevice | None=None): return { "num_ooms": 0, "num_alloc_retries": 0, @@ -70,11 +70,11 @@ class DirectML: mem_get_info: Callable = mem_get_info - def memory_allocated(device: Optional[rDevice]=None) -> int: + def memory_allocated(device: rDevice | None=None) -> int: return sum(torch_directml.gpu_memory(get_device(device).index)) * (1 << 20) - def max_memory_allocated(device: Optional[rDevice]=None): + def max_memory_allocated(device: rDevice | None=None): return DirectML.memory_allocated(device) # DirectML does not empty GPU memory - def reset_peak_memory_stats(device: Optional[rDevice]=None): + def reset_peak_memory_stats(device: rDevice | None=None): return diff --git a/modules/dml/device.py b/modules/dml/device.py index ae4d32a99..cd7006333 100644 --- a/modules/dml/device.py +++ b/modules/dml/device.py @@ -1,4 +1,3 @@ -from typing import Optional import torch from .utils import rDevice, get_device @@ -6,11 +5,11 @@ from .utils import rDevice, get_device class Device: idx: int - def __enter__(self, device: Optional[rDevice]=None): + def __enter__(self, device: rDevice | None=None): torch.dml.context_device = get_device(device) self.idx = torch.dml.context_device.index - def __init__(self, device: Optional[rDevice]=None) -> torch.device: # pylint: disable=return-in-init + def __init__(self, device: rDevice | None=None) -> torch.device: # pylint: disable=return-in-init self.idx = get_device(device).index def __exit__(self, t, v, tb): diff --git a/modules/dml/hijack/tomesd.py b/modules/dml/hijack/tomesd.py index 79de721df..0acc657ae 100644 --- a/modules/dml/hijack/tomesd.py +++ b/modules/dml/hijack/tomesd.py @@ -1,9 +1,8 @@ -from typing import Type import torch from modules.dml.hijack.utils import catch_nan -def make_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]: +def make_tome_block(block_class: type[torch.nn.Module]) -> type[torch.nn.Module]: class 
ToMeBlock(block_class): # Save for unpatching later _parent = block_class diff --git a/modules/dml/hijack/transformers.py b/modules/dml/hijack/transformers.py index 78ddb20a2..6b4e090e1 100644 --- a/modules/dml/hijack/transformers.py +++ b/modules/dml/hijack/transformers.py @@ -1,4 +1,3 @@ -from typing import Optional import torch import transformers.models.clip.modeling_clip @@ -22,9 +21,9 @@ def _make_causal_mask( def CLIPTextEmbeddings_forward( self: transformers.models.clip.modeling_clip.CLIPTextEmbeddings, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, ) -> torch.Tensor: from modules.devices import dtype seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] diff --git a/modules/dml/hijack/utils.py b/modules/dml/hijack/utils.py index 659431c22..8817251b0 100644 --- a/modules/dml/hijack/utils.py +++ b/modules/dml/hijack/utils.py @@ -1,5 +1,5 @@ import torch -from typing import Callable +from collections.abc import Callable from modules.shared import log, opts diff --git a/modules/dml/pdh/apis.py b/modules/dml/pdh/apis.py index f01222b45..9486e3321 100644 --- a/modules/dml/pdh/apis.py +++ b/modules/dml/pdh/apis.py @@ -1,6 +1,6 @@ from ctypes import CDLL, POINTER from ctypes.wintypes import LPCWSTR, LPDWORD, DWORD -from typing import Callable +from collections.abc import Callable from .structures import PDH_HQUERY, PDH_HCOUNTER, PPDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W from .defines import PDH_FUNCTION, PZZWSTR, DWORD_PTR diff --git a/modules/dml/utils.py b/modules/dml/utils.py index cb19ed900..58dd3238b 100644 --- a/modules/dml/utils.py +++ b/modules/dml/utils.py @@ -1,9 +1,9 @@ -from typing import Optional, Union +from typing import Union import torch rDevice = 
Union[torch.device, int] -def get_device(device: Optional[rDevice]=None) -> torch.device: +def get_device(device: rDevice | None=None) -> torch.device: if device is None: device = torch.dml.current_device() return torch.device(device) diff --git a/modules/errors.py b/modules/errors.py index a3397143a..110b1f5bb 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -10,13 +10,17 @@ install_traceback() already_displayed = {} -def install(suppress=[]): +def install(suppress=None): + if suppress is None: + suppress = [] warnings.filterwarnings("ignore", category=UserWarning) install_traceback(suppress=suppress) logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s') -def display(e: Exception, task: str, suppress=[]): +def display(e: Exception, task: str, suppress=None): + if suppress is None: + suppress = [] if isinstance(e, ErrorLimiterAbort): return log.critical(f"{task or 'error'}: {type(e).__name__}") @@ -45,7 +49,9 @@ def run(code, task: str): display(e, task) -def exception(suppress=[]): +def exception(suppress=None): + if suppress is None: + suppress = [] console = get_console() console.print_exception(show_locals=False, max_frames=16, extra_lines=2, suppress=suppress, theme="ansi_dark", word_wrap=False, width=min([console.width, 200])) diff --git a/modules/extensions.py b/modules/extensions.py index b29a6cbb0..c2643a4ad 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -186,7 +186,7 @@ class Extension: continue priority = '50' if os.path.isfile(os.path.join(dirpath, "..", ".priority")): - with open(os.path.join(dirpath, "..", ".priority"), "r", encoding="utf-8") as f: + with open(os.path.join(dirpath, "..", ".priority"), encoding="utf-8") as f: priority = str(f.read().strip()) res.append(scripts_manager.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority)) if priority != '50': diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 01913b187..66edf76a3 
100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -73,8 +73,12 @@ def is_stepwise(en_obj): return any([len(str(x).split("@")) > 1 for x in all_args]) # noqa C419 # pylint: disable=use-a-generator -def activate(p, extra_network_data=None, step=0, include=[], exclude=[]): +def activate(p, extra_network_data=None, step=0, include=None, exclude=None): """call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list""" + if exclude is None: + exclude = [] + if include is None: + include = [] if p.disable_extra_networks: return extra_network_data = extra_network_data or p.network_data diff --git a/modules/extras.py b/modules/extras.py index 221bbe6a0..7b2c9e6d6 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -33,7 +33,7 @@ def run_modelmerger(id_task, **kwargs): # pylint: disable=unused-argument from installer import install install('tensordict', quiet=True) try: - from tensordict import TensorDict # pylint: disable=unused-import + pass # pylint: disable=unused-import except Exception as e: shared.log.error(f"Merge: {e}") return [*[gr.update() for _ in range(4)], "tensordict not available"] diff --git a/modules/face/faceid.py b/modules/face/faceid.py index 8986e92b1..933df80ff 100644 --- a/modules/face/faceid.py +++ b/modules/face/faceid.py @@ -1,4 +1,3 @@ -from typing import List import os import cv2 import torch @@ -34,7 +33,7 @@ def hijack_load_ip_adapter(self): def face_id( p: processing.StableDiffusionProcessing, app, - source_images: List[Image.Image], + source_images: list[Image.Image], model: str, override: bool, cache: bool, diff --git a/modules/face/faceswap.py b/modules/face/faceswap.py index df3765fb2..a3e4f27fc 100644 --- a/modules/face/faceswap.py +++ b/modules/face/faceswap.py @@ -1,4 +1,3 @@ -from typing import List import os import cv2 import numpy as np @@ -12,7 +11,7 @@ insightface_app = None swapper = None -def 
face_swap(p: processing.StableDiffusionProcessing, app, input_images: List[Image.Image], source_image: Image.Image, cache: bool): +def face_swap(p: processing.StableDiffusionProcessing, app, input_images: list[Image.Image], source_image: Image.Image, cache: bool): global swapper # pylint: disable=global-statement if swapper is None: import insightface.model_zoo diff --git a/modules/face/instantid_model.py b/modules/face/instantid_model.py index 8af9a2907..2511ec27a 100644 --- a/modules/face/instantid_model.py +++ b/modules/face/instantid_model.py @@ -14,7 +14,8 @@ import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any +from collections.abc import Callable import cv2 import numpy as np @@ -544,40 +545,40 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, image: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, + height: int | None = None, + width: int | None = None, num_inference_steps: int = 50, guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - 
image_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + image_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, guess_mode: bool = False, - control_guidance_start: Union[float, List[float]] = 0.0, - control_guidance_end: Union[float, List[float]] = 1.0, - original_size: Tuple[int, int] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Tuple[int, int] = None, - negative_original_size: Optional[Tuple[int, int]] = None, - negative_crops_coords_top_left: Tuple[int, int] = (0, 0), - negative_target_size: Optional[Tuple[int, int]] = None, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = None, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = None, **kwargs, ): r""" @@ -890,7 
+891,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline): for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) + for s, e in zip(control_guidance_start, control_guidance_end, strict=False) ] controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) @@ -970,7 +971,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline): controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i], strict=False)] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): diff --git a/modules/face/photomaker_pipeline.py b/modules/face/photomaker_pipeline.py index 45006a7e1..191c4f352 100644 --- a/modules/face/photomaker_pipeline.py +++ b/modules/face/photomaker_pipeline.py @@ -1,7 +1,8 @@ ### original import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Union +from collections.abc import Callable import PIL import torch from transformers import CLIPImageProcessor @@ -26,8 +27,8 @@ from modules.face.photomaker_model_v2 import PhotoMakerIDEncoder_CLIPInsightface PipelineImageInput = Union[ PIL.Image.Image, torch.FloatTensor, - List[PIL.Image.Image], - List[torch.FloatTensor], + list[PIL.Image.Image], + list[torch.FloatTensor], ] @@ -49,10 +50,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = 
None, - sigmas: Optional[List[float]] = None, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, **kwargs, ): """ @@ -110,7 +111,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): @validate_hf_hub_args def load_photomaker_adapter( self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], weight_name: str, subfolder: str = '', trigger_word: str = 'img', @@ -214,21 +215,21 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): def encode_prompt_with_trigger_word( self, prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, + prompt_2: str | None = None, + device: torch.device | None = None, num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, ### Added args num_id_images: int = 1, - class_tokens_mask: Optional[torch.LongTensor] = None, + class_tokens_mask: torch.LongTensor | None = None, ): device = device or self._execution_device @@ -273,7 +274,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): # textual inversion: 
process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): # pylint: disable=redefined-argument-from-local + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False): # pylint: disable=redefined-argument-from-local if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) @@ -362,7 +363,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) - uncond_tokens: List[str] + uncond_tokens: list[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" @@ -377,7 +378,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): # pylint: disable=redefined-argument-from-local + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders, strict=False): # pylint: disable=redefined-argument-from-local if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) @@ -444,49 +445,47 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, num_inference_steps: int = 50, - timesteps: List[int] = None, - sigmas: List[float] = 
None, - denoising_end: Optional[float] = None, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, - ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, - output_type: Optional[str] = "pil", + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, + cross_attention_kwargs: dict[str, Any] | None = None, guidance_rescale: float = 0.0, - original_size: Optional[Tuple[int, int]] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Optional[Tuple[int, int]] = None, - negative_original_size: Optional[Tuple[int, int]] = None, - negative_crops_coords_top_left: Tuple[int, int] = (0, 0), - negative_target_size: Optional[Tuple[int, int]] = None, - clip_skip: Optional[int] 
= None, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = None, # Added parameters (for PhotoMaker) input_id_images: PipelineImageInput = None, start_merge_step: int = 10, - class_tokens_mask: Optional[torch.LongTensor] = None, - id_embeds: Optional[torch.FloatTensor] = None, - prompt_embeds_text_only: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None, + class_tokens_mask: torch.LongTensor | None = None, + id_embeds: torch.FloatTensor | None = None, + prompt_embeds_text_only: torch.FloatTensor | None = None, + pooled_prompt_embeds_text_only: torch.FloatTensor | None = None, **kwargs, ): r""" @@ -512,6 +511,8 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): `tuple`. When returning a tuple, the first element is a list with the generated images. 
""" + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["latents"] callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) diff --git a/modules/face/reswapper.py b/modules/face/reswapper.py index 77328a1ff..d688a51d2 100644 --- a/modules/face/reswapper.py +++ b/modules/face/reswapper.py @@ -1,4 +1,3 @@ -from typing import List import os import cv2 import torch @@ -43,8 +42,8 @@ def get_model(model_name: str): def reswapper( p: processing.StableDiffusionProcessing, app, - source_images: List[Image.Image], - target_images: List[Image.Image], + source_images: list[Image.Image], + target_images: list[Image.Image], model_name: str, original: bool, ): diff --git a/modules/face/reswapper_model.py b/modules/face/reswapper_model.py index de68d8566..606189bf8 100644 --- a/modules/face/reswapper_model.py +++ b/modules/face/reswapper_model.py @@ -6,7 +6,7 @@ import torch.nn.functional as F class ReSwapperModel(nn.Module): def __init__(self): - super(ReSwapperModel, self).__init__() + super().__init__() # self.pad = nn.ReflectionPad2d(3) # Encoder for target face @@ -87,7 +87,7 @@ class ReSwapperModel(nn.Module): class StyleBlock(nn.Module): def __init__(self, in_channels, out_channels, blockIndex): - super(StyleBlock, self).__init__() + super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0) self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0) self.style1 = nn.Linear(512, 2048) diff --git a/modules/files_cache.py b/modules/files_cache.py index d65e0f4f4..5ffc4d383 100644 --- a/modules/files_cache.py +++ b/modules/files_cache.py @@ -2,7 +2,8 @@ import itertools import os from collections import UserDict from dataclasses import dataclass, field -from typing import Callable, Dict, Iterator, List, Optional, Union +from typing import Union +from collections.abc import Callable, Iterator from installer import log @@ -10,19 +11,19 @@ do_cache_folders 
= os.environ.get('SD_NO_CACHE', None) is None class Directory: # forward declaration ... -FilePathList = List[str] +FilePathList = list[str] FilePathIterator = Iterator[str] -DirectoryPathList = List[str] +DirectoryPathList = list[str] DirectoryPathIterator = Iterator[str] -DirectoryList = List[Directory] +DirectoryList = list[Directory] DirectoryIterator = Iterator[Directory] -DirectoryCollection = Dict[str, Directory] +DirectoryCollection = dict[str, Directory] ExtensionFilter = Callable ExtensionList = list[str] RecursiveType = Union[bool,Callable] -def real_path(directory_path:str) -> Union[str, None]: +def real_path(directory_path:str) -> str | None: try: return os.path.abspath(os.path.expanduser(directory_path)) except Exception: @@ -52,7 +53,7 @@ class Directory(Directory): # pylint: disable=E0102 def clear(self) -> None: self._update(Directory.from_dict({ 'path': None, - 'mtime': float(), + 'mtime': 0.0, 'files': [], 'directories': [] })) @@ -125,7 +126,7 @@ def clean_directory(directory: Directory, /, recursive: RecursiveType=False) -> return is_clean -def get_directory(directory_or_path: str, /, fetch: bool=True) -> Union[Directory, None]: +def get_directory(directory_or_path: str, /, fetch: bool=True) -> Directory | None: if isinstance(directory_or_path, Directory): if directory_or_path.is_directory: return directory_or_path @@ -143,7 +144,7 @@ def get_directory(directory_or_path: str, /, fetch: bool=True) -> Union[Director return cache_folders[directory_or_path] if directory_or_path in cache_folders else None -def fetch_directory(directory_path: str) -> Union[Directory, None]: +def fetch_directory(directory_path: str) -> Directory | None: directory: Directory for directory in _walk(directory_path, recurse=False): return directory # The return is intentional, we get a generator, we only need the one @@ -255,7 +256,7 @@ def get_directories(*directory_paths: DirectoryPathList, fetch:bool=True, recurs return filter(bool, directories) -def 
directory_files(*directories_or_paths: Union[DirectoryPathList, DirectoryList], recursive: RecursiveType=True) -> FilePathIterator: +def directory_files(*directories_or_paths: DirectoryPathList | DirectoryList, recursive: RecursiveType=True) -> FilePathIterator: return itertools.chain.from_iterable( itertools.chain( directory_object.files, @@ -275,7 +276,7 @@ def directory_files(*directories_or_paths: Union[DirectoryPathList, DirectoryLis ) -def extension_filter(ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None) -> ExtensionFilter: +def extension_filter(ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None) -> ExtensionFilter: if ext_filter: ext_filter = [*map(str.upper, ext_filter)] if ext_blacklist: @@ -289,11 +290,11 @@ def not_hidden(filepath: str) -> bool: return not os.path.basename(filepath).startswith('.') -def filter_files(file_paths: FilePathList, ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None) -> FilePathIterator: +def filter_files(file_paths: FilePathList, ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None) -> FilePathIterator: return filter(extension_filter(ext_filter, ext_blacklist), file_paths) -def list_files(*directory_paths:DirectoryPathList, ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None, recursive:RecursiveType=True) -> FilePathIterator: +def list_files(*directory_paths:DirectoryPathList, ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None, recursive:RecursiveType=True) -> FilePathIterator: return filter_files(itertools.chain.from_iterable( directory_files(directory, recursive=recursive) for directory in get_directories(*directory_paths, recursive=recursive) diff --git a/modules/framepack/framepack_api.py b/modules/framepack/framepack_api.py index 0c2ecb25b..ad3bd0e70 100644 --- a/modules/framepack/framepack_api.py +++ 
b/modules/framepack/framepack_api.py @@ -1,4 +1,3 @@ -from typing import Optional, List from pydantic import BaseModel, Field # pylint: disable=no-name-in-module from fastapi.exceptions import HTTPException from modules import shared @@ -8,39 +7,39 @@ class ReqFramepack(BaseModel): variant: str = Field(default=None, title="Model variant", description="Model variant to use") prompt: str = Field(default=None, title="Prompt", description="Prompt for the model") init_image: str = Field(default=None, title="Initial image", description="Base64 encoded initial image") - end_image: Optional[str] = Field(default=None, title="End image", description="Base64 encoded end image") - start_weight: Optional[float] = Field(default=1.0, title="Start weight", description="Weight of the initial image") - end_weight: Optional[float] = Field(default=1.0, title="End weight", description="Weight of the end image") - vision_weight: Optional[float] = Field(default=1.0, title="Vision weight", description="Weight of the vision model") - system_prompt: Optional[str] = Field(default=None, title="System prompt", description="System prompt for the model") - optimized_prompt: Optional[bool] = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model") - section_prompt: Optional[str] = Field(default=None, title="Section prompt", description="Prompt for each section") - negative_prompt: Optional[str] = Field(default=None, title="Negative prompt", description="Negative prompt for the model") - styles: Optional[List[str]] = Field(default=None, title="Styles", description="Styles for the model") - seed: Optional[int] = Field(default=None, title="Seed", description="Seed for the model") - resolution: Optional[int] = Field(default=640, title="Resolution", description="Resolution of the image") - duration: Optional[float] = Field(default=4, title="Duration", description="Duration of the video in seconds") - latent_ws: Optional[int] = Field(default=9, 
title="Latent window size", description="Size of the latent window") - steps: Optional[int] = Field(default=25, title="Video steps", description="Number of steps for the video generation") - cfg_scale: Optional[float] = Field(default=1.0, title="CFG scale", description="CFG scale for the model") - cfg_distilled: Optional[float] = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model") - cfg_rescale: Optional[float] = Field(default=0.0, title="CFG re-scale", description="CFG re-scale for the model") - shift: Optional[float] = Field(default=0, title="Sampler shift", description="Shift for the sampler") - use_teacache: Optional[bool] = Field(default=True, title="Enable TeaCache", description="Use TeaCache for the model") - use_cfgzero: Optional[bool] = Field(default=False, title="Enable CFGZero", description="Use CFGZero for the model") - mp4_fps: Optional[int] = Field(default=30, title="FPS", description="Frames per second for the video") - mp4_codec: Optional[str] = Field(default="libx264", title="Codec", description="Codec for the video") - mp4_sf: Optional[bool] = Field(default=False, title="Save SafeTensors", description="Save SafeTensors for the video") - mp4_video: Optional[bool] = Field(default=True, title="Save Video", description="Save video") - mp4_frames: Optional[bool] = Field(default=False, title="Save Frames", description="Save frames for the video") - mp4_opt: Optional[str] = Field(default="crf:16", title="Options", description="Options for the video codec") - mp4_ext: Optional[str] = Field(default="mp4", title="Format", description="Format for the video") - mp4_interpolate: Optional[int] = Field(default=0, title="Interpolation", description="Interpolation for the video") - attention: Optional[str] = Field(default="Default", title="Attention", description="Attention type for the model") - vae_type: Optional[str] = Field(default="Local", title="VAE", description="VAE type for the model") - vlm_enhance: 
Optional[bool] = Field(default=False, title="VLM enhance", description="Enable VLM enhance") - vlm_model: Optional[str] = Field(default=None, title="VLM model", description="VLM model to use") - vlm_system_prompt: Optional[str] = Field(default=None, title="VLM system prompt", description="System prompt for the VLM model") + end_image: str | None = Field(default=None, title="End image", description="Base64 encoded end image") + start_weight: float | None = Field(default=1.0, title="Start weight", description="Weight of the initial image") + end_weight: float | None = Field(default=1.0, title="End weight", description="Weight of the end image") + vision_weight: float | None = Field(default=1.0, title="Vision weight", description="Weight of the vision model") + system_prompt: str | None = Field(default=None, title="System prompt", description="System prompt for the model") + optimized_prompt: bool | None = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model") + section_prompt: str | None = Field(default=None, title="Section prompt", description="Prompt for each section") + negative_prompt: str | None = Field(default=None, title="Negative prompt", description="Negative prompt for the model") + styles: list[str] | None = Field(default=None, title="Styles", description="Styles for the model") + seed: int | None = Field(default=None, title="Seed", description="Seed for the model") + resolution: int | None = Field(default=640, title="Resolution", description="Resolution of the image") + duration: float | None = Field(default=4, title="Duration", description="Duration of the video in seconds") + latent_ws: int | None = Field(default=9, title="Latent window size", description="Size of the latent window") + steps: int | None = Field(default=25, title="Video steps", description="Number of steps for the video generation") + cfg_scale: float | None = Field(default=1.0, title="CFG scale", description="CFG scale for the 
model") + cfg_distilled: float | None = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model") + cfg_rescale: float | None = Field(default=0.0, title="CFG re-scale", description="CFG re-scale for the model") + shift: float | None = Field(default=0, title="Sampler shift", description="Shift for the sampler") + use_teacache: bool | None = Field(default=True, title="Enable TeaCache", description="Use TeaCache for the model") + use_cfgzero: bool | None = Field(default=False, title="Enable CFGZero", description="Use CFGZero for the model") + mp4_fps: int | None = Field(default=30, title="FPS", description="Frames per second for the video") + mp4_codec: str | None = Field(default="libx264", title="Codec", description="Codec for the video") + mp4_sf: bool | None = Field(default=False, title="Save SafeTensors", description="Save SafeTensors for the video") + mp4_video: bool | None = Field(default=True, title="Save Video", description="Save video") + mp4_frames: bool | None = Field(default=False, title="Save Frames", description="Save frames for the video") + mp4_opt: str | None = Field(default="crf:16", title="Options", description="Options for the video codec") + mp4_ext: str | None = Field(default="mp4", title="Format", description="Format for the video") + mp4_interpolate: int | None = Field(default=0, title="Interpolation", description="Interpolation for the video") + attention: str | None = Field(default="Default", title="Attention", description="Attention type for the model") + vae_type: str | None = Field(default="Local", title="VAE", description="VAE type for the model") + vlm_enhance: bool | None = Field(default=False, title="VLM enhance", description="Enable VLM enhance") + vlm_model: str | None = Field(default=None, title="VLM model", description="VLM model to use") + vlm_system_prompt: str | None = Field(default=None, title="VLM system prompt", description="System prompt for the VLM model") class 
ResFramepack(BaseModel): diff --git a/modules/framepack/framepack_worker.py b/modules/framepack/framepack_worker.py index 345333ad9..9df245d59 100644 --- a/modules/framepack/framepack_worker.py +++ b/modules/framepack/framepack_worker.py @@ -43,8 +43,10 @@ def worker( mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate, vae_type, variant, - metadata:dict={}, + metadata:dict=None, ): + if metadata is None: + metadata = {} timer.process.reset() memstats.reset_stats() if stream is None or shared.state.interrupted or shared.state.skipped: diff --git a/modules/framepack/pipeline/hunyuan_video_packed.py b/modules/framepack/pipeline/hunyuan_video_packed.py index 0a3f8f62b..a852fbc10 100644 --- a/modules/framepack/pipeline/hunyuan_video_packed.py +++ b/modules/framepack/pipeline/hunyuan_video_packed.py @@ -1,4 +1,3 @@ -from typing import Optional, Tuple import torch import torch.nn as nn @@ -251,7 +250,7 @@ class CombinedTimestepTextProjEmbeddings(nn.Module): class HunyuanVideoAdaNorm(nn.Module): - def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + def __init__(self, in_features: int, out_features: int | None = None) -> None: super().__init__() out_features = out_features or 2 * in_features @@ -260,7 +259,7 @@ class HunyuanVideoAdaNorm(nn.Module): def forward( self, temb: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: temb = self.linear(self.nonlinearity(temb)) gate_msa, gate_mlp = temb.chunk(2, dim=-1) gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) @@ -298,7 +297,7 @@ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): self, hidden_states: torch.Tensor, temb: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: norm_hidden_states = self.norm1(hidden_states) @@ -346,7 +345,7 @@ 
class HunyuanVideoIndividualTokenRefiner(nn.Module): self, hidden_states: torch.Tensor, temb: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor | None = None, ) -> None: self_attn_mask = None if attention_mask is not None: @@ -396,7 +395,7 @@ class HunyuanVideoTokenRefiner(nn.Module): self, hidden_states: torch.Tensor, timestep: torch.LongTensor, - attention_mask: Optional[torch.LongTensor] = None, + attention_mask: torch.LongTensor | None = None, ) -> torch.Tensor: if attention_mask is None: pooled_projections = hidden_states.mean(dim=1) @@ -464,8 +463,8 @@ class AdaLayerNormZero(nn.Module): def forward( self, x: torch.Tensor, - emb: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: emb = emb.unsqueeze(-2) emb = self.linear(self.silu(emb)) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) @@ -487,8 +486,8 @@ class AdaLayerNormZeroSingle(nn.Module): def forward( self, x: torch.Tensor, - emb: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: emb = emb.unsqueeze(-2) emb = self.linear(self.silu(emb)) shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1) @@ -558,8 +557,8 @@ class HunyuanVideoSingleTransformerBlock(nn.Module): hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: text_seq_length = encoder_hidden_states.shape[1] hidden_states = 
torch.cat([hidden_states, encoder_hidden_states], dim=1) @@ -636,9 +635,9 @@ class HunyuanVideoTransformerBlock(nn.Module): hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + attention_mask: torch.Tensor | None = None, + freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: # 1. Input normalization norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb) @@ -734,7 +733,7 @@ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterM text_embed_dim: int = 4096, pooled_projection_dim: int = 768, rope_theta: float = 256.0, - rope_axes_dim: Tuple[int] = (16, 56, 56), + rope_axes_dim: tuple[int] = (16, 56, 56), has_image_proj=False, image_proj_dim=1152, has_clean_x_embedder=False, diff --git a/modules/framepack/pipeline/utils.py b/modules/framepack/pipeline/utils.py index 9cd99571d..a14e1d7d9 100644 --- a/modules/framepack/pipeline/utils.py +++ b/modules/framepack/pipeline/utils.py @@ -102,14 +102,14 @@ def just_crop(image, w, h): def write_to_json(data, file_path): temp_file_path = file_path + ".tmp" - with open(temp_file_path, 'wt', encoding='utf-8') as temp_file: + with open(temp_file_path, 'w', encoding='utf-8') as temp_file: json.dump(data, temp_file, indent=4) os.replace(temp_file_path, file_path) return def read_from_json(file_path): - with open(file_path, 'rt', encoding='utf-8') as file: + with open(file_path, encoding='utf-8') as file: data = json.load(file) return data @@ -283,7 +283,7 @@ def add_tensors_with_padding(tensor1, tensor2): shape1 = tensor1.shape shape2 = tensor2.shape - new_shape = tuple(max(s1, s2) for s1, s2 in 
zip(shape1, shape2)) + new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2, strict=False)) padded_tensor1 = torch.zeros(new_shape) padded_tensor2 = torch.zeros(new_shape) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 63ffc545e..42f6aa864 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -5,7 +5,7 @@ import os from PIL import Image import gradio as gr from modules import shared, gr_tempdir, script_callbacks, images -from modules.infotext import parse, mapping, quote, unquote # pylint: disable=unused-import +from modules.infotext import parse, mapping # pylint: disable=unused-import type_of_gr_update = type(gr.update()) @@ -259,7 +259,7 @@ def connect_paste(button, local_paste_fields, input_comp, override_settings_comp from modules.paths import params_path if prompt is None or len(prompt.strip()) == 0: if os.path.exists(params_path): - with open(params_path, "r", encoding="utf8") as file: + with open(params_path, encoding="utf8") as file: prompt = file.read() shared.log.debug(f'Prompt parse: type="params" prompt="{prompt}"') else: diff --git a/modules/ggml/gguf_utils.py b/modules/ggml/gguf_utils.py index c6c937380..f3fdc2146 100644 --- a/modules/ggml/gguf_utils.py +++ b/modules/ggml/gguf_utils.py @@ -1,7 +1,8 @@ # Original: invokeai.backend.quantization.gguf.utils # Largely based on https://github.com/city96/ComfyUI-GGUF -from typing import Callable, Optional, Union +from typing import Union +from collections.abc import Callable import gguf import torch @@ -28,7 +29,7 @@ def get_scale_min(scales: torch.Tensor): # Legacy Quants # def dequantize_blocks_Q8_0( - blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None + blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None ) -> torch.Tensor: d, x = split_block_dims(blocks, 2) d = d.view(torch.float16).to(dtype) @@ -37,7 +38,7 
@@ def dequantize_blocks_Q8_0( def dequantize_blocks_Q5_1( - blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None + blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None ) -> torch.Tensor: n_blocks = blocks.shape[0] @@ -58,7 +59,7 @@ def dequantize_blocks_Q5_1( def dequantize_blocks_Q5_0( - blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None + blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None ) -> torch.Tensor: n_blocks = blocks.shape[0] @@ -79,7 +80,7 @@ def dequantize_blocks_Q5_0( def dequantize_blocks_Q4_1( - blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None + blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None ) -> torch.Tensor: n_blocks = blocks.shape[0] @@ -96,7 +97,7 @@ def dequantize_blocks_Q4_1( def dequantize_blocks_Q4_0( - blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None + blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None ) -> torch.Tensor: n_blocks = blocks.shape[0] @@ -111,13 +112,13 @@ def dequantize_blocks_Q4_0( def dequantize_blocks_BF16( - blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None + blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None ) -> torch.Tensor: return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32) def dequantize_blocks_Q6_K( - blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None + blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None ) -> torch.Tensor: n_blocks = blocks.shape[0] @@ -147,7 +148,7 @@ def dequantize_blocks_Q6_K( def dequantize_blocks_Q5_K( - blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None + blocks: 
torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None ) -> torch.Tensor: n_blocks = blocks.shape[0] @@ -175,7 +176,7 @@ def dequantize_blocks_Q5_K( def dequantize_blocks_Q4_K( - blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None + blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None ) -> torch.Tensor: n_blocks = blocks.shape[0] @@ -197,7 +198,7 @@ def dequantize_blocks_Q4_K( def dequantize_blocks_Q3_K( - blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None + blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None ) -> torch.Tensor: n_blocks = blocks.shape[0] @@ -232,7 +233,7 @@ def dequantize_blocks_Q3_K( def dequantize_blocks_Q2_K( - blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None + blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None ) -> torch.Tensor: n_blocks = blocks.shape[0] @@ -254,7 +255,7 @@ def dequantize_blocks_Q2_K( DEQUANTIZE_FUNCTIONS: dict[ - gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor] + gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, torch.dtype | None], torch.Tensor] ] = { gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16, gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0, @@ -270,7 +271,7 @@ DEQUANTIZE_FUNCTIONS: dict[ } -def is_torch_compatible(tensor: Optional[torch.Tensor]): +def is_torch_compatible(tensor: torch.Tensor | None): return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES @@ -279,7 +280,7 @@ def is_quantized(tensor: torch.Tensor): def dequantize( - data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None + data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: torch.dtype | None = None ): """ Dequantize tensor back to 
usable shape/dtype diff --git a/modules/history.py b/modules/history.py index 9c0c9c0c6..2540aa7a0 100644 --- a/modules/history.py +++ b/modules/history.py @@ -9,8 +9,10 @@ import torch from modules import shared, devices -class Item(): - def __init__(self, latent, preview=None, info=None, ops=[]): +class Item: + def __init__(self, latent, preview=None, info=None, ops=None): + if ops is None: + ops = [] self.ts = datetime.datetime.now().replace(microsecond=0) self.name = self.ts.strftime('%Y-%m-%d %H:%M:%S') self.latent = latent.detach().clone().to(devices.cpu) @@ -20,7 +22,7 @@ class Item(): self.size = sys.getsizeof(self.latent.storage()) -class History(): +class History: def __init__(self): self.index = -1 self.latents = deque(maxlen=1024) @@ -58,7 +60,9 @@ class History(): return i return -1 - def add(self, latent, preview=None, info=None, ops=[]): + def add(self, latent, preview=None, info=None, ops=None): + if ops is None: + ops = [] shared.state.latent_history += 1 if shared.opts.latent_history == 0: return diff --git a/modules/image/grid.py b/modules/image/grid.py index 5827bdb95..2b6bc2ed1 100644 --- a/modules/image/grid.py +++ b/modules/image/grid.py @@ -171,7 +171,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N calc_img = Image.new("RGB", (1, 1), shared.opts.grid_background) calc_d = ImageDraw.Draw(calc_img) title_texts = [title] if title else [[GridAnnotation()]] - for texts, allowed_width in zip(x_texts + y_texts + title_texts, [width] * len(x_texts) + [pad_left] * len(y_texts) + [(width+margin)*cols]): + for texts, allowed_width in zip(x_texts + y_texts + title_texts, [width] * len(x_texts) + [pad_left] * len(y_texts) + [(width+margin)*cols], strict=False): items = [] + texts texts.clear() for line in items: diff --git a/modules/image/resize.py b/modules/image/resize.py index 029e2549e..a26c92714 100644 --- a/modules/image/resize.py +++ b/modules/image/resize.py @@ -1,4 +1,3 @@ -from typing import Union import sys 
import time import numpy as np @@ -8,7 +7,7 @@ from modules import shared, upscaler from modules.image import sharpfin -def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None): +def resize_image(resize_mode: int, im: Image.Image | torch.Tensor, width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None): upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img def verify_image(image): @@ -34,7 +33,7 @@ def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width: im = vae_decode(latents, shared.sd_model, output_type='pil', vae_type='Tiny')[0] return im - def resize(im: Union[Image.Image, torch.Tensor], w, h): + def resize(im: Image.Image | torch.Tensor, w, h): w, h = int(w), int(h) if upscaler_name is None or upscaler_name == "None" or (hasattr(im, 'mode') and im.mode == 'L'): return sharpfin.resize(im, (w, h), linearize=False) # force for mask diff --git a/modules/image/sharpfin.py b/modules/image/sharpfin.py index 987e98bba..ecdebdf0a 100644 --- a/modules/image/sharpfin.py +++ b/modules/image/sharpfin.py @@ -22,7 +22,6 @@ _triton_ok = False def check_sharpfin(): global _sharpfin_checked, _sharpfin_ok, _triton_ok # pylint: disable=global-statement if not _sharpfin_checked: - from modules.sharpfin.functional import scale # pylint: disable=unused-import _sharpfin_ok = True try: from modules.sharpfin import TRITON_AVAILABLE diff --git a/modules/images.py b/modules/images.py index 070beb9a4..068141e77 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,10 +1,11 @@ -from modules.image.util import flatten, draw_text # pylint: disable=unused-import -from modules.image.save import save_image # pylint: disable=unused-import -from modules.image.convert import to_pil, to_tensor # pylint: disable=unused-import -from modules.image.metadata import read_info_from_image, image_data # pylint: 
disable=unused-import -from modules.image.resize import resize_image # pylint: disable=unused-import -from modules.image.sharpfin import resize # pylint: disable=unused-import -from modules.image.namegen import FilenameGenerator, get_next_sequence_number # pylint: disable=unused-import -from modules.image.watermark import set_watermark, get_watermark # pylint: disable=unused-import -from modules.image.grid import image_grid, get_grid_size, split_grid, combine_grid, check_grid_size, get_font, draw_grid_annotations, draw_prompt_matrix, GridAnnotation, Grid # pylint: disable=unused-import -from modules.video import save_video # pylint: disable=unused-import +from modules.image.metadata import image_data, read_info_from_image +from modules.image.save import save_image, sanitize_filename_part +from modules.image.resize import resize_image +from modules.image.grid import image_grid, check_grid_size, get_grid_size, draw_grid_annotations, draw_prompt_matrix + +__all__ = [ + 'image_data', 'read_info_from_image', + 'save_image', 'sanitize_filename_part', + 'resize_image', + 'image_grid', 'check_grid_size', 'get_grid_size', 'draw_grid_annotations', 'draw_prompt_matrix' +] diff --git a/modules/img2img.py b/modules/img2img.py index bc080f732..9b832e7a0 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -77,7 +77,7 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args) caption_file = os.path.splitext(image_file)[0] + '.txt' prompt_type='default' if os.path.exists(caption_file): - with open(caption_file, 'r', encoding='utf8') as f: + with open(caption_file, encoding='utf8') as f: p.prompt = f.read() prompt_type='file' else: diff --git a/modules/infotext.py b/modules/infotext.py index dbee9fbd9..75a955e2d 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -129,7 +129,7 @@ if __name__ == '__main__': import sys if len(sys.argv) > 1: if os.path.exists(sys.argv[1]): - with open(sys.argv[1], 'r', encoding='utf8') as f: + with 
open(sys.argv[1], encoding='utf8') as f: parse(f.read()) else: parse(sys.argv[1]) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 6d947fd3f..b32c1e558 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -1,3 +1,2 @@ # a1111 compatibility module: unused -from modules.infotext import parse as parse_generation_parameters # pylint: disable=unused-import diff --git a/modules/intel/openvino/__init__.py b/modules/intel/openvino/__init__.py index e14b26ec0..0491525ad 100644 --- a/modules/intel/openvino/__init__.py +++ b/modules/intel/openvino/__init__.py @@ -81,7 +81,9 @@ def warn_once(msg): warned = True class OpenVINOGraphModule(torch.nn.Module): - def __init__(self, gm, partition_id, use_python_fusion_cache, model_hash_str: str = None, file_name="", int_inputs=[]): + def __init__(self, gm, partition_id, use_python_fusion_cache, model_hash_str: str = None, file_name="", int_inputs=None): + if int_inputs is None: + int_inputs = [] super().__init__() self.gm = gm self.int_inputs = int_inputs @@ -192,7 +194,7 @@ def execute( elif executor == "strictly_openvino": return openvino_execute(gm, *args, executor_parameters=executor_parameters, file_name=file_name) - msg = "Received unexpected value for 'executor': {0}. Allowed values are: openvino, strictly_openvino.".format(executor) + msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: openvino, strictly_openvino." 
raise ValueError(msg) @@ -373,7 +375,7 @@ def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition ov_inputs = [] for arg in flat_args: if not isinstance(arg, int): - ov_inputs.append((arg.detach().cpu().numpy())) + ov_inputs.append(arg.detach().cpu().numpy()) res = req.infer(ov_inputs, share_inputs=True, share_outputs=True) @@ -423,7 +425,9 @@ def openvino_execute_partitioned(gm: GraphModule, *args, executor_parameters=Non return shared.compiled_model_state.partitioned_modules[signature][0](*ov_inputs) -def partition_graph(gm: GraphModule, use_python_fusion_cache: bool, model_hash_str: str = None, file_name="", int_inputs=[]): +def partition_graph(gm: GraphModule, use_python_fusion_cache: bool, model_hash_str: str = None, file_name="", int_inputs=None): + if int_inputs is None: + int_inputs = [] for node in gm.graph.nodes: if node.op == "call_module" and "fused_" in node.name: openvino_submodule = getattr(gm, node.name) @@ -509,7 +513,7 @@ def openvino_fx(subgraph, example_inputs, options=None): if os.path.isfile(maybe_fs_cached_name + ".xml") and os.path.isfile(maybe_fs_cached_name + ".bin"): example_inputs_reordered = [] if (os.path.isfile(maybe_fs_cached_name + ".txt")): - f = open(maybe_fs_cached_name + ".txt", "r") + f = open(maybe_fs_cached_name + ".txt") for input_data in example_inputs: shape = f.readline() if (str(input_data.size()) != shape): @@ -532,7 +536,7 @@ def openvino_fx(subgraph, example_inputs, options=None): if (shared.compiled_model_state.cn_model != [] and str(shared.compiled_model_state.cn_model) in maybe_fs_cached_name): args_reordered = [] if (os.path.isfile(maybe_fs_cached_name + ".txt")): - f = open(maybe_fs_cached_name + ".txt", "r") + f = open(maybe_fs_cached_name + ".txt") for input_data in args: shape = f.readline() if (str(input_data.size()) != shape): diff --git a/modules/ipadapter.py b/modules/ipadapter.py index f29fc0d77..97ac171ae 100644 --- a/modules/ipadapter.py +++ b/modules/ipadapter.py @@ -288,7 
+288,19 @@ def parse_params(p: processing.StableDiffusionProcessing, adapters: list, adapte return adapter_images, adapter_masks, adapter_scales, adapter_crops, adapter_starts, adapter_ends -def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapter_scales=[1.0], adapter_crops=[False], adapter_starts=[0.0], adapter_ends=[1.0], adapter_images=[]): +def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=None, adapter_scales=None, adapter_crops=None, adapter_starts=None, adapter_ends=None, adapter_images=None): + if adapter_images is None: + adapter_images = [] + if adapter_ends is None: + adapter_ends = [1.0] + if adapter_starts is None: + adapter_starts = [0.0] + if adapter_crops is None: + adapter_crops = [False] + if adapter_scales is None: + adapter_scales = [1.0] + if adapter_names is None: + adapter_names = [] global adapters_loaded # pylint: disable=global-statement # overrides if hasattr(p, 'ip_adapter_names'): @@ -361,7 +373,7 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt if adapter_starts[i] > 0: adapter_scales[i] = 0.00 pipe.set_ip_adapter_scale(adapter_scales if len(adapter_scales) > 1 else adapter_scales[0]) - ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}:{crop}' for adapter, scale, start, end, crop in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends, adapter_crops)] + ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}:{crop}' for adapter, scale, start, end, crop in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends, adapter_crops, strict=False)] if hasattr(pipe, 'transformer') and 'Nunchaku' in pipe.transformer.__class__.__name__: if isinstance(repos, str): sd_models.clear_caches(full=True) diff --git a/modules/loader.py b/modules/loader.py index b34e01dd6..f00f39882 100644 --- a/modules/loader.py +++ b/modules/loader.py @@ -63,7 +63,7 @@ try: except Exception: pass try: - import torch.distributed.distributed_c10d as 
_c10d # pylint: disable=unused-import,ungrouped-imports + pass # pylint: disable=unused-import,ungrouped-imports except Exception: errors.log.warning('Loader: torch is not built with distributed support') @@ -73,7 +73,6 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvisi torchvision = None try: import torchvision # pylint: disable=W0611,C0411 - import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them # pylint: disable=W0611,C0411 except Exception as e: errors.log.error(f'Loader: torchvision=={torchvision.__version__ if "torchvision" in sys.modules else None} {e}') if '_no_nep' in str(e): @@ -100,7 +99,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__: timer.startup.record("torch") try: - import bitsandbytes # pylint: disable=W0611,C0411 + import bitsandbytes # pylint: disable=unused-import _bnb = True except Exception: _bnb = False @@ -132,7 +131,6 @@ except Exception as e: errors.log.warning(f'Torch onnxruntime: {e}') timer.startup.record("onnx") -from fastapi import FastAPI # pylint: disable=W0611,C0411 timer.startup.record("fastapi") import gradio # pylint: disable=W0611,C0411 @@ -161,17 +159,16 @@ except Exception as e: sys.exit(1) try: - import pillow_jxl # pylint: disable=W0611,C0411 + pass # pylint: disable=W0611,C0411 except Exception: pass -from PIL import Image # pylint: disable=W0611,C0411 timer.startup.record("pillow") import cv2 # pylint: disable=W0611,C0411 timer.startup.record("cv2") -class _tqdm_cls(): +class _tqdm_cls: def __call__(self, *args, **kwargs): bar_format = 'Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + '{desc}' + '\x1b[0m' return tqdm_lib.tqdm(*args, bar_format=bar_format, ncols=80, colour='#327fba', **kwargs) diff --git a/modules/localization.py b/modules/localization.py index e3cc19959..e6fc9fdbe 100644 --- 
a/modules/localization.py +++ b/modules/localization.py @@ -28,7 +28,7 @@ def localization_js(current_localization_name): data = {} if fn is not None: try: - with open(fn, "r", encoding="utf8") as file: + with open(fn, encoding="utf8") as file: data = json.load(file) except Exception as e: errors.log.error(f"Error loading localization from {fn}:") diff --git a/modules/lora/extra_networks_lora.py b/modules/lora/extra_networks_lora.py index cd23998a8..1dabaaa01 100644 --- a/modules/lora/extra_networks_lora.py +++ b/modules/lora/extra_networks_lora.py @@ -1,4 +1,3 @@ -from typing import List import os import re import numpy as np @@ -19,7 +18,7 @@ def get_stepwise(param, step, steps): # from https://github.com/cheald/sd-webui- return steps[0][0] steps = [[s[0], s[1] if len(s) == 2 else 1] for s in steps] # Add implicit 1s to any steps which don't have a weight steps.sort(key=lambda k: k[1]) # Sort by index - steps = [list(v) for v in zip(*steps)] + steps = [list(v) for v in zip(*steps, strict=False)] return steps def calculate_weight(m, step, max_steps, step_offset=2): @@ -170,10 +169,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): self.model = None self.errors = {} - def signature(self, names: List[str], te_multipliers: List, unet_multipliers: List): - return [f'{name}:{te}:{unet}' for name, te, unet in zip(names, te_multipliers, unet_multipliers)] + def signature(self, names: list[str], te_multipliers: list, unet_multipliers: list): + return [f'{name}:{te}:{unet}' for name, te, unet in zip(names, te_multipliers, unet_multipliers, strict=False)] - def changed(self, requested: List[str], include: List[str] = None, exclude: List[str] = None) -> bool: + def changed(self, requested: list[str], include: list[str] = None, exclude: list[str] = None) -> bool: if shared.opts.lora_force_reload: debug_log(f'Network check: type=LoRA requested={requested} status=forced') return True @@ -190,7 +189,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): 
sd_model.loaded_loras[key] = requested debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=changed') return True - for req, load in zip(requested, loaded): + for req, load in zip(requested, loaded, strict=False): if req != load: sd_model.loaded_loras[key] = requested debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=changed') @@ -198,7 +197,11 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=same') return False - def activate(self, p, params_list, step=0, include=[], exclude=[]): + def activate(self, p, params_list, step=0, include=None, exclude=None): + if exclude is None: + exclude = [] + if include is None: + include = [] self.errors.clear() if self.active: if self.model != shared.opts.sd_model_checkpoint: # reset if model changed diff --git a/modules/lora/lora_apply.py b/modules/lora/lora_apply.py index e79306c9f..5cf0a4d38 100644 --- a/modules/lora/lora_apply.py +++ b/modules/lora/lora_apply.py @@ -1,4 +1,3 @@ -from typing import Union import re import time import torch @@ -12,7 +11,7 @@ bnb = None re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") -def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, wanted_names: tuple): +def network_backup_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, network_layer_name: str, wanted_names: tuple): global bnb # pylint: disable=W0603 backup_size = 0 if len(l.loaded_networks) > 0 and network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in l.loaded_networks]): # noqa: C419 # pylint: disable=R1729 @@ 
-76,7 +75,7 @@ def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n return backup_size -def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, use_previous: bool = False): +def network_calc_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, network_layer_name: str, use_previous: bool = False): if shared.opts.diffusers_offload_mode == "none": try: self.to(devices.device) @@ -147,7 +146,7 @@ def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn. return batch_updown, batch_ex_bias -def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], model_weights: Union[None, torch.Tensor] = None, lora_weights: torch.Tensor = None, deactivate: bool = False, device: torch.device = None, bias: bool = False): +def network_add_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, model_weights: None | torch.Tensor = None, lora_weights: torch.Tensor = None, deactivate: bool = False, device: torch.device = None, bias: bool = False): if lora_weights is None: return if deactivate: @@ -239,7 +238,7 @@ def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.G del model_weights, lora_weights, new_weight, weight # required to avoid memory leak -def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, 
ex_bias: torch.Tensor, deactivate: bool = False, device: torch.device = devices.device): +def network_apply_direct(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False, device: torch.device = devices.device): weights_backup = getattr(self, "network_weights_backup", False) bias_backup = getattr(self, "network_bias_backup", False) if not isinstance(weights_backup, bool): # remove previous backup if we switched settings @@ -266,7 +265,7 @@ def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn. l.timer.apply += time.time() - t0 -def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, device: torch.device, deactivate: bool = False): +def network_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, updown: torch.Tensor, ex_bias: torch.Tensor, device: torch.device, deactivate: bool = False): weights_backup = getattr(self, "network_weights_backup", None) bias_backup = getattr(self, "network_bias_backup", None) if weights_backup is None and bias_backup is None: diff --git a/modules/lora/lora_common.py b/modules/lora/lora_common.py index a6b15ae13..7f171846b 100644 --- a/modules/lora/lora_common.py +++ b/modules/lora/lora_common.py @@ -1,4 +1,3 @@ -from typing import List import os from modules.lora import lora_timers from modules.lora import network_lora, network_hada, network_ia3, network_oft, network_lokr, network_full, network_norm, network_glora @@ -16,6 +15,6 @@ module_types = [ network_norm.ModuleTypeNorm(), network_glora.ModuleTypeGLora(), ] 
-loaded_networks: List = [] # no type due to circular import -previously_loaded_networks: List = [] # no type due to circular import +loaded_networks: list = [] # no type due to circular import +previously_loaded_networks: list = [] # no type due to circular import extra_network_lora = None # initialized in extra_networks.py diff --git a/modules/lora/lora_convert.py b/modules/lora/lora_convert.py index aaef92e43..f019c450e 100644 --- a/modules/lora/lora_convert.py +++ b/modules/lora/lora_convert.py @@ -1,7 +1,6 @@ import os import re import bisect -from typing import Dict import torch from modules import shared @@ -23,7 +22,7 @@ re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") re_compiled = {} -def make_unet_conversion_map() -> Dict[str, str]: +def make_unet_conversion_map() -> dict[str, str]: unet_conversion_map_layer = [] for i in range(4): # num_blocks is 3 in sdxl @@ -213,10 +212,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): ait_sd.update({k: down_weight for k in ait_down_keys}) # up_weight is split to each split - ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0), strict=False)}) # noqa: C416 # pylint: disable=unnecessary-comprehension else: # down_weight is chunked to each split - ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0), strict=False)}) # noqa: C416 # pylint: disable=unnecessary-comprehension # up_weight is sparse: only non-zero values are copied to each split i = 0 diff --git a/modules/lora/lora_diffusers.py b/modules/lora/lora_diffusers.py index eb1515ca0..02179fb21 100644 --- a/modules/lora/lora_diffusers.py +++ 
b/modules/lora/lora_diffusers.py @@ -1,4 +1,3 @@ -from typing import Union import os import time import diffusers @@ -50,7 +49,7 @@ def load_per_module(sd_model: diffusers.DiffusionPipeline, filename: str, adapte return adapter_name -def load_diffusers(name: str, network_on_disk: network.NetworkOnDisk, lora_scale:float=shared.opts.extra_networks_default_multiplier, lora_module=None) -> Union[network.Network, None]: +def load_diffusers(name: str, network_on_disk: network.NetworkOnDisk, lora_scale:float=shared.opts.extra_networks_default_multiplier, lora_module=None) -> network.Network | None: t0 = time.time() name = name.replace(".", "_") sd_model: diffusers.DiffusionPipeline = getattr(shared.sd_model, "pipe", shared.sd_model) diff --git a/modules/lora/lora_load.py b/modules/lora/lora_load.py index a836b5323..ff5659ea5 100644 --- a/modules/lora/lora_load.py +++ b/modules/lora/lora_load.py @@ -1,4 +1,3 @@ -from typing import Union import os import time import concurrent @@ -39,7 +38,7 @@ def lora_dump(lora, dct): f.write(line + "\n") -def load_safetensors(name, network_on_disk: network.NetworkOnDisk) -> Union[network.Network, None]: +def load_safetensors(name, network_on_disk: network.NetworkOnDisk) -> network.Network | None: if not shared.sd_loaded: return None @@ -241,7 +240,7 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non lora_diffusers.diffuser_scales.clear() t0 = time.time() - for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)): + for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names, strict=False)): net = None if network_on_disk is not None: shorthash = getattr(network_on_disk, 'shorthash', '').lower() diff --git a/modules/lora/lora_nunchaku.py b/modules/lora/lora_nunchaku.py index de4773158..318e25398 100644 --- a/modules/lora/lora_nunchaku.py +++ b/modules/lora/lora_nunchaku.py @@ -10,7 +10,7 @@ def load_nunchaku(names, strengths): global previously_loaded # pylint: 
disable=global-statement strengths = [s[0] if isinstance(s, list) else s for s in strengths] networks = lora_load.gather_networks(names) - networks = [(network, strength) for network, strength in zip(networks, strengths) if network is not None and strength > 0] + networks = [(network, strength) for network, strength in zip(networks, strengths, strict=False) if network is not None and strength > 0] loras = [(network.filename, strength) for network, strength in networks] is_changed = loras != previously_loaded if not is_changed: diff --git a/modules/lora/lora_timers.py b/modules/lora/lora_timers.py index 30c35a728..6f3e48c33 100644 --- a/modules/lora/lora_timers.py +++ b/modules/lora/lora_timers.py @@ -1,4 +1,4 @@ -class Timer(): +class Timer: list: float = 0 load: float = 0 backup: float = 0 diff --git a/modules/lora/network.py b/modules/lora/network.py index b8a09913b..a959a3338 100644 --- a/modules/lora/network.py +++ b/modules/lora/network.py @@ -1,6 +1,5 @@ import os import enum -from typing import Union from collections import namedtuple from modules import sd_models, hashes, shared @@ -120,7 +119,7 @@ class NetworkOnDisk: if self.filename is not None: fn = os.path.splitext(self.filename)[0] + '.txt' if os.path.exists(fn): - with open(fn, "r", encoding="utf-8") as file: + with open(fn, encoding="utf-8") as file: return file.read() return None @@ -144,7 +143,7 @@ class Network: # LoraModule class ModuleType: - def create_module(self, net: Network, weights: NetworkWeights) -> Union[Network, None]: # pylint: disable=W0613 + def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: # pylint: disable=W0613 return None diff --git a/modules/lora/networks.py b/modules/lora/networks.py index 69df992cc..512a63698 100644 --- a/modules/lora/networks.py +++ b/modules/lora/networks.py @@ -11,7 +11,11 @@ applied_layers: list[str] = [] default_components = ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'text_encoder_4', 'unet', 'transformer', 
'transformer_2'] -def network_activate(include=[], exclude=[]): +def network_activate(include=None, exclude=None): + if exclude is None: + exclude = [] + if include is None: + include = [] t0 = time.time() with limit_errors("network_activate"): sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) @@ -77,7 +81,11 @@ def network_activate(include=[], exclude=[]): sd_models.set_diffuser_offload(sd_model, op="model") -def network_deactivate(include=[], exclude=[]): +def network_deactivate(include=None, exclude=None): + if exclude is None: + exclude = [] + if include is None: + include = [] if not shared.opts.lora_fuse_native or shared.opts.lora_force_diffusers: return if len(l.previously_loaded_networks) == 0: diff --git a/modules/masking.py b/modules/masking.py index ea1844c19..92de3bb75 100644 --- a/modules/masking.py +++ b/modules/masking.py @@ -1,5 +1,4 @@ from types import SimpleNamespace -from typing import List import os import sys import time @@ -235,7 +234,7 @@ def run_segment(input_image: gr.Image, input_mask: np.ndarray): combined_mask = np.zeros(input_mask.shape, dtype='uint8') input_mask_size = np.count_nonzero(input_mask) debug(f'Segment SAM: {vars(opts)}') - for mask, score in zip(outputs['masks'], outputs['scores']): + for mask, score in zip(outputs['masks'], outputs['scores'], strict=False): mask = mask.astype('uint8') mask_size = np.count_nonzero(mask) if mask_size == 0: @@ -561,7 +560,7 @@ def create_segment_ui(): return controls -def bind_controls(image_controls: List[gr.Image], preview_image: gr.Image, output_image: gr.Image): +def bind_controls(image_controls: list[gr.Image], preview_image: gr.Image, output_image: gr.Image): for image_control in image_controls: btn_mask.click(run_mask, inputs=[image_control], outputs=[preview_image]) btn_lama.click(run_lama, inputs=[image_control], outputs=[output_image]) diff --git a/modules/memmon.py b/modules/memmon.py index 944d85d83..eb521f93f 100644 --- a/modules/memmon.py +++ b/modules/memmon.py @@ 
-2,7 +2,7 @@ from collections import defaultdict import torch -class MemUsageMonitor(): +class MemUsageMonitor: device = None disabled = False opts = None diff --git a/modules/memstats.py b/modules/memstats.py index c9bb13238..0d67da44d 100644 --- a/modules/memstats.py +++ b/modules/memstats.py @@ -24,7 +24,7 @@ def get_docker_limit(): if docker_limit is not None: return docker_limit try: - with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r', encoding='utf8') as f: + with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', encoding='utf8') as f: docker_limit = float(f.read()) except Exception: docker_limit = sys.float_info.max @@ -145,7 +145,9 @@ class Object: return f'{self.fn}.{self.name} type={self.type} size={self.size} ref={self.refcount}' -def get_objects(gcl={}, threshold:int=0): +def get_objects(gcl=None, threshold:int=0): + if gcl is None: + gcl = {} objects = [] seen = [] diff --git a/modules/merging/convert_sdxl.py b/modules/merging/convert_sdxl.py index 93fc71f5d..3238cfd35 100644 --- a/modules/merging/convert_sdxl.py +++ b/modules/merging/convert_sdxl.py @@ -260,7 +260,9 @@ def calculate_model_hash(state_dict): return func.hexdigest() -def convert(model_path:str, checkpoint_path:str, metadata:dict={}): +def convert(model_path:str, checkpoint_path:str, metadata:dict=None): + if metadata is None: + metadata = {} unet_path = os.path.join(model_path, "unet", "diffusion_pytorch_model.safetensors") vae_path = os.path.join(model_path, "vae", "diffusion_pytorch_model.safetensors") text_enc_path = os.path.join(model_path, "text_encoder", "model.safetensors") diff --git a/modules/merging/merge.py b/modules/merging/merge.py index d0e48bfba..ec493de2e 100644 --- a/modules/merging/merge.py +++ b/modules/merging/merge.py @@ -1,7 +1,6 @@ import os from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from typing import Dict, Optional, Tuple, Set import safetensors.torch import torch import modules.memstats @@ -37,7 +36,7 @@ 
KEY_POSITION_IDS = ".".join( ) -def fix_clip(model: Dict) -> Dict: +def fix_clip(model: dict) -> dict: if KEY_POSITION_IDS in model.keys(): model[KEY_POSITION_IDS] = torch.tensor( [list(range(MAX_TOKENS))], @@ -48,7 +47,7 @@ def fix_clip(model: Dict) -> Dict: return model -def prune_sd_model(model: Dict, keyset: Set) -> Dict: +def prune_sd_model(model: dict, keyset: set) -> dict: keys = list(model.keys()) for k in keys: if ( @@ -60,7 +59,7 @@ def prune_sd_model(model: Dict, keyset: Set) -> Dict: return model -def restore_sd_model(original_model: Dict, merged_model: Dict) -> Dict: +def restore_sd_model(original_model: dict, merged_model: dict) -> dict: for k in original_model: if k not in merged_model: merged_model[k] = original_model[k] @@ -72,11 +71,11 @@ def log_vram(txt=""): def load_thetas( - models: Dict[str, os.PathLike], + models: dict[str, os.PathLike], prune: bool, device: torch.device, precision: str, -) -> Dict: +) -> dict: from tensordict import TensorDict thetas = {k: TensorDict.from_dict(read_state_dict(m, "cpu")) for k, m in models.items()} if prune: @@ -95,7 +94,7 @@ def load_thetas( def merge_models( - models: Dict[str, os.PathLike], + models: dict[str, os.PathLike], merge_mode: str, precision: str = "fp16", weights_clip: bool = False, @@ -104,7 +103,7 @@ def merge_models( prune: bool = False, threads: int = 4, **kwargs, -) -> Dict: +) -> dict: thetas = load_thetas(models, prune, device, precision) # log.info(f'Merge start: models={models.values()} precision={precision} clip={weights_clip} rebasin={re_basin} prune={prune} threads={threads}') weight_matcher = WeightClass(thetas["model_a"], **kwargs) @@ -136,13 +135,13 @@ def merge_models( def un_prune_model( - merged: Dict, - thetas: Dict, - models: Dict, + merged: dict, + thetas: dict, + models: dict, device: torch.device, prune: bool, precision: str, -) -> Dict: +) -> dict: if prune: log.info("Merge restoring pruned keys") del thetas @@ -180,7 +179,7 @@ def un_prune_model( def simple_merge( - 
thetas: Dict[str, Dict], + thetas: dict[str, dict], weight_matcher: WeightClass, merge_mode: str, precision: str = "fp16", @@ -188,7 +187,7 @@ def simple_merge( device: torch.device = None, work_device: torch.device = None, threads: int = 4, -) -> Dict: +) -> dict: futures = [] import rich.progress as p with p.Progress(p.TextColumn('[cyan]{task.description}'), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TextColumn('[cyan]keys={task.fields[keys]}'), console=console) as progress: @@ -227,7 +226,7 @@ def simple_merge( def rebasin_merge( - thetas: Dict[str, os.PathLike], + thetas: dict[str, os.PathLike], weight_matcher: WeightClass, merge_mode: str, precision: str = "fp16", @@ -306,14 +305,14 @@ def simple_merge_key(progress, task, key, thetas, *args, **kwargs): def merge_key( # pylint: disable=inconsistent-return-statements key: str, - thetas: Dict, + thetas: dict, weight_matcher: WeightClass, merge_mode: str, precision: str = "fp16", weights_clip: bool = False, device: torch.device = None, work_device: torch.device = None, -) -> Optional[Tuple[str, Dict]]: +) -> tuple[str, dict] | None: if work_device is None: work_device = device @@ -376,11 +375,11 @@ def merge_key_context(*args, **kwargs): def get_merge_method_args( - current_bases: Dict, - thetas: Dict, + current_bases: dict, + thetas: dict, key: str, work_device: torch.device, -) -> Dict: +) -> dict: merge_method_args = { "a": thetas["model_a"][key].to(work_device), "b": thetas["model_b"][key].to(work_device), diff --git a/modules/merging/merge_methods.py b/modules/merging/merge_methods.py index ce196b60c..3e54d501f 100644 --- a/modules/merging/merge_methods.py +++ b/modules/merging/merge_methods.py @@ -1,5 +1,4 @@ import math -from typing import Tuple import torch from torch import Tensor @@ -151,7 +150,7 @@ def kth_abs_value(a: Tensor, k: int) -> Tensor: return torch.kthvalue(torch.abs(a.float()), k)[0] -def ratio_to_region(width: float, offset: float, n: int) -> 
Tuple[int, int, bool]: +def ratio_to_region(width: float, offset: float, n: int) -> tuple[int, int, bool]: if width < 0: offset += width width = -width @@ -233,7 +232,7 @@ def ties_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: flo delta_filters = (signs == final_sign).float() res = torch.zeros_like(c, device=c.device) - for delta_filter, delta in zip(delta_filters, deltas): + for delta_filter, delta in zip(delta_filters, deltas, strict=False): res += delta_filter * delta param_count = torch.sum(delta_filters, dim=0) diff --git a/modules/merging/modules_sdxl.py b/modules/merging/modules_sdxl.py index 959ad36e5..529066fc7 100644 --- a/modules/merging/modules_sdxl.py +++ b/modules/merging/modules_sdxl.py @@ -206,7 +206,7 @@ def test_model(pipe: diffusers.StableDiffusionXLPipeline, fn: str, **kwargs): if not test.generate: return try: - generator = torch.Generator(devices.device).manual_seed(int(4242)) + generator = torch.Generator(devices.device).manual_seed(4242) args = { 'prompt': test.prompt, 'negative_prompt': test.negative, @@ -278,7 +278,7 @@ def save_model(pipe: diffusers.StableDiffusionXLPipeline): yield msg(f'pretrained={folder}') shared.log.info(f'Modules merge save: type=sdxl diffusers="{folder}"') pipe.save_pretrained(folder, safe_serialization=True, push_to_hub=False) - with open(os.path.join(folder, 'vae', 'config.json'), 'r', encoding='utf8') as f: + with open(os.path.join(folder, 'vae', 'config.json'), encoding='utf8') as f: vae_config = json.load(f) vae_config['force_upcast'] = False vae_config['scaling_factor'] = 0.13025 diff --git a/modules/mit_nunchaku.py b/modules/mit_nunchaku.py index be6f84564..6ba77007e 100644 --- a/modules/mit_nunchaku.py +++ b/modules/mit_nunchaku.py @@ -53,8 +53,6 @@ def install_nunchaku(): import os import sys import platform - import importlib - import importlib.metadata import torch python_ver = f'{sys.version_info.major}{sys.version_info.minor}' if python_ver not in ['311', '312', '313']: diff --git 
a/modules/modeldata.py b/modules/modeldata.py index 7e2164095..48172baa3 100644 --- a/modules/modeldata.py +++ b/modules/modeldata.py @@ -220,5 +220,13 @@ class Shared(sys.modules[__name__].__class__): model_type = 'unknown' return model_type + @property + def console(self): + try: + from installer import get_console + return get_console() + except ImportError: + return None + model_data = ModelData() diff --git a/modules/modelloader.py b/modules/modelloader.py index ea45bea82..c168d981b 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -4,13 +4,12 @@ import time import shutil import importlib import contextlib -from typing import Dict from urllib.parse import urlparse import huggingface_hub as hf from installer import install, log from modules import shared, errors, files_cache from modules.upscaler import Upscaler -from modules.paths import script_path, models_path +from modules import paths loggedin = None @@ -55,7 +54,7 @@ def hf_login(token=None): return True -def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config: Dict[str, str] = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None): +def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config: dict[str, str] = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None): if hub_id is None or len(hub_id) == 0: return None from diffusers import DiffusionPipeline @@ -117,7 +116,7 @@ def load_diffusers_models(clear=True): # t0 = time.time() place = shared.opts.diffusers_dir if place is None or len(place) == 0 or not os.path.isdir(place): - place = os.path.join(models_path, 'Diffusers') + place = os.path.join(paths.models_path, 'Diffusers') if clear: diffuser_repos.clear() already_found = [] @@ -382,25 +381,25 @@ def cleanup_models(): # This code could probably be more efficient if we used a tuple list or something to store the src/destinations # and then enumerate that, but this 
works for now. In the future, it'd be nice to just have every "model" scaler # somehow auto-register and just do these things... - root_path = script_path - src_path = models_path - dest_path = os.path.join(models_path, "Stable-diffusion") + root_path = paths.script_path + src_path = paths.models_path + dest_path = os.path.join(paths.models_path, "Stable-diffusion") # move_files(src_path, dest_path, ".ckpt") # move_files(src_path, dest_path, ".safetensors") src_path = os.path.join(root_path, "ESRGAN") - dest_path = os.path.join(models_path, "ESRGAN") + dest_path = os.path.join(paths.models_path, "ESRGAN") move_files(src_path, dest_path) - src_path = os.path.join(models_path, "BSRGAN") - dest_path = os.path.join(models_path, "ESRGAN") + src_path = os.path.join(paths.models_path, "BSRGAN") + dest_path = os.path.join(paths.models_path, "ESRGAN") move_files(src_path, dest_path, ".pth") src_path = os.path.join(root_path, "SwinIR") - dest_path = os.path.join(models_path, "SwinIR") + dest_path = os.path.join(paths.models_path, "SwinIR") move_files(src_path, dest_path) src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/") - dest_path = os.path.join(models_path, "LDSR") + dest_path = os.path.join(paths.models_path, "LDSR") move_files(src_path, dest_path) src_path = os.path.join(root_path, "SCUNet") - dest_path = os.path.join(models_path, "SCUNet") + dest_path = os.path.join(paths.models_path, "SCUNet") move_files(src_path, dest_path) @@ -430,7 +429,7 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): def load_upscalers(): # We can only do this 'magic' method to dynamically load upscalers if they are referenced, so we'll try to import any _model.py files before looking in __subclasses__ t0 = time.time() - modules_dir = os.path.join(shared.script_path, "modules", "postprocess") + modules_dir = os.path.join(paths.script_path, "modules", "postprocess") for file in os.listdir(modules_dir): if "_model.py" in 
file: model_name = file.replace("_model.py", "") diff --git a/modules/modelstats.py b/modules/modelstats.py index 2ab1a6b62..b1fc54d8d 100644 --- a/modules/modelstats.py +++ b/modules/modelstats.py @@ -28,7 +28,7 @@ def stat(fn: str): return size, mtime -class Module(): +class Module: name: str = '' cls: str = None device: str = None @@ -61,7 +61,7 @@ class Module(): return s -class Model(): +class Model: name: str = '' fn: str = '' type: str = '' diff --git a/modules/olive_script.py b/modules/olive_script.py index c881679ed..c194e7f42 100644 --- a/modules/olive_script.py +++ b/modules/olive_script.py @@ -1,17 +1,18 @@ import os -from typing import Type, Callable, TypeVar, Dict, Any +from typing import TypeVar, Any +from collections.abc import Callable import torch import diffusers from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextModelWithProjection class ENVStore: - __DESERIALIZER: Dict[Type, Callable[[str,], Any]] = { + __DESERIALIZER: dict[type, Callable[[str,], Any]] = { bool: lambda x: bool(int(x)), int: int, str: lambda x: x, } - __SERIALIZER: Dict[Type, Callable[[Any,], str]] = { + __SERIALIZER: dict[type, Callable[[Any,], str]] = { bool: lambda x: str(int(x)), int: str, str: lambda x: x, @@ -89,7 +90,7 @@ def get_loader_arguments(no_variant: bool = False): T = TypeVar("T") -def from_pretrained(cls: Type[T], pretrained_model_name_or_path: os.PathLike, *args, no_variant: bool = False, **kwargs) -> T: +def from_pretrained(cls: type[T], pretrained_model_name_or_path: os.PathLike, *args, no_variant: bool = False, **kwargs) -> T: pretrained_model_name_or_path = str(pretrained_model_name_or_path) if pretrained_model_name_or_path.endswith(".onnx"): cls = diffusers.OnnxRuntimeModel diff --git a/modules/onnx_impl/__init__.py b/modules/onnx_impl/__init__.py index f7ef1b8be..1decd1d41 100644 --- a/modules/onnx_impl/__init__.py +++ b/modules/onnx_impl/__init__.py @@ -16,7 +16,7 @@ except Exception as e: class 
DynamicSessionOptions(ort.SessionOptions): - config: Optional[Dict] = None + config: dict | None = None def __init__(self): super().__init__() @@ -28,7 +28,7 @@ class DynamicSessionOptions(ort.SessionOptions): return sess_options.copy() return DynamicSessionOptions() - def enable_static_dims(self, config: Dict): + def enable_static_dims(self, config: dict): self.config = config self.add_free_dimension_override_by_name("unet_sample_batch", config["hidden_batch_size"]) self.add_free_dimension_override_by_name("unet_sample_channels", 4) @@ -103,9 +103,9 @@ class OnnxRuntimeModel(TorchCompatibleModule, diffusers.OnnxRuntimeModel): class VAEConfig: DEFAULTS = { "scaling_factor": 0.18215 } - config: Dict + config: dict - def __init__(self, config: Dict): + def __init__(self, config: dict): self.config = config def __getattr__(self, key): diff --git a/modules/onnx_impl/execution_providers.py b/modules/onnx_impl/execution_providers.py index dd2622f1f..b220692de 100644 --- a/modules/onnx_impl/execution_providers.py +++ b/modules/onnx_impl/execution_providers.py @@ -1,6 +1,5 @@ import sys from enum import Enum -from typing import Tuple, List from installer import log from modules import devices @@ -33,7 +32,7 @@ TORCH_DEVICE_TO_EP = { try: import onnxruntime as ort - available_execution_providers: List[ExecutionProvider] = ort.get_available_providers() + available_execution_providers: list[ExecutionProvider] = ort.get_available_providers() except Exception as e: log.error(f'ONNX import error: {e}') available_execution_providers = [] @@ -90,7 +89,7 @@ def get_execution_provider_options(): return execution_provider_options -def get_provider() -> Tuple: +def get_provider() -> tuple: from modules.shared import opts return (opts.onnx_execution_provider, get_execution_provider_options(),) diff --git a/modules/onnx_impl/pipelines/__init__.py b/modules/onnx_impl/pipelines/__init__.py index 99682b1aa..9b57d6577 100644 --- a/modules/onnx_impl/pipelines/__init__.py +++ 
b/modules/onnx_impl/pipelines/__init__.py @@ -103,12 +103,12 @@ class OnnxRawPipeline(PipelineBase): path: os.PathLike original_filename: str - constructor: Type[PipelineBase] - init_dict: Dict[str, Tuple[str]] = {} + constructor: type[PipelineBase] + init_dict: dict[str, tuple[str]] = {} default_scheduler: Any = None # for Img2Img - def __init__(self, constructor: Type[PipelineBase], path: os.PathLike): # pylint: disable=super-init-not-called + def __init__(self, constructor: type[PipelineBase], path: os.PathLike): # pylint: disable=super-init-not-called self._is_sdxl = check_pipeline_sdxl(constructor) self.from_diffusers_cache = check_diffusers_cache(path) self.path = path @@ -150,7 +150,7 @@ class OnnxRawPipeline(PipelineBase): pipeline.scheduler = self.default_scheduler return pipeline - def convert(self, submodels: List[str], in_dir: os.PathLike, out_dir: os.PathLike): + def convert(self, submodels: list[str], in_dir: os.PathLike, out_dir: os.PathLike): install('onnx') # may not be installed yet, this performs check and installs as needed import onnx shutil.rmtree("cache", ignore_errors=True) @@ -218,7 +218,7 @@ class OnnxRawPipeline(PipelineBase): with open(os.path.join(out_dir, "model_index.json"), 'w', encoding="utf-8") as file: json.dump(model_index, file) - def run_olive(self, submodels: List[str], in_dir: os.PathLike, out_dir: os.PathLike): + def run_olive(self, submodels: list[str], in_dir: os.PathLike, out_dir: os.PathLike): from olive.model import ONNXModelHandler from olive.workflows import run as run_workflows @@ -235,8 +235,8 @@ class OnnxRawPipeline(PipelineBase): for submodel in submodels: log.info(f"\nProcessing {submodel}") - with open(os.path.join(sd_configs_path, "olive", 'sdxl' if self._is_sdxl else 'sd', f"{submodel}.json"), "r", encoding="utf-8") as config_file: - olive_config: Dict[str, Dict[str, Dict]] = json.load(config_file) + with open(os.path.join(sd_configs_path, "olive", 'sdxl' if self._is_sdxl else 'sd', f"{submodel}.json"), 
encoding="utf-8") as config_file: + olive_config: dict[str, dict[str, dict]] = json.load(config_file) for flow in olive_config["pass_flows"]: for i in range(len(flow)): @@ -257,7 +257,7 @@ class OnnxRawPipeline(PipelineBase): run_workflows(olive_config) - with open(os.path.join("footprints", f"{submodel}_{EP_TO_NAME[shared.opts.onnx_execution_provider]}_footprints.json"), "r", encoding="utf-8") as footprint_file: + with open(os.path.join("footprints", f"{submodel}_{EP_TO_NAME[shared.opts.onnx_execution_provider]}_footprints.json"), encoding="utf-8") as footprint_file: footprints = json.load(footprint_file) processor_final_pass_footprint = None for _, footprint in footprints.items(): diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_img2img_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_img2img_pipeline.py index 6d8ea5946..82c9740a8 100644 --- a/modules/onnx_impl/pipelines/onnx_stable_diffusion_img2img_pipeline.py +++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_img2img_pipeline.py @@ -1,5 +1,6 @@ import inspect -from typing import Union, Optional, Callable, List, Any +from typing import Any +from collections.abc import Callable import numpy as np import torch import diffusers @@ -33,20 +34,20 @@ class OnnxStableDiffusionImg2ImgPipeline(diffusers.OnnxStableDiffusionImg2ImgPip def __call__( self, - prompt: Union[str, List[str]], + prompt: str | list[str], image: PipelineImageInput = None, strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - output_type: Optional[str] = "pil", + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + negative_prompt: 
str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + output_type: str | None = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback: Callable[[int, int, np.ndarray], None] | None = None, callback_steps: int = 1, ): # check inputs. Raise error if not correct diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_inpaint_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_inpaint_pipeline.py index dccfb808d..e8ce33fc4 100644 --- a/modules/onnx_impl/pipelines/onnx_stable_diffusion_inpaint_pipeline.py +++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_inpaint_pipeline.py @@ -1,5 +1,6 @@ import inspect -from typing import Union, Optional, Callable, List, Any +from typing import Any +from collections.abc import Callable import numpy as np import torch import diffusers @@ -31,25 +32,25 @@ class OnnxStableDiffusionInpaintPipeline(diffusers.OnnxStableDiffusionInpaintPip @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], + prompt: str | list[str], image: PipelineImageInput, mask_image: PipelineImageInput, masked_image_latents: torch.FloatTensor = None, - height: Optional[int] = 512, - width: Optional[int] = 512, + height: int | None = 512, + width: int | None = 512, strength: float = 1.0, num_inference_steps: int = 50, guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[np.ndarray] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = 
None, - output_type: Optional[str] = "pil", + generator: torch.Generator | list[torch.Generator] | None = None, + latents: np.ndarray | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + output_type: str | None = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback: Callable[[int, int, np.ndarray], None] | None = None, callback_steps: int = 1, ): # check inputs. Raise error if not correct diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_pipeline.py index 112241996..2b583e8f5 100644 --- a/modules/onnx_impl/pipelines/onnx_stable_diffusion_pipeline.py +++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_pipeline.py @@ -1,5 +1,6 @@ import inspect -from typing import Union, Optional, Callable, List, Any +from typing import Any +from collections.abc import Callable import numpy as np import torch import diffusers @@ -29,21 +30,21 @@ class OnnxStableDiffusionPipeline(diffusers.OnnxStableDiffusionPipeline, Callabl def __call__( self, - prompt: Union[str, List[str]] = None, - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[np.ndarray] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - output_type: Optional[str] = "pil", + prompt: str | list[str] = None, + height: int | None = 512, + width: int | None = 512, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, 
+ generator: torch.Generator | list[torch.Generator] | None = None, + latents: np.ndarray | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + output_type: str | None = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback: Callable[[int, int, np.ndarray], None] | None = None, callback_steps: int = 1, ): # check inputs. Raise error if not correct diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_upscale_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_upscale_pipeline.py index 5bdc09794..f575959ab 100644 --- a/modules/onnx_impl/pipelines/onnx_stable_diffusion_upscale_pipeline.py +++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_upscale_pipeline.py @@ -1,5 +1,6 @@ import inspect -from typing import Union, Optional, Callable, Any, List +from typing import Any +from collections.abc import Callable import torch import numpy as np import diffusers @@ -31,22 +32,22 @@ class OnnxStableDiffusionUpscalePipeline(diffusers.OnnxStableDiffusionUpscalePip def __call__( self, - prompt: Union[str, List[str]], + prompt: str | list[str], image: PipelineImageInput = None, num_inference_steps: int = 75, guidance_scale: float = 9.0, noise_level: int = 20, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[np.ndarray] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - output_type: Optional[str] = "pil", + generator: torch.Generator | list[torch.Generator] | None = None, + latents: np.ndarray | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + output_type: str | None 
= "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: Optional[int] = 1, + callback: Callable[[int, int, np.ndarray], None] | None = None, + callback_steps: int | None = 1, ): # 1. Check inputs self.check_inputs( diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_img2img_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_img2img_pipeline.py index 2627ba074..7a30a9a99 100644 --- a/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_img2img_pipeline.py +++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_img2img_pipeline.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Any +from typing import Any import numpy as np import torch import onnxruntime as ort @@ -17,16 +17,16 @@ class OnnxStableDiffusionXLImg2ImgPipeline(CallablePipelineBase, optimum.onnxrun vae_decoder: ort.InferenceSession, text_encoder: ort.InferenceSession, unet: ort.InferenceSession, - config: Dict[str, Any], + config: dict[str, Any], tokenizer: Any, scheduler: Any, feature_extractor = None, - vae_encoder: Optional[ort.InferenceSession] = None, - text_encoder_2: Optional[ort.InferenceSession] = None, + vae_encoder: ort.InferenceSession | None = None, + text_encoder_2: ort.InferenceSession | None = None, tokenizer_2: Any = None, - use_io_binding: Optional[bool] = None, + use_io_binding: bool | None = None, model_save_dir = None, - add_watermarker: Optional[bool] = None + add_watermarker: bool | None = None ): optimum.onnxruntime.ORTStableDiffusionXLImg2ImgPipeline.__init__(self, vae_decoder, text_encoder, unet, config, tokenizer, scheduler, feature_extractor, vae_encoder, text_encoder_2, tokenizer_2, use_io_binding, model_save_dir, add_watermarker) super().__init__() diff --git a/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_pipeline.py b/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_pipeline.py index 452e4f892..bbc541965 100644 --- 
a/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_pipeline.py +++ b/modules/onnx_impl/pipelines/onnx_stable_diffusion_xl_pipeline.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Any +from typing import Any import onnxruntime as ort import optimum.onnxruntime from modules.onnx_impl.pipelines import CallablePipelineBase @@ -14,16 +14,16 @@ class OnnxStableDiffusionXLPipeline(CallablePipelineBase, optimum.onnxruntime.OR vae_decoder: ort.InferenceSession, text_encoder: ort.InferenceSession, unet: ort.InferenceSession, - config: Dict[str, Any], + config: dict[str, Any], tokenizer: Any, scheduler: Any, feature_extractor: Any = None, - vae_encoder: Optional[ort.InferenceSession] = None, - text_encoder_2: Optional[ort.InferenceSession] = None, + vae_encoder: ort.InferenceSession | None = None, + text_encoder_2: ort.InferenceSession | None = None, tokenizer_2: Any = None, - use_io_binding: Optional[bool] = None, + use_io_binding: bool | None = None, model_save_dir = None, - add_watermarker: Optional[bool] = None + add_watermarker: bool | None = None ): optimum.onnxruntime.ORTStableDiffusionXLPipeline.__init__(self, vae_decoder, text_encoder, unet, config, tokenizer, scheduler, feature_extractor, vae_encoder, text_encoder_2, tokenizer_2, use_io_binding, model_save_dir, add_watermarker) super().__init__() diff --git a/modules/onnx_impl/pipelines/utils.py b/modules/onnx_impl/pipelines/utils.py index c389cac01..f6b980302 100644 --- a/modules/onnx_impl/pipelines/utils.py +++ b/modules/onnx_impl/pipelines/utils.py @@ -1,9 +1,8 @@ -from typing import Union, List import numpy as np import torch -def extract_generator_seed(generator: Union[torch.Generator, List[torch.Generator]]) -> List[int]: +def extract_generator_seed(generator: torch.Generator | list[torch.Generator]) -> list[int]: if isinstance(generator, list): generator = [g.seed() for g in generator] else: @@ -11,7 +10,7 @@ def extract_generator_seed(generator: Union[torch.Generator, List[torch.Generato return 
generator -def randn_tensor(shape, dtype: np.dtype, generator: Union[torch.Generator, List[torch.Generator], int, List[int]]): +def randn_tensor(shape, dtype: np.dtype, generator: torch.Generator | list[torch.Generator] | int | list[int]): if hasattr(generator, "seed") or (isinstance(generator, list) and hasattr(generator[0], "seed")): generator = extract_generator_seed(generator) if len(generator) == 1: @@ -25,8 +24,8 @@ def prepare_latents( height: int, width: int, dtype: np.dtype, - generator: Union[torch.Generator, List[torch.Generator]], - latents: Union[np.ndarray, None] = None, + generator: torch.Generator | list[torch.Generator], + latents: np.ndarray | None = None, num_channels_latents = 4, vae_scale_factor = 8, ): diff --git a/modules/onnx_impl/ui.py b/modules/onnx_impl/ui.py index 703392d82..0a8ca2d22 100644 --- a/modules/onnx_impl/ui.py +++ b/modules/onnx_impl/ui.py @@ -1,11 +1,10 @@ import os import json import shutil -from typing import Dict, List, Union import gradio as gr -def get_recursively(d: Union[Dict, List], *args): +def get_recursively(d: dict | list, *args): if len(args) == 0: return d return get_recursively(d.get(args[0]), *args[1:]) @@ -112,19 +111,19 @@ def create_ui(): with gr.TabItem("Stable Diffusion", id="sd"): sd_config_path = os.path.join(sd_configs_path, "olive", "sd") sd_submodels = os.listdir(sd_config_path) - sd_configs: Dict[str, Dict[str, Dict[str, Dict]]] = {} - sd_pass_config_components: Dict[str, Dict[str, Dict]] = {} + sd_configs: dict[str, dict[str, dict[str, dict]]] = {} + sd_pass_config_components: dict[str, dict[str, dict]] = {} with gr.Tabs(elem_id="tabs_sd_submodel"): def sd_create_change_listener(*args): - def listener(v: Dict): + def listener(v: dict): get_recursively(sd_configs, *args[:-1])[args[-1]] = v return listener for submodel in sd_submodels: - config: Dict = None + config: dict = None sd_pass_config_components[submodel] = {} - with open(os.path.join(sd_config_path, submodel), "r", encoding="utf-8") as 
file: + with open(os.path.join(sd_config_path, submodel), encoding="utf-8") as file: config = json.load(file) sd_configs[submodel] = config @@ -175,19 +174,19 @@ def create_ui(): with gr.TabItem("Stable Diffusion XL", id="sdxl"): sdxl_config_path = os.path.join(sd_configs_path, "olive", "sdxl") sdxl_submodels = os.listdir(sdxl_config_path) - sdxl_configs: Dict[str, Dict[str, Dict[str, Dict]]] = {} - sdxl_pass_config_components: Dict[str, Dict[str, Dict]] = {} + sdxl_configs: dict[str, dict[str, dict[str, dict]]] = {} + sdxl_pass_config_components: dict[str, dict[str, dict]] = {} with gr.Tabs(elem_id="tabs_sdxl_submodel"): def sdxl_create_change_listener(*args): - def listener(v: Dict): + def listener(v: dict): get_recursively(sdxl_configs, *args[:-1])[args[-1]] = v return listener for submodel in sdxl_submodels: - config: Dict = None + config: dict = None sdxl_pass_config_components[submodel] = {} - with open(os.path.join(sdxl_config_path, submodel), "r", encoding="utf-8") as file: + with open(os.path.join(sdxl_config_path, submodel), encoding="utf-8") as file: config = json.load(file) sdxl_configs[submodel] = config diff --git a/modules/onnx_impl/utils.py b/modules/onnx_impl/utils.py index 80b75cb4a..5d3f7c06e 100644 --- a/modules/onnx_impl/utils.py +++ b/modules/onnx_impl/utils.py @@ -1,12 +1,12 @@ import os import json import importlib -from typing import Type, Tuple, Union, List, Dict, Any +from typing import Any import torch import diffusers -def extract_device(args: List, kwargs: Dict): +def extract_device(args: list, kwargs: dict): device = kwargs.get("device", None) if device is None: @@ -42,7 +42,7 @@ def check_diffusers_cache(path: os.PathLike): return opts.diffusers_dir in os.path.abspath(path) -def check_pipeline_sdxl(cls: Type[diffusers.DiffusionPipeline]) -> bool: +def check_pipeline_sdxl(cls: type[diffusers.DiffusionPipeline]) -> bool: return 'XL' in cls.__name__ @@ -57,7 +57,7 @@ def check_cache_onnx(path: os.PathLike) -> bool: init_dict = None - 
with open(init_dict_path, "r", encoding="utf-8") as file: + with open(init_dict_path, encoding="utf-8") as file: init_dict = file.read() if "OnnxRuntimeModel" not in init_dict: @@ -66,15 +66,15 @@ def check_cache_onnx(path: os.PathLike) -> bool: return True -def load_init_dict(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike): - merged: Dict[str, Any] = {} +def load_init_dict(cls: type[diffusers.DiffusionPipeline], path: os.PathLike): + merged: dict[str, Any] = {} extracted = cls.extract_init_dict(diffusers.DiffusionPipeline.load_config(path)) for item in extracted: merged.update(item) merged = merged.items() - R: Dict[str, Tuple[str]] = {} + R: dict[str, tuple[str]] = {} for k, v in merged: if isinstance(v, list): @@ -85,7 +85,7 @@ def load_init_dict(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike): return R -def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: List[Union[str, None]], **kwargs_ort): +def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: list[str | None], **kwargs_ort): lib, atr = item if lib is None or atr is None: @@ -107,7 +107,7 @@ def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: Li return attribute.from_pretrained(path) -def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: Dict[str, Type], **kwargs_ort): +def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: dict[str, type], **kwargs_ort): loaded = {} for k, v in init_dict.items(): @@ -122,14 +122,14 @@ def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: Dict[str, Type], return loaded -def load_pipeline(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike, **kwargs_ort) -> diffusers.DiffusionPipeline: +def load_pipeline(cls: type[diffusers.DiffusionPipeline], path: os.PathLike, **kwargs_ort) -> diffusers.DiffusionPipeline: if os.path.isdir(path): return cls(**patch_kwargs(cls, load_submodels(path, check_pipeline_sdxl(cls), load_init_dict(cls, path), 
**kwargs_ort))) else: return cls.from_single_file(path) -def patch_kwargs(cls: Type[diffusers.DiffusionPipeline], kwargs: Dict) -> Dict: +def patch_kwargs(cls: type[diffusers.DiffusionPipeline], kwargs: dict) -> dict: if cls == diffusers.OnnxStableDiffusionPipeline or cls == diffusers.OnnxStableDiffusionImg2ImgPipeline or cls == diffusers.OnnxStableDiffusionInpaintPipeline: kwargs["safety_checker"] = None kwargs["requires_safety_checker"] = False @@ -140,7 +140,7 @@ def patch_kwargs(cls: Type[diffusers.DiffusionPipeline], kwargs: Dict) -> Dict: return kwargs -def get_base_constructor(cls: Type[diffusers.DiffusionPipeline], is_refiner: bool): +def get_base_constructor(cls: type[diffusers.DiffusionPipeline], is_refiner: bool): if cls == diffusers.OnnxStableDiffusionImg2ImgPipeline or cls == diffusers.OnnxStableDiffusionInpaintPipeline: return diffusers.OnnxStableDiffusionPipeline @@ -153,8 +153,8 @@ def get_base_constructor(cls: Type[diffusers.DiffusionPipeline], is_refiner: boo def get_io_config(submodel: str, is_sdxl: bool): from modules.paths import sd_configs_path - with open(os.path.join(sd_configs_path, "olive", 'sdxl' if is_sdxl else 'sd', f"{submodel}.json"), "r", encoding="utf-8") as config_file: - io_config: Dict[str, Any] = json.load(config_file)["input_model"]["config"]["io_config"] + with open(os.path.join(sd_configs_path, "olive", 'sdxl' if is_sdxl else 'sd', f"{submodel}.json"), encoding="utf-8") as config_file: + io_config: dict[str, Any] = json.load(config_file)["input_model"]["config"]["io_config"] for axe in io_config["dynamic_axes"]: io_config["dynamic_axes"][axe] = { int(k): v for k, v in io_config["dynamic_axes"][axe].items() } diff --git a/modules/options_handler.py b/modules/options_handler.py index b087c0529..61ee1a29f 100644 --- a/modules/options_handler.py +++ b/modules/options_handler.py @@ -18,13 +18,15 @@ cmd_opts = cmd_args.parse_args() compatibility_opts = ['clip_skip', 'uni_pc_lower_order_final', 'uni_pc_order'] -class Options(): 
+class Options: data_labels: dict[str, OptionInfo | LegacyOption] data: dict[str, Any] typemap = {int: float} debug = os.environ.get('SD_CONFIG_DEBUG', None) is not None - def __init__(self, options_templates: dict[str, OptionInfo | LegacyOption] = {}, restricted_opts: set[str] | None = None, *, filename = ''): + def __init__(self, options_templates: dict[str, OptionInfo | LegacyOption] = None, restricted_opts: set[str] | None = None, *, filename = ''): + if options_templates is None: + options_templates = {} if restricted_opts is None: restricted_opts = set() super().__setattr__('data_labels', options_templates) @@ -48,21 +50,21 @@ class Options(): log.warning(f'Settings set: {key}={value} legacy') self.data[key] = value return - return super(Options, self).__setattr__(key, value) # pylint: disable=super-with-arguments + return super().__setattr__(key, value) # pylint: disable=super-with-arguments def get(self, item): if item in self.data: return self.data[item] if item in self.data_labels: return self.data_labels[item].default - return super(Options, self).__getattribute__(item) # pylint: disable=super-with-arguments + return super().__getattribute__(item) # pylint: disable=super-with-arguments def __getattr__(self, item): if item in self.data: return self.data[item] if item in self.data_labels: return self.data_labels[item].default - return super(Options, self).__getattribute__(item) # pylint: disable=super-with-arguments + return super().__getattribute__(item) # pylint: disable=super-with-arguments def set(self, key, value): """sets an option and calls its onchange callback, returning True if the option changed and False otherwise""" diff --git a/modules/patches.py b/modules/patches.py index f24a38293..655cf4534 100644 --- a/modules/patches.py +++ b/modules/patches.py @@ -1,5 +1,4 @@ from collections import defaultdict -from typing import Optional from modules.errors import log @@ -55,13 +54,13 @@ def original(key, obj, field): return 
originals[key].get(patch_key, None) -def patch_method(cls, key:Optional[str]=None): +def patch_method(cls, key:str | None=None): def decorator(func): patch(func.__module__ if key is None else key, cls, func.__name__, func) return decorator -def add_method(cls, key:Optional[str]=None): +def add_method(cls, key:str | None=None): def decorator(func): patch(func.__module__ if key is None else key, cls, func.__name__, func, True) return decorator diff --git a/modules/paths.py b/modules/paths.py index 63ed005c7..cd7b6e5d4 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -2,7 +2,6 @@ import os import sys import json -import shlex import argparse import tempfile from installer import log @@ -19,7 +18,7 @@ cli = parser.parse_known_args(argv)[0] config_path = cli.config if os.path.isabs(cli.config) else os.path.join(cli.data_dir, cli.config) try: - with open(config_path, 'r', encoding='utf8') as f: + with open(config_path, encoding='utf8') as f: config = json.load(f) except Exception: config = {} diff --git a/modules/paths_internal.py b/modules/paths_internal.py index a9dabdd0f..f304361aa 100644 --- a/modules/paths_internal.py +++ b/modules/paths_internal.py @@ -1,3 +1,2 @@ # no longer used, all paths are defined in paths.py -from modules.paths import modules_path, script_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, data_path, models_path, extensions_dir, extensions_builtin_dir # pylint: disable=unused-import diff --git a/modules/postprocess/aurasr_model.py b/modules/postprocess/aurasr_model.py index 9d77af93f..c2030d184 100644 --- a/modules/postprocess/aurasr_model.py +++ b/modules/postprocess/aurasr_model.py @@ -1,5 +1,4 @@ import torch -import diffusers from PIL import Image from modules import shared, devices from modules.upscaler import Upscaler, UpscalerData diff --git a/modules/postprocess/esrgan_model_arch.py b/modules/postprocess/esrgan_model_arch.py index bf9f0ac6e..e50b441b7 100644 --- 
a/modules/postprocess/esrgan_model_arch.py +++ b/modules/postprocess/esrgan_model_arch.py @@ -14,7 +14,7 @@ class RRDBNet(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', finalact=None, gaussian_noise=False, plus=False): - super(RRDBNet, self).__init__() + super().__init__() n_upscale = int(math.log(upscale, 2)) if upscale == 3: n_upscale = 1 @@ -69,7 +69,7 @@ class RRDB(nn.Module): def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', spectral_norm=False, gaussian_noise=False, plus=False): - super(RRDB, self).__init__() + super().__init__() # This is for backwards compatibility with existing models if nr == 3: self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, @@ -111,7 +111,7 @@ class ResidualDenseBlock_5C(nn.Module): def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', spectral_norm=False, gaussian_noise=False, plus=False): - super(ResidualDenseBlock_5C, self).__init__() + super().__init__() self.noise = GaussianNoise() if gaussian_noise else None self.conv1x1 = conv1x1(nf, gc) if plus else None @@ -185,7 +185,7 @@ class SRVGGNetCompact(nn.Module): """ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): - super(SRVGGNetCompact, self).__init__() + super().__init__() self.num_in_ch = num_in_ch self.num_out_ch = num_out_ch self.num_feat = num_feat @@ -245,7 +245,7 @@ class Upsample(nn.Module): """ def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): - super(Upsample, self).__init__() + super().__init__() if isinstance(scale_factor, tuple): self.scale_factor = tuple(float(factor) for factor in scale_factor) else: @@ -354,7 +354,7 @@ def 
act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0): class Identity(nn.Module): def __init__(self, *kwargs): - super(Identity, self).__init__() + super().__init__() def forward(self, x, *kwargs): return x @@ -399,7 +399,7 @@ def get_valid_padding(kernel_size, dilation): class ShortcutBlock(nn.Module): """ Elementwise sum the output of a submodule to its input """ def __init__(self, submodule): - super(ShortcutBlock, self).__init__() + super().__init__() self.sub = submodule def forward(self, x): diff --git a/modules/postprocess/pixelart.py b/modules/postprocess/pixelart.py index 3295a0a21..00b9b9636 100644 --- a/modules/postprocess/pixelart.py +++ b/modules/postprocess/pixelart.py @@ -1,4 +1,3 @@ -from typing import List import math import torch @@ -225,8 +224,8 @@ class JPEGEncoder(ImageProcessingMixin, ConfigMixin): block_size: int = 16, cbcr_downscale: int = 2, norm: str = "ortho", - latents_std: List[float] = None, - latents_mean: List[float] = None, + latents_std: list[float] = None, + latents_mean: list[float] = None, ): self.block_size = block_size self.cbcr_downscale = cbcr_downscale diff --git a/modules/postprocess/realesrgan_model_arch.py b/modules/postprocess/realesrgan_model_arch.py index bfdfffad6..dd350e4d9 100644 --- a/modules/postprocess/realesrgan_model_arch.py +++ b/modules/postprocess/realesrgan_model_arch.py @@ -14,7 +14,7 @@ from modules.upscaler import compile_upscaler ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -class RealESRGANer(): +class RealESRGANer: """A helper class for upsampling images with RealESRGAN. 
Args: @@ -340,7 +340,7 @@ class SRVGGNetCompact(nn.Module): """ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): - super(SRVGGNetCompact, self).__init__() + super().__init__() self.num_in_ch = num_in_ch self.num_out_ch = num_out_ch self.num_feat = num_feat diff --git a/modules/postprocess/scunet_model_arch.py b/modules/postprocess/scunet_model_arch.py index b51a88062..2441e06e2 100644 --- a/modules/postprocess/scunet_model_arch.py +++ b/modules/postprocess/scunet_model_arch.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import numpy as np import torch import torch.nn as nn @@ -12,7 +11,7 @@ class WMSA(nn.Module): """ def __init__(self, input_dim, output_dim, head_dim, window_size, type): - super(WMSA, self).__init__() + super().__init__() self.input_dim = input_dim self.output_dim = output_dim self.head_dim = head_dim @@ -103,7 +102,7 @@ class Block(nn.Module): def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): """ SwinTransformer Block """ - super(Block, self).__init__() + super().__init__() self.input_dim = input_dim self.output_dim = output_dim assert type in ['W', 'SW'] @@ -131,7 +130,7 @@ class ConvTransBlock(nn.Module): def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): """ SwinTransformer and Conv Block """ - super(ConvTransBlock, self).__init__() + super().__init__() self.conv_dim = conv_dim self.trans_dim = trans_dim self.head_dim = head_dim @@ -170,7 +169,7 @@ class ConvTransBlock(nn.Module): class SCUNet(nn.Module): # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): - super(SCUNet, self).__init__() + super().__init__() if config is None: config = [2, 2, 2, 2, 2, 2, 2] self.config = config diff --git a/modules/postprocess/swinir_model.py 
b/modules/postprocess/swinir_model.py index 86cc2e77f..60a0d267f 100644 --- a/modules/postprocess/swinir_model.py +++ b/modules/postprocess/swinir_model.py @@ -4,7 +4,7 @@ from PIL import Image from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn from modules.postprocess.swinir_model_arch import SwinIR as net from modules.postprocess.swinir_model_arch_v2 import Swin2SR as net2 -from modules import devices, script_callbacks, shared +from modules import devices, shared from modules.upscaler import Upscaler, compile_upscaler diff --git a/modules/postprocess/swinir_model_arch.py b/modules/postprocess/swinir_model_arch.py index d5ae4dd32..4b306433d 100644 --- a/modules/postprocess/swinir_model_arch.py +++ b/modules/postprocess/swinir_model_arch.py @@ -232,7 +232,7 @@ class SwinTransformerBlock(nn.Module): mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill(attn_mask != 0, (-100.0)).masked_fill(attn_mask == 0, 0.0) return attn_mask @@ -442,7 +442,7 @@ class RSTB(nn.Module): mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() + super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -587,7 +587,7 @@ class Upsample(nn.Sequential): m.append(nn.PixelShuffle(3)) else: raise ValueError(f'scale {scale} is not supported. 
Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) + super().__init__(*m) class UpsampleOneStep(nn.Sequential): @@ -606,7 +606,7 @@ class UpsampleOneStep(nn.Sequential): m = [] m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) + super().__init__(*m) def flops(self): H, W = self.input_resolution @@ -649,7 +649,7 @@ class SwinIR(nn.Module): norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', **kwargs): - super(SwinIR, self).__init__() + super().__init__() num_in_ch = in_chans num_out_ch = in_chans num_feat = 64 diff --git a/modules/postprocess/swinir_model_arch_v2.py b/modules/postprocess/swinir_model_arch_v2.py index ca69e2969..d61e92668 100644 --- a/modules/postprocess/swinir_model_arch_v2.py +++ b/modules/postprocess/swinir_model_arch_v2.py @@ -260,7 +260,7 @@ class SwinTransformerBlock(nn.Module): mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill(attn_mask != 0, (-100.0)).masked_fill(attn_mask == 0, 0.0) return attn_mask @@ -518,7 +518,7 @@ class RSTB(nn.Module): mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() + super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -619,7 +619,7 @@ class Upsample(nn.Sequential): m.append(nn.PixelShuffle(3)) else: raise ValueError(f'scale {scale} is not supported. 
Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) + super().__init__(*m) class Upsample_hf(nn.Sequential): """Upsample module. @@ -640,7 +640,7 @@ class Upsample_hf(nn.Sequential): m.append(nn.PixelShuffle(3)) else: raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') - super(Upsample_hf, self).__init__(*m) + super().__init__(*m) class UpsampleOneStep(nn.Sequential): @@ -659,7 +659,7 @@ class UpsampleOneStep(nn.Sequential): m = [] m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) + super().__init__(*m) def flops(self): H, W = self.input_resolution @@ -702,7 +702,7 @@ class Swin2SR(nn.Module): norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', **kwargs): - super(Swin2SR, self).__init__() + super().__init__() num_in_ch = in_chans num_out_ch = in_chans num_feat = 64 diff --git a/modules/postprocess/yolo.py b/modules/postprocess/yolo.py index 61f5f5661..473f8ec3f 100644 --- a/modules/postprocess/yolo.py +++ b/modules/postprocess/yolo.py @@ -26,7 +26,9 @@ load_lock = threading.Lock() class YoloResult: - def __init__(self, cls: int, label: str, score: float, box: list[int], mask: Image.Image = None, item: Image.Image = None, width = 0, height = 0, args = {}): + def __init__(self, cls: int, label: str, score: float, box: list[int], mask: Image.Image = None, item: Image.Image = None, width = 0, height = 0, args = None): + if args is None: + args = {} self.cls = cls self.label = label self.score = score @@ -138,7 +140,7 @@ class YoloRestorer(Detailer): masks = prediction.masks.data.cpu().float().numpy() if prediction.masks is not None else [] if len(masks) < len(classes): masks = len(classes) * [None] - for score, box, cls, seg in zip(scores, boxes, classes, masks): + for score, box, cls, seg in zip(scores, boxes, classes, masks, strict=False): if 
seg is not None: try: seg = (255 * seg).astype(np.uint8) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index b624e3a7e..318750493 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -1,6 +1,5 @@ import os import tempfile -from typing import List from PIL import Image @@ -9,7 +8,7 @@ from modules.shared import opts from modules.paths import resolve_output_path -def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemporaryFile], input_dir, output_dir, show_extras_results, *args, save_output: bool = True): +def run_postprocessing(extras_mode, image, image_folder: list[tempfile.NamedTemporaryFile], input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() shared.state.begin('Extras') image_data = [] @@ -61,7 +60,7 @@ def run_postprocessing(extras_mode, image, image_folder: List[tempfile.NamedTemp else: outpath = resolve_output_path(opts.outdir_samples, opts.outdir_extras_samples) processed_images = [] - for image, name, ext in zip(image_data, image_names, image_ext): # pylint: disable=redefined-argument-from-local + for image, name, ext in zip(image_data, image_names, image_ext, strict=False): # pylint: disable=redefined-argument-from-local shared.log.debug(f'Process: image={image} {args}') info = '' if shared.state.interrupted: diff --git a/modules/processing.py b/modules/processing.py index 13f2c275c..8c47ce83a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -5,7 +5,13 @@ import numpy as np from PIL import Image, ImageOps from modules import shared, devices, errors, images, scripts_manager, memstats, script_callbacks, extra_networks, detailer, sd_models, sd_checkpoint, sd_vae, processing_helpers, timer from modules.sd_hijack_hypertile import context_hypertile_vae, context_hypertile_unet -from modules.processing_class import StableDiffusionProcessing, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, 
StableDiffusionProcessingControl, StableDiffusionProcessingVideo # pylint: disable=unused-import +from modules.processing_class import ( # pylint: disable=unused-import + StableDiffusionProcessing, + StableDiffusionProcessingTxt2Img, + StableDiffusionProcessingImg2Img, + StableDiffusionProcessingVideo, + StableDiffusionProcessingControl, +) from modules.processing_info import create_infotext from modules.modeldata import model_data @@ -433,7 +439,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: results = p.scripts.process_images(p) if results is not None: samples = results.images - for script_image, script_infotext in zip(results.images, results.infotexts): + for script_image, script_infotext in zip(results.images, results.infotexts, strict=False): output_images.append(script_image) infotexts.append(script_infotext) @@ -467,7 +473,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: output_binary = samples.bytes else: batch_images, batch_infotexts = process_samples(p, samples) - for batch_image, batch_infotext in zip(batch_images, batch_infotexts): + for batch_image, batch_infotext in zip(batch_images, batch_infotexts, strict=False): if batch_image is not None and batch_image not in output_images: output_images.append(batch_image) infotexts.append(batch_infotext) diff --git a/modules/processing_args.py b/modules/processing_args.py index 7a827ebea..97f4d69c9 100644 --- a/modules/processing_args.py +++ b/modules/processing_args.py @@ -1,4 +1,3 @@ -import typing import os import re import math @@ -9,7 +8,7 @@ import numpy as np from PIL import Image from modules import shared, sd_models, processing, processing_vae, processing_helpers, sd_hijack_hypertile, extra_networks, sd_vae from modules.processing_callbacks import diffusers_callback_legacy, diffusers_callback, set_callbacks_p -from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, get_generator, 
set_latents, apply_circular # pylint: disable=unused-import +from modules.processing_helpers import get_generator, apply_circular # pylint: disable=unused-import from modules.processing_prompt import set_prompt from modules.api import helpers @@ -185,7 +184,7 @@ def get_params(model): return possible -def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:typing.Optional[list]=None, negative_prompts_2:typing.Optional[list]=None, prompt_attention:typing.Optional[str]=None, desc:typing.Optional[str]='', **kwargs): +def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:list | None=None, negative_prompts_2:list | None=None, prompt_attention:str | None=None, desc:str | None='', **kwargs): t0 = time.time() shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) argsid = shared.state.begin('Params') diff --git a/modules/processing_callbacks.py b/modules/processing_callbacks.py index eed7f985f..ea5720eef 100644 --- a/modules/processing_callbacks.py +++ b/modules/processing_callbacks.py @@ -1,4 +1,3 @@ -import typing import os import time import torch @@ -33,7 +32,7 @@ def prompt_callback(step, kwargs): return kwargs -def diffusers_callback_legacy(step: int, timestep: int, latents: typing.Union[torch.FloatTensor, np.ndarray]): +def diffusers_callback_legacy(step: int, timestep: int, latents: torch.FloatTensor | np.ndarray): if p is None: return if isinstance(latents, np.ndarray): # latents from Onnx pipelines is ndarray. 
@@ -51,7 +50,9 @@ def diffusers_callback_legacy(step: int, timestep: int, latents: typing.Union[to time.sleep(0.1) -def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}): +def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = None): + if kwargs is None: + kwargs = {} t0 = time.time() if devices.backend == "ipex": torch.xpu.synchronize(devices.device) diff --git a/modules/processing_class.py b/modules/processing_class.py index 26d942efa..93f2bf135 100644 --- a/modules/processing_class.py +++ b/modules/processing_class.py @@ -2,7 +2,7 @@ import os import sys import inspect import hashlib -from typing import Any, Dict, List +from typing import Any from dataclasses import dataclass, field import numpy as np from PIL import Image, ImageOps @@ -51,7 +51,7 @@ class StableDiffusionProcessing: pag_scale: float = 0.0, pag_adaptive: float = 0.5, # styles - styles: List[str] = [], + styles: list[str] = None, # vae tiling: bool = False, vae_type: str = 'Full', @@ -79,8 +79,8 @@ class StableDiffusionProcessing: hdr_color_picker: str = None, hdr_tint_ratio: float = 0, # img2img - init_images: list = [], - init_control: list = [], + init_images: list = None, + init_control: list = None, denoising_strength: float = 0.3, image_cfg_scale: float = None, initial_noise_multiplier: float = None, # pylint: disable=unused-argument # a1111 compatibility @@ -150,9 +150,9 @@ class StableDiffusionProcessing: # xyz flag xyz: bool = False, # scripts - script_args: list = [], + script_args: list = None, # overrides - override_settings: Dict[str, Any] = {}, + override_settings: dict[str, Any] = None, override_settings_restore_afterwards: bool = True, # metadata # extra_generation_params: Dict[Any, Any] = {}, @@ -161,6 +161,16 @@ class StableDiffusionProcessing: **kwargs, ): + if override_settings is None: + override_settings = {} + if script_args is None: + script_args = [] + if init_control is None: + init_control = [] + if init_images is None: 
+ init_images = [] + if styles is None: + styles = [] for k, v in kwargs.items(): setattr(self, k, v) diff --git a/modules/processing_vae.py b/modules/processing_vae.py index 53f8c3751..c2cbffb57 100644 --- a/modules/processing_vae.py +++ b/modules/processing_vae.py @@ -365,7 +365,7 @@ def reprocess(gallery): shared.log.info(f'Reprocessing: latent={latent.shape}') reprocessed = vae_decode(latent, shared.sd_model, output_type='pil') outputs = [] - for i0, i1 in zip(gallery, reprocessed): + for i0, i1 in zip(gallery, reprocessed, strict=False): if isinstance(i1, np.ndarray): i1 = Image.fromarray(i1) fn = i0['name'] diff --git a/modules/progress.py b/modules/progress.py index f0fe52877..bbdbee553 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -2,7 +2,6 @@ import base64 import os import io import time -from typing import Union from pydantic import BaseModel, Field # pylint: disable=no-name-in-module import modules.shared as shared @@ -48,7 +47,7 @@ class ProgressRequest(BaseModel): class InternalProgressResponse(BaseModel): job: str = Field(default=None, title="Job name", description="Internal job name") - textinfo: Union[str|None] = Field(default=None, title="Info text", description="Info text used by WebUI.") + textinfo: str|None = Field(default=None, title="Info text", description="Info text used by WebUI.") # status fields active: bool = Field(title="Whether the task is being worked on right now") queued: bool = Field(title="Whether the task is in queue") @@ -62,10 +61,10 @@ class InternalProgressResponse(BaseModel): batch_count: int = Field(default=None, title="Total batches", description="Total number of batches") # calculated fields progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1") - eta: Union[float|None] = Field(default=None, title="ETA in secs") + eta: float|None = Field(default=None, title="ETA in secs") # image fields - live_preview: Union[str|None] = Field(default=None, title="Live 
preview image", description="Current live preview; a data: uri") - id_live_preview: Union[int|None] = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image") + live_preview: str|None = Field(default=None, title="Live preview image", description="Current live preview; a data: uri") + id_live_preview: int|None = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image") def api_progress(req: ProgressRequest): diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 57587e3b5..06df1a9e7 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -10,7 +10,6 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..')) import os import re from collections import namedtuple -from typing import List import lark import torch from compel import Compel @@ -181,7 +180,7 @@ def get_learned_conditioning(model, prompts, steps): res = [] prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) cache = {} - for prompt, prompt_schedule in zip(prompts, prompt_schedules): + for prompt, prompt_schedule in zip(prompts, prompt_schedules, strict=False): debug(f'Prompt schedule: {prompt_schedule}') cached = cache.get(prompt, None) if cached is not None: @@ -220,14 +219,14 @@ def get_multicond_prompt_list(prompts): class ComposableScheduledPromptConditioning: def __init__(self, schedules, weight=1.0): - self.schedules: List[ScheduledPromptConditioning] = schedules + self.schedules: list[ScheduledPromptConditioning] = schedules self.weight: float = weight class MulticondLearnedConditioning: def __init__(self, shape, batch): self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS - self.batch: List[List[ComposableScheduledPromptConditioning]] = batch + self.batch: list[list[ComposableScheduledPromptConditioning]] = batch def 
get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: @@ -243,7 +242,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) -def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step): +def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step): param = c[0][0].cond res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for i, cond_schedule in enumerate(c): diff --git a/modules/prompt_parser_diffusers.py b/modules/prompt_parser_diffusers.py index 25349e583..17f4821f0 100644 --- a/modules/prompt_parser_diffusers.py +++ b/modules/prompt_parser_diffusers.py @@ -1,7 +1,6 @@ import os import math import time -import typing from collections import OrderedDict import torch from compel.embeddings_provider import BaseTextualInversionManager, EmbeddingsProvider @@ -85,7 +84,7 @@ class PromptEmbedder: return seen_prompts = {} # per prompt in batch - for batchidx, (prompt, negative_prompt) in enumerate(zip(self.prompts, self.negative_prompts)): + for batchidx, (prompt, negative_prompt) in enumerate(zip(self.prompts, self.negative_prompts, strict=False)): self.prepare_schedule(prompt, negative_prompt) schedule_key = ( tuple(self.positive_schedule) if self.positive_schedule is not None else None, @@ -300,7 +299,7 @@ class PromptEmbedder: return None -def compel_hijack(self, token_ids: torch.Tensor, attention_mask: typing.Optional[torch.Tensor] = None) -> torch.Tensor: +def compel_hijack(self, token_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor: needs_hidden_states = self.returned_embeddings_type != 1 text_encoder_output = self.text_encoder(token_ids, attention_mask, output_hidden_states=needs_hidden_states, return_dict=True) @@ -323,7 +322,7 @@ def compel_hijack(self, token_ids: torch.Tensor, attention_mask: typing.Optional return 
hidden_state -def sd3_compel_hijack(self, token_ids: torch.Tensor, attention_mask: typing.Optional[torch.Tensor] = None) -> torch.Tensor: +def sd3_compel_hijack(self, token_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor: needs_hidden_states = True text_encoder_output = self.text_encoder(token_ids, attention_mask, output_hidden_states=needs_hidden_states, return_dict=True) clip_skip = int(self.returned_embeddings_type) @@ -353,10 +352,10 @@ class DiffusersTextualInversionManager(BaseTextualInversionManager): # code from # https://github.com/huggingface/diffusers/blob/705c592ea98ba4e288d837b9cba2767623c78603/src/diffusers/loaders.py - def maybe_convert_prompt(self, prompt: typing.Union[str, typing.List[str]], tokenizer: PreTrainedTokenizer): - prompts = [prompt] if not isinstance(prompt, typing.List) else prompt + def maybe_convert_prompt(self, prompt: str | list[str], tokenizer: PreTrainedTokenizer): + prompts = [prompt] if not isinstance(prompt, list) else prompt prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts] - if not isinstance(prompt, typing.List): + if not isinstance(prompt, list): return prompts[0] return prompts @@ -378,7 +377,7 @@ class DiffusersTextualInversionManager(BaseTextualInversionManager): debug(f'Prompt: convert="{prompt}"') return prompt - def expand_textual_inversion_token_ids_if_necessary(self, token_ids: typing.List[int]) -> typing.List[int]: + def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: if len(token_ids) == 0: return token_ids prompt = self.pipe.tokenizer.decode(token_ids) @@ -470,7 +469,7 @@ def get_prompts_with_weights(pipe, prompt: str): texts_and_weights = prompt_parser.parse_prompt_attention(prompt) if shared.opts.prompt_mean_norm: texts_and_weights = normalize_prompt(texts_and_weights) - texts, text_weights = zip(*texts_and_weights) + texts, text_weights = zip(*texts_and_weights, strict=False) avg_weight = 0 min_weight = 1 max_weight = 
0 @@ -478,7 +477,7 @@ def get_prompts_with_weights(pipe, prompt: str): try: all_tokens = 0 - for text, weight in zip(texts, text_weights): + for text, weight in zip(texts, text_weights, strict=False): tokens = get_tokens(pipe, 'section', text) all_tokens += tokens avg_weight += tokens*weight @@ -627,8 +626,8 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c ps = 2 * [get_prompts_with_weights(pipe, prompt)] ns = 2 * [get_prompts_with_weights(pipe, neg_prompt)] - positives, positive_weights = zip(*ps) - negatives, negative_weights = zip(*ns) + positives, positive_weights = zip(*ps, strict=False) + negatives, negative_weights = zip(*ns, strict=False) if hasattr(pipe, "tokenizer_2") and not hasattr(pipe, "tokenizer"): positives.pop(0) positive_weights.pop(0) diff --git a/modules/ras/ras_attention.py b/modules/ras/ras_attention.py index 4989cc931..5be0db7de 100644 --- a/modules/ras/ras_attention.py +++ b/modules/ras/ras_attention.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional import math import torch import torch.nn.functional as F @@ -38,10 +37,10 @@ class RASLuminaAttnProcessor2_0: attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - query_rotary_emb: Optional[torch.Tensor] = None, - key_rotary_emb: Optional[torch.Tensor] = None, - base_sequence_length: Optional[int] = None, + attention_mask: torch.Tensor | None = None, + query_rotary_emb: torch.Tensor | None = None, + key_rotary_emb: torch.Tensor | None = None, + base_sequence_length: int | None = None, ) -> torch.Tensor: from diffusers.models.embeddings import apply_rotary_emb @@ -165,7 +164,7 @@ class RASJointAttnProcessor2_0: attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: torch.FloatTensor | None = None, *args, **kwargs, ) -> torch.FloatTensor: diff --git a/modules/ras/ras_forward.py b/modules/ras/ras_forward.py index 63c71428e..ef0f245ea 100644 --- a/modules/ras/ras_forward.py +++ b/modules/ras/ras_forward.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Union +from typing import Any import torch from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers @@ -25,11 +25,11 @@ def ras_forward( encoder_hidden_states: torch.FloatTensor = None, pooled_projections: torch.FloatTensor = None, timestep: torch.LongTensor = None, - block_controlnet_hidden_states: List = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + block_controlnet_hidden_states: list = None, + joint_attention_kwargs: dict[str, Any] | None = None, return_dict: bool = True, - skip_layers: Optional[List[int]] = None, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + skip_layers: list[int] | None = None, + ) -> torch.FloatTensor | Transformer2DModelOutput: """ The [`SD3Transformer2DModel`] forward method. diff --git a/modules/ras/ras_scheduler.py b/modules/ras/ras_scheduler.py index a5143a067..f2131e4c6 100644 --- a/modules/ras/ras_scheduler.py +++ b/modules/ras/ras_scheduler.py @@ -15,7 +15,6 @@ # limitations under the License. 
from dataclasses import dataclass -from typing import Optional, Tuple, Union import torch from diffusers.configuration_utils import register_to_config from diffusers.utils import BaseOutput, logging @@ -66,10 +65,10 @@ class RASFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): num_train_timesteps: int = 1000, shift: float = 1.0, use_dynamic_shifting=False, - base_shift: Optional[float] = 0.5, - max_shift: Optional[float] = 1.15, - base_image_seq_len: Optional[int] = 256, - max_image_seq_len: Optional[int] = 4096, + base_shift: float | None = 0.5, + max_shift: float | None = 1.15, + base_image_seq_len: int | None = 256, + max_image_seq_len: int | None = 4096, invert_sigmas: bool = False, ): super().__init__(num_train_timesteps=num_train_timesteps, @@ -120,15 +119,15 @@ class RASFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): def step( self, model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], + timestep: float | torch.FloatTensor, sample: torch.FloatTensor, s_churn: float = 0.0, s_tmin: float = 0.0, s_tmax: float = float("inf"), s_noise: float = 1.0, - generator: Optional[torch.Generator] = None, + generator: torch.Generator | None = None, return_dict: bool = True, - ) -> Union[RASFlowMatchEulerDiscreteSchedulerOutput, Tuple]: + ) -> RASFlowMatchEulerDiscreteSchedulerOutput | tuple: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). diff --git a/modules/res4lyf/abnorsett_scheduler.py b/modules/res4lyf/abnorsett_scheduler.py index e2ba0a686..810b0deab 100644 --- a/modules/res4lyf/abnorsett_scheduler.py +++ b/modules/res4lyf/abnorsett_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import ClassVar, Literal import numpy as np import torch @@ -31,7 +31,7 @@ class ABNorsettScheduler(SchedulerMixin, ConfigMixin): Adams-Bashforth Norsett (ABNorsett) scheduler. """ - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config @@ -41,7 +41,7 @@ class ABNorsettScheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", variant: Literal["abnorsett_2m", "abnorsett_3m", "abnorsett_4m"] = "abnorsett_2m", use_analytic_solution: bool = True, @@ -87,23 +87,22 @@ class ABNorsettScheduler(SchedulerMixin, ConfigMixin): self.init_noise_sigma = 1.0 @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, get_dynamic_shift, get_sigmas_beta, get_sigmas_exponential, - get_sigmas_flow, get_sigmas_karras, ) @@ -183,7 +182,7 @@ class ABNorsettScheduler(SchedulerMixin, ConfigMixin): from .scheduler_utils import add_noise_to_sample return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps) - def 
scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -195,10 +194,10 @@ class ABNorsettScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/bong_tangent_scheduler.py b/modules/res4lyf/bong_tangent_scheduler.py index a0c827218..d3b7eaa84 100644 --- a/modules/res4lyf/bong_tangent_scheduler.py +++ b/modules/res4lyf/bong_tangent_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import ClassVar, List, Optional, Tuple, Union +from typing import ClassVar import numpy as np import torch @@ -29,7 +29,7 @@ class BongTangentScheduler(SchedulerMixin, ConfigMixin): BongTangent scheduler using Exponential Integrator step. 
""" - _compatibles: ClassVar[List[str]] = [] + _compatibles: ClassVar[list[str]] = [] order = 1 @register_to_config @@ -86,17 +86,17 @@ class BongTangentScheduler(SchedulerMixin, ConfigMixin): self.init_noise_sigma = 1.0 @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -105,7 +105,7 @@ class BongTangentScheduler(SchedulerMixin, ConfigMixin): sample = sample / ((sigma**2 + 1) ** 0.5) return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, get_dynamic_shift, @@ -210,7 +210,7 @@ class BongTangentScheduler(SchedulerMixin, ConfigMixin): from .scheduler_utils import add_noise_to_sample return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps) - def _get_bong_tangent_sigmas(self, steps: int, slope: float, pivot: int, start: float, end: float, dtype: torch.dtype = torch.float32) -> List[float]: + def _get_bong_tangent_sigmas(self, steps: int, slope: float, pivot: int, start: float, end: float, dtype: torch.dtype = torch.float32) -> list[float]: x = torch.arange(steps, dtype=dtype) def bong_fn(val): @@ -228,10 +228,10 @@ class 
BongTangentScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/common_sigma_scheduler.py b/modules/res4lyf/common_sigma_scheduler.py index 202d289af..bfe32a875 100644 --- a/modules/res4lyf/common_sigma_scheduler.py +++ b/modules/res4lyf/common_sigma_scheduler.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import ClassVar, Literal import numpy as np import torch @@ -30,7 +30,7 @@ class CommonSigmaScheduler(SchedulerMixin, ConfigMixin): Common Sigma scheduler using Exponential Integrator step. """ - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order: ClassVar[int] = 1 @register_to_config @@ -88,17 +88,17 @@ class CommonSigmaScheduler(SchedulerMixin, ConfigMixin): self._begin_index = None @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, get_dynamic_shift, @@ -200,7 +200,7 @@ class CommonSigmaScheduler(SchedulerMixin, ConfigMixin): from 
.scheduler_utils import add_noise_to_sample return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -212,10 +212,10 @@ class CommonSigmaScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/deis_scheduler_alt.py b/modules/res4lyf/deis_scheduler_alt.py index 70c63cecf..bcb3a266e 100644 --- a/modules/res4lyf/deis_scheduler_alt.py +++ b/modules/res4lyf/deis_scheduler_alt.py @@ -1,4 +1,3 @@ -from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -33,16 +32,16 @@ class RESDEISMultistepScheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", use_karras_sigmas: bool = False, use_exponential_sigmas: bool = False, use_beta_sigmas: bool = False, use_flow_sigmas: bool = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, + sigma_min: float | None = None, + sigma_max: float | None = None, rho: float = 7.0, - shift: Optional[float] = None, + shift: float | None = None, base_shift: float = 0.5, max_shift: float = 1.15, use_dynamic_shifting: bool = False, @@ -87,8 +86,8 @@ class RESDEISMultistepScheduler(SchedulerMixin, ConfigMixin): def 
set_timesteps( self, num_inference_steps: int, - device: Union[str, torch.device] = None, - mu: Optional[float] = None, + device: str | torch.device = None, + mu: float | None = None, dtype: torch.dtype = torch.float32): self.num_inference_steps = num_inference_steps @@ -225,7 +224,7 @@ class RESDEISMultistepScheduler(SchedulerMixin, ConfigMixin): if self._step_index is None: self._step_index = self.index_for_timestep(timestep) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -236,10 +235,10 @@ class RESDEISMultistepScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/etdrk_scheduler.py b/modules/res4lyf/etdrk_scheduler.py index 07b624ff6..cc7f693fe 100644 --- a/modules/res4lyf/etdrk_scheduler.py +++ b/modules/res4lyf/etdrk_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import ClassVar, Literal import numpy as np import torch @@ -31,7 +31,7 @@ class ETDRKScheduler(SchedulerMixin, ConfigMixin): Exponential Time Differencing Runge-Kutta (ETDRK) scheduler. 
""" - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config @@ -41,7 +41,7 @@ class ETDRKScheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", variant: Literal["etdrk2_2s", "etdrk3_a_3s", "etdrk3_b_3s", "etdrk4_4s", "etdrk4_4s_alt"] = "etdrk4_4s", use_analytic_solution: bool = True, @@ -87,17 +87,17 @@ class ETDRKScheduler(SchedulerMixin, ConfigMixin): self.init_noise_sigma = 1.0 @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, get_dynamic_shift, @@ -171,7 +171,7 @@ class ETDRKScheduler(SchedulerMixin, ConfigMixin): from .scheduler_utils import add_noise_to_sample return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -183,10 +183,10 @@ class 
ETDRKScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/gauss_legendre_scheduler.py b/modules/res4lyf/gauss_legendre_scheduler.py index 38db308b8..0cbb5ea03 100644 --- a/modules/res4lyf/gauss_legendre_scheduler.py +++ b/modules/res4lyf/gauss_legendre_scheduler.py @@ -1,4 +1,3 @@ -from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -22,17 +21,17 @@ class GaussLegendreScheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", variant: str = "gauss-legendre_2s", # 2s to 8s variants use_karras_sigmas: bool = False, use_exponential_sigmas: bool = False, use_beta_sigmas: bool = False, use_flow_sigmas: bool = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, + sigma_min: float | None = None, + sigma_max: float | None = None, rho: float = 7.0, - shift: Optional[float] = None, + shift: float | None = None, base_shift: float = 0.5, max_shift: float = 1.15, use_dynamic_shifting: bool = False, @@ -147,8 +146,8 @@ class GaussLegendreScheduler(SchedulerMixin, ConfigMixin): def set_timesteps( self, num_inference_steps: int, - device: Union[str, torch.device] = None, - mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + device: str | torch.device = None, + mu: float | None = None, dtype: torch.dtype = torch.float32): self.num_inference_steps = num_inference_steps # 1. 
Spacing @@ -248,7 +247,7 @@ class GaussLegendreScheduler(SchedulerMixin, ConfigMixin): timestep = timestep.to(self.timesteps.device) self._step_index = self.index_for_timestep(timestep) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -259,10 +258,10 @@ class GaussLegendreScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/langevin_dynamics_scheduler.py b/modules/res4lyf/langevin_dynamics_scheduler.py index 8e3c2eb48..af7213b52 100644 --- a/modules/res4lyf/langevin_dynamics_scheduler.py +++ b/modules/res4lyf/langevin_dynamics_scheduler.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import ClassVar, List, Optional, Tuple, Union +from typing import ClassVar import numpy as np import torch @@ -30,7 +30,7 @@ class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin): Langevin Dynamics sigma scheduler using Exponential Integrator step. 
""" - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order: ClassVar[int] = 1 @register_to_config @@ -85,11 +85,11 @@ class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin): self._begin_index = None @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: @@ -98,9 +98,9 @@ class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin): def set_timesteps( self, num_inference_steps: int, - device: Union[str, torch.device] = None, - generator: Optional[torch.Generator] = None, - mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + device: str | torch.device = None, + generator: torch.Generator | None = None, + mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, get_dynamic_shift, @@ -187,7 +187,7 @@ class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin): from .scheduler_utils import add_noise_to_sample return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -199,10 +199,10 @@ class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: 
self._init_step_index(timestep) diff --git a/modules/res4lyf/lawson_scheduler.py b/modules/res4lyf/lawson_scheduler.py index 0af304eb2..3631024bf 100644 --- a/modules/res4lyf/lawson_scheduler.py +++ b/modules/res4lyf/lawson_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import ClassVar, Literal import numpy as np import torch @@ -29,7 +29,7 @@ class LawsonScheduler(SchedulerMixin, ConfigMixin): Lawson's integration method scheduler. """ - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config @@ -39,7 +39,7 @@ class LawsonScheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", variant: Literal["lawson2a_2s", "lawson2b_2s", "lawson4_4s"] = "lawson4_4s", use_analytic_solution: bool = True, @@ -85,17 +85,17 @@ class LawsonScheduler(SchedulerMixin, ConfigMixin): self.init_noise_sigma = 1.0 @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, 
get_dynamic_shift, @@ -169,7 +169,7 @@ class LawsonScheduler(SchedulerMixin, ConfigMixin): from .scheduler_utils import add_noise_to_sample return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -181,10 +181,10 @@ class LawsonScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/linear_rk_scheduler.py b/modules/res4lyf/linear_rk_scheduler.py index 8e2a9aac1..955e4af8e 100644 --- a/modules/res4lyf/linear_rk_scheduler.py +++ b/modules/res4lyf/linear_rk_scheduler.py @@ -1,4 +1,3 @@ -from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -22,17 +21,17 @@ class LinearRKScheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", variant: str = "rk4", # euler, heun, rk2, rk3, rk4, ralston, midpoint use_karras_sigmas: bool = False, use_exponential_sigmas: bool = False, use_beta_sigmas: bool = False, use_flow_sigmas: bool = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, + sigma_min: float | None = None, + sigma_max: float | None = None, rho: float = 7.0, - shift: Optional[float] = None, + shift: float | None = None, base_shift: float 
= 0.5, max_shift: float = 1.15, use_dynamic_shifting: bool = False, @@ -103,8 +102,8 @@ class LinearRKScheduler(SchedulerMixin, ConfigMixin): def set_timesteps( self, num_inference_steps: int, - device: Union[str, torch.device] = None, - mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + device: str | torch.device = None, + mu: float | None = None, dtype: torch.dtype = torch.float32): self.num_inference_steps = num_inference_steps # 1. Spacing @@ -204,7 +203,7 @@ class LinearRKScheduler(SchedulerMixin, ConfigMixin): timestep = timestep.to(self.timesteps.device) self._step_index = self.index_for_timestep(timestep) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -215,10 +214,10 @@ class LinearRKScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) a_mat, b_vec, c_vec = self._get_tableau() diff --git a/modules/res4lyf/lobatto_scheduler.py b/modules/res4lyf/lobatto_scheduler.py index 97d073e88..e1698b935 100644 --- a/modules/res4lyf/lobatto_scheduler.py +++ b/modules/res4lyf/lobatto_scheduler.py @@ -1,4 +1,3 @@ -from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -23,17 +22,17 @@ class LobattoScheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", 
variant: str = "lobatto_iiia_3s", # Available: iiia, iiib, iiic use_karras_sigmas: bool = False, use_exponential_sigmas: bool = False, use_beta_sigmas: bool = False, use_flow_sigmas: bool = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, + sigma_min: float | None = None, + sigma_max: float | None = None, rho: float = 7.0, - shift: Optional[float] = None, + shift: float | None = None, base_shift: float = 0.5, max_shift: float = 1.15, use_dynamic_shifting: bool = False, @@ -103,8 +102,8 @@ class LobattoScheduler(SchedulerMixin, ConfigMixin): def set_timesteps( self, num_inference_steps: int, - device: Union[str, torch.device] = None, - mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + device: str | torch.device = None, + mu: float | None = None, dtype: torch.dtype = torch.float32): self.num_inference_steps = num_inference_steps # 1. Spacing @@ -204,7 +203,7 @@ class LobattoScheduler(SchedulerMixin, ConfigMixin): timestep = timestep.to(self.timesteps.device) self._step_index = self.index_for_timestep(timestep) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -215,10 +214,10 @@ class LobattoScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) a_mat, b_vec, c_vec = self._get_tableau() diff --git a/modules/res4lyf/pec_scheduler.py b/modules/res4lyf/pec_scheduler.py index f6df4f449..d5951b937 100644 --- a/modules/res4lyf/pec_scheduler.py +++ 
b/modules/res4lyf/pec_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import ClassVar, Literal import numpy as np import torch @@ -31,7 +31,7 @@ class PECScheduler(SchedulerMixin, ConfigMixin): Predictor-Corrector (PEC) scheduler. """ - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config @@ -41,7 +41,7 @@ class PECScheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", variant: Literal["pec423_2h2s", "pec433_2h3s"] = "pec423_2h2s", use_analytic_solution: bool = True, @@ -87,11 +87,11 @@ class PECScheduler(SchedulerMixin, ConfigMixin): self.init_noise_sigma = 1.0 @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: @@ -100,8 +100,8 @@ class PECScheduler(SchedulerMixin, ConfigMixin): def set_timesteps( self, num_inference_steps: int, - device: Union[str, torch.device] = None, - mu: Optional[float] = None, + device: str | torch.device = None, + mu: float | None = None, dtype: torch.dtype = torch.float32, ): from .scheduler_utils import ( @@ -177,7 +177,7 @@ class PECScheduler(SchedulerMixin, ConfigMixin): from .scheduler_utils import add_noise_to_sample return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, 
torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -188,10 +188,10 @@ class PECScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/phi_functions.py b/modules/res4lyf/phi_functions.py index 7941f7c2a..ddd859585 100644 --- a/modules/res4lyf/phi_functions.py +++ b/modules/res4lyf/phi_functions.py @@ -13,7 +13,6 @@ # limitations under the License. import math -from typing import Dict, List, Tuple, Union import torch from mpmath import exp as mp_exp @@ -89,10 +88,10 @@ class Phi: Supports both standard torch-based and high-precision mpmath-based solutions. 
""" - def __init__(self, h: torch.Tensor, c: List[Union[float, mpf]], analytic_solution: bool = True): + def __init__(self, h: torch.Tensor, c: list[float | mpf], analytic_solution: bool = True): self.h = h self.c = c - self.cache: Dict[Tuple[int, int], Union[float, torch.Tensor]] = {} + self.cache: dict[tuple[int, int], float | torch.Tensor] = {} self.analytic_solution = analytic_solution if analytic_solution: @@ -102,7 +101,7 @@ class Phi: else: self.phi_f = phi_standard_torch - def __call__(self, j: int, i: int = -1) -> Union[float, torch.Tensor]: + def __call__(self, j: int, i: int = -1) -> float | torch.Tensor: if (j, i) in self.cache: return self.cache[(j, i)] diff --git a/modules/res4lyf/radau_iia_scheduler.py b/modules/res4lyf/radau_iia_scheduler.py index 2cd5d85e3..4d072205b 100644 --- a/modules/res4lyf/radau_iia_scheduler.py +++ b/modules/res4lyf/radau_iia_scheduler.py @@ -1,4 +1,3 @@ -from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -23,17 +22,17 @@ class RadauIIAScheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", variant: str = "radau_iia_3s", # 2s to 11s variants use_karras_sigmas: bool = False, use_exponential_sigmas: bool = False, use_beta_sigmas: bool = False, use_flow_sigmas: bool = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, + sigma_min: float | None = None, + sigma_max: float | None = None, rho: float = 7.0, - shift: Optional[float] = None, + shift: float | None = None, base_shift: float = 0.5, max_shift: float = 1.15, use_dynamic_shifting: bool = False, @@ -137,8 +136,8 @@ class RadauIIAScheduler(SchedulerMixin, ConfigMixin): def set_timesteps( self, num_inference_steps: int, - device: Union[str, torch.device] = None, - mu: Optional[float] = None, 
dtype: torch.dtype = torch.float32): + device: str | torch.device = None, + mu: float | None = None, dtype: torch.dtype = torch.float32): self.num_inference_steps = num_inference_steps # 1. Spacing @@ -238,7 +237,7 @@ class RadauIIAScheduler(SchedulerMixin, ConfigMixin): return np.abs(schedule_timesteps - timestep).argmin().item() - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -257,10 +256,10 @@ class RadauIIAScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) a_mat, b_vec, c_vec = self._get_tableau() diff --git a/modules/res4lyf/res_multistep_scheduler.py b/modules/res4lyf/res_multistep_scheduler.py index e324408ee..081e83307 100644 --- a/modules/res4lyf/res_multistep_scheduler.py +++ b/modules/res4lyf/res_multistep_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import ClassVar, Literal import numpy as np import torch @@ -49,7 +49,7 @@ class RESMultistepScheduler(SchedulerMixin, ConfigMixin): Whether to use high-precision analytic solutions for phi functions. 
""" - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config @@ -102,17 +102,17 @@ class RESMultistepScheduler(SchedulerMixin, ConfigMixin): self.init_noise_sigma = 1.0 @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -120,13 +120,12 @@ class RESMultistepScheduler(SchedulerMixin, ConfigMixin): sigma = self.sigmas[self._step_index] return sample / ((sigma**2 + 1) ** 0.5) - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, get_dynamic_shift, get_sigmas_beta, get_sigmas_exponential, - get_sigmas_flow, get_sigmas_karras, ) @@ -208,10 +207,10 @@ class RESMultistepScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/res_multistep_sde_scheduler.py 
b/modules/res4lyf/res_multistep_sde_scheduler.py index 8ed98688b..adc40f832 100644 --- a/modules/res4lyf/res_multistep_sde_scheduler.py +++ b/modules/res4lyf/res_multistep_sde_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import ClassVar, Literal import numpy as np import torch @@ -38,7 +38,7 @@ class RESMultistepSDEScheduler(SchedulerMixin, ConfigMixin): The amount of noise to add during sampling (0.0 for ODE, 1.0 for full SDE). """ - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config @@ -92,17 +92,17 @@ class RESMultistepSDEScheduler(SchedulerMixin, ConfigMixin): self.init_noise_sigma = 1.0 @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -111,7 +111,7 @@ class RESMultistepSDEScheduler(SchedulerMixin, ConfigMixin): sample = sample / ((sigma**2 + 1) ** 0.5) return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from 
.scheduler_utils import ( apply_shift, get_dynamic_shift, @@ -188,11 +188,11 @@ class RESMultistepSDEScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, - generator: Optional[torch.Generator] = None, + generator: torch.Generator | None = None, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/res_singlestep_scheduler.py b/modules/res4lyf/res_singlestep_scheduler.py index 29146029f..86d10fd24 100644 --- a/modules/res4lyf/res_singlestep_scheduler.py +++ b/modules/res4lyf/res_singlestep_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import ClassVar, Literal import numpy as np import torch @@ -29,7 +29,7 @@ class RESSinglestepScheduler(SchedulerMixin, ConfigMixin): RESSinglestepScheduler (Multistage Exponential Integrator) ported from RES4LYF. 
""" - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config @@ -78,17 +78,17 @@ class RESSinglestepScheduler(SchedulerMixin, ConfigMixin): self.init_noise_sigma = 1.0 @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -96,7 +96,7 @@ class RESSinglestepScheduler(SchedulerMixin, ConfigMixin): sigma = self.sigmas[self._step_index] return sample / ((sigma**2 + 1) ** 0.5) - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, get_dynamic_shift, @@ -183,10 +183,10 @@ class RESSinglestepScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/res_singlestep_sde_scheduler.py b/modules/res4lyf/res_singlestep_sde_scheduler.py index ef7fea5b9..a83b5b403 100644 --- 
a/modules/res4lyf/res_singlestep_sde_scheduler.py +++ b/modules/res4lyf/res_singlestep_sde_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import ClassVar, Literal import numpy as np import torch @@ -30,7 +30,7 @@ class RESSinglestepSDEScheduler(SchedulerMixin, ConfigMixin): RESSinglestepSDEScheduler (Stochastic Multistage Exponential Integrator) ported from RES4LYF. """ - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config @@ -80,17 +80,17 @@ class RESSinglestepSDEScheduler(SchedulerMixin, ConfigMixin): self.init_noise_sigma = 1.0 @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -99,7 +99,7 @@ class RESSinglestepSDEScheduler(SchedulerMixin, ConfigMixin): sample = sample / ((sigma**2 + 1) ** 0.5) return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, get_dynamic_shift, @@ 
-173,11 +173,11 @@ class RESSinglestepSDEScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, - generator: Optional[torch.Generator] = None, + generator: torch.Generator | None = None, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/res_unified_scheduler.py b/modules/res4lyf/res_unified_scheduler.py index 5aa619db6..061517f10 100644 --- a/modules/res4lyf/res_unified_scheduler.py +++ b/modules/res4lyf/res_unified_scheduler.py @@ -1,4 +1,4 @@ -from typing import ClassVar, List, Optional, Tuple, Union +from typing import ClassVar import numpy as np import torch @@ -15,7 +15,7 @@ class RESUnifiedScheduler(SchedulerMixin, ConfigMixin): Supports DEIS 1S, 2M, 3M """ - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order: ClassVar[int] = 1 @register_to_config @@ -74,17 +74,17 @@ class RESUnifiedScheduler(SchedulerMixin, ConfigMixin): self._step_index = None @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -92,13 +92,12 @@ class RESUnifiedScheduler(SchedulerMixin, ConfigMixin): sigma = self.sigmas[self._step_index] 
return sample / ((sigma**2 + 1) ** 0.5) - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, get_dynamic_shift, get_sigmas_beta, get_sigmas_exponential, - get_sigmas_flow, get_sigmas_karras, ) @@ -236,10 +235,10 @@ class RESUnifiedScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/riemannian_flow_scheduler.py b/modules/res4lyf/riemannian_flow_scheduler.py index 926c31c46..2b2ada55d 100644 --- a/modules/res4lyf/riemannian_flow_scheduler.py +++ b/modules/res4lyf/riemannian_flow_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import ClassVar, Literal import numpy as np import torch @@ -29,7 +29,7 @@ class RiemannianFlowScheduler(SchedulerMixin, ConfigMixin): Riemannian Flow scheduler using Exponential Integrator step. 
""" - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order: ClassVar[int] = 1 @register_to_config @@ -84,17 +84,17 @@ class RiemannianFlowScheduler(SchedulerMixin, ConfigMixin): self._begin_index = None @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, get_dynamic_shift, @@ -202,7 +202,7 @@ class RiemannianFlowScheduler(SchedulerMixin, ConfigMixin): from .scheduler_utils import add_noise_to_sample return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -214,10 +214,10 @@ class RiemannianFlowScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/rungekutta_44s_scheduler.py 
b/modules/res4lyf/rungekutta_44s_scheduler.py index be6efe9da..d18941e01 100644 --- a/modules/res4lyf/rungekutta_44s_scheduler.py +++ b/modules/res4lyf/rungekutta_44s_scheduler.py @@ -1,4 +1,3 @@ -from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -23,16 +22,16 @@ class RungeKutta44Scheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", use_karras_sigmas: bool = False, use_exponential_sigmas: bool = False, use_beta_sigmas: bool = False, use_flow_sigmas: bool = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, + sigma_min: float | None = None, + sigma_max: float | None = None, rho: float = 7.0, - shift: Optional[float] = None, + shift: float | None = None, base_shift: float = 0.5, max_shift: float = 1.15, use_dynamic_shifting: bool = False, @@ -69,7 +68,7 @@ class RungeKutta44Scheduler(SchedulerMixin, ConfigMixin): self._sigmas_cpu = None self._step_index = None - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): self.num_inference_steps = num_inference_steps # 1. 
Base sigmas @@ -141,7 +140,7 @@ class RungeKutta44Scheduler(SchedulerMixin, ConfigMixin): timestep = timestep.to(self.timesteps.device) self._step_index = self.index_for_timestep(timestep) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -152,10 +151,10 @@ class RungeKutta44Scheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/rungekutta_57s_scheduler.py b/modules/res4lyf/rungekutta_57s_scheduler.py index d3f6b2297..5d118bff7 100644 --- a/modules/res4lyf/rungekutta_57s_scheduler.py +++ b/modules/res4lyf/rungekutta_57s_scheduler.py @@ -1,4 +1,3 @@ -from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -21,16 +20,16 @@ class RungeKutta57Scheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", use_karras_sigmas: bool = False, use_exponential_sigmas: bool = False, use_beta_sigmas: bool = False, use_flow_sigmas: bool = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, + sigma_min: float | None = None, + sigma_max: float | None = None, rho: float = 7.0, - shift: Optional[float] = None, + shift: float | None = None, base_shift: float = 0.5, max_shift: float = 1.15, use_dynamic_shifting: bool = False, @@ -72,8 +71,8 
@@ class RungeKutta57Scheduler(SchedulerMixin, ConfigMixin): def set_timesteps( self, num_inference_steps: int, - device: Union[str, torch.device] = None, - mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + device: str | torch.device = None, + mu: float | None = None, dtype: torch.dtype = torch.float32): self.num_inference_steps = num_inference_steps # 1. Spacing @@ -178,7 +177,7 @@ class RungeKutta57Scheduler(SchedulerMixin, ConfigMixin): timestep = timestep.to(self.timesteps.device) self._step_index = self.index_for_timestep(timestep) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -189,10 +188,10 @@ class RungeKutta57Scheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/rungekutta_67s_scheduler.py b/modules/res4lyf/rungekutta_67s_scheduler.py index b2c13ad47..55af1b16c 100644 --- a/modules/res4lyf/rungekutta_67s_scheduler.py +++ b/modules/res4lyf/rungekutta_67s_scheduler.py @@ -1,4 +1,3 @@ -from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -22,16 +21,16 @@ class RungeKutta67Scheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", use_karras_sigmas: bool = False, use_exponential_sigmas: bool = False, 
use_beta_sigmas: bool = False, use_flow_sigmas: bool = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, + sigma_min: float | None = None, + sigma_max: float | None = None, rho: float = 7.0, - shift: Optional[float] = None, + shift: float | None = None, base_shift: float = 0.5, max_shift: float = 1.15, use_dynamic_shifting: bool = False, @@ -72,8 +71,8 @@ class RungeKutta67Scheduler(SchedulerMixin, ConfigMixin): def set_timesteps( self, num_inference_steps: int, - device: Union[str, torch.device] = None, - mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + device: str | torch.device = None, + mu: float | None = None, dtype: torch.dtype = torch.float32): self.num_inference_steps = num_inference_steps # 1. Spacing @@ -177,7 +176,7 @@ class RungeKutta67Scheduler(SchedulerMixin, ConfigMixin): timestep = timestep.to(self.timesteps.device) self._step_index = self.index_for_timestep(timestep) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -188,10 +187,10 @@ class RungeKutta67Scheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/simple_exponential_scheduler.py b/modules/res4lyf/simple_exponential_scheduler.py index 52e678ca9..01a901e12 100644 --- a/modules/res4lyf/simple_exponential_scheduler.py +++ b/modules/res4lyf/simple_exponential_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations 
under the License. -from typing import ClassVar, List, Optional, Tuple, Union +from typing import ClassVar import numpy as np import torch @@ -29,7 +29,7 @@ class SimpleExponentialScheduler(SchedulerMixin, ConfigMixin): Simple Exponential sigma scheduler using Exponential Integrator step. """ - _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers] + _compatibles: ClassVar[list[str]] = [e.name for e in KarrasDiffusionSchedulers] order: ClassVar[int] = 1 @register_to_config @@ -85,17 +85,17 @@ class SimpleExponentialScheduler(SchedulerMixin, ConfigMixin): self._begin_index = None @property - def step_index(self) -> Optional[int]: + def step_index(self) -> int | None: return self._step_index @property - def begin_index(self) -> Optional[int]: + def begin_index(self) -> int | None: return self._begin_index def set_begin_index(self, begin_index: int = 0) -> None: self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + def set_timesteps(self, num_inference_steps: int, device: str | torch.device = None, mu: float | None = None, dtype: torch.dtype = torch.float32): from .scheduler_utils import ( apply_shift, get_dynamic_shift, @@ -152,7 +152,7 @@ class SimpleExponentialScheduler(SchedulerMixin, ConfigMixin): from .scheduler_utils import add_noise_to_sample return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -164,10 +164,10 @@ class SimpleExponentialScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, 
torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: if self._step_index is None: self._init_step_index(timestep) diff --git a/modules/res4lyf/specialized_rk_scheduler.py b/modules/res4lyf/specialized_rk_scheduler.py index fa9b23a2e..33b6df815 100644 --- a/modules/res4lyf/specialized_rk_scheduler.py +++ b/modules/res4lyf/specialized_rk_scheduler.py @@ -1,4 +1,3 @@ -from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -23,17 +22,17 @@ class SpecializedRKScheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + trained_betas: np.ndarray | list[float] | None = None, prediction_type: str = "epsilon", variant: str = "ssprk3_3s", # ssprk3_3s, ssprk4_4s, tsi_7s, ralston_4s, bogacki-shampine_4s use_karras_sigmas: bool = False, use_exponential_sigmas: bool = False, use_beta_sigmas: bool = False, use_flow_sigmas: bool = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, + sigma_min: float | None = None, + sigma_max: float | None = None, rho: float = 7.0, - shift: Optional[float] = None, + shift: float | None = None, base_shift: float = 0.5, max_shift: float = 1.15, use_dynamic_shifting: bool = False, @@ -107,8 +106,8 @@ class SpecializedRKScheduler(SchedulerMixin, ConfigMixin): def set_timesteps( self, num_inference_steps: int, - device: Union[str, torch.device] = None, - mu: Optional[float] = None, dtype: torch.dtype = torch.float32): + device: str | torch.device = None, + mu: float | None = None, dtype: torch.dtype = torch.float32): self.num_inference_steps = num_inference_steps # 1. 
Spacing @@ -211,7 +210,7 @@ class SpecializedRKScheduler(SchedulerMixin, ConfigMixin): timestep = timestep.to(self.timesteps.device) self._step_index = self.index_for_timestep(timestep) - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + def scale_model_input(self, sample: torch.Tensor, timestep: float | torch.Tensor) -> torch.Tensor: if self._step_index is None: self._init_step_index(timestep) if self.config.prediction_type == "flow_prediction": @@ -222,10 +221,10 @@ class SpecializedRKScheduler(SchedulerMixin, ConfigMixin): def step( self, model_output: torch.Tensor, - timestep: Union[float, torch.Tensor], + timestep: float | torch.Tensor, sample: torch.Tensor, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> SchedulerOutput | tuple: self._init_step_index(timestep) a_mat, b_vec, c_vec = self._get_tableau() num_stages = len(c_vec) diff --git a/modules/rife/__init__.py b/modules/rife/__init__.py index c4db35645..ba2a66d8b 100644 --- a/modules/rife/__init__.py +++ b/modules/rife/__init__.py @@ -12,7 +12,7 @@ from torch.nn import functional as F from tqdm.rich import tqdm from modules.rife.ssim import ssim_matlab from modules.rife.model_rife import RifeModel -from modules import devices, shared +from modules import devices, shared, paths model_url = 'https://github.com/vladmandic/rife/raw/main/model/flownet-v46.pkl' @@ -23,7 +23,7 @@ def load(model_path: str = 'rife/flownet-v46.pkl'): global model # pylint: disable=global-statement if model is None: from modules import modelloader - model_dir = os.path.join(shared.models_path, 'RIFE') + model_dir = os.path.join(paths.models_path, 'RIFE') model_path = modelloader.load_file_from_url(url=model_url, model_dir=model_dir, file_name='flownet-v46.pkl') shared.log.debug(f'Video interpolate: model="{model_path}"') model = RifeModel() @@ -104,7 +104,7 @@ def interpolate(images: list, count: int = 2, scale: float = 1.0, pad: int = 1, else: output = 
execute(I0, I1, count-1) for mid in output: - mid = (((mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0))) + mid = ((mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0)) buffer.put(mid[:h, :w]) buffer.put(frame) pbar.update(1) diff --git a/modules/rife/loss.py b/modules/rife/loss.py index 8b6309006..f525ff443 100644 --- a/modules/rife/loss.py +++ b/modules/rife/loss.py @@ -8,7 +8,7 @@ from modules import devices class EPE(nn.Module): def __init__(self): - super(EPE, self).__init__() + super().__init__() def forward(self, flow, gt, loss_mask): loss_map = (flow - gt.detach()) ** 2 @@ -18,7 +18,7 @@ class EPE(nn.Module): class Ternary(nn.Module): def __init__(self): - super(Ternary, self).__init__() + super().__init__() patch_size = 7 out_channels = patch_size * patch_size self.w = np.eye(out_channels).reshape( @@ -56,7 +56,7 @@ class Ternary(nn.Module): class SOBEL(nn.Module): def __init__(self): - super(SOBEL, self).__init__() + super().__init__() self.kernelX = torch.tensor([ [1, 0, -1], [2, 0, -2], @@ -82,7 +82,7 @@ class SOBEL(nn.Module): class MeanShift(nn.Conv2d): def __init__(self, data_mean, data_std, data_range=1, norm=True): c = len(data_mean) - super(MeanShift, self).__init__(c, c, kernel_size=1) + super().__init__(c, c, kernel_size=1) std = torch.Tensor(data_std) self.weight.data = torch.eye(c).view(c, c, 1, 1) if norm: @@ -97,7 +97,7 @@ class MeanShift(nn.Conv2d): class VGGPerceptualLoss(torch.nn.Module): def __init__(self, rank=0): # pylint: disable=unused-argument - super(VGGPerceptualLoss, self).__init__() + super().__init__() pretrained = True self.vgg_pretrained_features = models.vgg19( pretrained=pretrained).features diff --git a/modules/rife/model_ifnet.py b/modules/rife/model_ifnet.py index 843430bee..df32b59a1 100644 --- a/modules/rife/model_ifnet.py +++ b/modules/rife/model_ifnet.py @@ -28,7 +28,7 @@ def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation= class ResConv(nn.Module): def __init__(self, c, 
dilation=1): - super(ResConv, self).__init__() + super().__init__() self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1\ ) self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) @@ -39,7 +39,7 @@ class ResConv(nn.Module): class IFBlock(nn.Module): def __init__(self, in_planes, c=64): - super(IFBlock, self).__init__() + super().__init__() self.conv0 = nn.Sequential( conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1), @@ -74,7 +74,7 @@ class IFBlock(nn.Module): class IFNet(nn.Module): def __init__(self): - super(IFNet, self).__init__() + super().__init__() self.block0 = IFBlock(7, c=192) self.block1 = IFBlock(8+4, c=128) self.block2 = IFBlock(8+4, c=96) @@ -82,7 +82,9 @@ class IFNet(nn.Module): # self.contextnet = Contextnet() # self.unet = Unet() - def forward( self, x, timestep=0.5, scale_list=[8, 4, 2, 1], training=False, fastmode=True, ensemble=False): # pylint: disable=dangerous-default-value, unused-argument + def forward( self, x, timestep=0.5, scale_list=None, training=False, fastmode=True, ensemble=False): # pylint: disable=dangerous-default-value, unused-argument + if scale_list is None: + scale_list = [8, 4, 2, 1] if training is False: channel = x.shape[1] // 2 img0 = x[:, :channel] diff --git a/modules/rife/refine.py b/modules/rife/refine.py index 5d77582cc..9ea076d1f 100644 --- a/modules/rife/refine.py +++ b/modules/rife/refine.py @@ -29,7 +29,7 @@ def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): # pylint: class Conv2(nn.Module): def __init__(self, in_planes, out_planes, stride=2): - super(Conv2, self).__init__() + super().__init__() self.conv1 = conv(in_planes, out_planes, 3, stride, 1) self.conv2 = conv(out_planes, out_planes, 3, 1, 1) @@ -41,7 +41,7 @@ class Conv2(nn.Module): class Contextnet(nn.Module): def __init__(self): - super(Contextnet, self).__init__() + super().__init__() self.conv1 = Conv2(3, c) self.conv2 = Conv2(c, 2*c) self.conv3 = Conv2(2*c, 4*c) @@ -65,7 +65,7 @@ class 
Contextnet(nn.Module): class Unet(nn.Module): def __init__(self): - super(Unet, self).__init__() + super().__init__() self.down0 = Conv2(17, 2*c) self.down1 = Conv2(4*c, 4*c) self.down2 = Conv2(8*c, 8*c) diff --git a/modules/rife/ssim.py b/modules/rife/ssim.py index e2261ca7a..8233ec8f3 100644 --- a/modules/rife/ssim.py +++ b/modules/rife/ssim.py @@ -142,7 +142,7 @@ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normal # Classes to re-use window class SSIM(torch.nn.Module): def __init__(self, window_size=11, size_average=True, val_range=None): - super(SSIM, self).__init__() + super().__init__() self.window_size = window_size self.size_average = size_average self.val_range = val_range @@ -165,7 +165,7 @@ class SSIM(torch.nn.Module): class MSSSIM(torch.nn.Module): def __init__(self, window_size=11, size_average=True, channel=3): - super(MSSSIM, self).__init__() + super().__init__() self.window_size = window_size self.size_average = size_average self.channel = channel diff --git a/modules/rocm.py b/modules/rocm.py index c0b8b8df1..42af4db75 100644 --- a/modules/rocm.py +++ b/modules/rocm.py @@ -5,14 +5,14 @@ import ctypes import shutil import subprocess from types import ModuleType -from typing import Union, overload, TYPE_CHECKING +from typing import overload, TYPE_CHECKING from enum import Enum from functools import wraps if TYPE_CHECKING: import torch -rocm_sdk: Union[ModuleType, None] = None +rocm_sdk: ModuleType | None = None def resolve_link(path_: str) -> str: @@ -27,7 +27,7 @@ def dirname(path_: str, r: int = 1) -> str: return path_ -def spawn(command: Union[str, list[str]], cwd: os.PathLike = '.') -> str: +def spawn(command: str | list[str], cwd: os.PathLike = '.') -> str: process = subprocess.run(command, cwd=cwd, shell=True, check=False, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) return process.stdout.decode(encoding="utf8", errors="ignore") @@ -116,7 +116,7 @@ class Agent: return self.name @property - def therock(self) -> 
Union[str, None]: + def therock(self) -> str | None: if (self.gfx_version & 0xFFF0) == 0x1200: return "v2/gfx120X-all" if (self.gfx_version & 0xFFF0) == 0x1100: @@ -141,7 +141,7 @@ class Agent: # return "gfx950-dcgpu" return None - def get_gfx_version(self) -> Union[str, None]: + def get_gfx_version(self) -> str | None: if self.gfx_version is None: return None if self.gfx_version >= 0x1100 and self.gfx_version < 0x1200: @@ -153,7 +153,7 @@ class Agent: return None -def find() -> Union[ROCmEnvironment, None]: +def find() -> ROCmEnvironment | None: hip_path = shutil.which("hipconfig") if hip_path is not None: return ROCmEnvironment(dirname(resolve_link(hip_path), 2)) @@ -364,7 +364,6 @@ else: # sys.platform != "win32" def rocm_init(): try: - import torch from installer import log from modules.devices import get_hip_agent @@ -377,10 +376,10 @@ else: # sys.platform != "win32" is_wsl: bool = os.environ.get('WSL_DISTRO_NAME', 'unknown' if spawn('wslpath -w /') else None) is not None -environment: Union[Environment, None] = None -blaslt_tensile_libpath: Union[str, None] = None +environment: Environment | None = None +blaslt_tensile_libpath: str | None = None is_installed: bool = False -version: Union[str, None] = None +version: str | None = None refresh() # amdgpu-arch.exe written in Python diff --git a/modules/rocm_triton_windows.py b/modules/rocm_triton_windows.py index 4bcbaff18..022e27b10 100644 --- a/modules/rocm_triton_windows.py +++ b/modules/rocm_triton_windows.py @@ -1,5 +1,4 @@ import sys -from typing import Union import torch from modules import shared, devices from modules.rocm import Agent @@ -58,7 +57,7 @@ if sys.platform == "win32": from modules import zluda return zluda.core.to_hip_stream(_cuda_getCurrentRawStream(device)) - def get_default_agent() -> Union[Agent, None]: + def get_default_agent() -> Agent | None: if shared.devices.has_rocm(): return devices.get_hip_agent() else: diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 
a0c85a283..fad126ce1 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,7 +2,7 @@ import os import sys import time from collections import namedtuple -from typing import Optional, Dict, Any +from typing import Any from fastapi import FastAPI from gradio import Blocks import modules.errors as errors @@ -149,7 +149,7 @@ def clear_callbacks(): callback_list.clear() -def app_started_callback(demo: Optional[Blocks], app: FastAPI): +def app_started_callback(demo: Blocks | None, app: FastAPI): for c in callback_map['callbacks_app_started']: try: t0 = time.time() @@ -319,7 +319,7 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(e, c, 'image_grid') -def infotext_pasted_callback(infotext: str, params: Dict[str, Any]): +def infotext_pasted_callback(infotext: str, params: dict[str, Any]): for c in callback_map['callbacks_infotext_pasted']: try: t0 = time.time() diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py index a1ebc104e..2b1cbb847 100644 --- a/modules/scripts_auto_postprocessing.py +++ b/modules/scripts_auto_postprocessing.py @@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts_manager.Script): return self.postprocessing_controls.values() def postprocess_image(self, p, script_pp, *args): # pylint: disable=arguments-differ - args_dict = dict(zip(self.postprocessing_controls, args)) + args_dict = dict(zip(self.postprocessing_controls, args, strict=False)) pp = scripts_postprocessing.PostprocessedImage(script_pp.image) pp.info = {} self.script.process(pp, **args_dict) diff --git a/modules/scripts_manager.py b/modules/scripts_manager.py index 7ca38f5cd..1ac69bdf5 100644 --- a/modules/scripts_manager.py +++ b/modules/scripts_manager.py @@ -234,7 +234,7 @@ def list_scripts(scriptdirname, extension): else: priority = '9' if os.path.isfile(os.path.join(base, "..", ".priority")): - with open(os.path.join(base, "..", ".priority"), "r", encoding="utf-8") as f: + with 
open(os.path.join(base, "..", ".priority"), encoding="utf-8") as f: priority = priority + str(f.read().strip()) errors.log.debug(f'Script priority override: ${script.name}:{priority}') else: diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index e72e16d55..1bea4e219 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -4,7 +4,9 @@ from modules import errors, shared class PostprocessedImage: - def __init__(self, image, info = {}): + def __init__(self, image, info = None): + if info is None: + info = {} self.image = image self.info = info @@ -99,7 +101,7 @@ class ScriptPostprocessingRunner: jobid = shared.state.begin(script.name) script_args = args[script.args_from:script.args_to] process_args = {} - for (name, _component), value in zip(script.controls.items(), script_args): + for (name, _component), value in zip(script.controls.items(), script_args, strict=False): process_args[name] = value shared.log.debug(f'Process: script="{script.name}" args={process_args}') script.process(pp, **process_args) @@ -129,7 +131,7 @@ class ScriptPostprocessingRunner: jobid = shared.state.begin(script.name) script_args = args[script.args_from:script.args_to] process_args = {} - for (name, _component), value in zip(script.controls.items(), script_args): + for (name, _component), value in zip(script.controls.items(), script_args, strict=False): process_args[name] = value shared.log.debug(f'Postprocess: script={script.name} args={process_args}') script.postprocess(filenames, **process_args) diff --git a/modules/sd_checkpoint.py b/modules/sd_checkpoint.py index 874cb6729..c050811fe 100644 --- a/modules/sd_checkpoint.py +++ b/modules/sd_checkpoint.py @@ -149,11 +149,11 @@ def list_models(): checkpoint_info.register() if shared.cmd_opts.ckpt is not None: checkpoint_info = CheckpointInfo(shared.cmd_opts.ckpt) - if checkpoint_info.name is not None: + if checkpoint_info.name is not None and 
os.path.exists(checkpoint_info.filename): checkpoint_info.register() shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title - elif shared.cmd_opts.ckpt != shared.default_sd_model_file and shared.cmd_opts.ckpt is not None: - shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found') + elif shared.cmd_opts.ckpt != shared.default_sd_model_file: + shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found') shared.log.info(f'Available Models: safetensors="{shared.opts.ckpt_dir}":{len(safetensors_list)} diffusers="{shared.opts.diffusers_dir}":{len(diffusers_list)} reference={len(list(shared.reference_models))} items={len(checkpoints_list)} time={time.time()-t0:.2f}') checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename)) diff --git a/modules/sd_hijack_accelerate.py b/modules/sd_hijack_accelerate.py index 834038a41..fd8f4a3cb 100644 --- a/modules/sd_hijack_accelerate.py +++ b/modules/sd_hijack_accelerate.py @@ -1,4 +1,3 @@ -from typing import Optional, Union import time import torch import torch.nn as nn @@ -17,10 +16,10 @@ orig_torch_conv = torch.nn.modules.conv.Conv2d._conv_forward # pylint: disable=p def hijack_set_module_tensor( module: nn.Module, tensor_name: str, - device: Union[int, str, torch.device], - value: Optional[torch.Tensor] = None, - dtype: Optional[Union[str, torch.dtype]] = None, # pylint: disable=unused-argument - fp16_statistics: Optional[torch.HalfTensor] = None, # pylint: disable=unused-argument + device: int | str | torch.device, + value: torch.Tensor | None = None, + dtype: str | torch.dtype | None = None, # pylint: disable=unused-argument + fp16_statistics: torch.HalfTensor | None = None, # pylint: disable=unused-argument ): global tensor_to_timer # pylint: disable=global-statement if device == 'cpu': # override to load directly to gpu @@ -46,10 +45,10 @@ def hijack_set_module_tensor( def hijack_set_module_tensor_simple( module: nn.Module, tensor_name: str, - device: 
Union[int, str, torch.device], - value: Optional[torch.Tensor] = None, - dtype: Optional[Union[str, torch.dtype]] = None, # pylint: disable=unused-argument - fp16_statistics: Optional[torch.HalfTensor] = None, # pylint: disable=unused-argument + device: int | str | torch.device, + value: torch.Tensor | None = None, + dtype: str | torch.dtype | None = None, # pylint: disable=unused-argument + fp16_statistics: torch.HalfTensor | None = None, # pylint: disable=unused-argument ): global tensor_to_timer # pylint: disable=global-statement if device == 'cpu': # override to load directly to gpu diff --git a/modules/sd_hijack_dynamic_atten.py b/modules/sd_hijack_dynamic_atten.py index c3202ad16..6410b5a71 100644 --- a/modules/sd_hijack_dynamic_atten.py +++ b/modules/sd_hijack_dynamic_atten.py @@ -1,8 +1,6 @@ -from typing import Tuple, Optional from functools import cache, wraps import torch -from diffusers.utils import USE_PEFT_BACKEND # pylint: disable=unused-import from modules import shared, devices @@ -21,7 +19,7 @@ def find_split_size(original_size: int, slice_block_size: int, slice_rate: int = # Find slice sizes for SDPA @cache -def find_sdpa_slice_sizes(query_shape: Tuple[int], key_shape: Tuple[int], query_element_size: int, slice_rate: int = 2, trigger_rate: int = 3) -> Tuple[bool, int]: +def find_sdpa_slice_sizes(query_shape: tuple[int], key_shape: tuple[int], query_element_size: int, slice_rate: int = 2, trigger_rate: int = 3) -> tuple[bool, int]: batch_size, attn_heads, query_len, _ = query_shape _, _, key_len, _ = key_shape @@ -55,7 +53,7 @@ def find_sdpa_slice_sizes(query_shape: Tuple[int], key_shape: Tuple[int], query_ if devices.sdpa_pre_dyanmic_atten is None: devices.sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention @wraps(devices.sdpa_pre_dyanmic_atten) -def dynamic_scaled_dot_product_attention(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.FloatTensor] = None, dropout_p: float = 
0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: +def dynamic_scaled_dot_product_attention(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.FloatTensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: is_unsqueezed = False if query.dim() == 3: query = query.unsqueeze(0) diff --git a/modules/sd_hijack_hypertile.py b/modules/sd_hijack_hypertile.py index ec8f0c64f..f25ac8078 100644 --- a/modules/sd_hijack_hypertile.py +++ b/modules/sd_hijack_hypertile.py @@ -2,7 +2,7 @@ # based on: https://github.com/tfernd/HyperTile/tree/main/hyper_tile/utils.py + https://github.com/tfernd/HyperTile/tree/main/hyper_tile/hyper_tile.py from __future__ import annotations -from typing import Callable +from collections.abc import Callable from functools import wraps, cache from contextlib import contextmanager, nullcontext import random diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index 179ebc78e..e20a247b3 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -2,7 +2,7 @@ import importlib class CondFunc: def __new__(cls, orig_func, sub_func, cond_func): - self = super(CondFunc, cls).__new__(cls) + self = super().__new__(cls) if isinstance(orig_func, str): func_path = orig_func.split('.') for i in range(len(func_path)-1, -1, -1): diff --git a/modules/sd_models.py b/modules/sd_models.py index a88acded9..6590ebf3c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,9 +14,9 @@ from installer import log from modules import timer, paths, shared, shared_items, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, sd_models_compile, sd_detect, model_quant, sd_hijack_te, sd_hijack_accelerate, sd_hijack_safetensors, attention from modules.memstats import memory_stats from modules.modeldata import model_data -from 
modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, sd_metadata_file, checkpoints_list, checkpoint_titles, get_closest_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import +from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoint_titles, get_closest_checkpoint_match, update_model_hashes, write_metadata, checkpoints_list # pylint: disable=unused-import from modules.sd_offload import get_module_names, disable_offload, set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import -from modules.sd_models_utils import NoWatermark, get_signature, get_call, path_to_repo, patch_diffuser_config, convert_to_faketensors, read_state_dict, get_state_dict_from_checkpoint, apply_function_to_model # pylint: disable=unused-import +from modules.sd_models_utils import NoWatermark, get_signature, path_to_repo, apply_function_to_model, read_state_dict, get_state_dict_from_checkpoint # pylint: disable=unused-import model_dir = "Stable-diffusion" diff --git a/modules/sd_models_compile.py b/modules/sd_models_compile.py index 7c91b9bfc..564a64556 100644 --- a/modules/sd_models_compile.py +++ b/modules/sd_models_compile.py @@ -65,7 +65,6 @@ def ipex_optimize(sd_model, apply_to_components=True, op="Model"): def optimize_openvino(sd_model, clear_cache=True): try: - from modules.intel.openvino import openvino_fx # pylint: disable=unused-import if clear_cache and shared.compiled_model_state is not None: shared.compiled_model_state.compiled_cache.clear() shared.compiled_model_state.req_cache.clear() @@ -124,12 +123,10 @@ def compile_stablefast(sd_model): return sd_model config = sf.CompilationConfig.Default() try: - import xformers # pylint: disable=unused-import config.enable_xformers = True except Exception: pass try: - import triton # pylint: disable=unused-import config.enable_triton = True except Exception: pass 
@@ -196,7 +193,7 @@ def compile_torch(sd_model, apply_to_components=True, op="Model"): shared.compiled_model_state = CompiledModelState() return sd_model elif shared.opts.cuda_compile_backend == "migraphx": - import torch_migraphx # pylint: disable=unused-import + pass # pylint: disable=unused-import log_level = logging.WARNING if 'verbose' in shared.opts.cuda_compile_options else logging.CRITICAL # pylint: disable=protected-access if hasattr(torch, '_logging'): torch._logging.set_logs(dynamo=log_level, aot=log_level, inductor=log_level) # pylint: disable=protected-access diff --git a/modules/sd_models_utils.py b/modules/sd_models_utils.py index 05f44ef6c..f766270b7 100644 --- a/modules/sd_models_utils.py +++ b/modules/sd_models_utils.py @@ -8,8 +8,7 @@ import torch import safetensors.torch from modules import paths, shared, errors -from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closest_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import -from modules.sd_offload import disable_offload, set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import +from modules.sd_checkpoint import CheckpointInfo # pylint: disable=unused-import class NoWatermark: @@ -124,11 +123,11 @@ def patch_diffuser_config(sd_model, model_file): cfg_file = f'{model_file}_{k}.json' try: if os.path.exists(cfg_file): - with open(cfg_file, 'r', encoding='utf-8') as f: + with open(cfg_file, encoding='utf-8') as f: return json.load(f) cfg_file = f'{os.path.join(paths.sd_configs_path, os.path.basename(model_file))}_{k}.json' if os.path.exists(cfg_file): - with open(cfg_file, 'r', encoding='utf-8') as f: + with open(cfg_file, encoding='utf-8') as f: return json.load(f) except Exception: pass diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 8c1eb1ecd..c0e45f7a5 100644 --- 
a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,7 +1,6 @@ import os import copy from modules import shared -from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # pylint: disable=unused-import debug = shared.log.trace if os.environ.get('SD_SAMPLER_DEBUG', None) is not None else lambda *args, **kwargs: None diff --git a/modules/sd_te_remote.py b/modules/sd_te_remote.py index 264eeb145..cdc743e0d 100644 --- a/modules/sd_te_remote.py +++ b/modules/sd_te_remote.py @@ -1,4 +1,3 @@ -from typing import List, Optional, Union import os import time import json @@ -8,11 +7,11 @@ from modules import devices, errors def get_t5_prompt_embeds( - prompt: Union[str, List[str]] = None, + prompt: str | list[str] = None, num_images_per_prompt: int = 1, # pylint: disable=unused-argument max_sequence_length: int = 512, # pylint: disable=unused-argument - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ): device = device or devices.device dtype = dtype or devices.dtype diff --git a/modules/sdnq/dequantizer.py b/modules/sdnq/dequantizer.py index ff1036260..882af5802 100644 --- a/modules/sdnq/dequantizer.py +++ b/modules/sdnq/dequantizer.py @@ -1,6 +1,5 @@ # pylint: disable=redefined-builtin,no-member,protected-access -from typing import List, Tuple, Optional from dataclasses import dataclass import torch @@ -13,7 +12,7 @@ from .layers import SDNQLayer @devices.inference_context() -def dequantize_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor: +def dequantize_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, svd_up: torch.FloatTensor | 
None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor: result = torch.addcmul(zero_point, weight.to(dtype=scale.dtype), scale) if result_shape is not None: result = result.view(result_shape) @@ -34,7 +33,7 @@ def dequantize_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, ze @devices.inference_context() -def dequantize_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor: +def dequantize_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor: result = weight.to(dtype=scale.dtype).mul_(scale) if skip_quantized_matmul and not re_quantize_for_matmul: result.t_() @@ -57,7 +56,7 @@ def dequantize_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, svd @devices.inference_context() -def dequantize_symmetric_with_bias(weight: torch.CharTensor, scale: torch.FloatTensor, bias: torch.FloatTensor, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None) -> torch.FloatTensor: +def dequantize_symmetric_with_bias(weight: torch.CharTensor, scale: torch.FloatTensor, bias: torch.FloatTensor, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None) -> torch.FloatTensor: result = torch.addcmul(bias, weight.to(dtype=scale.dtype), scale) if result_shape is not None: result = result.view(result_shape) @@ -67,48 +66,48 @@ def dequantize_symmetric_with_bias(weight: 
torch.CharTensor, scale: torch.FloatT @devices.inference_context() -def dequantize_packed_int_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor: +def dequantize_packed_int_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor: return dequantize_asymmetric(unpack_int_asymetric(weight, shape, weights_dtype), scale, zero_point, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=result_shape, skip_quantized_matmul=skip_quantized_matmul) @devices.inference_context() -def dequantize_packed_int_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor: +def dequantize_packed_int_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor: return dequantize_symmetric(unpack_int_symetric(weight, shape, weights_dtype, dtype=scale.dtype), scale, svd_up=svd_up, svd_down=svd_down, 
dtype=dtype, result_shape=result_shape, skip_quantized_matmul=skip_quantized_matmul, re_quantize_for_matmul=re_quantize_for_matmul) @devices.inference_context() -def dequantize_packed_float_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor: +def dequantize_packed_float_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False) -> torch.FloatTensor: return dequantize_asymmetric(unpack_float(weight, shape, weights_dtype), scale, zero_point, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=result_shape, skip_quantized_matmul=skip_quantized_matmul) @devices.inference_context() -def dequantize_packed_float_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None, dtype: Optional[torch.dtype] = None, result_shape: Optional[torch.Size] = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor: +def dequantize_packed_float_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None, dtype: torch.dtype | None = None, result_shape: torch.Size | None = None, skip_quantized_matmul: bool = False, re_quantize_for_matmul: bool = False) -> torch.FloatTensor: return 
dequantize_symmetric(unpack_float(weight, shape, weights_dtype), scale, svd_up=svd_up, svd_down=svd_down, dtype=dtype, result_shape=result_shape, skip_quantized_matmul=skip_quantized_matmul, re_quantize_for_matmul=re_quantize_for_matmul) @devices.inference_context() -def quantize_int_mm(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "int8") -> Tuple[torch.Tensor, torch.FloatTensor]: +def quantize_int_mm(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "int8") -> tuple[torch.Tensor, torch.FloatTensor]: scale = torch.amax(input.abs(), dim=dim, keepdims=True).div_(dtype_dict[matmul_dtype]["max"]) input = torch.div(input, scale).round_().clamp_(dtype_dict[matmul_dtype]["min"], dtype_dict[matmul_dtype]["max"]).to(dtype=dtype_dict[matmul_dtype]["torch_dtype"]) return input, scale @devices.inference_context() -def quantize_int_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "int8") -> Tuple[torch.Tensor, torch.FloatTensor]: +def quantize_int_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "int8") -> tuple[torch.Tensor, torch.FloatTensor]: scale = torch.amax(input.abs(), dim=dim, keepdims=True).div_(dtype_dict[matmul_dtype]["max"]) input = torch.div(input, scale).add_(torch.randn_like(input), alpha=0.1).round_().clamp_(dtype_dict[matmul_dtype]["min"], dtype_dict[matmul_dtype]["max"]).to(dtype=dtype_dict[matmul_dtype]["torch_dtype"]) return input, scale @devices.inference_context() -def quantize_fp_mm(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]: +def quantize_fp_mm(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]: scale = torch.amax(input.abs(), dim=dim, keepdims=True).div_(dtype_dict[matmul_dtype]["max"]) input = torch.div(input, scale).nan_to_num_().clamp_(dtype_dict[matmul_dtype]["min"], 
dtype_dict[matmul_dtype]["max"]).to(dtype=dtype_dict[matmul_dtype]["torch_dtype"]) return input, scale @devices.inference_context() -def quantize_fp_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]: +def quantize_fp_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]: mantissa_difference = 1 << (23 - dtype_dict[matmul_dtype]["mantissa"]) scale = torch.amax(input.abs(), dim=dim, keepdims=True).div_(dtype_dict[matmul_dtype]["max"]) input = torch.div(input, scale).to(dtype=torch.float32).view(dtype=torch.int32) @@ -118,7 +117,7 @@ def quantize_fp_mm_sr(input: torch.FloatTensor, dim: int = -1, matmul_dtype: str @devices.inference_context() -def re_quantize_int_mm(weight: torch.FloatTensor) -> Tuple[torch.Tensor, torch.FloatTensor]: +def re_quantize_int_mm(weight: torch.FloatTensor) -> tuple[torch.Tensor, torch.FloatTensor]: if weight.ndim > 2: # convs weight = weight.flatten(1,-1) if use_contiguous_mm: @@ -130,7 +129,7 @@ def re_quantize_int_mm(weight: torch.FloatTensor) -> Tuple[torch.Tensor, torch.F @devices.inference_context() -def re_quantize_fp_mm(weight: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]: +def re_quantize_fp_mm(weight: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]: if weight.ndim > 2: # convs weight = weight.flatten(1,-1) weight, scale = quantize_fp_mm(weight.contiguous(), dim=-1, matmul_dtype=matmul_dtype) @@ -141,7 +140,7 @@ def re_quantize_fp_mm(weight: torch.FloatTensor, matmul_dtype: str = "float8_e4m @devices.inference_context() -def re_quantize_matmul_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, matmul_dtype: str, result_shape: Optional[torch.Size] = None, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) 
-> Tuple[torch.Tensor, torch.FloatTensor]: +def re_quantize_matmul_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, matmul_dtype: str, result_shape: torch.Size | None = None, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> tuple[torch.Tensor, torch.FloatTensor]: weight = dequantize_asymmetric(weight, scale, zero_point, svd_up=svd_up, svd_down=svd_down, dtype=scale.dtype, result_shape=result_shape) if dtype_dict[matmul_dtype]["is_integer"]: return re_quantize_int_mm(weight) @@ -150,7 +149,7 @@ def re_quantize_matmul_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTe @devices.inference_context() -def re_quantize_matmul_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, matmul_dtype: str, result_shape: Optional[torch.Size] = None, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]: +def re_quantize_matmul_symmetric(weight: torch.CharTensor, scale: torch.FloatTensor, matmul_dtype: str, result_shape: torch.Size | None = None, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> tuple[torch.Tensor, torch.FloatTensor]: weight = dequantize_symmetric(weight, scale, svd_up=svd_up, svd_down=svd_down, dtype=scale.dtype, result_shape=result_shape) if dtype_dict[matmul_dtype]["is_integer"]: return re_quantize_int_mm(weight) @@ -159,22 +158,22 @@ def re_quantize_matmul_symmetric(weight: torch.CharTensor, scale: torch.FloatTen @devices.inference_context() -def re_quantize_matmul_packed_int_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]: +def re_quantize_matmul_packed_int_asymmetric(weight: torch.ByteTensor, 
scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> tuple[torch.Tensor, torch.FloatTensor]: return re_quantize_matmul_asymmetric(unpack_int_asymetric(weight, shape, weights_dtype), scale, zero_point, matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=result_shape) @devices.inference_context() -def re_quantize_matmul_packed_int_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: Optional[torch.Size] = None, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]: +def re_quantize_matmul_packed_int_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size | None = None, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> tuple[torch.Tensor, torch.FloatTensor]: return re_quantize_matmul_symmetric(unpack_int_symetric(weight, shape, weights_dtype, dtype=scale.dtype), scale, matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=result_shape) @devices.inference_context() -def re_quantize_matmul_packed_float_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]: +def re_quantize_matmul_packed_float_asymmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, zero_point: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> 
tuple[torch.Tensor, torch.FloatTensor]: return re_quantize_matmul_asymmetric(unpack_float(weight, shape, weights_dtype), scale, zero_point, matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=result_shape) @devices.inference_context() -def re_quantize_matmul_packed_float_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: Optional[torch.Size] = None, svd_up: Optional[torch.FloatTensor] = None, svd_down: Optional[torch.FloatTensor] = None) -> Tuple[torch.Tensor, torch.FloatTensor]: +def re_quantize_matmul_packed_float_symmetric(weight: torch.ByteTensor, scale: torch.FloatTensor, shape: torch.Size, weights_dtype: str, matmul_dtype: str, result_shape: torch.Size | None = None, svd_up: torch.FloatTensor | None = None, svd_down: torch.FloatTensor | None = None) -> tuple[torch.Tensor, torch.FloatTensor]: return re_quantize_matmul_symmetric(unpack_float(weight, shape, weights_dtype), scale, matmul_dtype, svd_up=svd_up, svd_down=svd_down, result_shape=result_shape) @@ -220,7 +219,7 @@ class SDNQDequantizer: result_dtype: torch.dtype result_shape: torch.Size original_shape: torch.Size - original_stride: List[int] + original_stride: list[int] quantized_weight_shape: torch.Size weights_dtype: str quantized_matmul_dtype: str @@ -241,7 +240,7 @@ class SDNQDequantizer: result_dtype: torch.dtype, result_shape: torch.Size, original_shape: torch.Size, - original_stride: List[int], + original_stride: list[int], quantized_weight_shape: torch.Size, weights_dtype: str, quantized_matmul_dtype: str, diff --git a/modules/sdnq/forward.py b/modules/sdnq/forward.py index 9fc99d9f9..2deccf3ac 100644 --- a/modules/sdnq/forward.py +++ b/modules/sdnq/forward.py @@ -1,6 +1,6 @@ # pylint: disable=protected-access -from typing import Callable +from collections.abc import Callable from .common import dtype_dict, conv_types, conv_transpose_types, use_tensorwise_fp8_matmul diff --git 
a/modules/sdnq/layers/conv/conv_fp16.py b/modules/sdnq/layers/conv/conv_fp16.py index 8b60767cc..a8f1c4460 100644 --- a/modules/sdnq/layers/conv/conv_fp16.py +++ b/modules/sdnq/layers/conv/conv_fp16.py @@ -1,6 +1,5 @@ # pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access -from typing import List import torch @@ -18,10 +17,10 @@ def conv_fp16_matmul( weight: torch.Tensor, scale: torch.FloatTensor, result_shape: torch.Size, - reversed_padding_repeated_twice: List[int], + reversed_padding_repeated_twice: list[int], padding_mode: str, conv_type: int, - groups: int, stride: List[int], - padding: List[int], dilation: List[int], + groups: int, stride: list[int], + padding: list[int], dilation: list[int], bias: torch.FloatTensor = None, svd_up: torch.FloatTensor = None, svd_down: torch.FloatTensor = None, diff --git a/modules/sdnq/layers/conv/conv_fp8.py b/modules/sdnq/layers/conv/conv_fp8.py index 994850fb1..a2b864381 100644 --- a/modules/sdnq/layers/conv/conv_fp8.py +++ b/modules/sdnq/layers/conv/conv_fp8.py @@ -1,6 +1,5 @@ # pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access -from typing import List import torch @@ -17,10 +16,10 @@ def conv_fp8_matmul( weight: torch.Tensor, scale: torch.FloatTensor, result_shape: torch.Size, - reversed_padding_repeated_twice: List[int], + reversed_padding_repeated_twice: list[int], padding_mode: str, conv_type: int, - groups: int, stride: List[int], - padding: List[int], dilation: List[int], + groups: int, stride: list[int], + padding: list[int], dilation: list[int], bias: torch.FloatTensor = None, svd_up: torch.FloatTensor = None, svd_down: torch.FloatTensor = None, diff --git a/modules/sdnq/layers/conv/conv_fp8_tensorwise.py b/modules/sdnq/layers/conv/conv_fp8_tensorwise.py index 9be958923..9fc388873 100644 --- a/modules/sdnq/layers/conv/conv_fp8_tensorwise.py +++ b/modules/sdnq/layers/conv/conv_fp8_tensorwise.py @@ -1,6 +1,5 @@ # pylint: 
disable=relative-beyond-top-level,redefined-builtin,protected-access -from typing import List import torch @@ -18,10 +17,10 @@ def conv_fp8_matmul_tensorwise( weight: torch.Tensor, scale: torch.FloatTensor, result_shape: torch.Size, - reversed_padding_repeated_twice: List[int], + reversed_padding_repeated_twice: list[int], padding_mode: str, conv_type: int, - groups: int, stride: List[int], - padding: List[int], dilation: List[int], + groups: int, stride: list[int], + padding: list[int], dilation: list[int], bias: torch.FloatTensor = None, svd_up: torch.FloatTensor = None, svd_down: torch.FloatTensor = None, diff --git a/modules/sdnq/layers/conv/conv_int8.py b/modules/sdnq/layers/conv/conv_int8.py index 9777b3d9b..3e28c11ea 100644 --- a/modules/sdnq/layers/conv/conv_int8.py +++ b/modules/sdnq/layers/conv/conv_int8.py @@ -1,6 +1,5 @@ # pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access -from typing import List import torch @@ -18,10 +17,10 @@ def conv_int8_matmul( weight: torch.Tensor, scale: torch.FloatTensor, result_shape: torch.Size, - reversed_padding_repeated_twice: List[int], + reversed_padding_repeated_twice: list[int], padding_mode: str, conv_type: int, - groups: int, stride: List[int], - padding: List[int], dilation: List[int], + groups: int, stride: list[int], + padding: list[int], dilation: list[int], bias: torch.FloatTensor = None, svd_up: torch.FloatTensor = None, svd_down: torch.FloatTensor = None, diff --git a/modules/sdnq/layers/conv/forward.py b/modules/sdnq/layers/conv/forward.py index 2ed3d816f..74454d2d9 100644 --- a/modules/sdnq/layers/conv/forward.py +++ b/modules/sdnq/layers/conv/forward.py @@ -1,6 +1,5 @@ # pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access -from typing import Optional import torch @@ -78,16 +77,16 @@ def quantized_conv_forward(self, input) -> torch.FloatTensor: return self._conv_forward(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, 
self.svd_down), self.bias) -def quantized_conv_transpose_1d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor: +def quantized_conv_transpose_1d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor: output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 1, self.dilation) return torch.nn.functional.conv_transpose1d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) -def quantized_conv_transpose_2d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor: +def quantized_conv_transpose_2d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor: output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 2, self.dilation) return torch.nn.functional.conv_transpose2d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) -def quantized_conv_transpose_3d_forward(self, input: torch.FloatTensor, output_size: Optional[list[int]] = None) -> torch.FloatTensor: +def quantized_conv_transpose_3d_forward(self, input: torch.FloatTensor, output_size: list[int] | None = None) -> torch.FloatTensor: output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 3, self.dilation) return torch.nn.functional.conv_transpose3d(input, self.sdnq_dequantizer(self.weight, self.scale, self.zero_point, self.svd_up, self.svd_down), self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) diff --git a/modules/sdnq/layers/linear/forward.py b/modules/sdnq/layers/linear/forward.py index 7b3a169d9..be51a66ad 100644 
--- a/modules/sdnq/layers/linear/forward.py +++ b/modules/sdnq/layers/linear/forward.py @@ -1,13 +1,12 @@ # pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access -from typing import Tuple import torch from ...common import use_contiguous_mm # noqa: TID252 -def check_mats(input: torch.Tensor, weight: torch.Tensor, allow_contiguous_mm: bool = True) -> Tuple[torch.Tensor, torch.Tensor]: +def check_mats(input: torch.Tensor, weight: torch.Tensor, allow_contiguous_mm: bool = True) -> tuple[torch.Tensor, torch.Tensor]: input = input.contiguous() if allow_contiguous_mm and use_contiguous_mm: weight = weight.contiguous() diff --git a/modules/sdnq/layers/linear/linear_fp8.py b/modules/sdnq/layers/linear/linear_fp8.py index 169d318f9..80bf64b0e 100644 --- a/modules/sdnq/layers/linear/linear_fp8.py +++ b/modules/sdnq/layers/linear/linear_fp8.py @@ -1,6 +1,5 @@ # pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access -from typing import Tuple import torch @@ -11,7 +10,7 @@ from ...dequantizer import quantize_fp_mm # noqa: TID252 from .forward import check_mats -def quantize_fp_mm_input(input: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]: +def quantize_fp_mm_input(input: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]: input = input.flatten(0,-2).to(dtype=torch.float32) input, input_scale = quantize_fp_mm(input, dim=-1, matmul_dtype=matmul_dtype) return input, input_scale diff --git a/modules/sdnq/layers/linear/linear_fp8_tensorwise.py b/modules/sdnq/layers/linear/linear_fp8_tensorwise.py index a5ea71c55..8b4954c35 100644 --- a/modules/sdnq/layers/linear/linear_fp8_tensorwise.py +++ b/modules/sdnq/layers/linear/linear_fp8_tensorwise.py @@ -1,6 +1,5 @@ # pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access -from typing import Tuple import torch @@ -11,7 +10,7 @@ from ...dequantizer import quantize_fp_mm, 
dequantize_symmetric, dequantize_symm from .forward import check_mats -def quantize_fp_mm_input_tensorwise(input: torch.FloatTensor, scale: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> Tuple[torch.Tensor, torch.FloatTensor]: +def quantize_fp_mm_input_tensorwise(input: torch.FloatTensor, scale: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]: input = input.flatten(0,-2).to(dtype=scale.dtype) input, input_scale = quantize_fp_mm(input, dim=-1, matmul_dtype=matmul_dtype) scale = torch.mul(input_scale, scale) diff --git a/modules/sdnq/layers/linear/linear_int8.py b/modules/sdnq/layers/linear/linear_int8.py index 2d26a6086..2a1213cb8 100644 --- a/modules/sdnq/layers/linear/linear_int8.py +++ b/modules/sdnq/layers/linear/linear_int8.py @@ -1,6 +1,5 @@ # pylint: disable=relative-beyond-top-level,redefined-builtin,protected-access -from typing import Tuple import torch @@ -11,7 +10,7 @@ from ...dequantizer import quantize_int_mm, dequantize_symmetric, dequantize_sym from .forward import check_mats -def quantize_int_mm_input(input: torch.FloatTensor, scale: torch.FloatTensor) -> Tuple[torch.CharTensor, torch.FloatTensor]: +def quantize_int_mm_input(input: torch.FloatTensor, scale: torch.FloatTensor) -> tuple[torch.CharTensor, torch.FloatTensor]: input = input.flatten(0,-2).to(dtype=scale.dtype) input, input_scale = quantize_int_mm(input, dim=-1) scale = torch.mul(input_scale, scale) diff --git a/modules/sdnq/loader.py b/modules/sdnq/loader.py index 91be08394..789fd6e96 100644 --- a/modules/sdnq/loader.py +++ b/modules/sdnq/loader.py @@ -72,14 +72,14 @@ def load_sdnq_model(model_path: str, model_cls: ModelMixin = None, file_name: st if model_config is None: if os.path.exists(model_config_path): - with open(model_config_path, "r", encoding="utf-8") as f: + with open(model_config_path, encoding="utf-8") as f: model_config = json.load(f) else: model_config = {} if quantization_config is None: if 
os.path.exists(quantization_config_path): - with open(quantization_config_path, "r", encoding="utf-8") as f: + with open(quantization_config_path, encoding="utf-8") as f: quantization_config = json.load(f) else: quantization_config = model_config.get("quantization_config", None) diff --git a/modules/sdnq/packed_int.py b/modules/sdnq/packed_int.py index 09a38efbc..0cee35309 100644 --- a/modules/sdnq/packed_int.py +++ b/modules/sdnq/packed_int.py @@ -1,6 +1,5 @@ # pylint: disable=redefined-builtin,no-member,protected-access -from typing import Optional import torch @@ -15,7 +14,7 @@ def pack_int_asymetric(tensor: torch.CharTensor, weights_dtype: str) -> torch.By return packed_int_function_dict[weights_dtype]["pack"](tensor.to(dtype=dtype_dict[weights_dtype]["storage_dtype"])) -def unpack_int_symetric(packed_tensor: torch.ByteTensor, shape: torch.Size, weights_dtype: str, dtype: Optional[torch.dtype] = None) -> torch.CharTensor: +def unpack_int_symetric(packed_tensor: torch.ByteTensor, shape: torch.Size, weights_dtype: str, dtype: torch.dtype | None = None) -> torch.CharTensor: if dtype is None: dtype = dtype_dict[weights_dtype]["torch_dtype"] return packed_int_function_dict[weights_dtype]["unpack"](packed_tensor, shape).to(dtype=dtype).add_(dtype_dict[weights_dtype]["min"]) diff --git a/modules/sdnq/quantizer.py b/modules/sdnq/quantizer.py index a88da98cc..7035adc88 100644 --- a/modules/sdnq/quantizer.py +++ b/modules/sdnq/quantizer.py @@ -1,6 +1,6 @@ # pylint: disable=redefined-builtin,no-member,protected-access -from typing import Dict, List, Tuple, Optional, Union +from typing import Union from dataclasses import dataclass from enum import Enum @@ -29,7 +29,7 @@ class QuantizationMethod(str, Enum): @devices.inference_context() -def get_scale_asymmetric(weight: torch.FloatTensor, reduction_axes: Union[int, List[int]], weights_dtype: str) -> Tuple[torch.FloatTensor, torch.FloatTensor]: +def get_scale_asymmetric(weight: torch.FloatTensor, reduction_axes: int | 
list[int], weights_dtype: str) -> tuple[torch.FloatTensor, torch.FloatTensor]: zero_point = torch.amin(weight, dim=reduction_axes, keepdims=True) scale = torch.amax(weight, dim=reduction_axes, keepdims=True).sub_(zero_point).div_(dtype_dict[weights_dtype]["max"] - dtype_dict[weights_dtype]["min"]) if dtype_dict[weights_dtype]["min"] != 0: @@ -38,12 +38,12 @@ def get_scale_asymmetric(weight: torch.FloatTensor, reduction_axes: Union[int, L @devices.inference_context() -def get_scale_symmetric(weight: torch.FloatTensor, reduction_axes: Union[int, List[int]], weights_dtype: str) -> torch.FloatTensor: +def get_scale_symmetric(weight: torch.FloatTensor, reduction_axes: int | list[int], weights_dtype: str) -> torch.FloatTensor: return torch.amax(weight.abs(), dim=reduction_axes, keepdims=True).div_(dtype_dict[weights_dtype]["max"]) @devices.inference_context() -def quantize_weight(weight: torch.FloatTensor, reduction_axes: Union[int, List[int]], weights_dtype: str, dtype: torch.dtype = None, use_stochastic_rounding: bool = False) -> Tuple[torch.Tensor, torch.FloatTensor, torch.FloatTensor]: +def quantize_weight(weight: torch.FloatTensor, reduction_axes: int | list[int], weights_dtype: str, dtype: torch.dtype = None, use_stochastic_rounding: bool = False) -> tuple[torch.Tensor, torch.FloatTensor, torch.FloatTensor]: weight = weight.to(dtype=torch.float32) if dtype_dict[weights_dtype]["is_unsigned"]: @@ -73,7 +73,7 @@ def quantize_weight(weight: torch.FloatTensor, reduction_axes: Union[int, List[i @devices.inference_context() -def apply_svdquant(weight: torch.FloatTensor, rank: int = 32, niter: int = 8, dtype: torch.dtype = None) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: +def apply_svdquant(weight: torch.FloatTensor, rank: int = 32, niter: int = 8, dtype: torch.dtype = None) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: reshape_weight = False if weight.ndim > 2: # convs reshape_weight = True @@ -102,7 +102,7 @@ def 
prepare_weight_for_matmul(weight: torch.Tensor) -> torch.Tensor: @devices.inference_context() -def prepare_svd_for_matmul(svd_up: torch.FloatTensor, svd_down: torch.FloatTensor, use_quantized_matmul: bool) -> Tuple[torch.FloatTensor, torch.FloatTensor]: +def prepare_svd_for_matmul(svd_up: torch.FloatTensor, svd_down: torch.FloatTensor, use_quantized_matmul: bool) -> tuple[torch.FloatTensor, torch.FloatTensor]: if svd_up is not None: if use_quantized_matmul: svd_up = prepare_weight_for_matmul(svd_up) @@ -113,7 +113,7 @@ def prepare_svd_for_matmul(svd_up: torch.FloatTensor, svd_down: torch.FloatTenso return svd_up, svd_down -def check_param_name_in(param_name: str, param_list: List[str]) -> str: +def check_param_name_in(param_name: str, param_list: list[str]) -> str: split_param_name = param_name.split(".") for param in param_list: if param.startswith("."): @@ -153,7 +153,7 @@ def get_quant_args_from_config(quantization_config: Union["SDNQConfig", dict]) - return quantization_config_dict -def get_minimum_dtype(weights_dtype: str, param_name: str, modules_dtype_dict: Dict[str, List[str]]): +def get_minimum_dtype(weights_dtype: str, param_name: str, modules_dtype_dict: dict[str, list[str]]): if len(modules_dtype_dict.keys()) > 0: for key, value in modules_dtype_dict.items(): if check_param_name_in(param_name, value) is not None: @@ -180,7 +180,7 @@ def get_minimum_dtype(weights_dtype: str, param_name: str, modules_dtype_dict: D return weights_dtype -def get_quant_kwargs(quant_kwargs: dict, modules_quant_config: Dict[str, dict]) -> dict: +def get_quant_kwargs(quant_kwargs: dict, modules_quant_config: dict[str, dict]) -> dict: param_key = check_param_name_in(quant_kwargs["param_name"], modules_quant_config.keys()) if param_key is not None: for key, value in modules_quant_config[param_key].items(): @@ -189,7 +189,7 @@ def get_quant_kwargs(quant_kwargs: dict, modules_quant_config: Dict[str, dict]) return quant_kwargs -def add_module_skip_keys(model, modules_to_not_convert: 
List[str] = None, modules_dtype_dict: Dict[str, List[str]] = None): +def add_module_skip_keys(model, modules_to_not_convert: list[str] = None, modules_dtype_dict: dict[str, list[str]] = None): if modules_to_not_convert is None: modules_to_not_convert = [] if modules_dtype_dict is None: @@ -552,7 +552,7 @@ def sdnq_quantize_layer(layer, weights_dtype="int8", quantized_matmul_dtype=None @devices.inference_context() -def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=False, non_blocking=False, modules_to_not_convert: List[str] = None, modules_dtype_dict: Dict[str, List[str]] = None, modules_quant_config: Dict[str, dict] = None, quantization_device=None, return_device=None, full_param_name=""): # pylint: disable=unused-argument +def apply_sdnq_to_module(model, weights_dtype="int8", quantized_matmul_dtype=None, torch_dtype=None, group_size=0, svd_rank=32, svd_steps=8, dynamic_loss_threshold=1e-2, use_svd=False, quant_conv=False, use_quantized_matmul=False, use_quantized_matmul_conv=False, use_dynamic_quantization=False, use_stochastic_rounding=False, dequantize_fp32=False, non_blocking=False, modules_to_not_convert: list[str] = None, modules_dtype_dict: dict[str, list[str]] = None, modules_quant_config: dict[str, dict] = None, quantization_device=None, return_device=None, full_param_name=""): # pylint: disable=unused-argument has_children = list(model.children()) if not has_children: return model, modules_to_not_convert, modules_dtype_dict @@ -648,11 +648,11 @@ def sdnq_post_load_quant( dequantize_fp32: bool = False, non_blocking: bool = False, add_skip_keys:bool = True, - quantization_device: Optional[torch.device] = None, - return_device: Optional[torch.device] = None, - 
modules_to_not_convert: Optional[List[str]] = None, - modules_dtype_dict: Optional[Dict[str, List[str]]] = None, - modules_quant_config: Optional[Dict[str, dict]] = None, + quantization_device: torch.device | None = None, + return_device: torch.device | None = None, + modules_to_not_convert: list[str] | None = None, + modules_dtype_dict: dict[str, list[str]] | None = None, + modules_quant_config: dict[str, dict] | None = None, ): if modules_to_not_convert is None: modules_to_not_convert = [] @@ -733,7 +733,7 @@ def sdnq_post_load_quant( return model -class SDNQQuantize(): +class SDNQQuantize: def __init__(self, hf_quantizer): self.hf_quantizer = hf_quantizer @@ -887,7 +887,7 @@ class SDNQQuantizer(DiffusersQuantizer, HfQuantizer): def get_quantize_ops(self): return SDNQQuantize(self) - def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]: max_memory = {key: val * 0.80 for key, val in max_memory.items()} return max_memory @@ -908,7 +908,7 @@ class SDNQQuantizer(DiffusersQuantizer, HfQuantizer): self, model, device_map, # pylint: disable=unused-argument - keep_in_fp32_modules: List[str] = None, + keep_in_fp32_modules: list[str] = None, **kwargs, # pylint: disable=unused-argument ): if self.pre_quantized: @@ -1067,11 +1067,11 @@ class SDNQConfig(QuantizationConfigMixin): dequantize_fp32: bool = False, non_blocking: bool = False, add_skip_keys: bool = True, - quantization_device: Optional[torch.device] = None, - return_device: Optional[torch.device] = None, - modules_to_not_convert: Optional[List[str]] = None, - modules_dtype_dict: Optional[Dict[str, List[str]]] = None, - modules_quant_config: Optional[Dict[str, dict]] = None, + quantization_device: torch.device | None = None, + return_device: torch.device | None = None, + modules_to_not_convert: list[str] | None = None, + modules_dtype_dict: dict[str, list[str]] | None = None, + 
modules_quant_config: dict[str, dict] | None = None, is_training: bool = False, **kwargs, # pylint: disable=unused-argument ): diff --git a/modules/server.py b/modules/server.py index 4b757a1d3..8f1229a73 100644 --- a/modules/server.py +++ b/modules/server.py @@ -41,9 +41,8 @@ class UvicornServer(uvicorn.Server): self.start() -class HypercornServer(): +class HypercornServer: def __init__(self, app: fastapi.FastAPI, listen = None, port = None, keyfile = None, certfile = None, loop = "auto", http = None): - import asyncio import hypercorn self.app: fastapi.FastAPI = app self.server: HypercornServer = None diff --git a/modules/shared.py b/modules/shared.py index c48b46a03..fc8b71a83 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -8,17 +8,16 @@ import contextlib from enum import Enum from typing import TYPE_CHECKING import gradio as gr -from installer import log, print_dict, console, get_version # pylint: disable=unused-import +from installer import log, print_dict # pylint: disable=unused-import log.debug('Initializing: shared module') import modules.memmon import modules.paths as paths -from modules.json_helpers import readfile, writefile # pylint: disable=W0611 -from modules.shared_helpers import listdir, walk_files, html_path, html, req, total_tqdm # pylint: disable=W0611 +from modules.json_helpers import readfile # pylint: disable=W0611 +from modules.shared_helpers import listdir, req # pylint: disable=W0611 from modules import errors, devices, shared_state, cmd_args, theme, history, files_cache from modules.shared_defaults import get_default_modes -from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611 -from modules.memstats import memory_stats, ram_stats # pylint: disable=unused-import +from modules.memstats import memory_stats # pylint: disable=unused-import log.debug('Initializing: pipelines') from modules 
import shared_items @@ -74,6 +73,9 @@ sdnq_quant_modes = ["int8", "int7", "int6", "uint5", "uint4", "uint3", "uint2", sdnq_matmul_modes = ["auto", "int8", "float8_e4m3fn", "float16"] default_hfcache_dir = os.environ.get("SD_HFCACHEDIR", None) or os.path.join(paths.models_path, 'huggingface') state = shared_state.State() +models_path = paths.models_path +script_path = paths.script_path +data_path = paths.data_path # early select backend @@ -120,6 +122,7 @@ def list_checkpoint_titles(): list_checkpoint_tiles = list_checkpoint_titles # alias for legacy typo +default_sd_model_file = paths.default_sd_model_file default_checkpoint = list_checkpoint_titles()[0] if len(list_checkpoint_titles()) > 0 else "model.safetensors" @@ -862,7 +865,6 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", devices.device) history = history.History() if devices.backend == "directml": directml_do_hijack() -from modules import sdnq # pylint: disable=unused-import # register to diffusers and transformers log.debug('Quantization: registered=SDNQ') try: diff --git a/modules/shared_state.py b/modules/shared_state.py index ebc412265..f7a086849 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -148,7 +148,9 @@ class State: return job return None - def history(self, op:str, task_id:str=None, results:list=[]): + def history(self, op:str, task_id:str=None, results:list=None): + if results is None: + results = [] job = { 'id': task_id or self.id, 'job': self.job.lower(), diff --git a/modules/styles.py b/modules/styles.py index 072b03390..bb555e616 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -5,14 +5,13 @@ import csv import json import time import random -from typing import Dict from modules import files_cache, shared, infotext, sd_models, sd_vae debug_enabled = os.environ.get('SD_STYLES_DEBUG', None) is not None -class Style(): +class Style: def __init__(self, name: str, desc: str = "", prompt: str = "", negative_prompt: str = "", extra: str = "", wildcards: str = 
"", filename: str = "", preview: str = "", mtime: float = 0): self.name = name self.description = desc @@ -50,7 +49,7 @@ def select_from_weighted_list(inner: str) -> str: return '' parts = [p.strip() for p in inner.split('|') if p.strip()] - weighted: Dict[str, float] = {} + weighted: dict[str, float] = {} unweighted = [] for p in parts: @@ -102,7 +101,7 @@ def select_from_weighted_list(inner: str) -> str: if total <= 0.0: return items[0][0] - names, weights = zip(*items) + names, weights = zip(*items, strict=False) return random.choices(names, weights=weights, k=1)[0] @@ -130,7 +129,11 @@ def apply_curly_braces_to_prompt(prompt, seed=-1): return prompt -def apply_file_wildcards(prompt, replaced = [], not_found = [], recursion=0, seed=-1): +def apply_file_wildcards(prompt, replaced = None, not_found = None, recursion=0, seed=-1): + if not_found is None: + not_found = [] + if replaced is None: + replaced = [] def check_wildcard_files(prompt, wildcard, files, file_only=True): trimmed = wildcard.replace('\\', os.path.sep).replace('/', os.path.sep).strip().lower() for file in files: @@ -141,7 +144,7 @@ def apply_file_wildcards(prompt, replaced = [], not_found = [], recursion=0, see paths.insert(0, os.path.splitext(file)[0].lower()) if (trimmed in paths) or (os.path.sep in trimmed and trimmed in paths[0]): try: - with open(file, 'r', encoding='utf-8') as f: + with open(file, encoding='utf-8') as f: lines = f.readlines() lines = [line.split('#')[0].strip('\n').strip() for line in lines] lines = [line for line in lines if len(line) > 0] @@ -317,7 +320,7 @@ class StyleDatabase: pass def load_style(self, fn, prefix=None): - with open(fn, 'r', encoding='utf-8') as f: + with open(fn, encoding='utf-8') as f: new_style = None try: all_styles = json.load(f) @@ -508,7 +511,7 @@ class StyleDatabase: def load_csv(self, legacy_file): if not os.path.isfile(legacy_file): return - with open(legacy_file, "r", encoding="utf-8-sig", newline='') as file: + with open(legacy_file, 
encoding="utf-8-sig", newline='') as file: reader = csv.DictReader(file, skipinitialspace=True) num = 0 for row in reader: diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index f7302ca8e..3b233cfcf 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -12,7 +12,7 @@ from functools import partial import math -from typing import Optional, NamedTuple, List +from typing import NamedTuple import torch from torch import Tensor from torch.utils.checkpoint import checkpoint @@ -97,10 +97,10 @@ def _query_chunk_attention( ) return summarize_chunk(query, key_chunk, value_chunk) - chunks: List[AttnChunk] = [ + chunks: list[AttnChunk] = [ chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) ] - acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) + acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks, strict=False))) chunk_values, chunk_weights, chunk_max = acc_chunk global_max, _ = torch.max(chunk_max, 0, keepdim=True) @@ -142,8 +142,8 @@ def efficient_dot_product_attention( key: Tensor, value: Tensor, query_chunk_size=1024, - kv_chunk_size: Optional[int] = None, - kv_chunk_size_min: Optional[int] = None, + kv_chunk_size: int | None = None, + kv_chunk_size_min: int | None = None, use_checkpoint=True, ): """Computes efficient dot-product attention given query, key, and value. diff --git a/modules/taesd/hybrid_small.py b/modules/taesd/hybrid_small.py index a59b0b4d7..8ca1135ab 100644 --- a/modules/taesd/hybrid_small.py +++ b/modules/taesd/hybrid_small.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -74,19 +73,19 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin): self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int] = (64,), - encoder_block_out_channels: Tuple[int] = None, - decoder_block_out_channels: Tuple[int] = None, + down_block_types: tuple[str] = ("DownEncoderBlock2D",), + up_block_types: tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: tuple[int] = (64,), + encoder_block_out_channels: tuple[int] = None, + decoder_block_out_channels: tuple[int] = None, layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 4, norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, - latents_mean: Optional[Tuple[float]] = None, - latents_std: Optional[Tuple[float]] = None, + latents_mean: tuple[float] | None = None, + latents_std: tuple[float] | None = None, force_upcast: float = True, ): super().__init__() @@ -177,7 +176,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin): @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: + def attn_processors(self) -> dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with @@ -186,7 +185,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin): # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = 
module.get_processor(return_deprecated_lora=True) @@ -201,7 +200,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin): return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]): r""" Sets the attention processor to use to compute attention. @@ -254,7 +253,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin): @apply_forward_hook def encode( self, x: torch.FloatTensor, return_dict: bool = True - ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: """ Encode a batch of images into latents. @@ -284,7 +283,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin): return AutoencoderKLOutput(latent_dist=posterior) - def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> DecoderOutput | torch.FloatTensor: if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): return self.tiled_decode(z, return_dict=return_dict) @@ -299,7 +298,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin): @apply_forward_hook def decode( self, z: torch.FloatTensor, return_dict: bool = True, generator=None - ) -> Union[DecoderOutput, torch.FloatTensor]: + ) -> DecoderOutput | torch.FloatTensor: """ Decode a batch of images. 
@@ -391,7 +390,7 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin): return AutoencoderKLOutput(latent_dist=posterior) - def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> DecoderOutput | torch.FloatTensor: r""" Decode a batch of images using a tiled decoder. @@ -444,8 +443,8 @@ class AutoencoderSmall(ModelMixin, ConfigMixin, FromOriginalModelMixin): sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.FloatTensor: r""" Args: sample (`torch.FloatTensor`): Input sample. diff --git a/modules/textual_inversion.py b/modules/textual_inversion.py index 064d7d214..d01118e64 100644 --- a/modules/textual_inversion.py +++ b/modules/textual_inversion.py @@ -1,4 +1,3 @@ -from typing import List, Union import os import time import torch @@ -83,7 +82,7 @@ def get_text_encoders(): text_encoders = [] tokenizers = [] hidden_sizes = [] - for te, tok in zip(te_names, tokenizers_names): + for te, tok in zip(te_names, tokenizers_names, strict=False): text_encoder = getattr(pipe, te, None) if text_encoder is None: continue @@ -135,14 +134,14 @@ def insert_vectors(embedding, tokenizers, text_encoders, hiddensizes): this may cause collisions. 
""" with devices.inference_context(): - for vector, size in zip(embedding.vec, embedding.vector_sizes): + for vector, size in zip(embedding.vec, embedding.vector_sizes, strict=False): if size not in hiddensizes: continue idx = hiddensizes.index(size) unk_token_id = tokenizers[idx].convert_tokens_to_ids(tokenizers[idx].unk_token) if text_encoders[idx].get_input_embeddings().weight.data.shape[0] != len(tokenizers[idx]): text_encoders[idx].resize_token_embeddings(len(tokenizers[idx])) - for token, v in zip(embedding.tokens, vector.unbind()): + for token, v in zip(embedding.tokens, vector.unbind(), strict=False): token_id = tokenizers[idx].convert_tokens_to_ids(token) if token_id > unk_token_id: text_encoders[idx].get_input_embeddings().weight.data[token_id] = v @@ -254,7 +253,7 @@ class EmbeddingDatabase: self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True) return embedding - def load_diffusers_embedding(self, filename: Union[str, List[str]] = None, data: dict = None): + def load_diffusers_embedding(self, filename: str | list[str] = None, data: dict = None): """ File names take precidence over bundled embeddings passed as a dict. Bundled embeddings are automatically set to overwrite previous embeddings. 
diff --git a/modules/theme.py b/modules/theme.py index da6fa562e..0b384ac20 100644 --- a/modules/theme.py +++ b/modules/theme.py @@ -18,7 +18,7 @@ def refresh_themes(no_update=False): res = [] if os.path.exists(themes_file): try: - with open(themes_file, 'r', encoding='utf8') as f: + with open(themes_file, encoding='utf8') as f: res = json.load(f) except Exception: modules.shared.log.error('Exception loading UI themes') diff --git a/modules/todo/todo_merge.py b/modules/todo/todo_merge.py index 77840d6ed..cde8381fe 100644 --- a/modules/todo/todo_merge.py +++ b/modules/todo/todo_merge.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Callable +from collections.abc import Callable import math import torch import torch.nn.functional as F @@ -136,7 +136,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, sy: int, r: int, no_rand: bool = False, - generator: torch.Generator = None) -> Tuple[Callable, Callable]: + generator: torch.Generator = None) -> tuple[Callable, Callable]: """ Partitions the tokens into src and dst and merges r tokens from src to dst. Dst tokens are partitioned by choosing one randomy in each (sx, sy) region. 
@@ -305,9 +305,9 @@ class TokenMergeAttentionProcessor: self, attn: Attention, hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: torch.FloatTensor | None = None, + attention_mask: torch.FloatTensor | None = None, + temb: torch.FloatTensor | None = None, scale: float = 1.0, ) -> torch.FloatTensor: residual = hidden_states diff --git a/modules/todo/todo_utils.py b/modules/todo/todo_utils.py index 34a24bb82..0077b5558 100644 --- a/modules/todo/todo_utils.py +++ b/modules/todo/todo_utils.py @@ -29,7 +29,9 @@ def remove_tome_patch(pipe: torch.nn.Module): if hasattr(m, "processor"): m.processor = AttnProcessor2_0() -def patch_attention_proc(unet, token_merge_args={}): +def patch_attention_proc(unet, token_merge_args=None): + if token_merge_args is None: + token_merge_args = {} unet._tome_info = { # pylint: disable=protected-access "size": None, "timestep": None, diff --git a/modules/ui.py b/modules/ui.py index 3fe74b32d..b1c0033b6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -2,7 +2,6 @@ import gradio as gr import gradio.routes import gradio.utils from modules import errors, timer, gr_hijack, shared, script_callbacks, ui_common, ui_symbols, ui_javascript, ui_sections, generation_parameters_copypaste, call_queue, scripts_manager -from modules.paths import script_path, data_path # pylint: disable=unused-import from modules.api import mime diff --git a/modules/ui_common.py b/modules/ui_common.py index 89dd2038a..740d44385 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -38,7 +38,9 @@ def update_generation_info(generation_info, html_info, img_index): return html_info, html_info -def plaintext_to_html(text, elem_classes=[]): +def plaintext_to_html(text, elem_classes=None): + if elem_classes is None: + elem_classes = [] res = f'

' + '
\n'.join([f"{html.escape(x)}" for x in text.split('\n')]) + '

' return res diff --git a/modules/ui_control.py b/modules/ui_control.py index 1049b5f54..2c1c74748 100644 --- a/modules/ui_control.py +++ b/modules/ui_control.py @@ -73,7 +73,7 @@ def return_controls(res, t: float = None): def get_units(*values): update = [] what = None - for c, v in zip(controls, values): + for c, v in zip(controls, values, strict=False): if isinstance(c, gr.Label): # unit type indicator what = c.value['label'] c.value = v diff --git a/modules/ui_docs.py b/modules/ui_docs.py index 08306e5ac..a8e824b30 100644 --- a/modules/ui_docs.py +++ b/modules/ui_docs.py @@ -5,7 +5,7 @@ from modules import ui_symbols, ui_components from installer import install, log -class Page(): +class Page: def __init__(self, fn, full: bool = True): self.fn = fn self.title = '' @@ -21,7 +21,7 @@ class Page(): try: self.title = ' ' + os.path.basename(self.fn).replace('.md', '').replace('-', ' ') + ' ' self.mtime = time.localtime(os.path.getmtime(self.fn)) - with open(self.fn, 'r', encoding='utf-8') as f: + with open(self.fn, encoding='utf-8') as f: content = f.read() self.size = len(content) self.lines = [line.strip().lower() + ' ' for line in content.splitlines() if len(line)>1] @@ -80,7 +80,7 @@ class Page(): log.error(f'Search docs: page="{self.fn}" does not exist') return f'page="{self.fn}" does not exist' try: - with open(self.fn, 'r', encoding='utf-8') as f: + with open(self.fn, encoding='utf-8') as f: content = f.read() return content except Exception as e: @@ -91,7 +91,7 @@ class Page(): return f'Page(title="{self.title.strip()}" fn="{self.fn}" mtime={self.mtime} h1={[h.strip() for h in self.h1]} h2={len(self.h2)} h3={len(self.h3)} lines={len(self.lines)} size={self.size})' -class Pages(): +class Pages: def __init__(self): self.time = time.time() self.size = 0 @@ -117,7 +117,7 @@ class Pages(): text = text.lower() scores = [page.search(text) for page in self.pages] mtimes = [page.mtime for page in self.pages] - found = sorted(zip(scores, mtimes, self.pages), 
key=lambda x: (x[0], x[1]), reverse=True) + found = sorted(zip(scores, mtimes, self.pages, strict=False), key=lambda x: (x[0], x[1]), reverse=True) found = [item for item in found if item[0] > 0] return [(item[0], item[2]) for item in found][:topk] except Exception as e: @@ -177,7 +177,7 @@ def search_docs(search_term): def get_github_page(page): try: - with open(os.path.join('wiki', f'{page}.md'), 'r', encoding='utf-8') as f: + with open(os.path.join('wiki', f'{page}.md'), encoding='utf-8') as f: content = f.read() log.debug(f'Search wiki: page="{page}" size={len(content)}') except Exception as e: @@ -230,7 +230,7 @@ def search_github(search_term): def create_ui_logs(): def get_changelog(): - with open('CHANGELOG.md', 'r', encoding='utf-8') as f: + with open('CHANGELOG.md', encoding='utf-8') as f: content = f.read() content = content.replace('# Change Log for SD.Next', ' ') return content diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 68204650e..063580f51 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -391,7 +391,7 @@ class ExtraNetworksPage: r = random.randint(100, 255) g = random.randint(100, 255) b = random.randint(100, 255) - return '#{:02x}{:02x}{:02x}'.format(r, g, b) # pylint: disable=consider-using-f-string + return f'#{r:02x}{g:02x}{b:02x}' # pylint: disable=consider-using-f-string try: onclick = f'cardClicked({item.get("prompt", None)})' @@ -515,7 +515,7 @@ class ExtraNetworksPage: fn = os.path.splitext(path)[0] + '.txt' if os.path.exists(fn): try: - with open(fn, "r", encoding="utf-8", errors="replace") as f: + with open(fn, encoding="utf-8", errors="replace") as f: txt = f.read() txt = re.sub('[<>]', '', txt) return txt @@ -588,7 +588,6 @@ def register_pages(): if shared.opts.diffusers_enable_embed: from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion register_page(ExtraNetworksPageTextualInversion()) - from modules.video_models.models_def import 
models # pylint: disable=unused-import def get_pages(title=None): @@ -1044,7 +1043,7 @@ def create_ui(container, button_parent, tabname, skip_indexing = False): params, text = get_last_args() if (not params) or (not text) or (len(text) == 0): if os.path.exists(paths.params_path): - with open(paths.params_path, "r", encoding="utf8") as file: + with open(paths.params_path, encoding="utf8") as file: text = file.read() else: text = '' @@ -1062,7 +1061,7 @@ def create_ui(container, button_parent, tabname, skip_indexing = False): params, text = get_last_args() if (not params) or (not text) or (len(text) == 0): if os.path.exists(paths.params_path): - with open(paths.params_path, "r", encoding="utf8") as file: + with open(paths.params_path, encoding="utf8") as file: text = file.read() else: text = '' diff --git a/modules/ui_img2img.py b/modules/ui_img2img.py index 5e651de91..7828bc8ab 100644 --- a/modules/ui_img2img.py +++ b/modules/ui_img2img.py @@ -57,7 +57,7 @@ def create_ui(): def add_copy_image_controls(tab_name, elem): with gr.Row(variant="compact", elem_id=f"img2img_copy_{tab_name}_row"): - for title, name in zip(['➠ Image', '➠ Inpaint', '➠ Sketch', '➠ Composite'], ['img2img', 'inpaint', 'sketch', 'composite']): + for title, name in zip(['➠ Image', '➠ Inpaint', '➠ Sketch', '➠ Composite'], ['img2img', 'inpaint', 'sketch', 'composite'], strict=False): if name == tab_name: gr.Button(title, elem_id=f'{tab_name}_copy_to_{name}', interactive=False) copy_image_destinations[name] = elem diff --git a/modules/ui_javascript.py b/modules/ui_javascript.py index dcbf14731..5c5f966a8 100644 --- a/modules/ui_javascript.py +++ b/modules/ui_javascript.py @@ -55,7 +55,7 @@ def html_body(): def html_login(): fn = os.path.join(script_path, "javascript", "login.js") - with open(fn, 'r', encoding='utf8') as f: + with open(fn, encoding='utf8') as f: inline = f.read() js = f'\n' return js @@ -110,11 +110,11 @@ def reload_javascript(): def template_response(*args, **kwargs): res = 
shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace(b'', f'{title}'.encode("utf8")) - res.body = res.body.replace(b'', f'{manifest}'.encode("utf8")) - res.body = res.body.replace(b'', f'{login}'.encode("utf8")) - res.body = res.body.replace(b'', f'{js}'.encode("utf8")) - res.body = res.body.replace(b'', f'{css}{body}'.encode("utf8")) + res.body = res.body.replace(b'', f'{title}'.encode()) + res.body = res.body.replace(b'', f'{manifest}'.encode()) + res.body = res.body.replace(b'', f'{login}'.encode()) + res.body = res.body.replace(b'', f'{js}'.encode()) + res.body = res.body.replace(b'', f'{css}{body}'.encode()) lines = res.body.decode("utf8").split('\n') for line in lines: if 'meta name="twitter:' in line: diff --git a/modules/ui_models.py b/modules/ui_models.py index def5ec3c8..34b9b9798 100644 --- a/modules/ui_models.py +++ b/modules/ui_models.py @@ -346,7 +346,7 @@ def create_ui(): preset = interpolate(presets, ratio) else: preset = presets[0] - preset = ['%.3f' % x if int(x) != x else str(x) for x in preset] # pylint: disable=consider-using-f-string + preset = [f'{x:.3f}' if int(x) != x else str(x) for x in preset] # pylint: disable=consider-using-f-string preset = [preset[0], ",".join(preset[1:13]), preset[13], ",".join(preset[14:])] return [gr.update(value=x) for x in preset] + [gr.update(selected=2)] @@ -498,7 +498,7 @@ def create_ui(): def civitai_download(model_urls, model_names, model_types, model_path, civit_token, model_output): from modules.civitai.download_civitai import download_civit_model - for model_url, model_name, model_type in zip(model_urls, model_names, model_types): + for model_url, model_name, model_type in zip(model_urls, model_names, model_types, strict=False): msg = f"

Initiating download

{model_name} | {model_type} | {model_url}

" yield msg + model_output download_civit_model(model_url, model_name, model_path, model_type, civit_token) diff --git a/modules/ui_models_load.py b/modules/ui_models_load.py index 58e0a0f58..d63a63d20 100644 --- a/modules/ui_models_load.py +++ b/modules/ui_models_load.py @@ -1,6 +1,5 @@ import os import re -import json # pylint: disable=unused-import import inspect import gradio as gr import torch @@ -101,7 +100,7 @@ def process_huggingface_url(url): return repo, subfolder, fn, download -class Component(): +class Component: def __init__(self, signature, name=None, cls=None, val=None, local=None, remote=None, typ=None, dtype=None, quant=False, loadable=None): self.id = len(components) + 1 self.name = signature.name if signature else name diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 345077ec9..247325b3f 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -119,7 +119,7 @@ def create_dirty_indicator(key, keys_to_reset, **kwargs): def run_settings(*args): changed = [] - for key, value, comp in zip(shared.opts.data_labels.keys(), args, components): + for key, value, comp in zip(shared.opts.data_labels.keys(), args, components, strict=False): if comp == dummy_component or value=='dummy': # or getattr(comp, 'visible', True) is False or key in hidden_list: # actual = shared.opts.data.get(key, None) # ensure the key is in data # default = shared.opts.data_labels[key].default @@ -173,7 +173,9 @@ def run_settings_single(value, key, progress=False): return get_value_for_setting(key), shared.opts.dumpjson() -def create_ui(disabled_tabs=[]): +def create_ui(disabled_tabs=None): + if disabled_tabs is None: + disabled_tabs = [] shared.log.debug('UI initialize: tab=settings') global text_settings # pylint: disable=global-statement text_settings = gr.Textbox(elem_id="settings_json", elem_classes=["settings_json"], value=lambda: shared.opts.dumpjson(), visible=False) diff --git a/modules/upscaler.py b/modules/upscaler.py index 0cdd7892c..10e816569 
100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -1,7 +1,7 @@ import os from abc import abstractmethod from PIL import Image -from modules import modelloader, shared +from modules import modelloader, shared, paths models = None @@ -39,14 +39,13 @@ class Upscaler: if self.user_path is not None and len(self.user_path) > 0 and not os.path.exists(self.user_path): shared.log.info(f'Upscaler create: folder="{self.user_path}"') if self.model_path is None and self.name: - self.model_path = os.path.join(shared.models_path, self.name) + self.model_path = os.path.join(paths.models_path, self.name) try: if self.model_path and create_dirs: os.makedirs(self.model_path, exist_ok=True) except Exception: pass try: - import cv2 # pylint: disable=unused-import self.can_tile = True except Exception: pass diff --git a/modules/vae/sd_vae_fal.py b/modules/vae/sd_vae_fal.py index bd482a779..0f0e5b3bb 100644 --- a/modules/vae/sd_vae_fal.py +++ b/modules/vae/sd_vae_fal.py @@ -49,17 +49,25 @@ class Flux2TinyAutoEncoder(ModelMixin, ConfigMixin): in_channels: int = 3, out_channels: int = 3, latent_channels: int = 128, - encoder_block_out_channels: list[int] = [64, 64, 64, 64], - decoder_block_out_channels: list[int] = [64, 64, 64, 64], + encoder_block_out_channels: list[int] = None, + decoder_block_out_channels: list[int] = None, act_fn: str = "silu", upsampling_scaling_factor: int = 2, - num_encoder_blocks: list[int] = [1, 3, 3, 3], - num_decoder_blocks: list[int] = [3, 3, 3, 1], + num_encoder_blocks: list[int] = None, + num_decoder_blocks: list[int] = None, latent_magnitude: float = 3.0, latent_shift: float = 0.5, force_upcast: bool = False, scaling_factor: float = 0.13025, ) -> None: + if num_decoder_blocks is None: + num_decoder_blocks = [3, 3, 3, 1] + if num_encoder_blocks is None: + num_encoder_blocks = [1, 3, 3, 3] + if decoder_block_out_channels is None: + decoder_block_out_channels = [64, 64, 64, 64] + if encoder_block_out_channels is None: + encoder_block_out_channels = 
[64, 64, 64, 64] super().__init__() self.tiny_vae = AutoencoderTiny( in_channels=in_channels, diff --git a/modules/vae/sd_vae_natten.py b/modules/vae/sd_vae_natten.py index 478e9b654..246816c8d 100644 --- a/modules/vae/sd_vae_natten.py +++ b/modules/vae/sd_vae_natten.py @@ -1,7 +1,6 @@ # copied from https://github.com/Birch-san/sdxl-play/blob/main/src/attn/natten_attn_processor.py import os -from typing import Optional from diffusers.models.attention import Attention import torch from torch.nn import Linear @@ -45,9 +44,9 @@ class NattenAttnProcessor: self, attn: Attention, hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: torch.FloatTensor | None = None, + attention_mask: torch.BoolTensor | None = None, + temb: torch.FloatTensor | None = None, ): import natten assert hasattr(attn, 'qkv'), "Did not find property qkv on attn. Expected you to fuse its q_proj, k_proj, v_proj weights and biases beforehand, and multiply attn.scale into the q weights and bias." 
diff --git a/modules/video_models/google_veo.py b/modules/video_models/google_veo.py index aebc3f22f..49893b260 100644 --- a/modules/video_models/google_veo.py +++ b/modules/video_models/google_veo.py @@ -43,7 +43,7 @@ def get_size_buckets(width: int, height: int) -> str: return closest_size, closest_aspect_ratio -class GoogleVeoVideoPipeline(): +class GoogleVeoVideoPipeline: def __init__(self, model_name: str): self.model = model_name self.client = None diff --git a/modules/video_models/models_def.py b/modules/video_models/models_def.py index 06b64fdd1..351d1c33a 100644 --- a/modules/video_models/models_def.py +++ b/modules/video_models/models_def.py @@ -6,7 +6,7 @@ from installer import log @dataclass -class Model(): +class Model: name: str url: str = '' repo: str = None diff --git a/modules/video_models/video_load.py b/modules/video_models/video_load.py index aae853a8a..e60463368 100644 --- a/modules/video_models/video_load.py +++ b/modules/video_models/video_load.py @@ -2,7 +2,6 @@ import os import sys import copy import time -import transformers # pylint: disable=unused-import import diffusers from modules import shared, errors, sd_models, sd_checkpoint, model_quant, devices, sd_hijack_te, sd_hijack_vae from modules.video_models import models_def, video_utils, video_overrides, video_cache diff --git a/modules/video_models/video_save.py b/modules/video_models/video_save.py index e137ca966..9f75df9c4 100644 --- a/modules/video_models/video_save.py +++ b/modules/video_models/video_save.py @@ -136,9 +136,11 @@ def atomic_save_video(filename: str, pix_fmt:str='yuv420p', options:str='', aac:int=24000, - metadata:dict={}, + metadata:dict=None, pbar=None, ): + if metadata is None: + metadata = {} av = check_av() if av is None or av is False: shared.log.error('Video: ffmpeg/av not available') @@ -205,9 +207,11 @@ def save_video( mp4_interpolate:int=0, # rife interpolation aac_sample_rate:int=24000, # audio sample rate stream=None, # async progress reporting stream - 
metadata:dict={}, # metadata for video + metadata:dict=None, # metadata for video pbar=None, # progress bar for video ): + if metadata is None: + metadata = {} output_video = None if binary is not None: diff --git a/modules/zluda.py b/modules/zluda.py index 7b85eec62..258958cc7 100644 --- a/modules/zluda.py +++ b/modules/zluda.py @@ -1,13 +1,11 @@ import sys -from typing import Union -from modules.zluda_installer import core, default_agent # pylint: disable=unused-import PLATFORM = sys.platform do_nothing = lambda _: None # pylint: disable=unnecessary-lambda-assignment -def test(device) -> Union[Exception, None]: +def test(device) -> Exception | None: import torch device = torch.device(device) try: diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py index b5055b049..5326588eb 100644 --- a/modules/zluda_installer.py +++ b/modules/zluda_installer.py @@ -6,7 +6,6 @@ import ctypes import shutil import zipfile import urllib.request -from typing import Union from installer import args, log from modules import rocm @@ -23,7 +22,7 @@ HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll', 'rocsparse.dll', 'hipfft.dll', MIOpen_enabled = False path = os.path.abspath(os.environ.get('ZLUDA', '.zluda')) -default_agent: Union[rocm.Agent, None] = None +default_agent: rocm.Agent | None = None hipBLASLt_enabled = False diff --git a/webui.py b/webui.py index 399e0ff4e..e6a995533 100644 --- a/webui.py +++ b/webui.py @@ -90,7 +90,7 @@ def initialize(): timer.startup.record("te") modules.modelloader.cleanup_models() - modules.sd_models.setup_model() + modules.sd_checkpoint.setup_model() timer.startup.record("models") from modules.lora import lora_load