mirror of https://github.com/vladmandic/automatic
modernize typing
parent
7aded79e8a
commit
bfe014f5da
|
|
@ -45,7 +45,7 @@ TBD
|
|||
see [docs](https://vladmandic.github.io/sdnext-docs/Python/) for details
|
||||
- remove hard-dependnecies:
|
||||
`clip, numba, skimage, torchsde, omegaconf, antlr, patch-ng, patch-ng, astunparse, addict, inflection, jsonmerge, kornia`,
|
||||
`resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets`
|
||||
`resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets, imp`
|
||||
these are now installed on-demand when needed
|
||||
- refactor: to/from image/tensor logic, thanks @CalamitousFelicitousness
|
||||
- refactor: switch to `pyproject.toml` for tool configs
|
||||
|
|
@ -57,6 +57,8 @@ TBD
|
|||
- refactor: unified command line parsing
|
||||
- refactor: launch use threads to async execute non-critical tasks
|
||||
- refactor: switch from deprecated `pkg_resources` to `importlib`
|
||||
- refactor: modernize typing and type annotations
|
||||
- refactor: improve `pydantic==2.x` compatibility
|
||||
- update `lint` rules, thanks @awsr
|
||||
- remove requirements: `clip`, `open-clip`
|
||||
- update `requirements`
|
||||
|
|
|
|||
51
installer.py
51
installer.py
|
|
@ -1,4 +1,4 @@
|
|||
from typing import overload, List, Optional
|
||||
from typing import overload
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
|
@ -106,10 +106,12 @@ def str_to_bool(val: str | bool | None) -> bool | None:
|
|||
return val
|
||||
|
||||
|
||||
def install_traceback(suppress: list = []):
|
||||
def install_traceback(suppress: list = None):
|
||||
from rich.traceback import install as traceback_install
|
||||
from rich.pretty import install as pretty_install
|
||||
|
||||
if suppress is None:
|
||||
suppress = []
|
||||
width = os.environ.get("SD_TRACEWIDTH", console.width if console else None)
|
||||
if width is not None:
|
||||
width = int(width)
|
||||
|
|
@ -133,7 +135,7 @@ def setup_logging():
|
|||
from functools import partial, partialmethod
|
||||
from logging.handlers import RotatingFileHandler
|
||||
try:
|
||||
import rich # pylint: disable=unused-import
|
||||
pass # pylint: disable=unused-import
|
||||
except Exception:
|
||||
log.error('Please restart SD.Next so changes take effect')
|
||||
sys.exit(1)
|
||||
|
|
@ -187,7 +189,7 @@ def setup_logging():
|
|||
_Segment = Segment
|
||||
left = _Segment(" " * self.left, style) if self.left else None
|
||||
right = [_Segment.line()]
|
||||
blank_line: Optional[List[Segment]] = None
|
||||
blank_line: list[Segment] | None = None
|
||||
if self.top:
|
||||
blank_line = [_Segment(f'{" " * width}\n', style)]
|
||||
yield from blank_line * self.top
|
||||
|
|
@ -215,8 +217,10 @@ def setup_logging():
|
|||
logging.Logger.trace = partialmethod(logging.Logger.log, logging.TRACE)
|
||||
logging.trace = partial(logging.log, logging.TRACE)
|
||||
|
||||
def exception_hook(e: Exception, suppress=[]):
|
||||
def exception_hook(e: Exception, suppress=None):
|
||||
from rich.traceback import Traceback
|
||||
if suppress is None:
|
||||
suppress = []
|
||||
tb = Traceback.from_exception(type(e), e, e.__traceback__, show_locals=False, max_frames=16, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=console.width)
|
||||
# print-to-console, does not get printed-to-file
|
||||
exc_type, exc_value, exc_traceback = sys.exc_info()
|
||||
|
|
@ -416,7 +420,7 @@ def uninstall(package, quiet = False):
|
|||
|
||||
|
||||
def run(cmd: str, arg: str):
|
||||
result = subprocess.run(f'"{cmd}" {arg}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
result = subprocess.run(f'"{cmd}" {arg}', shell=True, check=False, env=os.environ, capture_output=True)
|
||||
txt = result.stdout.decode(encoding="utf8", errors="ignore")
|
||||
if len(result.stderr) > 0:
|
||||
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
|
||||
|
|
@ -461,7 +465,7 @@ def pip(arg: str, ignore: bool = False, quiet: bool = True, uv = True):
|
|||
all_args = f'{pip_log}{arg} {env_args}'.strip()
|
||||
if not quiet:
|
||||
log.debug(f'Running: {pipCmd}="{all_args}"')
|
||||
result = subprocess.run(f'"{sys.executable}" -m {pipCmd} {all_args}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
result = subprocess.run(f'"{sys.executable}" -m {pipCmd} {all_args}', shell=True, check=False, env=os.environ, capture_output=True)
|
||||
txt = result.stdout.decode(encoding="utf8", errors="ignore")
|
||||
if len(result.stderr) > 0:
|
||||
if uv and result.returncode != 0:
|
||||
|
|
@ -509,7 +513,7 @@ def git(arg: str, folder: str = None, ignore: bool = False, optional: bool = Fal
|
|||
git_cmd = os.environ.get('GIT', "git")
|
||||
if git_cmd != "git":
|
||||
git_cmd = os.path.abspath(git_cmd)
|
||||
result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder or '.')
|
||||
result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, capture_output=True, cwd=folder or '.')
|
||||
stdout = result.stdout.decode(encoding="utf8", errors="ignore")
|
||||
if len(result.stderr) > 0:
|
||||
stdout += ('\n' if len(stdout) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
|
||||
|
|
@ -639,7 +643,11 @@ def get_platform():
|
|||
|
||||
|
||||
# check python version
|
||||
def check_python(supported_minors=[], experimental_minors=[], reason=None):
|
||||
def check_python(supported_minors=None, experimental_minors=None, reason=None):
|
||||
if experimental_minors is None:
|
||||
experimental_minors = []
|
||||
if supported_minors is None:
|
||||
supported_minors = []
|
||||
if supported_minors is None or len(supported_minors) == 0:
|
||||
supported_minors = [10, 11, 12, 13]
|
||||
experimental_minors = [14]
|
||||
|
|
@ -911,8 +919,6 @@ def install_torch_addons():
|
|||
if 'xformers' in xformers_package:
|
||||
try:
|
||||
install(xformers_package, ignore=True, no_deps=True)
|
||||
import torch # pylint: disable=unused-import
|
||||
import xformers # pylint: disable=unused-import
|
||||
except Exception as e:
|
||||
log.debug(f'xFormers cannot install: {e}')
|
||||
elif not args.experimental and not args.use_xformers and opts.get('cross_attention_optimization', '') != 'xFormers':
|
||||
|
|
@ -1126,7 +1132,7 @@ def run_extension_installer(folder):
|
|||
if os.environ.get('PYTHONPATH', None) is not None:
|
||||
seperator = ';' if sys.platform == 'win32' else ':'
|
||||
env['PYTHONPATH'] += seperator + os.environ.get('PYTHONPATH', None)
|
||||
result = subprocess.run(f'"{sys.executable}" "{path_installer}"', shell=True, env=env, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder)
|
||||
result = subprocess.run(f'"{sys.executable}" "{path_installer}"', shell=True, env=env, check=False, capture_output=True, cwd=folder)
|
||||
txt = result.stdout.decode(encoding="utf8", errors="ignore")
|
||||
debug(f'Extension installer: file="{path_installer}" {txt}')
|
||||
if result.returncode != 0:
|
||||
|
|
@ -1265,7 +1271,7 @@ def ensure_base_requirements():
|
|||
local_log = logging.getLogger('sdnext.installer')
|
||||
global setuptools, distutils # pylint: disable=global-statement
|
||||
# python may ship with incompatible setuptools
|
||||
subprocess.run(f'"{sys.executable}" -m pip install setuptools=={setuptools_version}', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
subprocess.run(f'"{sys.executable}" -m pip install setuptools=={setuptools_version}', shell=True, check=False, env=os.environ, capture_output=True)
|
||||
# need to delete all references to modules to be able to reload them otherwise python will use cached version
|
||||
modules = [m for m in sys.modules if m.startswith('setuptools') or m.startswith('distutils')]
|
||||
for m in modules:
|
||||
|
|
@ -1399,7 +1405,7 @@ def install_requirements():
|
|||
log.info('Install: verifying requirements')
|
||||
if args.new:
|
||||
log.debug('Install: flag=new')
|
||||
with open('requirements.txt', 'r', encoding='utf8') as f:
|
||||
with open('requirements.txt', encoding='utf8') as f:
|
||||
lines = [line.strip() for line in f.readlines() if line.strip() != '' and not line.startswith('#') and line is not None]
|
||||
for line in lines:
|
||||
if not installed(line, quiet=True):
|
||||
|
|
@ -1495,20 +1501,20 @@ def get_version(force=False):
|
|||
t_start = time.time()
|
||||
if (version is None) or (version.get('branch', 'unknown') == 'unknown') or force:
|
||||
try:
|
||||
subprocess.run('git config log.showsignature false', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
subprocess.run('git config log.showsignature false', capture_output=True, shell=True, check=True)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', capture_output=True, shell=True, check=True)
|
||||
ver = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ' '
|
||||
commit, updated = ver.split(' ')
|
||||
version['commit'], version['updated'] = commit, updated
|
||||
except Exception as e:
|
||||
log.warning(f'Version: where=commit {e}')
|
||||
try:
|
||||
res = subprocess.run('git remote get-url origin', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
res = subprocess.run('git remote get-url origin', capture_output=True, shell=True, check=True)
|
||||
origin = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True)
|
||||
branch_name = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
version['url'] = origin.replace('\n', '').removesuffix('.git') + '/tree/' + branch_name.replace('\n', '')
|
||||
version['branch'] = branch_name.replace('\n', '')
|
||||
|
|
@ -1520,7 +1526,7 @@ def get_version(force=False):
|
|||
try:
|
||||
if os.path.exists('extensions-builtin/sdnext-modernui'):
|
||||
os.chdir('extensions-builtin/sdnext-modernui')
|
||||
res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True)
|
||||
branch_ui = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
branch_ui = 'dev' if 'dev' in branch_ui else 'main'
|
||||
version['ui'] = branch_ui
|
||||
|
|
@ -1536,7 +1542,7 @@ def get_version(force=False):
|
|||
version['kanvas'] = 'disabled'
|
||||
elif os.path.exists('extensions-builtin/sdnext-kanvas'):
|
||||
os.chdir('extensions-builtin/sdnext-kanvas')
|
||||
res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
res = subprocess.run('git rev-parse --abbrev-ref HEAD', capture_output=True, shell=True, check=True)
|
||||
branch_kanvas = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
branch_kanvas = 'dev' if 'dev' in branch_kanvas else 'main'
|
||||
version['kanvas'] = branch_kanvas
|
||||
|
|
@ -1723,7 +1729,7 @@ def check_timestamp():
|
|||
ok = True
|
||||
setup_time = -1
|
||||
version_time = -1
|
||||
with open(log_file, 'r', encoding='utf8') as f:
|
||||
with open(log_file, encoding='utf8') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
if 'Setup complete without errors' in line:
|
||||
|
|
@ -1752,7 +1758,6 @@ def check_timestamp():
|
|||
|
||||
|
||||
def add_args(parser):
|
||||
import argparse
|
||||
group_install = parser.add_argument_group('Install')
|
||||
group_install.add_argument('--quick', default=os.environ.get("SD_QUICK",False), action='store_true', help="Bypass version checks, default: %(default)s")
|
||||
group_install.add_argument('--reset', default=os.environ.get("SD_RESET",False), action='store_true', help="Reset main repository to latest version, default: %(default)s")
|
||||
|
|
@ -1832,7 +1837,7 @@ def read_options():
|
|||
t_start = time.time()
|
||||
global opts # pylint: disable=global-statement
|
||||
if os.path.isfile(args.config):
|
||||
with open(args.config, "r", encoding="utf8") as file:
|
||||
with open(args.config, encoding="utf8") as file:
|
||||
try:
|
||||
opts = json.load(file)
|
||||
if type(opts) is str:
|
||||
|
|
|
|||
18
launch.py
18
launch.py
|
|
@ -72,7 +72,7 @@ def get_custom_args():
|
|||
rec('args')
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def commit_hash(): # compatbility function
|
||||
global stored_commit_hash # pylint: disable=global-statement
|
||||
if stored_commit_hash is not None:
|
||||
|
|
@ -85,7 +85,7 @@ def commit_hash(): # compatbility function
|
|||
return stored_commit_hash
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compatbility function
|
||||
if desc is not None:
|
||||
installer.log.info(desc)
|
||||
|
|
@ -94,7 +94,7 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compat
|
|||
if result.returncode != 0:
|
||||
raise RuntimeError(f"""{errdesc or 'Error running command'} Command: {command} Error code: {result.returncode}""")
|
||||
return ''
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, check=False, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
|
||||
result = subprocess.run(command, capture_output=True, check=False, shell=True, env=os.environ if custom_env is None else custom_env)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"""{errdesc or 'Error running command'}: {command} code: {result.returncode}
|
||||
{result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''}
|
||||
|
|
@ -104,26 +104,26 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compat
|
|||
|
||||
|
||||
def check_run(command): # compatbility function
|
||||
result = subprocess.run(command, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
||||
result = subprocess.run(command, check=False, capture_output=True, shell=True)
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def is_installed(pkg): # compatbility function
|
||||
return installer.installed(pkg)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def repo_dir(name): # compatbility function
|
||||
return os.path.join(script_path, dir_repos, name)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def run_python(code, desc=None, errdesc=None): # compatbility function
|
||||
return run(f'"{sys.executable}" -c "{code}"', desc, errdesc)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def run_pip(pkg, desc=None): # compatbility function
|
||||
forbidden = ['onnxruntime', 'opencv-python']
|
||||
if desc is None:
|
||||
|
|
@ -136,7 +136,7 @@ def run_pip(pkg, desc=None): # compatbility function
|
|||
return run(f'"{sys.executable}" -m pip {pkg} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def check_run_python(code): # compatbility function
|
||||
return check_run(f'"{sys.executable}" -c "{code}"')
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
|
|
@ -63,11 +63,11 @@ class StableCascadePriorPipelineOutput(BaseOutput):
|
|||
Text embeddings for the negative prompt.
|
||||
"""
|
||||
|
||||
image_embeddings: Union[torch.Tensor, np.ndarray]
|
||||
prompt_embeds: Union[torch.Tensor, np.ndarray]
|
||||
prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
|
||||
negative_prompt_embeds: Union[torch.Tensor, np.ndarray]
|
||||
negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
|
||||
image_embeddings: torch.Tensor | np.ndarray
|
||||
prompt_embeds: torch.Tensor | np.ndarray
|
||||
prompt_embeds_pooled: torch.Tensor | np.ndarray
|
||||
negative_prompt_embeds: torch.Tensor | np.ndarray
|
||||
negative_prompt_embeds_pooled: torch.Tensor | np.ndarray
|
||||
|
||||
|
||||
class StableCascadePriorPipelineAPG(DiffusionPipeline):
|
||||
|
|
@ -109,8 +109,8 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
|
|||
prior: StableCascadeUNet,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
resolution_multiple: float = 42.67,
|
||||
feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
|
||||
feature_extractor: CLIPImageProcessor | None = None,
|
||||
image_encoder: CLIPVisionModelWithProjection | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
|
|
@ -151,10 +151,10 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
|
|||
do_classifier_free_guidance,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
prompt_embeds_pooled: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds_pooled: torch.Tensor | None = None,
|
||||
):
|
||||
if prompt_embeds is None:
|
||||
# get prompt text embeddings
|
||||
|
|
@ -196,7 +196,7 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
|
|||
prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if negative_prompt_embeds is None and do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
uncond_tokens: list[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
|
|
@ -367,26 +367,26 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
|
|||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
|
||||
prompt: str | list[str] | None = None,
|
||||
images: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image] = None,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
num_inference_steps: int = 20,
|
||||
timesteps: List[float] = None,
|
||||
timesteps: list[float] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
|
||||
image_embeds: Optional[torch.Tensor] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pt",
|
||||
negative_prompt: str | list[str] | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
prompt_embeds_pooled: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds_pooled: torch.Tensor | None = None,
|
||||
image_embeds: torch.Tensor | None = None,
|
||||
num_images_per_prompt: int | None = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
output_type: str | None = "pt",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] = None,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
|
@ -460,6 +460,8 @@ class StableCascadePriorPipelineAPG(DiffusionPipeline):
|
|||
"""
|
||||
|
||||
# 0. Define commonly used variables
|
||||
if callback_on_step_end_tensor_inputs is None:
|
||||
callback_on_step_end_tensor_inputs = ["latents"]
|
||||
device = self._execution_device
|
||||
dtype = next(self.prior.parameters()).dtype
|
||||
self._guidance_scale = guidance_scale
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@
|
|||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
import torch
|
||||
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
|
@ -28,7 +29,6 @@ from diffusers.utils import USE_PEFT_BACKEND, deprecate, is_invisible_watermark_
|
|||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from modules import apg
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
|
|
@ -76,10 +76,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
|||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
sigmas: list[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
|
@ -217,7 +217,7 @@ class StableDiffusionXLPipelineAPG(
|
|||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
add_watermarker: bool | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -248,18 +248,18 @@ class StableDiffusionXLPipelineAPG(
|
|||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_2: Optional[str] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_2: str | None = None,
|
||||
device: torch.device | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[str] = None,
|
||||
negative_prompt_2: Optional[str] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
negative_prompt: str | None = None,
|
||||
negative_prompt_2: str | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
pooled_prompt_embeds: torch.Tensor | None = None,
|
||||
negative_pooled_prompt_embeds: torch.Tensor | None = None,
|
||||
lora_scale: float | None = None,
|
||||
clip_skip: int | None = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
|
@ -343,7 +343,7 @@ class StableDiffusionXLPipelineAPG(
|
|||
# textual inversion: process multi-vector tokens if necessary
|
||||
prompt_embeds_list = []
|
||||
prompts = [prompt, prompt_2]
|
||||
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
|
||||
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
||||
|
||||
|
|
@ -396,7 +396,7 @@ class StableDiffusionXLPipelineAPG(
|
|||
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
||||
)
|
||||
|
||||
uncond_tokens: List[str]
|
||||
uncond_tokens: list[str]
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
|
|
@ -412,7 +412,7 @@ class StableDiffusionXLPipelineAPG(
|
|||
uncond_tokens = [negative_prompt, negative_prompt_2]
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
|
||||
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders, strict=False):
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
|
||||
|
||||
|
|
@ -521,7 +521,7 @@ class StableDiffusionXLPipelineAPG(
|
|||
)
|
||||
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers, strict=False
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
|
|
@ -793,42 +793,40 @@ class StableDiffusionXLPipelineAPG(
|
|||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
prompt: str | list[str] = None,
|
||||
prompt_2: str | list[str] | None = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
sigmas: List[float] = None,
|
||||
denoising_end: Optional[float] = None,
|
||||
timesteps: list[int] = None,
|
||||
sigmas: list[float] = None,
|
||||
denoising_end: float | None = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
negative_prompt: str | list[str] | None = None,
|
||||
negative_prompt_2: str | list[str] | None = None,
|
||||
num_images_per_prompt: int | None = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
pooled_prompt_embeds: torch.Tensor | None = None,
|
||||
negative_pooled_prompt_embeds: torch.Tensor | None = None,
|
||||
ip_adapter_image: PipelineImageInput | None = None,
|
||||
ip_adapter_image_embeds: list[torch.Tensor] | None = None,
|
||||
output_type: str | None = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: dict[str, Any] | None = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
original_size: tuple[int, int] | None = None,
|
||||
crops_coords_top_left: tuple[int, int] = (0, 0),
|
||||
target_size: tuple[int, int] | None = None,
|
||||
negative_original_size: tuple[int, int] | None = None,
|
||||
negative_crops_coords_top_left: tuple[int, int] = (0, 0),
|
||||
negative_target_size: tuple[int, int] | None = None,
|
||||
clip_skip: int | None = None,
|
||||
callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
|
|
@ -976,6 +974,8 @@ class StableDiffusionXLPipelineAPG(
|
|||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if callback_on_step_end_tensor_inputs is None:
|
||||
callback_on_step_end_tensor_inputs = ["latents"]
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@
|
|||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
|
|
@ -71,10 +72,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
|||
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
sigmas: list[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
|
@ -273,9 +274,9 @@ class StableDiffusionPipelineAPG(
|
|||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
lora_scale: float | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
|
|
@ -305,10 +306,10 @@ class StableDiffusionPipelineAPG(
|
|||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
lora_scale: float | None = None,
|
||||
clip_skip: int | None = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
|
@ -421,7 +422,7 @@ class StableDiffusionPipelineAPG(
|
|||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
uncond_tokens: list[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
|
|
@ -520,7 +521,7 @@ class StableDiffusionPipelineAPG(
|
|||
)
|
||||
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers, strict=False
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
|
|
@ -748,31 +749,29 @@ class StableDiffusionPipelineAPG(
|
|||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
prompt: str | list[str] = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
sigmas: List[float] = None,
|
||||
timesteps: list[int] = None,
|
||||
sigmas: list[float] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
negative_prompt: str | list[str] | None = None,
|
||||
num_images_per_prompt: int | None = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
ip_adapter_image: PipelineImageInput | None = None,
|
||||
ip_adapter_image_embeds: list[torch.Tensor] | None = None,
|
||||
output_type: str | None = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: dict[str, Any] | None = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
clip_skip: int | None = None,
|
||||
callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
|
|
@ -861,6 +860,8 @@ class StableDiffusionPipelineAPG(
|
|||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if callback_on_step_end_tensor_inputs is None:
|
||||
callback_on_step_end_tensor_inputs = ["latents"]
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import List, Optional
|
||||
from threading import Lock
|
||||
from secrets import compare_digest
|
||||
from fastapi import FastAPI, APIRouter, Depends, Request
|
||||
|
|
@ -19,7 +18,7 @@ class Api:
|
|||
user, password = auth.split(":")
|
||||
self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip()
|
||||
if shared.cmd_opts.auth_file:
|
||||
with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file:
|
||||
with open(shared.cmd_opts.auth_file, encoding="utf8") as file:
|
||||
for line in file.readlines():
|
||||
user, password = line.split(":")
|
||||
self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip()
|
||||
|
|
@ -41,7 +40,7 @@ class Api:
|
|||
self.add_api_route("/js", server.get_js, methods=["GET"], auth=False)
|
||||
# server api
|
||||
self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
|
||||
self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=list[str])
|
||||
self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"])
|
||||
self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"])
|
||||
|
|
@ -56,7 +55,7 @@ class Api:
|
|||
self.add_api_route("/sdapi/v1/options", server.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||
self.add_api_route("/sdapi/v1/options", server.set_config, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||
self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], response_model=List[models.ResGPU])
|
||||
self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], response_model=list[models.ResGPU])
|
||||
|
||||
# core api using locking
|
||||
self.add_api_route("/sdapi/v1/txt2img", self.generate.post_text2img, methods=["POST"], response_model=models.ResTxt2Img)
|
||||
|
|
@ -71,21 +70,21 @@ class Api:
|
|||
|
||||
# api dealing with optional scripts
|
||||
self.add_api_route("/sdapi/v1/scripts", script.get_scripts_list, methods=["GET"], response_model=models.ResScripts)
|
||||
self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=List[models.ItemScript])
|
||||
self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=list[models.ItemScript])
|
||||
|
||||
# enumerator api
|
||||
self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=List[process.ItemPreprocess])
|
||||
self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=list[process.ItemPreprocess])
|
||||
self.add_api_route("/sdapi/v1/masking", self.process.get_mask, methods=["GET"], response_model=process.ItemMask)
|
||||
self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=List[models.ItemSampler])
|
||||
self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=List[models.ItemUpscaler])
|
||||
self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=List[models.ItemModel])
|
||||
self.add_api_route("/sdapi/v1/controlnets", endpoints.get_controlnets, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/detailers", endpoints.get_detailers, methods=["GET"], response_model=List[models.ItemDetailer])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=List[models.ItemStyle])
|
||||
self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=list[models.ItemSampler])
|
||||
self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=list[models.ItemUpscaler])
|
||||
self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=list[models.ItemModel])
|
||||
self.add_api_route("/sdapi/v1/controlnets", endpoints.get_controlnets, methods=["GET"], response_model=list[str])
|
||||
self.add_api_route("/sdapi/v1/detailers", endpoints.get_detailers, methods=["GET"], response_model=list[models.ItemDetailer])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=list[models.ItemStyle])
|
||||
self.add_api_route("/sdapi/v1/embeddings", endpoints.get_embeddings, methods=["GET"], response_model=models.ResEmbeddings)
|
||||
self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae])
|
||||
self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension])
|
||||
self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork])
|
||||
self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=list[models.ItemVae])
|
||||
self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=list[models.ItemExtension])
|
||||
self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=list[models.ItemExtraNetwork])
|
||||
|
||||
# functional api
|
||||
self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo)
|
||||
|
|
@ -96,7 +95,7 @@ class Api:
|
|||
self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/lock-checkpoint", endpoints.post_lock_checkpoint, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=list[str])
|
||||
self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int)
|
||||
self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"])
|
||||
self.add_api_route("/sdapi/v1/sampler", endpoints.get_sampler, methods=["GET"], response_model=dict)
|
||||
|
|
@ -146,7 +145,7 @@ class Api:
|
|||
shared.log.error(f'API authentication: user="{credentials.username}"')
|
||||
raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"})
|
||||
|
||||
def get_session_start(self, req: Request, agent: Optional[str] = None):
|
||||
def get_session_start(self, req: Request, agent: str | None = None):
|
||||
token = req.cookies.get("access-token") or req.cookies.get("access-token-unsecure")
|
||||
user = self.app.tokens.get(token) if hasattr(self.app, 'tokens') else None
|
||||
shared.log.info(f'Browser session: user={user} client={req.client.host} agent={agent}')
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ Core processing logic is shared between direct and dispatch handlers via
|
|||
``do_openclip``, ``do_tagger``, and ``do_vqa`` functions to avoid duplication.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Union, Literal, Annotated
|
||||
from typing import Literal, Annotated
|
||||
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
|
||||
from fastapi.exceptions import HTTPException
|
||||
from modules import shared
|
||||
|
|
@ -49,21 +49,21 @@ class ReqCaption(BaseModel):
|
|||
mode: str = Field(default="best", title="Mode", description="Caption mode. 'best': Most thorough analysis, slowest but highest quality. 'fast': Quick caption with minimal flavor terms. 'classic': Standard captioning with balanced quality and speed. 'caption': BLIP caption only, no CLIP flavor matching. 'negative': Generate terms suitable for use as a negative prompt.")
|
||||
analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed image analysis breakdown (medium, artist, movement, trending, flavor) in addition to caption.")
|
||||
# Advanced settings (optional per-request overrides)
|
||||
max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.")
|
||||
chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up captioning but increase VRAM usage.")
|
||||
min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.")
|
||||
max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.")
|
||||
flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.")
|
||||
num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.")
|
||||
max_length: int | None = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.")
|
||||
chunk_size: int | None = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up captioning but increase VRAM usage.")
|
||||
min_flavors: int | None = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.")
|
||||
max_flavors: int | None = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.")
|
||||
flavor_count: int | None = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.")
|
||||
num_beams: int | None = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.")
|
||||
|
||||
class ResCaption(BaseModel):
|
||||
"""Response model for image captioning results."""
|
||||
caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption/prompt describing the image content and style.")
|
||||
medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (e.g., 'oil painting', 'digital art', 'photograph'). Only returned when analyze=True.")
|
||||
artist: Optional[str] = Field(default=None, title="Artist", description="Detected similar artist style (e.g., 'by greg rutkowski'). Only returned when analyze=True.")
|
||||
movement: Optional[str] = Field(default=None, title="Movement", description="Detected art movement (e.g., 'art nouveau', 'impressionism'). Only returned when analyze=True.")
|
||||
trending: Optional[str] = Field(default=None, title="Trending", description="Trending/platform tags (e.g., 'trending on artstation'). Only returned when analyze=True.")
|
||||
flavor: Optional[str] = Field(default=None, title="Flavor", description="Additional descriptive elements (e.g., 'cinematic lighting', 'highly detailed'). Only returned when analyze=True.")
|
||||
caption: str | None = Field(default=None, title="Caption", description="Generated caption/prompt describing the image content and style.")
|
||||
medium: str | None = Field(default=None, title="Medium", description="Detected artistic medium (e.g., 'oil painting', 'digital art', 'photograph'). Only returned when analyze=True.")
|
||||
artist: str | None = Field(default=None, title="Artist", description="Detected similar artist style (e.g., 'by greg rutkowski'). Only returned when analyze=True.")
|
||||
movement: str | None = Field(default=None, title="Movement", description="Detected art movement (e.g., 'art nouveau', 'impressionism'). Only returned when analyze=True.")
|
||||
trending: str | None = Field(default=None, title="Trending", description="Trending/platform tags (e.g., 'trending on artstation'). Only returned when analyze=True.")
|
||||
flavor: str | None = Field(default=None, title="Flavor", description="Additional descriptive elements (e.g., 'cinematic lighting', 'highly detailed'). Only returned when analyze=True.")
|
||||
|
||||
class ReqVQA(BaseModel):
|
||||
"""Request model for Vision-Language Model (VLM) captioning.
|
||||
|
|
@ -74,32 +74,32 @@ class ReqVQA(BaseModel):
|
|||
image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string containing the image data.")
|
||||
model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="Select which model to use for Visual Language tasks. Use GET /sdapi/v1/vqa/models for full list. Models which support thinking mode are indicated in capabilities.")
|
||||
question: str = Field(default="describe the image", title="Question/Task", description="Task for the model to perform. Common tasks: 'Short Caption', 'Normal Caption', 'Long Caption'. Set to 'Use Prompt' to pass custom text via the prompt field. Florence-2 tasks: 'Object Detection', 'OCR (Read Text)', 'Phrase Grounding', 'Dense Region Caption', 'Region Proposal', 'OCR with Regions'. PromptGen tasks: 'Analyze', 'Generate Tags', 'Mixed Caption'. Moondream tasks: 'Point at...', 'Detect all...', 'Detect Gaze' (Moondream 2 only). Use GET /sdapi/v1/vqa/prompts?model=<name> to list tasks available for a specific model.")
|
||||
prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text. Required when question is 'Use Prompt'. For 'Point at...' tasks, specify what to find (e.g., 'the red car'). For 'Detect all...' tasks, specify what to detect (e.g., 'faces').")
|
||||
prompt: str | None = Field(default=None, title="Prompt", description="Custom prompt text. Required when question is 'Use Prompt'. For 'Point at...' tasks, specify what to find (e.g., 'the red car'). For 'Detect all...' tasks, specify what to detect (e.g., 'faces').")
|
||||
system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt controls behavior of the LLM. Processed first and persists throughout conversation. Has highest priority weighting and is always appended at the beginning of the sequence. Use for: Response formatting rules, role definition, style.")
|
||||
include_annotated: bool = Field(default=False, title="Include Annotated Image", description="If True and the task produces detection results (object detection, point detection, gaze), returns annotated image with bounding boxes/points drawn. Only applicable for detection tasks on models like Florence-2 and Moondream.")
|
||||
# LLM generation parameters (optional overrides)
|
||||
max_tokens: Optional[int] = Field(default=None, title="Max Tokens", description="Maximum number of tokens the model can generate in its response. The model is not aware of this limit during generation; it simply sets the hard limit for the length and will forcefully cut off the response when reached.")
|
||||
temperature: Optional[float] = Field(default=None, title="Temperature", description="Controls randomness in token selection. Lower values (e.g., 0.1) make outputs more focused and deterministic, always choosing high-probability tokens. Higher values (e.g., 0.9) increase creativity and diversity by allowing less probable tokens. Set to 0 for fully deterministic output.")
|
||||
top_k: Optional[int] = Field(default=None, title="Top-K", description="Limits token selection to the K most likely candidates at each step. Lower values (e.g., 40) make outputs more focused and predictable, while higher values allow more diverse choices. Set to 0 to disable.")
|
||||
top_p: Optional[float] = Field(default=None, title="Top-P", description="Selects tokens from the smallest set whose cumulative probability exceeds P (e.g., 0.9). Dynamically adapts the number of candidates based on model confidence; fewer options when certain, more when uncertain. Set to 1 to disable.")
|
||||
num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Maintains multiple candidate paths simultaneously and selects the overall best sequence. More thorough but much slower and less creative than random sampling. Generally not recommended; most modern VLMs perform better with sampling methods. Set to 1 to disable.")
|
||||
do_sample: Optional[bool] = Field(default=None, title="Use Samplers", description="Enable to use sampling (randomly selecting tokens based on sampling methods like Top-K or Top-P) or disable to use greedy decoding (selecting the most probable token at each step). Enabling makes outputs more diverse and creative but less deterministic.")
|
||||
thinking_mode: Optional[bool] = Field(default=None, title="Thinking Mode", description="Enables thinking/reasoning, allowing the model to take more time to generate responses. Can lead to more thoughtful and detailed answers but increases response time. Only works with models that support this feature.")
|
||||
prefill: Optional[str] = Field(default=None, title="Prefill Text", description="Pre-fills the start of the model's response to guide its output format or content by forcing it to continue the prefill text. Prefill is filtered out and does not appear in the final response unless keep_prefill is True. Leave empty to let the model generate from scratch.")
|
||||
keep_thinking: Optional[bool] = Field(default=None, title="Keep Thinking Trace", description="Include the model's reasoning process in the final output. Useful for understanding how the model arrived at its answer. Only works with models that support thinking mode.")
|
||||
keep_prefill: Optional[bool] = Field(default=None, title="Keep Prefill", description="Include the prefill text at the beginning of the final output. If disabled, the prefill text used to guide the model is removed from the result.")
|
||||
max_tokens: int | None = Field(default=None, title="Max Tokens", description="Maximum number of tokens the model can generate in its response. The model is not aware of this limit during generation; it simply sets the hard limit for the length and will forcefully cut off the response when reached.")
|
||||
temperature: float | None = Field(default=None, title="Temperature", description="Controls randomness in token selection. Lower values (e.g., 0.1) make outputs more focused and deterministic, always choosing high-probability tokens. Higher values (e.g., 0.9) increase creativity and diversity by allowing less probable tokens. Set to 0 for fully deterministic output.")
|
||||
top_k: int | None = Field(default=None, title="Top-K", description="Limits token selection to the K most likely candidates at each step. Lower values (e.g., 40) make outputs more focused and predictable, while higher values allow more diverse choices. Set to 0 to disable.")
|
||||
top_p: float | None = Field(default=None, title="Top-P", description="Selects tokens from the smallest set whose cumulative probability exceeds P (e.g., 0.9). Dynamically adapts the number of candidates based on model confidence; fewer options when certain, more when uncertain. Set to 1 to disable.")
|
||||
num_beams: int | None = Field(default=None, title="Num Beams", description="Maintains multiple candidate paths simultaneously and selects the overall best sequence. More thorough but much slower and less creative than random sampling. Generally not recommended; most modern VLMs perform better with sampling methods. Set to 1 to disable.")
|
||||
do_sample: bool | None = Field(default=None, title="Use Samplers", description="Enable to use sampling (randomly selecting tokens based on sampling methods like Top-K or Top-P) or disable to use greedy decoding (selecting the most probable token at each step). Enabling makes outputs more diverse and creative but less deterministic.")
|
||||
thinking_mode: bool | None = Field(default=None, title="Thinking Mode", description="Enables thinking/reasoning, allowing the model to take more time to generate responses. Can lead to more thoughtful and detailed answers but increases response time. Only works with models that support this feature.")
|
||||
prefill: str | None = Field(default=None, title="Prefill Text", description="Pre-fills the start of the model's response to guide its output format or content by forcing it to continue the prefill text. Prefill is filtered out and does not appear in the final response unless keep_prefill is True. Leave empty to let the model generate from scratch.")
|
||||
keep_thinking: bool | None = Field(default=None, title="Keep Thinking Trace", description="Include the model's reasoning process in the final output. Useful for understanding how the model arrived at its answer. Only works with models that support thinking mode.")
|
||||
keep_prefill: bool | None = Field(default=None, title="Keep Prefill", description="Include the prefill text at the beginning of the final output. If disabled, the prefill text used to guide the model is removed from the result.")
|
||||
|
||||
class ResVQA(BaseModel):
|
||||
"""Response model for VLM captioning results."""
|
||||
answer: Optional[str] = Field(default=None, title="Answer", description="Generated caption, answer, or analysis from the VLM. Format depends on the question/task type.")
|
||||
annotated_image: Optional[str] = Field(default=None, title="Annotated Image", description="Base64 encoded PNG image with detection results drawn (bounding boxes, points). Only returned when include_annotated=True and the task produces detection results.")
|
||||
answer: str | None = Field(default=None, title="Answer", description="Generated caption, answer, or analysis from the VLM. Format depends on the question/task type.")
|
||||
annotated_image: str | None = Field(default=None, title="Annotated Image", description="Base64 encoded PNG image with detection results drawn (bounding boxes, points). Only returned when include_annotated=True and the task produces detection results.")
|
||||
|
||||
class ItemVLMModel(BaseModel):
|
||||
"""VLM model information."""
|
||||
name: str = Field(title="Name", description="Display name of the model")
|
||||
repo: str = Field(title="Repository", description="HuggingFace repository ID")
|
||||
prompts: List[str] = Field(title="Prompts", description="Available prompts/tasks for this model")
|
||||
capabilities: List[str] = Field(title="Capabilities", description="Model capabilities. Possible values: 'caption' (image captioning), 'vqa' (visual question answering), 'detection' (object/point detection), 'ocr' (text recognition), 'thinking' (reasoning mode support).")
|
||||
prompts: list[str] = Field(title="Prompts", description="Available prompts/tasks for this model")
|
||||
capabilities: list[str] = Field(title="Capabilities", description="Model capabilities. Possible values: 'caption' (image captioning), 'vqa' (visual question answering), 'detection' (object/point detection), 'ocr' (text recognition), 'thinking' (reasoning mode support).")
|
||||
|
||||
class ResVLMPrompts(BaseModel):
|
||||
"""Available VLM prompts grouped by category.
|
||||
|
|
@ -107,12 +107,12 @@ class ResVLMPrompts(BaseModel):
|
|||
When called without ``model`` parameter, returns all prompt categories.
|
||||
When called with ``model``, returns only the ``available`` field with prompts for that model.
|
||||
"""
|
||||
common: Optional[List[str]] = Field(default=None, title="Common", description="Prompts available for all models: Use Prompt, Short/Normal/Long Caption.")
|
||||
florence: Optional[List[str]] = Field(default=None, title="Florence", description="Florence-2 base model tasks: Phrase Grounding, Object Detection, Dense Region Caption, Region Proposal, OCR (Read Text), OCR with Regions.")
|
||||
promptgen: Optional[List[str]] = Field(default=None, title="PromptGen", description="MiaoshouAI PromptGen fine-tune tasks: Analyze, Generate Tags, Mixed Caption, Mixed Caption+. Only available on PromptGen models.")
|
||||
moondream: Optional[List[str]] = Field(default=None, title="Moondream", description="Moondream 2 and 3 tasks: Point at..., Detect all...")
|
||||
moondream2_only: Optional[List[str]] = Field(default=None, title="Moondream 2 Only", description="Moondream 2 exclusive tasks: Detect Gaze. Not available in Moondream 3.")
|
||||
available: Optional[List[str]] = Field(default=None, title="Available", description="Populated only when filtering by model. Contains the combined list of prompts available for the specified model.")
|
||||
common: list[str] | None = Field(default=None, title="Common", description="Prompts available for all models: Use Prompt, Short/Normal/Long Caption.")
|
||||
florence: list[str] | None = Field(default=None, title="Florence", description="Florence-2 base model tasks: Phrase Grounding, Object Detection, Dense Region Caption, Region Proposal, OCR (Read Text), OCR with Regions.")
|
||||
promptgen: list[str] | None = Field(default=None, title="PromptGen", description="MiaoshouAI PromptGen fine-tune tasks: Analyze, Generate Tags, Mixed Caption, Mixed Caption+. Only available on PromptGen models.")
|
||||
moondream: list[str] | None = Field(default=None, title="Moondream", description="Moondream 2 and 3 tasks: Point at..., Detect all...")
|
||||
moondream2_only: list[str] | None = Field(default=None, title="Moondream 2 Only", description="Moondream 2 exclusive tasks: Detect Gaze. Not available in Moondream 3.")
|
||||
available: list[str] | None = Field(default=None, title="Available", description="Populated only when filtering by model. Contains the combined list of prompts available for the specified model.")
|
||||
|
||||
class ItemTaggerModel(BaseModel):
|
||||
"""Tagger model information."""
|
||||
|
|
@ -136,7 +136,7 @@ class ReqTagger(BaseModel):
|
|||
class ResTagger(BaseModel):
|
||||
"""Response model for image tagging results."""
|
||||
tags: str = Field(title="Tags", description="Comma-separated list of detected tags")
|
||||
scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (when show_scores=True)")
|
||||
scores: dict | None = Field(default=None, title="Scores", description="Tag confidence scores (when show_scores=True)")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -158,12 +158,12 @@ class ReqCaptionOpenCLIP(BaseModel):
|
|||
blip_model: str = Field(default="blip-large", title="Caption Model", description="BLIP model used to generate the initial image caption.")
|
||||
mode: str = Field(default="best", title="Mode", description="Caption mode: 'best' (highest quality, slowest), 'fast' (quick, fewer flavors), 'classic' (balanced), 'caption' (BLIP only, no CLIP matching), 'negative' (for negative prompts).")
|
||||
analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed breakdown (medium, artist, movement, trending, flavor).")
|
||||
max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum tokens in generated caption.")
|
||||
chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing flavors.")
|
||||
min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum descriptive tags to keep.")
|
||||
max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum descriptive tags to keep.")
|
||||
flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of intermediate candidate pool.")
|
||||
num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beams for beam search during caption generation.")
|
||||
max_length: int | None = Field(default=None, title="Max Length", description="Maximum tokens in generated caption.")
|
||||
chunk_size: int | None = Field(default=None, title="Chunk Size", description="Batch size for processing flavors.")
|
||||
min_flavors: int | None = Field(default=None, title="Min Flavors", description="Minimum descriptive tags to keep.")
|
||||
max_flavors: int | None = Field(default=None, title="Max Flavors", description="Maximum descriptive tags to keep.")
|
||||
flavor_count: int | None = Field(default=None, title="Intermediates", description="Size of intermediate candidate pool.")
|
||||
num_beams: int | None = Field(default=None, title="Num Beams", description="Beams for beam search during caption generation.")
|
||||
|
||||
|
||||
class ReqCaptionTagger(BaseModel):
|
||||
|
|
@ -196,24 +196,24 @@ class ReqCaptionVLM(BaseModel):
|
|||
image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string.")
|
||||
model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="VLM model to use. See GET /sdapi/v1/vqa/models for full list.")
|
||||
question: str = Field(default="describe the image", title="Question/Task", description="Task to perform: 'Short Caption', 'Normal Caption', 'Long Caption', 'Use Prompt' (custom text via prompt field). Model-specific tasks available via GET /sdapi/v1/vqa/prompts.")
|
||||
prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text when question is 'Use Prompt'.")
|
||||
prompt: str | None = Field(default=None, title="Prompt", description="Custom prompt text when question is 'Use Prompt'.")
|
||||
system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt for LLM behavior.")
|
||||
include_annotated: bool = Field(default=False, title="Include Annotated Image", description="Return annotated image for detection tasks.")
|
||||
max_tokens: Optional[int] = Field(default=None, title="Max Tokens", description="Maximum tokens in response.")
|
||||
temperature: Optional[float] = Field(default=None, title="Temperature", description="Randomness in token selection (0=deterministic, 0.9=creative).")
|
||||
top_k: Optional[int] = Field(default=None, title="Top-K", description="Limit to K most likely tokens per step.")
|
||||
top_p: Optional[float] = Field(default=None, title="Top-P", description="Nucleus sampling threshold.")
|
||||
num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beam search width (1=disabled).")
|
||||
do_sample: Optional[bool] = Field(default=None, title="Use Samplers", description="Enable sampling vs greedy decoding.")
|
||||
thinking_mode: Optional[bool] = Field(default=None, title="Thinking Mode", description="Enable reasoning mode (supported models only).")
|
||||
prefill: Optional[str] = Field(default=None, title="Prefill Text", description="Pre-fill response start to guide output.")
|
||||
keep_thinking: Optional[bool] = Field(default=None, title="Keep Thinking Trace", description="Include reasoning in output.")
|
||||
keep_prefill: Optional[bool] = Field(default=None, title="Keep Prefill", description="Keep prefill text in final output.")
|
||||
max_tokens: int | None = Field(default=None, title="Max Tokens", description="Maximum tokens in response.")
|
||||
temperature: float | None = Field(default=None, title="Temperature", description="Randomness in token selection (0=deterministic, 0.9=creative).")
|
||||
top_k: int | None = Field(default=None, title="Top-K", description="Limit to K most likely tokens per step.")
|
||||
top_p: float | None = Field(default=None, title="Top-P", description="Nucleus sampling threshold.")
|
||||
num_beams: int | None = Field(default=None, title="Num Beams", description="Beam search width (1=disabled).")
|
||||
do_sample: bool | None = Field(default=None, title="Use Samplers", description="Enable sampling vs greedy decoding.")
|
||||
thinking_mode: bool | None = Field(default=None, title="Thinking Mode", description="Enable reasoning mode (supported models only).")
|
||||
prefill: str | None = Field(default=None, title="Prefill Text", description="Pre-fill response start to guide output.")
|
||||
keep_thinking: bool | None = Field(default=None, title="Keep Thinking Trace", description="Include reasoning in output.")
|
||||
keep_prefill: bool | None = Field(default=None, title="Keep Prefill", description="Keep prefill text in final output.")
|
||||
|
||||
|
||||
# Discriminated union for the dispatch endpoint
|
||||
ReqCaptionDispatch = Annotated[
|
||||
Union[ReqCaptionOpenCLIP, ReqCaptionTagger, ReqCaptionVLM],
|
||||
ReqCaptionOpenCLIP | ReqCaptionTagger | ReqCaptionVLM,
|
||||
Field(discriminator="backend")
|
||||
]
|
||||
|
||||
|
|
@ -226,18 +226,18 @@ class ResCaptionDispatch(BaseModel):
|
|||
# Common
|
||||
backend: str = Field(title="Backend", description="The backend that processed the request: 'openclip', 'tagger', or 'vlm'.")
|
||||
# OpenCLIP fields
|
||||
caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption (OpenCLIP backend).")
|
||||
medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (OpenCLIP with analyze=True).")
|
||||
artist: Optional[str] = Field(default=None, title="Artist", description="Detected artist style (OpenCLIP with analyze=True).")
|
||||
movement: Optional[str] = Field(default=None, title="Movement", description="Detected art movement (OpenCLIP with analyze=True).")
|
||||
trending: Optional[str] = Field(default=None, title="Trending", description="Trending tags (OpenCLIP with analyze=True).")
|
||||
flavor: Optional[str] = Field(default=None, title="Flavor", description="Flavor descriptors (OpenCLIP with analyze=True).")
|
||||
caption: str | None = Field(default=None, title="Caption", description="Generated caption (OpenCLIP backend).")
|
||||
medium: str | None = Field(default=None, title="Medium", description="Detected artistic medium (OpenCLIP with analyze=True).")
|
||||
artist: str | None = Field(default=None, title="Artist", description="Detected artist style (OpenCLIP with analyze=True).")
|
||||
movement: str | None = Field(default=None, title="Movement", description="Detected art movement (OpenCLIP with analyze=True).")
|
||||
trending: str | None = Field(default=None, title="Trending", description="Trending tags (OpenCLIP with analyze=True).")
|
||||
flavor: str | None = Field(default=None, title="Flavor", description="Flavor descriptors (OpenCLIP with analyze=True).")
|
||||
# Tagger fields
|
||||
tags: Optional[str] = Field(default=None, title="Tags", description="Comma-separated tags (Tagger backend).")
|
||||
scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (Tagger with show_scores=True).")
|
||||
tags: str | None = Field(default=None, title="Tags", description="Comma-separated tags (Tagger backend).")
|
||||
scores: dict | None = Field(default=None, title="Scores", description="Tag confidence scores (Tagger with show_scores=True).")
|
||||
# VLM fields
|
||||
answer: Optional[str] = Field(default=None, title="Answer", description="VLM response (VLM backend).")
|
||||
annotated_image: Optional[str] = Field(default=None, title="Annotated Image", description="Base64 annotated image (VLM with include_annotated=True).")
|
||||
answer: str | None = Field(default=None, title="Answer", description="VLM response (VLM backend).")
|
||||
annotated_image: str | None = Field(default=None, title="Annotated Image", description="Base64 annotated image (VLM with include_annotated=True).")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -596,7 +596,7 @@ def get_vqa_models():
|
|||
return models_list
|
||||
|
||||
|
||||
def get_vqa_prompts(model: Optional[str] = None):
|
||||
def get_vqa_prompts(model: str | None = None):
|
||||
"""
|
||||
List available prompts/tasks for VLM models.
|
||||
|
||||
|
|
@ -653,11 +653,11 @@ def get_tagger_models():
|
|||
|
||||
def register_api():
|
||||
from modules.shared import api
|
||||
api.add_api_route("/sdapi/v1/openclip", get_caption, methods=["GET"], response_model=List[str], tags=["Caption"])
|
||||
api.add_api_route("/sdapi/v1/openclip", get_caption, methods=["GET"], response_model=list[str], tags=["Caption"])
|
||||
api.add_api_route("/sdapi/v1/caption", post_caption_dispatch, methods=["POST"], response_model=ResCaptionDispatch, tags=["Caption"])
|
||||
api.add_api_route("/sdapi/v1/openclip", post_caption, methods=["POST"], response_model=ResCaption, tags=["Caption"])
|
||||
api.add_api_route("/sdapi/v1/vqa", post_vqa, methods=["POST"], response_model=ResVQA, tags=["Caption"])
|
||||
api.add_api_route("/sdapi/v1/vqa/models", get_vqa_models, methods=["GET"], response_model=List[ItemVLMModel], tags=["Caption"])
|
||||
api.add_api_route("/sdapi/v1/vqa/models", get_vqa_models, methods=["GET"], response_model=list[ItemVLMModel], tags=["Caption"])
|
||||
api.add_api_route("/sdapi/v1/vqa/prompts", get_vqa_prompts, methods=["GET"], response_model=ResVLMPrompts, tags=["Caption"])
|
||||
api.add_api_route("/sdapi/v1/tagger", post_tagger, methods=["POST"], response_model=ResTagger, tags=["Caption"])
|
||||
api.add_api_route("/sdapi/v1/tagger/models", get_tagger_models, methods=["GET"], response_model=List[ItemTaggerModel], tags=["Caption"])
|
||||
api.add_api_route("/sdapi/v1/tagger/models", get_tagger_models, methods=["GET"], response_model=list[ItemTaggerModel], tags=["Caption"])
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
from threading import Lock
|
||||
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
|
||||
from modules import errors, shared, processing_helpers
|
||||
|
|
@ -43,9 +43,9 @@ ReqControl = models.create_model_from_signature(
|
|||
{"key": "send_images", "type": bool, "default": True},
|
||||
{"key": "save_images", "type": bool, "default": False},
|
||||
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
||||
{"key": "ip_adapter", "type": Optional[List[models.ItemIPAdapter]], "default": None, "exclude": True},
|
||||
{"key": "ip_adapter", "type": Optional[list[models.ItemIPAdapter]], "default": None, "exclude": True},
|
||||
{"key": "face", "type": Optional[models.ItemFace], "default": None, "exclude": True},
|
||||
{"key": "control", "type": Optional[List[ItemControl]], "default": [], "exclude": True},
|
||||
{"key": "control", "type": Optional[list[ItemControl]], "default": [], "exclude": True},
|
||||
{"key": "xyz", "type": Optional[ItemXYZ], "default": None, "exclude": True},
|
||||
# {"key": "extra", "type": Optional[dict], "default": {}, "exclude": True},
|
||||
]
|
||||
|
|
@ -55,13 +55,13 @@ if not hasattr(ReqControl, "__config__"):
|
|||
|
||||
|
||||
class ResControl(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Images", description="")
|
||||
processed: List[str] = Field(default=None, title="Processed", description="")
|
||||
images: list[str] = Field(default=None, title="Images", description="")
|
||||
processed: list[str] = Field(default=None, title="Processed", description="")
|
||||
params: dict = Field(default={}, title="Settings", description="")
|
||||
info: str = Field(default="", title="Info", description="")
|
||||
|
||||
|
||||
class APIControl():
|
||||
class APIControl:
|
||||
def __init__(self, queue_lock: Lock):
|
||||
self.queue_lock = queue_lock
|
||||
self.default_script_arg = []
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Optional
|
||||
from modules import shared
|
||||
from modules.api import models, helpers
|
||||
|
||||
|
|
@ -43,7 +42,7 @@ def get_sd_models():
|
|||
checkpoints.append(model)
|
||||
return checkpoints
|
||||
|
||||
def get_controlnets(model_type: Optional[str] = None):
|
||||
def get_controlnets(model_type: str | None = None):
|
||||
from modules.control.units.controlnet import api_list_models
|
||||
return api_list_models(model_type)
|
||||
|
||||
|
|
@ -60,7 +59,7 @@ def get_embeddings():
|
|||
return models.ResEmbeddings(loaded=[], skipped=[])
|
||||
return models.ResEmbeddings(loaded=list(db.word_embeddings.keys()), skipped=list(db.skipped_embeddings.keys()))
|
||||
|
||||
def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, filename: Optional[str] = None, title: Optional[str] = None, fullname: Optional[str] = None, hash: Optional[str] = None): # pylint: disable=redefined-builtin
|
||||
def get_extra_networks(page: str | None = None, name: str | None = None, filename: str | None = None, title: str | None = None, fullname: str | None = None, hash: str | None = None): # pylint: disable=redefined-builtin
|
||||
res = []
|
||||
for pg in shared.extra_networks:
|
||||
if page is not None and pg.name != page.lower():
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import io
|
|||
import os
|
||||
import time
|
||||
import base64
|
||||
from typing import List, Union
|
||||
from urllib.parse import quote, unquote
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
|
|
@ -52,7 +51,7 @@ class ConnectionManager:
|
|||
debug(f'Browser WS disconnect: client={ws.client.host}')
|
||||
self.active.remove(ws)
|
||||
|
||||
async def send(self, ws: WebSocket, data: Union[str, dict, bytes]):
|
||||
async def send(self, ws: WebSocket, data: str | dict | bytes):
|
||||
# debug(f'Browser WS send: client={ws.client.host} data={type(data)}')
|
||||
if ws.client_state != WebSocketState.CONNECTED:
|
||||
return
|
||||
|
|
@ -65,7 +64,7 @@ class ConnectionManager:
|
|||
else:
|
||||
debug(f'Browser WS send: client={ws.client.host} data={type(data)} unknown')
|
||||
|
||||
async def broadcast(self, data: Union[str, dict, bytes]):
|
||||
async def broadcast(self, data: str | dict | bytes):
|
||||
for ws in self.active:
|
||||
await self.send(ws, data)
|
||||
|
||||
|
|
@ -206,7 +205,7 @@ def register_api(app: FastAPI): # register api
|
|||
shared.log.error(f'Gallery: {folder} {e}')
|
||||
return []
|
||||
|
||||
shared.api.add_api_route("/sdapi/v1/browser/folders", get_folders, methods=["GET"], response_model=List[str])
|
||||
shared.api.add_api_route("/sdapi/v1/browser/folders", get_folders, methods=["GET"], response_model=list[str])
|
||||
shared.api.add_api_route("/sdapi/v1/browser/thumb", get_thumb, methods=["GET"], response_model=dict)
|
||||
shared.api.add_api_route("/sdapi/v1/browser/files", ht_files, methods=["GET"], response_model=list)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from modules.paths import resolve_output_path
|
|||
errors.install()
|
||||
|
||||
|
||||
class APIGenerate():
|
||||
class APIGenerate:
|
||||
def __init__(self, queue_lock: Lock):
|
||||
self.queue_lock = queue_lock
|
||||
self.default_script_arg_txt2img = []
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import List
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
|
||||
|
|
@ -25,5 +24,5 @@ def post_refresh_loras():
|
|||
def register_api():
|
||||
from modules.shared import api
|
||||
api.add_api_route("/sdapi/v1/lora", get_lora, methods=["GET"], response_model=dict)
|
||||
api.add_api_route("/sdapi/v1/loras", get_loras, methods=["GET"], response_model=List[dict])
|
||||
api.add_api_route("/sdapi/v1/loras", get_loras, methods=["GET"], response_model=list[dict])
|
||||
api.add_api_route("/sdapi/v1/refresh-loras", post_refresh_loras, methods=["POST"])
|
||||
|
|
|
|||
|
|
@ -1,7 +1,15 @@
|
|||
import re
|
||||
import inspect
|
||||
from typing import Any, Optional, Dict, List, Type, Callable, Union
|
||||
from pydantic import BaseModel, Field, create_model # pylint: disable=no-name-in-module
|
||||
from typing import Any, Optional, Union
|
||||
from collections.abc import Callable
|
||||
import pydantic
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
try:
|
||||
from pydantic import ConfigDict
|
||||
PYDANTIC_V2 = True
|
||||
except ImportError:
|
||||
ConfigDict = None
|
||||
PYDANTIC_V2 = False
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
import modules.shared as shared
|
||||
|
||||
|
|
@ -41,8 +49,10 @@ class PydanticModelGenerator:
|
|||
model_name: str = None,
|
||||
class_instance = None,
|
||||
additional_fields = None,
|
||||
exclude_fields: List = [],
|
||||
exclude_fields: list = None,
|
||||
):
|
||||
if exclude_fields is None:
|
||||
exclude_fields = []
|
||||
def field_type_generator(_k, v):
|
||||
field_type = v.annotation
|
||||
return Optional[field_type]
|
||||
|
|
@ -80,12 +90,15 @@ class PydanticModelGenerator:
|
|||
|
||||
def generate_model(self):
|
||||
model_fields = { d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def }
|
||||
DynamicModel = create_model(self._model_name, **model_fields)
|
||||
try:
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
except Exception:
|
||||
pass
|
||||
if PYDANTIC_V2:
|
||||
config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True)
|
||||
else:
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
orm_mode = True
|
||||
allow_population_by_field_name = True
|
||||
config = Config
|
||||
DynamicModel = create_model(self._model_name, __config__=config, **model_fields)
|
||||
return DynamicModel
|
||||
|
||||
### item classes
|
||||
|
|
@ -100,49 +113,49 @@ class ItemVae(BaseModel):
|
|||
|
||||
class ItemUpscaler(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
model_name: Optional[str] = Field(title="Model Name")
|
||||
model_path: Optional[str] = Field(title="Path")
|
||||
model_url: Optional[str] = Field(title="URL")
|
||||
scale: Optional[float] = Field(title="Scale")
|
||||
model_name: str | None = Field(title="Model Name")
|
||||
model_path: str | None = Field(title="Path")
|
||||
model_url: str | None = Field(title="URL")
|
||||
scale: float | None = Field(title="Scale")
|
||||
|
||||
class ItemModel(BaseModel):
|
||||
title: str = Field(title="Title")
|
||||
model_name: str = Field(title="Model Name")
|
||||
filename: str = Field(title="Filename")
|
||||
type: str = Field(title="Model type")
|
||||
sha256: Optional[str] = Field(title="SHA256 hash")
|
||||
hash: Optional[str] = Field(title="Short hash")
|
||||
config: Optional[str] = Field(title="Config file")
|
||||
sha256: str | None = Field(title="SHA256 hash")
|
||||
hash: str | None = Field(title="Short hash")
|
||||
config: str | None = Field(title="Config file")
|
||||
|
||||
class ItemHypernetwork(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
path: str | None = Field(title="Path")
|
||||
|
||||
class ItemDetailer(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
path: str | None = Field(title="Path")
|
||||
|
||||
class ItemGAN(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
scale: Optional[int] = Field(title="Scale")
|
||||
path: str | None = Field(title="Path")
|
||||
scale: int | None = Field(title="Scale")
|
||||
|
||||
class ItemStyle(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
prompt: Optional[str] = Field(title="Prompt")
|
||||
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
||||
extra: Optional[str] = Field(title="Extra")
|
||||
filename: Optional[str] = Field(title="Filename")
|
||||
preview: Optional[str] = Field(title="Preview")
|
||||
prompt: str | None = Field(title="Prompt")
|
||||
negative_prompt: str | None = Field(title="Negative Prompt")
|
||||
extra: str | None = Field(title="Extra")
|
||||
filename: str | None = Field(title="Filename")
|
||||
preview: str | None = Field(title="Preview")
|
||||
|
||||
class ItemExtraNetwork(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
type: str = Field(title="Type")
|
||||
title: Optional[str] = Field(title="Title")
|
||||
fullname: Optional[str] = Field(title="Fullname")
|
||||
filename: Optional[str] = Field(title="Filename")
|
||||
hash: Optional[str] = Field(title="Hash")
|
||||
preview: Optional[str] = Field(title="Preview image URL")
|
||||
title: str | None = Field(title="Title")
|
||||
fullname: str | None = Field(title="Fullname")
|
||||
filename: str | None = Field(title="Filename")
|
||||
hash: str | None = Field(title="Hash")
|
||||
preview: str | None = Field(title="Preview image URL")
|
||||
|
||||
class ItemArtist(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
|
|
@ -150,16 +163,16 @@ class ItemArtist(BaseModel):
|
|||
category: str = Field(title="Category")
|
||||
|
||||
class ItemEmbedding(BaseModel):
|
||||
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
|
||||
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
|
||||
step: int | None = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||
sd_checkpoint: str | None = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
|
||||
sd_checkpoint_name: str | None = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
|
||||
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
|
||||
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
||||
|
||||
class ItemIPAdapter(BaseModel):
|
||||
adapter: str = Field(title="Adapter", default="Base", description="IP adapter name")
|
||||
images: List[str] = Field(title="Image", default=[], description="IP adapter input images")
|
||||
masks: Optional[List[str]] = Field(title="Mask", default=[], description="IP adapter mask images")
|
||||
images: list[str] = Field(title="Image", default=[], description="IP adapter input images")
|
||||
masks: list[str] | None = Field(title="Mask", default=[], description="IP adapter mask images")
|
||||
scale: float = Field(title="Scale", default=0.5, ge=0, le=1, description="IP adapter scale")
|
||||
start: float = Field(title="Start", default=0.0, ge=0, le=1, description="IP adapter start step")
|
||||
end: float = Field(title="End", default=1.0, gt=0, le=1, description="IP adapter end step")
|
||||
|
|
@ -183,17 +196,17 @@ class ItemFace(BaseModel):
|
|||
|
||||
class ScriptArg(BaseModel):
|
||||
label: str = Field(default=None, title="Label", description="Name of the argument in UI")
|
||||
value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
|
||||
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
|
||||
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
|
||||
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
|
||||
choices: Optional[Any] = Field(default=None, title="Choices", description="Possible values for the argument")
|
||||
value: Any | None = Field(default=None, title="Value", description="Default value of the argument")
|
||||
minimum: Any | None = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
|
||||
maximum: Any | None = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
|
||||
step: Any | None = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
|
||||
choices: Any | None = Field(default=None, title="Choices", description="Possible values for the argument")
|
||||
|
||||
class ItemScript(BaseModel):
|
||||
name: str = Field(default=None, title="Name", description="Script name")
|
||||
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
|
||||
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
|
||||
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
||||
args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
||||
|
||||
class ItemExtension(BaseModel):
|
||||
name: str = Field(title="Name", description="Extension name")
|
||||
|
|
@ -201,13 +214,13 @@ class ItemExtension(BaseModel):
|
|||
branch: str = Field(default="uknnown", title="Branch", description="Extension Repository Branch")
|
||||
commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
|
||||
version: str = Field(title="Version", description="Extension Version")
|
||||
commit_date: Union[str, int] = Field(title="Commit Date", description="Extension Repository Commit Date")
|
||||
commit_date: str | int = Field(title="Commit Date", description="Extension Repository Commit Date")
|
||||
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
|
||||
|
||||
class ItemScheduler(BaseModel):
|
||||
name: str = Field(title="Name", description="Scheduler name")
|
||||
cls: str = Field(title="Class", description="Scheduler class name")
|
||||
options: Dict[str, Any] = Field(title="Options", description="Dictionary of scheduler options")
|
||||
options: dict[str, Any] = Field(title="Options", description="Dictionary of scheduler options")
|
||||
|
||||
### request/response classes
|
||||
|
||||
|
|
@ -223,7 +236,7 @@ ReqTxt2Img = PydanticModelGenerator(
|
|||
{"key": "send_images", "type": bool, "default": True},
|
||||
{"key": "save_images", "type": bool, "default": False},
|
||||
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
||||
{"key": "ip_adapter", "type": Optional[List[ItemIPAdapter]], "default": None, "exclude": True},
|
||||
{"key": "ip_adapter", "type": Optional[list[ItemIPAdapter]], "default": None, "exclude": True},
|
||||
{"key": "face", "type": Optional[ItemFace], "default": None, "exclude": True},
|
||||
{"key": "extra", "type": Optional[dict], "default": {}, "exclude": True},
|
||||
]
|
||||
|
|
@ -233,7 +246,7 @@ if not hasattr(ReqTxt2Img, "__config__"):
|
|||
StableDiffusionTxt2ImgProcessingAPI = ReqTxt2Img
|
||||
|
||||
class ResTxt2Img(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
|
|
@ -253,7 +266,7 @@ ReqImg2Img = PydanticModelGenerator(
|
|||
{"key": "send_images", "type": bool, "default": True},
|
||||
{"key": "save_images", "type": bool, "default": False},
|
||||
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
||||
{"key": "ip_adapter", "type": Optional[List[ItemIPAdapter]], "default": None, "exclude": True},
|
||||
{"key": "ip_adapter", "type": Optional[list[ItemIPAdapter]], "default": None, "exclude": True},
|
||||
{"key": "face_id", "type": Optional[ItemFace], "default": None, "exclude": True},
|
||||
{"key": "extra", "type": Optional[dict], "default": {}, "exclude": True},
|
||||
]
|
||||
|
|
@ -263,7 +276,7 @@ if not hasattr(ReqImg2Img, "__config__"):
|
|||
StableDiffusionImg2ImgProcessingAPI = ReqImg2Img
|
||||
|
||||
class ResImg2Img(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated images in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
|
|
@ -289,9 +302,9 @@ class ResProcess(BaseModel):
|
|||
class ReqPromptEnhance(BaseModel):
|
||||
prompt: str = Field(title="Prompt", description="Prompt to enhance")
|
||||
type: str = Field(title="Type", default='text', description="Type of enhancement to perform")
|
||||
model: Optional[str] = Field(title="Model", default=None, description="Model to use for enhancement")
|
||||
system_prompt: Optional[str] = Field(title="System prompt", default=None, description="Model system prompt")
|
||||
image: Optional[str] = Field(title="Image", default=None, description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
model: str | None = Field(title="Model", default=None, description="Model to use for enhancement")
|
||||
system_prompt: str | None = Field(title="System prompt", default=None, description="Model system prompt")
|
||||
image: str | None = Field(title="Image", default=None, description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
seed: int = Field(title="Seed", default=-1, description="Seed used to generate the prompt")
|
||||
nsfw: bool = Field(title="NSFW", default=True, description="Should NSFW content be allowed?")
|
||||
|
||||
|
|
@ -306,10 +319,10 @@ class ResProcessImage(ResProcess):
|
|||
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
|
||||
class ReqProcessBatch(ReqProcess):
|
||||
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
|
||||
class ResProcessBatch(ResProcess):
|
||||
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
|
||||
class ReqImageInfo(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded image")
|
||||
|
|
@ -325,38 +338,38 @@ class ReqGetLog(BaseModel):
|
|||
|
||||
|
||||
class ReqPostLog(BaseModel):
|
||||
message: Optional[str] = Field(default=None, title="Message", description="The info message to log")
|
||||
debug: Optional[str] = Field(default=None, title="Debug message", description="The debug message to log")
|
||||
error: Optional[str] = Field(default=None, title="Error message", description="The error message to log")
|
||||
message: str | None = Field(default=None, title="Message", description="The info message to log")
|
||||
debug: str | None = Field(default=None, title="Debug message", description="The debug message to log")
|
||||
error: str | None = Field(default=None, title="Error message", description="The error message to log")
|
||||
|
||||
class ReqHistory(BaseModel):
|
||||
id: Union[int, str, None] = Field(default=None, title="Task ID", description="Task ID")
|
||||
id: int | str | None = Field(default=None, title="Task ID", description="Task ID")
|
||||
|
||||
class ReqProgress(BaseModel):
|
||||
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
||||
|
||||
class ResProgress(BaseModel):
|
||||
id: Union[int, str, None] = Field(title="TaskID", description="Task ID")
|
||||
id: int | str | None = Field(title="TaskID", description="Task ID")
|
||||
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: Optional[str] = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
textinfo: Optional[str] = Field(default=None, title="Info text", description="Info text used by WebUI.")
|
||||
current_image: str | None = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
textinfo: str | None = Field(default=None, title="Info text", description="Info text used by WebUI.")
|
||||
|
||||
class ResHistory(BaseModel):
|
||||
id: Union[int, str, None] = Field(title="ID", description="Task ID")
|
||||
id: int | str | None = Field(title="ID", description="Task ID")
|
||||
job: str = Field(title="Job", description="Job name")
|
||||
op: str = Field(title="Operation", description="Job state")
|
||||
timestamp: Union[float, None] = Field(title="Timestamp", description="Job timestamp")
|
||||
duration: Union[float, None] = Field(title="Duration", description="Job duration")
|
||||
outputs: List[str] = Field(title="Outputs", description="List of filenames")
|
||||
timestamp: float | None = Field(title="Timestamp", description="Job timestamp")
|
||||
duration: float | None = Field(title="Duration", description="Job duration")
|
||||
outputs: list[str] = Field(title="Outputs", description="List of filenames")
|
||||
|
||||
class ResStatus(BaseModel):
|
||||
status: str = Field(title="Status", description="Current status")
|
||||
task: str = Field(title="Task", description="Current job")
|
||||
timestamp: Optional[str] = Field(title="Timestamp", description="Timestamp of the current job")
|
||||
timestamp: str | None = Field(title="Timestamp", description="Timestamp of the current job")
|
||||
current: str = Field(title="Task", description="Current job")
|
||||
id: Union[int, str, None] = Field(title="ID", description="ID of the current task")
|
||||
id: int | str | None = Field(title="ID", description="ID of the current task")
|
||||
job: int = Field(title="Job", description="Current job")
|
||||
jobs: int = Field(title="Jobs", description="Total jobs")
|
||||
total: int = Field(title="Total Jobs", description="Total jobs")
|
||||
|
|
@ -364,9 +377,9 @@ class ResStatus(BaseModel):
|
|||
steps: int = Field(title="Steps", description="Total steps")
|
||||
queued: int = Field(title="Queued", description="Number of queued tasks")
|
||||
uptime: int = Field(title="Uptime", description="Uptime of the server")
|
||||
elapsed: Optional[float] = Field(default=None, title="Elapsed time")
|
||||
eta: Optional[float] = Field(default=None, title="ETA in secs")
|
||||
progress: Optional[float] = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
|
||||
elapsed: float | None = Field(default=None, title="Elapsed time")
|
||||
eta: float | None = Field(default=None, title="ETA in secs")
|
||||
progress: float | None = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
|
||||
|
||||
class ReqLatentHistory(BaseModel):
|
||||
name: str = Field(title="Name", description="Name of the history item to select")
|
||||
|
|
@ -392,7 +405,15 @@ for key, metadata in shared.opts.data_labels.items():
|
|||
else:
|
||||
fields.update({key: (Optional[optType], Field())})
|
||||
|
||||
OptionsModel = create_model("Options", **fields)
|
||||
if PYDANTIC_V2:
|
||||
config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True)
|
||||
else:
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
orm_mode = True
|
||||
allow_population_by_field_name = True
|
||||
config = Config
|
||||
OptionsModel = create_model("Options", __config__=config, **fields)
|
||||
|
||||
flags = {}
|
||||
_options = vars(shared.parser)['_option_string_actions']
|
||||
|
|
@ -404,7 +425,15 @@ for key in _options:
|
|||
_type = type(_options[key].default)
|
||||
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
|
||||
|
||||
FlagsModel = create_model("Flags", **flags)
|
||||
if PYDANTIC_V2:
|
||||
config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True)
|
||||
else:
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
orm_mode = True
|
||||
allow_population_by_field_name = True
|
||||
config = Config
|
||||
FlagsModel = create_model("Flags", __config__=config, **flags)
|
||||
|
||||
class ResEmbeddings(BaseModel):
|
||||
loaded: list = Field(default=None, title="loaded", description="List of loaded embeddings")
|
||||
|
|
@ -426,9 +455,13 @@ class ResGPU(BaseModel): # definition of http response
|
|||
|
||||
# helper function
|
||||
|
||||
def create_model_from_signature(func: Callable, model_name: str, base_model: Type[BaseModel] = BaseModel, additional_fields: List = [], exclude_fields: List[str] = []) -> type[BaseModel]:
|
||||
def create_model_from_signature(func: Callable, model_name: str, base_model: type[BaseModel] = BaseModel, additional_fields: list = None, exclude_fields: list[str] = None) -> type[BaseModel]:
|
||||
from PIL import Image
|
||||
|
||||
if exclude_fields is None:
|
||||
exclude_fields = []
|
||||
if additional_fields is None:
|
||||
additional_fields = []
|
||||
class Config:
|
||||
extra = 'allow'
|
||||
|
||||
|
|
@ -443,13 +476,13 @@ def create_model_from_signature(func: Callable, model_name: str, base_model: Typ
|
|||
defaults = (...,) * non_default_args + defaults
|
||||
keyword_only_params = {param: kwonlydefaults.get(param, Any) for param in kwonlyargs}
|
||||
for k, v in annotations.items():
|
||||
if v == List[Image.Image]:
|
||||
annotations[k] = List[str]
|
||||
if v == list[Image.Image]:
|
||||
annotations[k] = list[str]
|
||||
elif v == Image.Image:
|
||||
annotations[k] = str
|
||||
elif str(v) == 'typing.List[modules.control.unit.Unit]':
|
||||
annotations[k] = List[str]
|
||||
model_fields = {param: (annotations.get(param, Any), default) for param, default in zip(args, defaults)}
|
||||
annotations[k] = list[str]
|
||||
model_fields = {param: (annotations.get(param, Any), default) for param, default in zip(args, defaults, strict=False)}
|
||||
|
||||
for fld in additional_fields:
|
||||
model_def = ModelDef(
|
||||
|
|
@ -464,16 +497,21 @@ def create_model_from_signature(func: Callable, model_name: str, base_model: Typ
|
|||
if fld in model_fields:
|
||||
del model_fields[fld]
|
||||
|
||||
if PYDANTIC_V2:
|
||||
config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True, populate_by_name=True, extra='allow' if varkw else 'ignore')
|
||||
else:
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
orm_mode = True
|
||||
allow_population_by_field_name = True
|
||||
extra = 'allow' if varkw else 'ignore'
|
||||
config = Config
|
||||
|
||||
model = create_model(
|
||||
model_name,
|
||||
**model_fields,
|
||||
**keyword_only_params,
|
||||
__base__=base_model,
|
||||
__config__=config,
|
||||
**model_fields,
|
||||
**keyword_only_params,
|
||||
)
|
||||
try:
|
||||
model.__config__.allow_population_by_field_name = True
|
||||
model.__config__.allow_mutation = True
|
||||
except Exception:
|
||||
pass
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Optional, List
|
||||
from threading import Lock
|
||||
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
|
||||
from fastapi.responses import JSONResponse
|
||||
|
|
@ -15,7 +14,7 @@ errors.install()
|
|||
class ReqPreprocess(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded image")
|
||||
model: str = Field(title="Model", description="The model to use for preprocessing")
|
||||
params: Optional[dict] = Field(default={}, title="Settings", description="Preprocessor settings")
|
||||
params: dict | None = Field(default={}, title="Settings", description="Preprocessor settings")
|
||||
|
||||
class ResPreprocess(BaseModel):
|
||||
model: str = Field(default='', title="Model", description="The processor model used")
|
||||
|
|
@ -24,20 +23,20 @@ class ResPreprocess(BaseModel):
|
|||
class ReqMask(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded image")
|
||||
type: str = Field(title="Mask type", description="Type of masking image to return")
|
||||
mask: Optional[str] = Field(title="Mask", description="If optional maks image is not provided auto-masking will be performed")
|
||||
model: Optional[str] = Field(title="Model", description="The model to use for preprocessing")
|
||||
params: Optional[dict] = Field(default={}, title="Settings", description="Preprocessor settings")
|
||||
mask: str | None = Field(title="Mask", description="If optional maks image is not provided auto-masking will be performed")
|
||||
model: str | None = Field(title="Model", description="The model to use for preprocessing")
|
||||
params: dict | None = Field(default={}, title="Settings", description="Preprocessor settings")
|
||||
|
||||
class ReqFace(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded image")
|
||||
model: Optional[str] = Field(title="Model", description="The model to use for detection")
|
||||
model: str | None = Field(title="Model", description="The model to use for detection")
|
||||
|
||||
class ResFace(BaseModel):
|
||||
classes: List[int] = Field(title="Class", description="The class of detected item")
|
||||
labels: List[str] = Field(title="Label", description="The label of detected item")
|
||||
boxes: List[List[int]] = Field(title="Box", description="The bounding box of detected item")
|
||||
images: List[str] = Field(title="Image", description="The base64 encoded images of detected faces")
|
||||
scores: List[float] = Field(title="Scores", description="The scores of the detected faces")
|
||||
classes: list[int] = Field(title="Class", description="The class of detected item")
|
||||
labels: list[str] = Field(title="Label", description="The label of detected item")
|
||||
boxes: list[list[int]] = Field(title="Box", description="The bounding box of detected item")
|
||||
images: list[str] = Field(title="Image", description="The base64 encoded images of detected faces")
|
||||
scores: list[float] = Field(title="Scores", description="The scores of the detected faces")
|
||||
|
||||
class ResMask(BaseModel):
|
||||
mask: str = Field(default='', title="Image", description="The processed image in base64 format")
|
||||
|
|
@ -47,13 +46,13 @@ class ItemPreprocess(BaseModel):
|
|||
params: dict = Field(title="Params")
|
||||
|
||||
class ItemMask(BaseModel):
|
||||
models: List[str] = Field(title="Models")
|
||||
colormaps: List[str] = Field(title="Color maps")
|
||||
models: list[str] = Field(title="Models")
|
||||
colormaps: list[str] = Field(title="Color maps")
|
||||
params: dict = Field(title="Params")
|
||||
types: List[str] = Field(title="Types")
|
||||
types: list[str] = Field(title="Types")
|
||||
|
||||
|
||||
class APIProcess():
|
||||
class APIProcess:
|
||||
def __init__(self, queue_lock: Lock):
|
||||
self.queue_lock = queue_lock
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Optional
|
||||
from fastapi.exceptions import HTTPException
|
||||
import gradio as gr
|
||||
from modules.api import models
|
||||
|
|
@ -36,7 +35,7 @@ def get_scripts_list():
|
|||
return models.ResScripts(txt2img = t2ilist, img2img = i2ilist, control = control)
|
||||
|
||||
|
||||
def get_script_info(script_name: Optional[str] = None):
|
||||
def get_script_info(script_name: str | None = None):
|
||||
res = []
|
||||
for script_list in [scripts_manager.scripts_txt2img.scripts, scripts_manager.scripts_img2img.scripts, scripts_manager.scripts_control.scripts]:
|
||||
for script in script_list:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from typing import List
|
||||
|
||||
|
||||
def xyz_grid_enum(option: str = "") -> List[dict]:
|
||||
def xyz_grid_enum(option: str = "") -> list[dict]:
|
||||
from scripts.xyz import xyz_grid_classes # pylint: disable=no-name-in-module
|
||||
options = []
|
||||
for x in xyz_grid_classes.axis_options:
|
||||
|
|
@ -23,4 +22,4 @@ def xyz_grid_enum(option: str = "") -> List[dict]:
|
|||
|
||||
def register_api():
|
||||
from modules.shared import api as api_instance
|
||||
api_instance.add_api_route("/sdapi/v1/xyz-grid", xyz_grid_enum, methods=["GET"], response_model=List[dict])
|
||||
api_instance.add_api_route("/sdapi/v1/xyz-grid", xyz_grid_enum, methods=["GET"], response_model=list[dict])
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Optional
|
||||
from functools import wraps
|
||||
import torch
|
||||
from modules import rocm
|
||||
|
|
@ -23,7 +22,7 @@ def set_triton_flash_attention(backend: str):
|
|||
from modules.flash_attn_triton_amd import interface_fa
|
||||
sdpa_pre_triton_flash_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(sdpa_pre_triton_flash_atten)
|
||||
def sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
|
||||
def sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
|
||||
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
|
||||
if scale is None:
|
||||
scale = query.shape[-1] ** (-0.5)
|
||||
|
|
@ -56,7 +55,7 @@ def set_flex_attention():
|
|||
|
||||
sdpa_pre_flex_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(sdpa_pre_flex_atten)
|
||||
def sdpa_flex_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: # pylint: disable=unused-argument
|
||||
def sdpa_flex_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: # pylint: disable=unused-argument
|
||||
score_mod = None
|
||||
block_mask = None
|
||||
if attn_mask is not None:
|
||||
|
|
@ -96,7 +95,7 @@ def set_ck_flash_attention(backend: str, device: torch.device):
|
|||
from flash_attn import flash_attn_func
|
||||
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(sdpa_pre_flash_atten)
|
||||
def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
|
||||
def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
|
||||
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
|
||||
is_unsqueezed = False
|
||||
if query.dim() == 3:
|
||||
|
|
@ -162,7 +161,7 @@ def set_sage_attention(backend: str, device: torch.device):
|
|||
|
||||
sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(sdpa_pre_sage_atten)
|
||||
def sdpa_sage_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
|
||||
def sdpa_sage_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
|
||||
if (query.shape[-1] in {128, 96, 64}) and (attn_mask is None) and (query.dtype != torch.float32):
|
||||
if enable_gqa:
|
||||
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
|
||||
|
|
|
|||
|
|
@ -373,7 +373,7 @@ class BasicLayer(nn.Module):
|
|||
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, (-100.0)).masked_fill(attn_mask == 0, 0.0)
|
||||
|
||||
for blk in self.blocks:
|
||||
blk.H, blk.W = H, W
|
||||
|
|
@ -464,8 +464,8 @@ class SwinTransformer(nn.Module):
|
|||
patch_size=4,
|
||||
in_chans=3,
|
||||
embed_dim=96,
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
depths=None,
|
||||
num_heads=None,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
|
|
@ -479,6 +479,10 @@ class SwinTransformer(nn.Module):
|
|||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=-1,
|
||||
use_checkpoint=False):
|
||||
if num_heads is None:
|
||||
num_heads = [3, 6, 12, 24]
|
||||
if depths is None:
|
||||
depths = [2, 2, 6, 2]
|
||||
super().__init__()
|
||||
|
||||
self.pretrain_img_size = pretrain_img_size
|
||||
|
|
@ -668,8 +672,10 @@ class PositionEmbeddingSine:
|
|||
|
||||
|
||||
class MCLM(nn.Module):
|
||||
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
||||
super(MCLM, self).__init__()
|
||||
def __init__(self, d_model, num_heads, pool_ratios=None):
|
||||
if pool_ratios is None:
|
||||
pool_ratios = [1, 4, 8]
|
||||
super().__init__()
|
||||
self.attention = nn.ModuleList([
|
||||
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
||||
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
||||
|
|
@ -739,7 +745,7 @@ class MCLM(nn.Module):
|
|||
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
||||
_g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
|
||||
outputs_re = []
|
||||
for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
||||
for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1), strict=False)):
|
||||
outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
|
||||
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
||||
|
||||
|
|
@ -760,8 +766,10 @@ class MCLM(nn.Module):
|
|||
|
||||
|
||||
class MCRM(nn.Module):
|
||||
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None): # pylint: disable=unused-argument
|
||||
super(MCRM, self).__init__()
|
||||
def __init__(self, d_model, num_heads, pool_ratios=None, h=None): # pylint: disable=unused-argument
|
||||
if pool_ratios is None:
|
||||
pool_ratios = [4, 8, 16]
|
||||
super().__init__()
|
||||
self.attention = nn.ModuleList([
|
||||
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
||||
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
||||
|
|
@ -1049,7 +1057,7 @@ class BEN_Base(nn.Module):
|
|||
"""
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
raise IOError(f"Cannot open video: {video_path}")
|
||||
raise OSError(f"Cannot open video: {video_path}")
|
||||
|
||||
original_fps = cap.get(cv2.CAP_PROPFPS)
|
||||
original_fps = 30 if original_fps == 0 else original_fps
|
||||
|
|
@ -1225,7 +1233,7 @@ def add_audio_to_video(video_without_audio_path, original_video_path, output_pat
|
|||
'-of', 'csv=p=0',
|
||||
original_video_path
|
||||
]
|
||||
result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False)
|
||||
result = subprocess.run(probe_command, capture_output=True, text=True, check=False)
|
||||
|
||||
# result.stdout is empty if no audio stream found
|
||||
if not result.stdout.strip():
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import threading
|
|||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from modules import modelloader, devices, shared
|
||||
from modules import modelloader, devices, shared, paths
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
load_lock = threading.Lock()
|
||||
|
|
@ -18,7 +18,7 @@ class DeepDanbooru:
|
|||
with load_lock:
|
||||
if self.model is not None:
|
||||
return
|
||||
model_path = os.path.join(shared.models_path, "DeepDanbooru")
|
||||
model_path = os.path.join(paths.models_path, "DeepDanbooru")
|
||||
shared.log.debug(f'Caption load: module=DeepDanbooru folder="{model_path}"')
|
||||
files = modelloader.load_models(
|
||||
model_path=model_path,
|
||||
|
|
@ -96,7 +96,7 @@ class DeepDanbooru:
|
|||
x = torch.from_numpy(a).to(device=devices.device, dtype=devices.dtype)
|
||||
y = self.model(x)[0].detach().float().cpu().numpy()
|
||||
probability_dict = {}
|
||||
for current, probability in zip(self.model.tags, y):
|
||||
for current, probability in zip(self.model.tags, y, strict=False):
|
||||
if probability < general_threshold:
|
||||
continue
|
||||
if current.startswith("rating:") and not include_rating:
|
||||
|
|
|
|||
|
|
@ -671,4 +671,4 @@ class DeepDanbooruModel(nn.Module):
|
|||
|
||||
def load_state_dict(self, state_dict, **kwargs): # pylint: disable=arguments-differ,unused-argument
|
||||
self.tags = state_dict.get('tags', [])
|
||||
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) # pylint: disable=R1725
|
||||
super().load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) # pylint: disable=R1725
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ vl_chat_processor = None
|
|||
loaded_repo = None
|
||||
|
||||
|
||||
class fake_attrdict():
|
||||
class fake_attrdict:
|
||||
class AttrDict(dict): # dot notation access to dictionary attributes
|
||||
__getattr__ = dict.get
|
||||
__setattr__ = dict.__setitem__
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ Extra Options:
|
|||
"""
|
||||
|
||||
@dataclass
|
||||
class JoyOptions():
|
||||
class JoyOptions:
|
||||
repo: str = "fancyfeast/llama-joycaption-alpha-two-hf-llava"
|
||||
temp: float = 0.5
|
||||
top_k: float = 10
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import os
|
|||
import math
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torch.backends.cuda
|
||||
|
|
@ -126,7 +125,7 @@ class VisionModel(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def load_model(path: str) -> 'VisionModel':
|
||||
with open(Path(path) / 'config.json', 'r', encoding='utf8') as f:
|
||||
with open(Path(path) / 'config.json', encoding='utf8') as f:
|
||||
config = json.load(f)
|
||||
from safetensors.torch import load_file
|
||||
resume = load_file(Path(path) / 'model.safetensors', device='cpu')
|
||||
|
|
@ -244,7 +243,7 @@ class CLIPMlp(nn.Module):
|
|||
|
||||
class FastCLIPAttention2(nn.Module):
|
||||
"""Fast Attention module for CLIP-like. This is NOT a drop-in replacement for CLIPAttention, since it adds additional flexibility. Mainly uses xformers."""
|
||||
def __init__(self, hidden_size: int, out_dim: int, num_attention_heads: int, out_seq_len: Optional[int] = None, norm_qk: bool = False):
|
||||
def __init__(self, hidden_size: int, out_dim: int, num_attention_heads: int, out_seq_len: int | None = None, norm_qk: bool = False):
|
||||
super().__init__()
|
||||
self.out_seq_len = out_seq_len
|
||||
self.embed_dim = hidden_size
|
||||
|
|
@ -308,12 +307,12 @@ class FastCLIPEncoderLayer(nn.Module):
|
|||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
out_seq_len: Optional[int],
|
||||
out_seq_len: int | None,
|
||||
activation_cls = QuickGELUActivation,
|
||||
use_palm_alt: bool = False,
|
||||
norm_qk: bool = False,
|
||||
skip_init: Optional[float] = None,
|
||||
stochastic_depth: Optional[float] = None,
|
||||
skip_init: float | None = None,
|
||||
stochastic_depth: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_palm_alt = use_palm_alt
|
||||
|
|
@ -523,8 +522,8 @@ class CLIPLikeModel(VisionModel):
|
|||
norm_qk: bool = False,
|
||||
no_wd_bias: bool = False,
|
||||
use_gap_head: bool = False,
|
||||
skip_init: Optional[float] = None,
|
||||
stochastic_depth: Optional[float] = None,
|
||||
skip_init: float | None = None,
|
||||
stochastic_depth: float | None = None,
|
||||
):
|
||||
super().__init__(image_size, n_tags)
|
||||
out_dim = n_tags
|
||||
|
|
@ -939,7 +938,7 @@ class ViT(VisionModel):
|
|||
stochdepth_rate: float,
|
||||
use_sine: bool,
|
||||
loss_type: str,
|
||||
layerscale_init: Optional[float] = None,
|
||||
layerscale_init: float | None = None,
|
||||
head_mean_after: bool = False,
|
||||
cnn_stem: str = None,
|
||||
patch_dropout: float = 0.0,
|
||||
|
|
@ -1048,7 +1047,7 @@ def load():
|
|||
model = VisionModel.load_model(folder)
|
||||
model.to(dtype=devices.dtype)
|
||||
model.eval() # required: custom loader, not from_pretrained
|
||||
with open(os.path.join(folder, 'top_tags.txt'), 'r', encoding='utf8') as f:
|
||||
with open(os.path.join(folder, 'top_tags.txt'), encoding='utf8') as f:
|
||||
tags = [line.strip() for line in f.readlines() if line.strip()]
|
||||
shared.log.info(f'Caption: type=vlm model="JoyTag" repo="{MODEL_REPO}" tags={len(tags)}')
|
||||
sd_models.move_model(model, devices.device)
|
||||
|
|
|
|||
|
|
@ -330,11 +330,11 @@ def analyze_image(image, clip_model, blip_model):
|
|||
top_movements = ci.movements.rank(image_features, 5)
|
||||
top_trendings = ci.trendings.rank(image_features, 5)
|
||||
top_flavors = ci.flavors.rank(image_features, 5)
|
||||
medium_ranks = dict(sorted(zip(top_mediums, ci.similarities(image_features, top_mediums)), key=lambda x: x[1], reverse=True))
|
||||
artist_ranks = dict(sorted(zip(top_artists, ci.similarities(image_features, top_artists)), key=lambda x: x[1], reverse=True))
|
||||
movement_ranks = dict(sorted(zip(top_movements, ci.similarities(image_features, top_movements)), key=lambda x: x[1], reverse=True))
|
||||
trending_ranks = dict(sorted(zip(top_trendings, ci.similarities(image_features, top_trendings)), key=lambda x: x[1], reverse=True))
|
||||
flavor_ranks = dict(sorted(zip(top_flavors, ci.similarities(image_features, top_flavors)), key=lambda x: x[1], reverse=True))
|
||||
medium_ranks = dict(sorted(zip(top_mediums, ci.similarities(image_features, top_mediums), strict=False), key=lambda x: x[1], reverse=True))
|
||||
artist_ranks = dict(sorted(zip(top_artists, ci.similarities(image_features, top_artists), strict=False), key=lambda x: x[1], reverse=True))
|
||||
movement_ranks = dict(sorted(zip(top_movements, ci.similarities(image_features, top_movements), strict=False), key=lambda x: x[1], reverse=True))
|
||||
trending_ranks = dict(sorted(zip(top_trendings, ci.similarities(image_features, top_trendings), strict=False), key=lambda x: x[1], reverse=True))
|
||||
flavor_ranks = dict(sorted(zip(top_flavors, ci.similarities(image_features, top_flavors), strict=False), key=lambda x: x[1], reverse=True))
|
||||
shared.log.debug(f'CLIP analyze: complete time={time.time()-t0:.2f}')
|
||||
|
||||
# Format labels as text
|
||||
|
|
|
|||
|
|
@ -708,7 +708,7 @@ class VQA:
|
|||
debug(f'VQA caption: handler=qwen output_ids_shape={output_ids.shape}')
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):]
|
||||
for input_ids, output_ids in zip(inputs.input_ids, output_ids)
|
||||
for input_ids, output_ids in zip(inputs.input_ids, output_ids, strict=False)
|
||||
]
|
||||
response = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
if debug_enabled:
|
||||
|
|
@ -887,7 +887,7 @@ class VQA:
|
|||
|
||||
def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
||||
try:
|
||||
import flash_attn # pylint: disable=unused-import
|
||||
pass # pylint: disable=unused-import
|
||||
except Exception:
|
||||
shared.log.error(f'Caption: vlm="{repo}" flash-attn is not available')
|
||||
return ''
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ class WaifuDiffusionTagger:
|
|||
self.tags = []
|
||||
self.tag_categories = []
|
||||
|
||||
with open(csv_path, 'r', encoding='utf-8') as f:
|
||||
with open(csv_path, encoding='utf-8') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
self.tags.append(row['name'])
|
||||
|
|
@ -269,7 +269,7 @@ class WaifuDiffusionTagger:
|
|||
character_count = 0
|
||||
rating_count = 0
|
||||
|
||||
for i, (tag_name, prob) in enumerate(zip(self.tags, probs)):
|
||||
for i, (tag_name, prob) in enumerate(zip(self.tags, probs, strict=False)):
|
||||
category = self.tag_categories[i]
|
||||
tag_lower = tag_name.lower()
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ selected_model = None
|
|||
|
||||
|
||||
class CivitModel:
|
||||
def __init__(self, name, fn, sha = None, meta = {}):
|
||||
def __init__(self, name, fn, sha = None, meta = None):
|
||||
if meta is None:
|
||||
meta = {}
|
||||
self.name = name
|
||||
self.file = name
|
||||
self.id = meta.get('id', 0)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ full_html = False
|
|||
base_models = ['', 'AuraFlow', 'Chroma', 'CogVideoX', 'Flux.1 S', 'Flux.1 D', 'Flux.1 Krea', 'Flux.1 Kontext', 'Flux.2 D', 'HiDream', 'Hunyuan 1', 'Hunyuan Video', 'Illustrious', 'Kolors', 'LTXV', 'Lumina', 'Mochi', 'NoobAI', 'PixArt a', 'PixArt E', 'Pony', 'Pony V7', 'Qwen', 'SD 1.4', 'SD 1.5', 'SD 1.5 LCM', 'SD 1.5 Hyper', 'SD 2.0', 'SD 2.1', 'SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper', 'Wan Video 1.3B t2v', 'Wan Video 14B t2v', 'Wan Video 14B i2v 480p', 'Wan Video 14B i2v 720p', 'Wan Video 2.2 TI2V-5B', 'Wan Video 2.2 I2V-A14B', 'Wan Video 2.2 T2V-A14B', 'Wan Video 2.5 T2V', 'Wan Video 2.5 I2V', 'ZImageTurbo', 'Other']
|
||||
|
||||
@dataclass
|
||||
class ModelImage():
|
||||
class ModelImage:
|
||||
def __init__(self, dct: dict):
|
||||
if isinstance(dct, str):
|
||||
dct = json.loads(dct)
|
||||
|
|
@ -26,7 +26,7 @@ class ModelImage():
|
|||
|
||||
|
||||
@dataclass
|
||||
class ModelFile():
|
||||
class ModelFile:
|
||||
def __init__(self, dct: dict):
|
||||
if isinstance(dct, str):
|
||||
dct = json.loads(dct)
|
||||
|
|
@ -43,7 +43,7 @@ class ModelFile():
|
|||
|
||||
|
||||
@dataclass
|
||||
class ModelVersion():
|
||||
class ModelVersion:
|
||||
def __init__(self, dct: dict):
|
||||
import bs4
|
||||
if isinstance(dct, str):
|
||||
|
|
@ -65,7 +65,7 @@ class ModelVersion():
|
|||
|
||||
|
||||
@dataclass
|
||||
class Model():
|
||||
class Model:
|
||||
def __init__(self, dct: dict):
|
||||
import bs4
|
||||
if isinstance(dct, str):
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ def main_args():
|
|||
|
||||
def compatibility_args():
|
||||
# removed args are added here as hidden in fixed format for compatbility reasons
|
||||
from modules.paths import data_path, models_path
|
||||
from modules.paths import data_path
|
||||
group_compat = parser.add_argument_group('Compatibility options')
|
||||
group_compat.add_argument('--backend', type=str, choices=['diffusers', 'original'], help=argparse.SUPPRESS)
|
||||
group_compat.add_argument("--allow-code", default=os.environ.get("SD_ALLOWCODE", False), action='store_true', help=argparse.SUPPRESS)
|
||||
|
|
|
|||
|
|
@ -46,11 +46,17 @@ def preprocess_image(
|
|||
input_mask:Image.Image = None,
|
||||
input_type:str = 0,
|
||||
unit_type:str = 'controlnet',
|
||||
active_process:list = [],
|
||||
active_model:list = [],
|
||||
selected_models:list = [],
|
||||
active_process:list = None,
|
||||
active_model:list = None,
|
||||
selected_models:list = None,
|
||||
has_models:bool = False,
|
||||
):
|
||||
if selected_models is None:
|
||||
selected_models = []
|
||||
if active_model is None:
|
||||
active_model = []
|
||||
if active_process is None:
|
||||
active_process = []
|
||||
t0 = time.time()
|
||||
jobid = shared.state.begin('Preprocess')
|
||||
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ def update_settings(*settings):
|
|||
update(['Depth Pro', 'params', 'color_map'], settings[28])
|
||||
|
||||
|
||||
class Processor():
|
||||
class Processor:
|
||||
def __init__(self, processor_id: str = None, resize = True):
|
||||
self.model = None
|
||||
self.processor_id = None
|
||||
|
|
@ -268,7 +268,9 @@ class Processor():
|
|||
display(e, 'Control Processor load')
|
||||
return f'Processor load filed: {processor_id}'
|
||||
|
||||
def __call__(self, image_input: Image, mode: str = 'RGB', width: int = 0, height: int = 0, resize_mode: int = 0, resize_name: str = 'None', scale_tab: int = 1, scale_by: float = 1.0, local_config: dict = {}):
|
||||
def __call__(self, image_input: Image, mode: str = 'RGB', width: int = 0, height: int = 0, resize_mode: int = 0, resize_name: str = 'None', scale_tab: int = 1, scale_by: float = 1.0, local_config: dict = None):
|
||||
if local_config is None:
|
||||
local_config = {}
|
||||
if self.override is not None:
|
||||
debug(f'Control Processor: id="{self.processor_id}" override={self.override}')
|
||||
width = image_input.width if image_input is not None else width
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
import sys
|
||||
from typing import List, Union
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from modules.control import util # helper functions
|
||||
|
|
@ -141,12 +140,12 @@ def set_pipe(p, has_models, unit_type, selected_models, active_model, active_str
|
|||
|
||||
|
||||
def check_active(p, unit_type, units):
|
||||
active_process: List[processors.Processor] = [] # all active preprocessors
|
||||
active_model: List[Union[controlnet.ControlNet, xs.ControlNetXS, t2iadapter.Adapter]] = [] # all active models
|
||||
active_strength: List[float] = [] # strength factors for all active models
|
||||
active_start: List[float] = [] # start step for all active models
|
||||
active_end: List[float] = [] # end step for all active models
|
||||
active_units: List[unit.Unit] = [] # all active units
|
||||
active_process: list[processors.Processor] = [] # all active preprocessors
|
||||
active_model: list[controlnet.ControlNet | xs.ControlNetXS | t2iadapter.Adapter] = [] # all active models
|
||||
active_strength: list[float] = [] # strength factors for all active models
|
||||
active_start: list[float] = [] # start step for all active models
|
||||
active_end: list[float] = [] # end step for all active models
|
||||
active_units: list[unit.Unit] = [] # all active units
|
||||
num_units = 0
|
||||
for u in units:
|
||||
if u.type != unit_type:
|
||||
|
|
@ -218,7 +217,7 @@ def check_active(p, unit_type, units):
|
|||
|
||||
def check_enabled(p, unit_type, units, active_model, active_strength, active_start, active_end):
|
||||
has_models = False
|
||||
selected_models: List[Union[controlnet.ControlNetModel, xs.ControlNetXSModel, t2iadapter.AdapterModel]] = None
|
||||
selected_models: list[controlnet.ControlNetModel | xs.ControlNetXSModel | t2iadapter.AdapterModel] = None
|
||||
control_conditioning = None
|
||||
control_guidance_start = None
|
||||
control_guidance_end = None
|
||||
|
|
@ -254,7 +253,7 @@ def control_set(kwargs):
|
|||
p_extra_args[k] = v
|
||||
|
||||
|
||||
def init_units(units: List[unit.Unit]):
|
||||
def init_units(units: list[unit.Unit]):
|
||||
for u in units:
|
||||
if not u.enabled:
|
||||
continue
|
||||
|
|
@ -271,9 +270,9 @@ def init_units(units: List[unit.Unit]):
|
|||
|
||||
|
||||
def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
|
||||
units: List[unit.Unit] = [], inputs: List[Image.Image] = [], inits: List[Image.Image] = [], mask: Image.Image = None, unit_type: str = None, is_generator: bool = True,
|
||||
units: list[unit.Unit] = None, inputs: list[Image.Image] = None, inits: list[Image.Image] = None, mask: Image.Image = None, unit_type: str = None, is_generator: bool = True,
|
||||
input_type: int = 0,
|
||||
prompt: str = '', negative_prompt: str = '', styles: List[str] = [],
|
||||
prompt: str = '', negative_prompt: str = '', styles: list[str] = None,
|
||||
steps: int = 20, sampler_index: int = None,
|
||||
seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1,
|
||||
guidance_name: str = 'Default', guidance_scale: float = 6.0, guidance_rescale: float = 0.0, guidance_start: float = 0.0, guidance_stop: float = 1.0,
|
||||
|
|
@ -289,11 +288,23 @@ def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
|
|||
enable_hr: bool = False, hr_sampler_index: int = None, hr_denoising_strength: float = 0.0, hr_resize_mode: int = 0, hr_resize_context: str = 'None', hr_upscaler: str = None, hr_force: bool = False, hr_second_pass_steps: int = 20,
|
||||
hr_scale: float = 1.0, hr_resize_x: int = 0, hr_resize_y: int = 0, refiner_steps: int = 5, refiner_start: float = 0.0, refiner_prompt: str = '', refiner_negative: str = '',
|
||||
video_skip_frames: int = 0, video_type: str = 'None', video_duration: float = 2.0, video_loop: bool = False, video_pad: int = 0, video_interpolate: int = 0,
|
||||
extra: dict = {},
|
||||
extra: dict = None,
|
||||
override_script_name: str = None,
|
||||
override_script_args = [],
|
||||
override_script_args = None,
|
||||
*input_script_args,
|
||||
):
|
||||
if override_script_args is None:
|
||||
override_script_args = []
|
||||
if extra is None:
|
||||
extra = {}
|
||||
if styles is None:
|
||||
styles = []
|
||||
if inits is None:
|
||||
inits = []
|
||||
if inputs is None:
|
||||
inputs = []
|
||||
if units is None:
|
||||
units = []
|
||||
global pipe, original_pipeline # pylint: disable=global-statement
|
||||
if 'refine' in state:
|
||||
enable_hr = True
|
||||
|
|
@ -303,7 +314,7 @@ def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
|
|||
init_units(units)
|
||||
if inputs is None or (type(inputs) is list and len(inputs) == 0):
|
||||
inputs = [None]
|
||||
output_images: List[Image.Image] = [] # output images
|
||||
output_images: list[Image.Image] = [] # output images
|
||||
processed_image: Image.Image = None # last processed image
|
||||
if mask is not None and input_type == 0:
|
||||
input_type = 1 # inpaint always requires control_image
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Union
|
||||
from PIL import Image
|
||||
import gradio as gr
|
||||
from installer import log
|
||||
|
|
@ -7,7 +6,6 @@ from modules.control.units import controlnet
|
|||
from modules.control.units import xs
|
||||
from modules.control.units import lite
|
||||
from modules.control.units import t2iadapter
|
||||
from modules.control.units import reference # pylint: disable=unused-import
|
||||
|
||||
|
||||
default_device = None
|
||||
|
|
@ -16,7 +14,7 @@ unit_types = ['t2i adapter', 'controlnet', 'xs', 'lite', 'reference', 'ip']
|
|||
current = []
|
||||
|
||||
|
||||
class Unit(): # mashup of gradio controls and mapping to actual implementation classes
|
||||
class Unit: # mashup of gradio controls and mapping to actual implementation classes
|
||||
def update_choices(self, model_id=None):
|
||||
name = model_id or self.model_name
|
||||
if name == 'InstantX Union F1':
|
||||
|
|
@ -57,8 +55,10 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
|
|||
control_mode = None,
|
||||
control_tile = None,
|
||||
result_txt = None,
|
||||
extra_controls: list = [],
|
||||
extra_controls: list = None,
|
||||
):
|
||||
if extra_controls is None:
|
||||
extra_controls = []
|
||||
self.model_id = model_id
|
||||
self.process_id = process_id
|
||||
self.controls = [gr.Label(value=unit_type, visible=False)] # separator
|
||||
|
|
@ -77,7 +77,7 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation c
|
|||
self.process_name = None
|
||||
self.process: processors.Processor = processors.Processor()
|
||||
self.adapter: t2iadapter.Adapter = None
|
||||
self.controlnet: Union[controlnet.ControlNet, xs.ControlNetXS] = None
|
||||
self.controlnet: controlnet.ControlNet | xs.ControlNetXS = None
|
||||
# map to input image
|
||||
self.override: Image = None
|
||||
# global settings but passed per-unit
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ def get_gpu_info():
|
|||
elif torch.cuda.is_available() and torch.version.cuda:
|
||||
try:
|
||||
import subprocess
|
||||
result = subprocess.run('nvidia-smi --query-gpu=driver_version --format=csv,noheader', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
result = subprocess.run('nvidia-smi --query-gpu=driver_version --format=csv,noheader', shell=True, check=False, env=os.environ, capture_output=True)
|
||||
version = result.stdout.decode(encoding="utf8", errors="ignore").strip()
|
||||
return version
|
||||
except Exception:
|
||||
|
|
@ -307,7 +307,7 @@ def set_cuda_tunable():
|
|||
lines={0}
|
||||
try:
|
||||
if os.path.exists(fn):
|
||||
with open(fn, 'r', encoding='utf8') as f:
|
||||
with open(fn, encoding='utf8') as f:
|
||||
lines = sum(1 for _line in f)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Optional
|
||||
import torch
|
||||
|
||||
|
||||
class Generator(torch.Generator):
|
||||
def __init__(self, device: Optional[torch.device] = None):
|
||||
def __init__(self, device: torch.device | None = None):
|
||||
super().__init__("cpu")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import platform
|
||||
from typing import NamedTuple, Callable, Optional
|
||||
from typing import NamedTuple, Optional
|
||||
from collections.abc import Callable
|
||||
import torch
|
||||
from modules.errors import log
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
|
|
@ -86,8 +87,8 @@ def directml_do_hijack():
|
|||
|
||||
class OverrideItem(NamedTuple):
|
||||
value: str
|
||||
condition: Optional[Callable]
|
||||
message: Optional[str]
|
||||
condition: Callable | None
|
||||
message: str | None
|
||||
|
||||
|
||||
opts_override_table = {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import importlib
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ class autocast:
|
|||
|
||||
fast_dtype: torch.dtype = torch.float16
|
||||
prev_fast_dtype: torch.dtype
|
||||
def __init__(self, dtype: Optional[torch.dtype] = torch.float16):
|
||||
def __init__(self, dtype: torch.dtype | None = torch.float16):
|
||||
self.fast_dtype = dtype
|
||||
|
||||
def __enter__(self):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# pylint: disable=no-member,no-self-argument,no-method-argument
|
||||
from typing import Optional, Callable
|
||||
from collections.abc import Callable
|
||||
import torch
|
||||
import torch_directml # pylint: disable=import-error
|
||||
import modules.dml.amp as amp
|
||||
|
|
@ -9,17 +9,17 @@ from .Generator import Generator
|
|||
from .device_properties import DeviceProperties
|
||||
|
||||
|
||||
def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
|
||||
def amd_mem_get_info(device: rDevice | None=None) -> tuple[int, int]:
|
||||
from .memory_amd import AMDMemoryProvider
|
||||
return AMDMemoryProvider.mem_get_info(get_device(device).index)
|
||||
|
||||
|
||||
def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
|
||||
def pdh_mem_get_info(device: rDevice | None=None) -> tuple[int, int]:
|
||||
mem_info = DirectML.memory_provider.get_memory(get_device(device).index)
|
||||
return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"])
|
||||
|
||||
|
||||
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument
|
||||
def mem_get_info(device: rDevice | None=None) -> tuple[int, int]: # pylint: disable=unused-argument
|
||||
return (8589934592, 8589934592)
|
||||
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ class DirectML:
|
|||
device = Device
|
||||
Generator = Generator
|
||||
|
||||
context_device: Optional[torch.device] = None
|
||||
context_device: torch.device | None = None
|
||||
|
||||
is_autocast_enabled = False
|
||||
autocast_gpu_dtype = torch.float16
|
||||
|
|
@ -41,7 +41,7 @@ class DirectML:
|
|||
def is_directml_device(device: torch.device) -> bool:
|
||||
return device.type == "privateuseone"
|
||||
|
||||
def has_float64_support(device: Optional[rDevice]=None) -> bool:
|
||||
def has_float64_support(device: rDevice | None=None) -> bool:
|
||||
return torch_directml.has_float64_support(get_device(device).index)
|
||||
|
||||
def device_count() -> int:
|
||||
|
|
@ -53,16 +53,16 @@ class DirectML:
|
|||
def default_device() -> torch.device:
|
||||
return torch_directml.device(torch_directml.default_device())
|
||||
|
||||
def get_device_string(device: Optional[rDevice]=None) -> str:
|
||||
def get_device_string(device: rDevice | None=None) -> str:
|
||||
return f"privateuseone:{get_device(device).index}"
|
||||
|
||||
def get_device_name(device: Optional[rDevice]=None) -> str:
|
||||
def get_device_name(device: rDevice | None=None) -> str:
|
||||
return torch_directml.device_name(get_device(device).index)
|
||||
|
||||
def get_device_properties(device: Optional[rDevice]=None) -> DeviceProperties:
|
||||
def get_device_properties(device: rDevice | None=None) -> DeviceProperties:
|
||||
return DeviceProperties(get_device(device))
|
||||
|
||||
def memory_stats(device: Optional[rDevice]=None):
|
||||
def memory_stats(device: rDevice | None=None):
|
||||
return {
|
||||
"num_ooms": 0,
|
||||
"num_alloc_retries": 0,
|
||||
|
|
@ -70,11 +70,11 @@ class DirectML:
|
|||
|
||||
mem_get_info: Callable = mem_get_info
|
||||
|
||||
def memory_allocated(device: Optional[rDevice]=None) -> int:
|
||||
def memory_allocated(device: rDevice | None=None) -> int:
|
||||
return sum(torch_directml.gpu_memory(get_device(device).index)) * (1 << 20)
|
||||
|
||||
def max_memory_allocated(device: Optional[rDevice]=None):
|
||||
def max_memory_allocated(device: rDevice | None=None):
|
||||
return DirectML.memory_allocated(device) # DirectML does not empty GPU memory
|
||||
|
||||
def reset_peak_memory_stats(device: Optional[rDevice]=None):
|
||||
def reset_peak_memory_stats(device: rDevice | None=None):
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Optional
|
||||
import torch
|
||||
from .utils import rDevice, get_device
|
||||
|
||||
|
|
@ -6,11 +5,11 @@ from .utils import rDevice, get_device
|
|||
class Device:
|
||||
idx: int
|
||||
|
||||
def __enter__(self, device: Optional[rDevice]=None):
|
||||
def __enter__(self, device: rDevice | None=None):
|
||||
torch.dml.context_device = get_device(device)
|
||||
self.idx = torch.dml.context_device.index
|
||||
|
||||
def __init__(self, device: Optional[rDevice]=None) -> torch.device: # pylint: disable=return-in-init
|
||||
def __init__(self, device: rDevice | None=None) -> torch.device: # pylint: disable=return-in-init
|
||||
self.idx = get_device(device).index
|
||||
|
||||
def __exit__(self, t, v, tb):
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from typing import Type
|
||||
import torch
|
||||
from modules.dml.hijack.utils import catch_nan
|
||||
|
||||
|
||||
def make_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
|
||||
def make_tome_block(block_class: type[torch.nn.Module]) -> type[torch.nn.Module]:
|
||||
class ToMeBlock(block_class):
|
||||
# Save for unpatching later
|
||||
_parent = block_class
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Optional
|
||||
import torch
|
||||
import transformers.models.clip.modeling_clip
|
||||
|
||||
|
|
@ -22,9 +21,9 @@ def _make_causal_mask(
|
|||
|
||||
def CLIPTextEmbeddings_forward(
|
||||
self: transformers.models.clip.modeling_clip.CLIPTextEmbeddings,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
from modules.devices import dtype
|
||||
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
from modules.shared import log, opts
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from ctypes import CDLL, POINTER
|
||||
from ctypes.wintypes import LPCWSTR, LPDWORD, DWORD
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
from .structures import PDH_HQUERY, PDH_HCOUNTER, PPDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
|
||||
from .defines import PDH_FUNCTION, PZZWSTR, DWORD_PTR
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
import torch
|
||||
|
||||
|
||||
rDevice = Union[torch.device, int]
|
||||
def get_device(device: Optional[rDevice]=None) -> torch.device:
|
||||
def get_device(device: rDevice | None=None) -> torch.device:
|
||||
if device is None:
|
||||
device = torch.dml.current_device()
|
||||
return torch.device(device)
|
||||
|
|
|
|||
|
|
@ -10,13 +10,17 @@ install_traceback()
|
|||
already_displayed = {}
|
||||
|
||||
|
||||
def install(suppress=[]):
|
||||
def install(suppress=None):
|
||||
if suppress is None:
|
||||
suppress = []
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
install_traceback(suppress=suppress)
|
||||
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s')
|
||||
|
||||
|
||||
def display(e: Exception, task: str, suppress=[]):
|
||||
def display(e: Exception, task: str, suppress=None):
|
||||
if suppress is None:
|
||||
suppress = []
|
||||
if isinstance(e, ErrorLimiterAbort):
|
||||
return
|
||||
log.critical(f"{task or 'error'}: {type(e).__name__}")
|
||||
|
|
@ -45,7 +49,9 @@ def run(code, task: str):
|
|||
display(e, task)
|
||||
|
||||
|
||||
def exception(suppress=[]):
|
||||
def exception(suppress=None):
|
||||
if suppress is None:
|
||||
suppress = []
|
||||
console = get_console()
|
||||
console.print_exception(show_locals=False, max_frames=16, extra_lines=2, suppress=suppress, theme="ansi_dark", word_wrap=False, width=min([console.width, 200]))
|
||||
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ class Extension:
|
|||
continue
|
||||
priority = '50'
|
||||
if os.path.isfile(os.path.join(dirpath, "..", ".priority")):
|
||||
with open(os.path.join(dirpath, "..", ".priority"), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(dirpath, "..", ".priority"), encoding="utf-8") as f:
|
||||
priority = str(f.read().strip())
|
||||
res.append(scripts_manager.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
|
||||
if priority != '50':
|
||||
|
|
|
|||
|
|
@ -73,8 +73,12 @@ def is_stepwise(en_obj):
|
|||
return any([len(str(x).split("@")) > 1 for x in all_args]) # noqa C419 # pylint: disable=use-a-generator
|
||||
|
||||
|
||||
def activate(p, extra_network_data=None, step=0, include=[], exclude=[]):
|
||||
def activate(p, extra_network_data=None, step=0, include=None, exclude=None):
|
||||
"""call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list"""
|
||||
if exclude is None:
|
||||
exclude = []
|
||||
if include is None:
|
||||
include = []
|
||||
if p.disable_extra_networks:
|
||||
return
|
||||
extra_network_data = extra_network_data or p.network_data
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def run_modelmerger(id_task, **kwargs): # pylint: disable=unused-argument
|
|||
from installer import install
|
||||
install('tensordict', quiet=True)
|
||||
try:
|
||||
from tensordict import TensorDict # pylint: disable=unused-import
|
||||
pass # pylint: disable=unused-import
|
||||
except Exception as e:
|
||||
shared.log.error(f"Merge: {e}")
|
||||
return [*[gr.update() for _ in range(4)], "tensordict not available"]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import List
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
|
|
@ -34,7 +33,7 @@ def hijack_load_ip_adapter(self):
|
|||
def face_id(
|
||||
p: processing.StableDiffusionProcessing,
|
||||
app,
|
||||
source_images: List[Image.Image],
|
||||
source_images: list[Image.Image],
|
||||
model: str,
|
||||
override: bool,
|
||||
cache: bool,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import List
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
|
@ -12,7 +11,7 @@ insightface_app = None
|
|||
swapper = None
|
||||
|
||||
|
||||
def face_swap(p: processing.StableDiffusionProcessing, app, input_images: List[Image.Image], source_image: Image.Image, cache: bool):
|
||||
def face_swap(p: processing.StableDiffusionProcessing, app, input_images: list[Image.Image], source_image: Image.Image, cache: bool):
|
||||
global swapper # pylint: disable=global-statement
|
||||
if swapper is None:
|
||||
import insightface.model_zoo
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@
|
|||
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
|
@ -544,40 +545,40 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
|
|||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
prompt: str | list[str] = None,
|
||||
prompt_2: str | list[str] | None = None,
|
||||
image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
negative_prompt: str | list[str] | None = None,
|
||||
negative_prompt_2: str | list[str] | None = None,
|
||||
num_images_per_prompt: int | None = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
image_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.FloatTensor | None = None,
|
||||
prompt_embeds: torch.FloatTensor | None = None,
|
||||
negative_prompt_embeds: torch.FloatTensor | None = None,
|
||||
pooled_prompt_embeds: torch.FloatTensor | None = None,
|
||||
negative_pooled_prompt_embeds: torch.FloatTensor | None = None,
|
||||
image_embeds: torch.FloatTensor | None = None,
|
||||
output_type: str | None = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
cross_attention_kwargs: dict[str, Any] | None = None,
|
||||
controlnet_conditioning_scale: float | list[float] = 1.0,
|
||||
guess_mode: bool = False,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
original_size: Tuple[int, int] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Tuple[int, int] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = None,
|
||||
control_guidance_start: float | list[float] = 0.0,
|
||||
control_guidance_end: float | list[float] = 1.0,
|
||||
original_size: tuple[int, int] = None,
|
||||
crops_coords_top_left: tuple[int, int] = (0, 0),
|
||||
target_size: tuple[int, int] = None,
|
||||
negative_original_size: tuple[int, int] | None = None,
|
||||
negative_crops_coords_top_left: tuple[int, int] = (0, 0),
|
||||
negative_target_size: tuple[int, int] | None = None,
|
||||
clip_skip: int | None = None,
|
||||
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
|
|
@ -890,7 +891,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
|
|||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end, strict=False)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
|
|
@ -970,7 +971,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
|
|||
controlnet_added_cond_kwargs = added_cond_kwargs
|
||||
|
||||
if isinstance(controlnet_keep[i], list):
|
||||
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
||||
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i], strict=False)]
|
||||
else:
|
||||
controlnet_cond_scale = controlnet_conditioning_scale
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
### original <https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/pipeline.py>
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Union
|
||||
from collections.abc import Callable
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor
|
||||
|
|
@ -26,8 +27,8 @@ from modules.face.photomaker_model_v2 import PhotoMakerIDEncoder_CLIPInsightface
|
|||
PipelineImageInput = Union[
|
||||
PIL.Image.Image,
|
||||
torch.FloatTensor,
|
||||
List[PIL.Image.Image],
|
||||
List[torch.FloatTensor],
|
||||
list[PIL.Image.Image],
|
||||
list[torch.FloatTensor],
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -49,10 +50,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
|||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
sigmas: list[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
|
@ -110,7 +111,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
@validate_hf_hub_args
|
||||
def load_photomaker_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
weight_name: str,
|
||||
subfolder: str = '',
|
||||
trigger_word: str = 'img',
|
||||
|
|
@ -214,21 +215,21 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
def encode_prompt_with_trigger_word(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_2: Optional[str] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_2: str | None = None,
|
||||
device: torch.device | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[str] = None,
|
||||
negative_prompt_2: Optional[str] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
negative_prompt: str | None = None,
|
||||
negative_prompt_2: str | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
pooled_prompt_embeds: torch.Tensor | None = None,
|
||||
negative_pooled_prompt_embeds: torch.Tensor | None = None,
|
||||
lora_scale: float | None = None,
|
||||
clip_skip: int | None = None,
|
||||
### Added args
|
||||
num_id_images: int = 1,
|
||||
class_tokens_mask: Optional[torch.LongTensor] = None,
|
||||
class_tokens_mask: torch.LongTensor | None = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
|
|
@ -273,7 +274,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
# textual inversion: process multi-vector tokens if necessary
|
||||
prompt_embeds_list = []
|
||||
prompts = [prompt, prompt_2]
|
||||
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): # pylint: disable=redefined-argument-from-local
|
||||
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False): # pylint: disable=redefined-argument-from-local
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
||||
|
||||
|
|
@ -362,7 +363,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
||||
)
|
||||
|
||||
uncond_tokens: List[str]
|
||||
uncond_tokens: list[str]
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
|
|
@ -377,7 +378,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
uncond_tokens = [negative_prompt, negative_prompt_2]
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): # pylint: disable=redefined-argument-from-local
|
||||
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders, strict=False): # pylint: disable=redefined-argument-from-local
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
|
||||
|
||||
|
|
@ -444,49 +445,47 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
prompt: str | list[str] = None,
|
||||
prompt_2: str | list[str] | None = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
sigmas: List[float] = None,
|
||||
denoising_end: Optional[float] = None,
|
||||
timesteps: list[int] = None,
|
||||
sigmas: list[float] = None,
|
||||
denoising_end: float | None = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
negative_prompt: str | list[str] | None = None,
|
||||
negative_prompt_2: str | list[str] | None = None,
|
||||
num_images_per_prompt: int | None = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
pooled_prompt_embeds: torch.Tensor | None = None,
|
||||
negative_pooled_prompt_embeds: torch.Tensor | None = None,
|
||||
ip_adapter_image: PipelineImageInput | None = None,
|
||||
ip_adapter_image_embeds: list[torch.Tensor] | None = None,
|
||||
output_type: str | None = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: dict[str, Any] | None = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
original_size: tuple[int, int] | None = None,
|
||||
crops_coords_top_left: tuple[int, int] = (0, 0),
|
||||
target_size: tuple[int, int] | None = None,
|
||||
negative_original_size: tuple[int, int] | None = None,
|
||||
negative_crops_coords_top_left: tuple[int, int] = (0, 0),
|
||||
negative_target_size: tuple[int, int] | None = None,
|
||||
clip_skip: int | None = None,
|
||||
callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] = None,
|
||||
# Added parameters (for PhotoMaker)
|
||||
input_id_images: PipelineImageInput = None,
|
||||
start_merge_step: int = 10,
|
||||
class_tokens_mask: Optional[torch.LongTensor] = None,
|
||||
id_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
|
||||
class_tokens_mask: torch.LongTensor | None = None,
|
||||
id_embeds: torch.FloatTensor | None = None,
|
||||
prompt_embeds_text_only: torch.FloatTensor | None = None,
|
||||
pooled_prompt_embeds_text_only: torch.FloatTensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
|
|
@ -512,6 +511,8 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if callback_on_step_end_tensor_inputs is None:
|
||||
callback_on_step_end_tensor_inputs = ["latents"]
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import List
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
|
|
@ -43,8 +42,8 @@ def get_model(model_name: str):
|
|||
def reswapper(
|
||||
p: processing.StableDiffusionProcessing,
|
||||
app,
|
||||
source_images: List[Image.Image],
|
||||
target_images: List[Image.Image],
|
||||
source_images: list[Image.Image],
|
||||
target_images: list[Image.Image],
|
||||
model_name: str,
|
||||
original: bool,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|||
|
||||
class ReSwapperModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(ReSwapperModel, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
# self.pad = nn.ReflectionPad2d(3)
|
||||
# Encoder for target face
|
||||
|
|
@ -87,7 +87,7 @@ class ReSwapperModel(nn.Module):
|
|||
|
||||
class StyleBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, blockIndex):
|
||||
super(StyleBlock, self).__init__()
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
|
||||
self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
|
||||
self.style1 = nn.Linear(512, 2048)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import itertools
|
|||
import os
|
||||
from collections import UserDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, Iterator, List, Optional, Union
|
||||
from typing import Union
|
||||
from collections.abc import Callable, Iterator
|
||||
from installer import log
|
||||
|
||||
|
||||
|
|
@ -10,19 +11,19 @@ do_cache_folders = os.environ.get('SD_NO_CACHE', None) is None
|
|||
class Directory: # forward declaration
|
||||
...
|
||||
|
||||
FilePathList = List[str]
|
||||
FilePathList = list[str]
|
||||
FilePathIterator = Iterator[str]
|
||||
DirectoryPathList = List[str]
|
||||
DirectoryPathList = list[str]
|
||||
DirectoryPathIterator = Iterator[str]
|
||||
DirectoryList = List[Directory]
|
||||
DirectoryList = list[Directory]
|
||||
DirectoryIterator = Iterator[Directory]
|
||||
DirectoryCollection = Dict[str, Directory]
|
||||
DirectoryCollection = dict[str, Directory]
|
||||
ExtensionFilter = Callable
|
||||
ExtensionList = list[str]
|
||||
RecursiveType = Union[bool,Callable]
|
||||
|
||||
|
||||
def real_path(directory_path:str) -> Union[str, None]:
|
||||
def real_path(directory_path:str) -> str | None:
|
||||
try:
|
||||
return os.path.abspath(os.path.expanduser(directory_path))
|
||||
except Exception:
|
||||
|
|
@ -52,7 +53,7 @@ class Directory(Directory): # pylint: disable=E0102
|
|||
def clear(self) -> None:
|
||||
self._update(Directory.from_dict({
|
||||
'path': None,
|
||||
'mtime': float(),
|
||||
'mtime': 0.0,
|
||||
'files': [],
|
||||
'directories': []
|
||||
}))
|
||||
|
|
@ -125,7 +126,7 @@ def clean_directory(directory: Directory, /, recursive: RecursiveType=False) ->
|
|||
return is_clean
|
||||
|
||||
|
||||
def get_directory(directory_or_path: str, /, fetch: bool=True) -> Union[Directory, None]:
|
||||
def get_directory(directory_or_path: str, /, fetch: bool=True) -> Directory | None:
|
||||
if isinstance(directory_or_path, Directory):
|
||||
if directory_or_path.is_directory:
|
||||
return directory_or_path
|
||||
|
|
@ -143,7 +144,7 @@ def get_directory(directory_or_path: str, /, fetch: bool=True) -> Union[Director
|
|||
return cache_folders[directory_or_path] if directory_or_path in cache_folders else None
|
||||
|
||||
|
||||
def fetch_directory(directory_path: str) -> Union[Directory, None]:
|
||||
def fetch_directory(directory_path: str) -> Directory | None:
|
||||
directory: Directory
|
||||
for directory in _walk(directory_path, recurse=False):
|
||||
return directory # The return is intentional, we get a generator, we only need the one
|
||||
|
|
@ -255,7 +256,7 @@ def get_directories(*directory_paths: DirectoryPathList, fetch:bool=True, recurs
|
|||
return filter(bool, directories)
|
||||
|
||||
|
||||
def directory_files(*directories_or_paths: Union[DirectoryPathList, DirectoryList], recursive: RecursiveType=True) -> FilePathIterator:
|
||||
def directory_files(*directories_or_paths: DirectoryPathList | DirectoryList, recursive: RecursiveType=True) -> FilePathIterator:
|
||||
return itertools.chain.from_iterable(
|
||||
itertools.chain(
|
||||
directory_object.files,
|
||||
|
|
@ -275,7 +276,7 @@ def directory_files(*directories_or_paths: Union[DirectoryPathList, DirectoryLis
|
|||
)
|
||||
|
||||
|
||||
def extension_filter(ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None) -> ExtensionFilter:
|
||||
def extension_filter(ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None) -> ExtensionFilter:
|
||||
if ext_filter:
|
||||
ext_filter = [*map(str.upper, ext_filter)]
|
||||
if ext_blacklist:
|
||||
|
|
@ -289,11 +290,11 @@ def not_hidden(filepath: str) -> bool:
|
|||
return not os.path.basename(filepath).startswith('.')
|
||||
|
||||
|
||||
def filter_files(file_paths: FilePathList, ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None) -> FilePathIterator:
|
||||
def filter_files(file_paths: FilePathList, ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None) -> FilePathIterator:
|
||||
return filter(extension_filter(ext_filter, ext_blacklist), file_paths)
|
||||
|
||||
|
||||
def list_files(*directory_paths:DirectoryPathList, ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None, recursive:RecursiveType=True) -> FilePathIterator:
|
||||
def list_files(*directory_paths:DirectoryPathList, ext_filter: ExtensionList | None=None, ext_blacklist: ExtensionList | None=None, recursive:RecursiveType=True) -> FilePathIterator:
|
||||
return filter_files(itertools.chain.from_iterable(
|
||||
directory_files(directory, recursive=recursive)
|
||||
for directory in get_directories(*directory_paths, recursive=recursive)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
|
||||
from fastapi.exceptions import HTTPException
|
||||
from modules import shared
|
||||
|
|
@ -8,39 +7,39 @@ class ReqFramepack(BaseModel):
|
|||
variant: str = Field(default=None, title="Model variant", description="Model variant to use")
|
||||
prompt: str = Field(default=None, title="Prompt", description="Prompt for the model")
|
||||
init_image: str = Field(default=None, title="Initial image", description="Base64 encoded initial image")
|
||||
end_image: Optional[str] = Field(default=None, title="End image", description="Base64 encoded end image")
|
||||
start_weight: Optional[float] = Field(default=1.0, title="Start weight", description="Weight of the initial image")
|
||||
end_weight: Optional[float] = Field(default=1.0, title="End weight", description="Weight of the end image")
|
||||
vision_weight: Optional[float] = Field(default=1.0, title="Vision weight", description="Weight of the vision model")
|
||||
system_prompt: Optional[str] = Field(default=None, title="System prompt", description="System prompt for the model")
|
||||
optimized_prompt: Optional[bool] = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model")
|
||||
section_prompt: Optional[str] = Field(default=None, title="Section prompt", description="Prompt for each section")
|
||||
negative_prompt: Optional[str] = Field(default=None, title="Negative prompt", description="Negative prompt for the model")
|
||||
styles: Optional[List[str]] = Field(default=None, title="Styles", description="Styles for the model")
|
||||
seed: Optional[int] = Field(default=None, title="Seed", description="Seed for the model")
|
||||
resolution: Optional[int] = Field(default=640, title="Resolution", description="Resolution of the image")
|
||||
duration: Optional[float] = Field(default=4, title="Duration", description="Duration of the video in seconds")
|
||||
latent_ws: Optional[int] = Field(default=9, title="Latent window size", description="Size of the latent window")
|
||||
steps: Optional[int] = Field(default=25, title="Video steps", description="Number of steps for the video generation")
|
||||
cfg_scale: Optional[float] = Field(default=1.0, title="CFG scale", description="CFG scale for the model")
|
||||
cfg_distilled: Optional[float] = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model")
|
||||
cfg_rescale: Optional[float] = Field(default=0.0, title="CFG re-scale", description="CFG re-scale for the model")
|
||||
shift: Optional[float] = Field(default=0, title="Sampler shift", description="Shift for the sampler")
|
||||
use_teacache: Optional[bool] = Field(default=True, title="Enable TeaCache", description="Use TeaCache for the model")
|
||||
use_cfgzero: Optional[bool] = Field(default=False, title="Enable CFGZero", description="Use CFGZero for the model")
|
||||
mp4_fps: Optional[int] = Field(default=30, title="FPS", description="Frames per second for the video")
|
||||
mp4_codec: Optional[str] = Field(default="libx264", title="Codec", description="Codec for the video")
|
||||
mp4_sf: Optional[bool] = Field(default=False, title="Save SafeTensors", description="Save SafeTensors for the video")
|
||||
mp4_video: Optional[bool] = Field(default=True, title="Save Video", description="Save video")
|
||||
mp4_frames: Optional[bool] = Field(default=False, title="Save Frames", description="Save frames for the video")
|
||||
mp4_opt: Optional[str] = Field(default="crf:16", title="Options", description="Options for the video codec")
|
||||
mp4_ext: Optional[str] = Field(default="mp4", title="Format", description="Format for the video")
|
||||
mp4_interpolate: Optional[int] = Field(default=0, title="Interpolation", description="Interpolation for the video")
|
||||
attention: Optional[str] = Field(default="Default", title="Attention", description="Attention type for the model")
|
||||
vae_type: Optional[str] = Field(default="Local", title="VAE", description="VAE type for the model")
|
||||
vlm_enhance: Optional[bool] = Field(default=False, title="VLM enhance", description="Enable VLM enhance")
|
||||
vlm_model: Optional[str] = Field(default=None, title="VLM model", description="VLM model to use")
|
||||
vlm_system_prompt: Optional[str] = Field(default=None, title="VLM system prompt", description="System prompt for the VLM model")
|
||||
end_image: str | None = Field(default=None, title="End image", description="Base64 encoded end image")
|
||||
start_weight: float | None = Field(default=1.0, title="Start weight", description="Weight of the initial image")
|
||||
end_weight: float | None = Field(default=1.0, title="End weight", description="Weight of the end image")
|
||||
vision_weight: float | None = Field(default=1.0, title="Vision weight", description="Weight of the vision model")
|
||||
system_prompt: str | None = Field(default=None, title="System prompt", description="System prompt for the model")
|
||||
optimized_prompt: bool | None = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model")
|
||||
section_prompt: str | None = Field(default=None, title="Section prompt", description="Prompt for each section")
|
||||
negative_prompt: str | None = Field(default=None, title="Negative prompt", description="Negative prompt for the model")
|
||||
styles: list[str] | None = Field(default=None, title="Styles", description="Styles for the model")
|
||||
seed: int | None = Field(default=None, title="Seed", description="Seed for the model")
|
||||
resolution: int | None = Field(default=640, title="Resolution", description="Resolution of the image")
|
||||
duration: float | None = Field(default=4, title="Duration", description="Duration of the video in seconds")
|
||||
latent_ws: int | None = Field(default=9, title="Latent window size", description="Size of the latent window")
|
||||
steps: int | None = Field(default=25, title="Video steps", description="Number of steps for the video generation")
|
||||
cfg_scale: float | None = Field(default=1.0, title="CFG scale", description="CFG scale for the model")
|
||||
cfg_distilled: float | None = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model")
|
||||
cfg_rescale: float | None = Field(default=0.0, title="CFG re-scale", description="CFG re-scale for the model")
|
||||
shift: float | None = Field(default=0, title="Sampler shift", description="Shift for the sampler")
|
||||
use_teacache: bool | None = Field(default=True, title="Enable TeaCache", description="Use TeaCache for the model")
|
||||
use_cfgzero: bool | None = Field(default=False, title="Enable CFGZero", description="Use CFGZero for the model")
|
||||
mp4_fps: int | None = Field(default=30, title="FPS", description="Frames per second for the video")
|
||||
mp4_codec: str | None = Field(default="libx264", title="Codec", description="Codec for the video")
|
||||
mp4_sf: bool | None = Field(default=False, title="Save SafeTensors", description="Save SafeTensors for the video")
|
||||
mp4_video: bool | None = Field(default=True, title="Save Video", description="Save video")
|
||||
mp4_frames: bool | None = Field(default=False, title="Save Frames", description="Save frames for the video")
|
||||
mp4_opt: str | None = Field(default="crf:16", title="Options", description="Options for the video codec")
|
||||
mp4_ext: str | None = Field(default="mp4", title="Format", description="Format for the video")
|
||||
mp4_interpolate: int | None = Field(default=0, title="Interpolation", description="Interpolation for the video")
|
||||
attention: str | None = Field(default="Default", title="Attention", description="Attention type for the model")
|
||||
vae_type: str | None = Field(default="Local", title="VAE", description="VAE type for the model")
|
||||
vlm_enhance: bool | None = Field(default=False, title="VLM enhance", description="Enable VLM enhance")
|
||||
vlm_model: str | None = Field(default=None, title="VLM model", description="VLM model to use")
|
||||
vlm_system_prompt: str | None = Field(default=None, title="VLM system prompt", description="System prompt for the VLM model")
|
||||
|
||||
|
||||
class ResFramepack(BaseModel):
|
||||
|
|
|
|||
|
|
@ -43,8 +43,10 @@ def worker(
|
|||
mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate,
|
||||
vae_type,
|
||||
variant,
|
||||
metadata:dict={},
|
||||
metadata:dict=None,
|
||||
):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
timer.process.reset()
|
||||
memstats.reset_stats()
|
||||
if stream is None or shared.state.interrupted or shared.state.skipped:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -251,7 +250,7 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
|
|||
|
||||
|
||||
class HunyuanVideoAdaNorm(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
|
||||
def __init__(self, in_features: int, out_features: int | None = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or 2 * in_features
|
||||
|
|
@ -260,7 +259,7 @@ class HunyuanVideoAdaNorm(nn.Module):
|
|||
|
||||
def forward(
|
||||
self, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
temb = self.linear(self.nonlinearity(temb))
|
||||
gate_msa, gate_mlp = temb.chunk(2, dim=-1)
|
||||
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
|
||||
|
|
@ -298,7 +297,7 @@ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
|
|||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
|
|
@ -346,7 +345,7 @@ class HunyuanVideoIndividualTokenRefiner(nn.Module):
|
|||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
self_attn_mask = None
|
||||
if attention_mask is not None:
|
||||
|
|
@ -396,7 +395,7 @@ class HunyuanVideoTokenRefiner(nn.Module):
|
|||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attention_mask is None:
|
||||
pooled_projections = hidden_states.mean(dim=1)
|
||||
|
|
@ -464,8 +463,8 @@ class AdaLayerNormZero(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = emb.unsqueeze(-2)
|
||||
emb = self.linear(self.silu(emb))
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
|
||||
|
|
@ -487,8 +486,8 @@ class AdaLayerNormZeroSingle(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = emb.unsqueeze(-2)
|
||||
emb = self.linear(self.silu(emb))
|
||||
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
|
||||
|
|
@ -558,8 +557,8 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
|
|||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
|
@ -636,9 +635,9 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
|||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
|
||||
|
|
@ -734,7 +733,7 @@ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterM
|
|||
text_embed_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
rope_axes_dim: tuple[int] = (16, 56, 56),
|
||||
has_image_proj=False,
|
||||
image_proj_dim=1152,
|
||||
has_clean_x_embedder=False,
|
||||
|
|
|
|||
|
|
@ -102,14 +102,14 @@ def just_crop(image, w, h):
|
|||
|
||||
def write_to_json(data, file_path):
|
||||
temp_file_path = file_path + ".tmp"
|
||||
with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
|
||||
with open(temp_file_path, 'w', encoding='utf-8') as temp_file:
|
||||
json.dump(data, temp_file, indent=4)
|
||||
os.replace(temp_file_path, file_path)
|
||||
return
|
||||
|
||||
|
||||
def read_from_json(file_path):
|
||||
with open(file_path, 'rt', encoding='utf-8') as file:
|
||||
with open(file_path, encoding='utf-8') as file:
|
||||
data = json.load(file)
|
||||
return data
|
||||
|
||||
|
|
@ -283,7 +283,7 @@ def add_tensors_with_padding(tensor1, tensor2):
|
|||
shape1 = tensor1.shape
|
||||
shape2 = tensor2.shape
|
||||
|
||||
new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
|
||||
new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2, strict=False))
|
||||
|
||||
padded_tensor1 = torch.zeros(new_shape)
|
||||
padded_tensor2 = torch.zeros(new_shape)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import os
|
|||
from PIL import Image
|
||||
import gradio as gr
|
||||
from modules import shared, gr_tempdir, script_callbacks, images
|
||||
from modules.infotext import parse, mapping, quote, unquote # pylint: disable=unused-import
|
||||
from modules.infotext import parse, mapping # pylint: disable=unused-import
|
||||
|
||||
|
||||
type_of_gr_update = type(gr.update())
|
||||
|
|
@ -259,7 +259,7 @@ def connect_paste(button, local_paste_fields, input_comp, override_settings_comp
|
|||
from modules.paths import params_path
|
||||
if prompt is None or len(prompt.strip()) == 0:
|
||||
if os.path.exists(params_path):
|
||||
with open(params_path, "r", encoding="utf8") as file:
|
||||
with open(params_path, encoding="utf8") as file:
|
||||
prompt = file.read()
|
||||
shared.log.debug(f'Prompt parse: type="params" prompt="{prompt}"')
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
# Original: invokeai.backend.quantization.gguf.utils
|
||||
# Largely based on https://github.com/city96/ComfyUI-GGUF
|
||||
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Union
|
||||
from collections.abc import Callable
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
|
|
@ -28,7 +29,7 @@ def get_scale_min(scales: torch.Tensor):
|
|||
|
||||
# Legacy Quants #
|
||||
def dequantize_blocks_Q8_0(
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
|
||||
) -> torch.Tensor:
|
||||
d, x = split_block_dims(blocks, 2)
|
||||
d = d.view(torch.float16).to(dtype)
|
||||
|
|
@ -37,7 +38,7 @@ def dequantize_blocks_Q8_0(
|
|||
|
||||
|
||||
def dequantize_blocks_Q5_1(
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
|
||||
) -> torch.Tensor:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
|
|
@ -58,7 +59,7 @@ def dequantize_blocks_Q5_1(
|
|||
|
||||
|
||||
def dequantize_blocks_Q5_0(
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
|
||||
) -> torch.Tensor:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
|
|
@ -79,7 +80,7 @@ def dequantize_blocks_Q5_0(
|
|||
|
||||
|
||||
def dequantize_blocks_Q4_1(
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
|
||||
) -> torch.Tensor:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
|
|
@ -96,7 +97,7 @@ def dequantize_blocks_Q4_1(
|
|||
|
||||
|
||||
def dequantize_blocks_Q4_0(
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
|
||||
) -> torch.Tensor:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
|
|
@ -111,13 +112,13 @@ def dequantize_blocks_Q4_0(
|
|||
|
||||
|
||||
def dequantize_blocks_BF16(
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
|
||||
) -> torch.Tensor:
|
||||
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
|
||||
|
||||
|
||||
def dequantize_blocks_Q6_K(
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
|
||||
) -> torch.Tensor:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
|
|
@ -147,7 +148,7 @@ def dequantize_blocks_Q6_K(
|
|||
|
||||
|
||||
def dequantize_blocks_Q5_K(
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
|
||||
) -> torch.Tensor:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
|
|
@ -175,7 +176,7 @@ def dequantize_blocks_Q5_K(
|
|||
|
||||
|
||||
def dequantize_blocks_Q4_K(
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
|
||||
) -> torch.Tensor:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
|
|
@ -197,7 +198,7 @@ def dequantize_blocks_Q4_K(
|
|||
|
||||
|
||||
def dequantize_blocks_Q3_K(
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
|
||||
) -> torch.Tensor:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
|
|
@ -232,7 +233,7 @@ def dequantize_blocks_Q3_K(
|
|||
|
||||
|
||||
def dequantize_blocks_Q2_K(
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
|
||||
blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None
|
||||
) -> torch.Tensor:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
|
|
@ -254,7 +255,7 @@ def dequantize_blocks_Q2_K(
|
|||
|
||||
|
||||
DEQUANTIZE_FUNCTIONS: dict[
|
||||
gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor]
|
||||
gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, torch.dtype | None], torch.Tensor]
|
||||
] = {
|
||||
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
|
||||
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
|
||||
|
|
@ -270,7 +271,7 @@ DEQUANTIZE_FUNCTIONS: dict[
|
|||
}
|
||||
|
||||
|
||||
def is_torch_compatible(tensor: Optional[torch.Tensor]):
|
||||
def is_torch_compatible(tensor: torch.Tensor | None):
|
||||
return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES
|
||||
|
||||
|
||||
|
|
@ -279,7 +280,7 @@ def is_quantized(tensor: torch.Tensor):
|
|||
|
||||
|
||||
def dequantize(
|
||||
data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None
|
||||
data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: torch.dtype | None = None
|
||||
):
|
||||
"""
|
||||
Dequantize tensor back to usable shape/dtype
|
||||
|
|
|
|||
|
|
@ -9,8 +9,10 @@ import torch
|
|||
from modules import shared, devices
|
||||
|
||||
|
||||
class Item():
|
||||
def __init__(self, latent, preview=None, info=None, ops=[]):
|
||||
class Item:
|
||||
def __init__(self, latent, preview=None, info=None, ops=None):
|
||||
if ops is None:
|
||||
ops = []
|
||||
self.ts = datetime.datetime.now().replace(microsecond=0)
|
||||
self.name = self.ts.strftime('%Y-%m-%d %H:%M:%S')
|
||||
self.latent = latent.detach().clone().to(devices.cpu)
|
||||
|
|
@ -20,7 +22,7 @@ class Item():
|
|||
self.size = sys.getsizeof(self.latent.storage())
|
||||
|
||||
|
||||
class History():
|
||||
class History:
|
||||
def __init__(self):
|
||||
self.index = -1
|
||||
self.latents = deque(maxlen=1024)
|
||||
|
|
@ -58,7 +60,9 @@ class History():
|
|||
return i
|
||||
return -1
|
||||
|
||||
def add(self, latent, preview=None, info=None, ops=[]):
|
||||
def add(self, latent, preview=None, info=None, ops=None):
|
||||
if ops is None:
|
||||
ops = []
|
||||
shared.state.latent_history += 1
|
||||
if shared.opts.latent_history == 0:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
|
|||
calc_img = Image.new("RGB", (1, 1), shared.opts.grid_background)
|
||||
calc_d = ImageDraw.Draw(calc_img)
|
||||
title_texts = [title] if title else [[GridAnnotation()]]
|
||||
for texts, allowed_width in zip(x_texts + y_texts + title_texts, [width] * len(x_texts) + [pad_left] * len(y_texts) + [(width+margin)*cols]):
|
||||
for texts, allowed_width in zip(x_texts + y_texts + title_texts, [width] * len(x_texts) + [pad_left] * len(y_texts) + [(width+margin)*cols], strict=False):
|
||||
items = [] + texts
|
||||
texts.clear()
|
||||
for line in items:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Union
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
|
|
@ -8,7 +7,7 @@ from modules import shared, upscaler
|
|||
from modules.image import sharpfin
|
||||
|
||||
|
||||
def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None):
|
||||
def resize_image(resize_mode: int, im: Image.Image | torch.Tensor, width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None):
|
||||
upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img
|
||||
|
||||
def verify_image(image):
|
||||
|
|
@ -34,7 +33,7 @@ def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width:
|
|||
im = vae_decode(latents, shared.sd_model, output_type='pil', vae_type='Tiny')[0]
|
||||
return im
|
||||
|
||||
def resize(im: Union[Image.Image, torch.Tensor], w, h):
|
||||
def resize(im: Image.Image | torch.Tensor, w, h):
|
||||
w, h = int(w), int(h)
|
||||
if upscaler_name is None or upscaler_name == "None" or (hasattr(im, 'mode') and im.mode == 'L'):
|
||||
return sharpfin.resize(im, (w, h), linearize=False) # force for mask
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ _triton_ok = False
|
|||
def check_sharpfin():
|
||||
global _sharpfin_checked, _sharpfin_ok, _triton_ok # pylint: disable=global-statement
|
||||
if not _sharpfin_checked:
|
||||
from modules.sharpfin.functional import scale # pylint: disable=unused-import
|
||||
_sharpfin_ok = True
|
||||
try:
|
||||
from modules.sharpfin import TRITON_AVAILABLE
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from modules.image.util import flatten, draw_text # pylint: disable=unused-import
|
||||
from modules.image.save import save_image # pylint: disable=unused-import
|
||||
from modules.image.convert import to_pil, to_tensor # pylint: disable=unused-import
|
||||
from modules.image.metadata import read_info_from_image, image_data # pylint: disable=unused-import
|
||||
from modules.image.resize import resize_image # pylint: disable=unused-import
|
||||
from modules.image.sharpfin import resize # pylint: disable=unused-import
|
||||
from modules.image.namegen import FilenameGenerator, get_next_sequence_number # pylint: disable=unused-import
|
||||
from modules.image.watermark import set_watermark, get_watermark # pylint: disable=unused-import
|
||||
from modules.image.grid import image_grid, get_grid_size, split_grid, combine_grid, check_grid_size, get_font, draw_grid_annotations, draw_prompt_matrix, GridAnnotation, Grid # pylint: disable=unused-import
|
||||
from modules.video import save_video # pylint: disable=unused-import
|
||||
from modules.image.metadata import image_data, read_info_from_image
|
||||
from modules.image.save import save_image, sanitize_filename_part
|
||||
from modules.image.resize import resize_image
|
||||
from modules.image.grid import image_grid, check_grid_size, get_grid_size, draw_grid_annotations, draw_prompt_matrix
|
||||
|
||||
__all__ = [
|
||||
'image_data', 'read_info_from_image',
|
||||
'save_image', 'sanitize_filename_part',
|
||||
'resize_image',
|
||||
'image_grid', 'check_grid_size', 'get_grid_size', 'draw_grid_annotations', 'draw_prompt_matrix'
|
||||
]
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args)
|
|||
caption_file = os.path.splitext(image_file)[0] + '.txt'
|
||||
prompt_type='default'
|
||||
if os.path.exists(caption_file):
|
||||
with open(caption_file, 'r', encoding='utf8') as f:
|
||||
with open(caption_file, encoding='utf8') as f:
|
||||
p.prompt = f.read()
|
||||
prompt_type='file'
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ if __name__ == '__main__':
|
|||
import sys
|
||||
if len(sys.argv) > 1:
|
||||
if os.path.exists(sys.argv[1]):
|
||||
with open(sys.argv[1], 'r', encoding='utf8') as f:
|
||||
with open(sys.argv[1], encoding='utf8') as f:
|
||||
parse(f.read())
|
||||
else:
|
||||
parse(sys.argv[1])
|
||||
|
|
|
|||
|
|
@ -1,3 +1,2 @@
|
|||
# a1111 compatibility module: unused
|
||||
|
||||
from modules.infotext import parse as parse_generation_parameters # pylint: disable=unused-import
|
||||
|
|
|
|||
|
|
@ -81,7 +81,9 @@ def warn_once(msg):
|
|||
warned = True
|
||||
|
||||
class OpenVINOGraphModule(torch.nn.Module):
|
||||
def __init__(self, gm, partition_id, use_python_fusion_cache, model_hash_str: str = None, file_name="", int_inputs=[]):
|
||||
def __init__(self, gm, partition_id, use_python_fusion_cache, model_hash_str: str = None, file_name="", int_inputs=None):
|
||||
if int_inputs is None:
|
||||
int_inputs = []
|
||||
super().__init__()
|
||||
self.gm = gm
|
||||
self.int_inputs = int_inputs
|
||||
|
|
@ -192,7 +194,7 @@ def execute(
|
|||
elif executor == "strictly_openvino":
|
||||
return openvino_execute(gm, *args, executor_parameters=executor_parameters, file_name=file_name)
|
||||
|
||||
msg = "Received unexpected value for 'executor': {0}. Allowed values are: openvino, strictly_openvino.".format(executor)
|
||||
msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: openvino, strictly_openvino."
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
|
|
@ -373,7 +375,7 @@ def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition
|
|||
ov_inputs = []
|
||||
for arg in flat_args:
|
||||
if not isinstance(arg, int):
|
||||
ov_inputs.append((arg.detach().cpu().numpy()))
|
||||
ov_inputs.append(arg.detach().cpu().numpy())
|
||||
|
||||
res = req.infer(ov_inputs, share_inputs=True, share_outputs=True)
|
||||
|
||||
|
|
@ -423,7 +425,9 @@ def openvino_execute_partitioned(gm: GraphModule, *args, executor_parameters=Non
|
|||
return shared.compiled_model_state.partitioned_modules[signature][0](*ov_inputs)
|
||||
|
||||
|
||||
def partition_graph(gm: GraphModule, use_python_fusion_cache: bool, model_hash_str: str = None, file_name="", int_inputs=[]):
|
||||
def partition_graph(gm: GraphModule, use_python_fusion_cache: bool, model_hash_str: str = None, file_name="", int_inputs=None):
|
||||
if int_inputs is None:
|
||||
int_inputs = []
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "call_module" and "fused_" in node.name:
|
||||
openvino_submodule = getattr(gm, node.name)
|
||||
|
|
@ -509,7 +513,7 @@ def openvino_fx(subgraph, example_inputs, options=None):
|
|||
if os.path.isfile(maybe_fs_cached_name + ".xml") and os.path.isfile(maybe_fs_cached_name + ".bin"):
|
||||
example_inputs_reordered = []
|
||||
if (os.path.isfile(maybe_fs_cached_name + ".txt")):
|
||||
f = open(maybe_fs_cached_name + ".txt", "r")
|
||||
f = open(maybe_fs_cached_name + ".txt")
|
||||
for input_data in example_inputs:
|
||||
shape = f.readline()
|
||||
if (str(input_data.size()) != shape):
|
||||
|
|
@ -532,7 +536,7 @@ def openvino_fx(subgraph, example_inputs, options=None):
|
|||
if (shared.compiled_model_state.cn_model != [] and str(shared.compiled_model_state.cn_model) in maybe_fs_cached_name):
|
||||
args_reordered = []
|
||||
if (os.path.isfile(maybe_fs_cached_name + ".txt")):
|
||||
f = open(maybe_fs_cached_name + ".txt", "r")
|
||||
f = open(maybe_fs_cached_name + ".txt")
|
||||
for input_data in args:
|
||||
shape = f.readline()
|
||||
if (str(input_data.size()) != shape):
|
||||
|
|
|
|||
|
|
@ -288,7 +288,19 @@ def parse_params(p: processing.StableDiffusionProcessing, adapters: list, adapte
|
|||
return adapter_images, adapter_masks, adapter_scales, adapter_crops, adapter_starts, adapter_ends
|
||||
|
||||
|
||||
def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapter_scales=[1.0], adapter_crops=[False], adapter_starts=[0.0], adapter_ends=[1.0], adapter_images=[]):
|
||||
def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=None, adapter_scales=None, adapter_crops=None, adapter_starts=None, adapter_ends=None, adapter_images=None):
|
||||
if adapter_images is None:
|
||||
adapter_images = []
|
||||
if adapter_ends is None:
|
||||
adapter_ends = [1.0]
|
||||
if adapter_starts is None:
|
||||
adapter_starts = [0.0]
|
||||
if adapter_crops is None:
|
||||
adapter_crops = [False]
|
||||
if adapter_scales is None:
|
||||
adapter_scales = [1.0]
|
||||
if adapter_names is None:
|
||||
adapter_names = []
|
||||
global adapters_loaded # pylint: disable=global-statement
|
||||
# overrides
|
||||
if hasattr(p, 'ip_adapter_names'):
|
||||
|
|
@ -361,7 +373,7 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
|
|||
if adapter_starts[i] > 0:
|
||||
adapter_scales[i] = 0.00
|
||||
pipe.set_ip_adapter_scale(adapter_scales if len(adapter_scales) > 1 else adapter_scales[0])
|
||||
ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}:{crop}' for adapter, scale, start, end, crop in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends, adapter_crops)]
|
||||
ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}:{crop}' for adapter, scale, start, end, crop in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends, adapter_crops, strict=False)]
|
||||
if hasattr(pipe, 'transformer') and 'Nunchaku' in pipe.transformer.__class__.__name__:
|
||||
if isinstance(repos, str):
|
||||
sd_models.clear_caches(full=True)
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ try:
|
|||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import torch.distributed.distributed_c10d as _c10d # pylint: disable=unused-import,ungrouped-imports
|
||||
pass # pylint: disable=unused-import,ungrouped-imports
|
||||
except Exception:
|
||||
errors.log.warning('Loader: torch is not built with distributed support')
|
||||
|
||||
|
|
@ -73,7 +73,6 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvisi
|
|||
torchvision = None
|
||||
try:
|
||||
import torchvision # pylint: disable=W0611,C0411
|
||||
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them # pylint: disable=W0611,C0411
|
||||
except Exception as e:
|
||||
errors.log.error(f'Loader: torchvision=={torchvision.__version__ if "torchvision" in sys.modules else None} {e}')
|
||||
if '_no_nep' in str(e):
|
||||
|
|
@ -100,7 +99,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
|||
timer.startup.record("torch")
|
||||
|
||||
try:
|
||||
import bitsandbytes # pylint: disable=W0611,C0411
|
||||
import bitsandbytes # pylint: disable=unused-import
|
||||
_bnb = True
|
||||
except Exception:
|
||||
_bnb = False
|
||||
|
|
@ -132,7 +131,6 @@ except Exception as e:
|
|||
errors.log.warning(f'Torch onnxruntime: {e}')
|
||||
timer.startup.record("onnx")
|
||||
|
||||
from fastapi import FastAPI # pylint: disable=W0611,C0411
|
||||
timer.startup.record("fastapi")
|
||||
|
||||
import gradio # pylint: disable=W0611,C0411
|
||||
|
|
@ -161,17 +159,16 @@ except Exception as e:
|
|||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import pillow_jxl # pylint: disable=W0611,C0411
|
||||
pass # pylint: disable=W0611,C0411
|
||||
except Exception:
|
||||
pass
|
||||
from PIL import Image # pylint: disable=W0611,C0411
|
||||
timer.startup.record("pillow")
|
||||
|
||||
|
||||
import cv2 # pylint: disable=W0611,C0411
|
||||
timer.startup.record("cv2")
|
||||
|
||||
class _tqdm_cls():
|
||||
class _tqdm_cls:
|
||||
def __call__(self, *args, **kwargs):
|
||||
bar_format = 'Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + '{desc}' + '\x1b[0m'
|
||||
return tqdm_lib.tqdm(*args, bar_format=bar_format, ncols=80, colour='#327fba', **kwargs)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def localization_js(current_localization_name):
|
|||
data = {}
|
||||
if fn is not None:
|
||||
try:
|
||||
with open(fn, "r", encoding="utf8") as file:
|
||||
with open(fn, encoding="utf8") as file:
|
||||
data = json.load(file)
|
||||
except Exception as e:
|
||||
errors.log.error(f"Error loading localization from {fn}:")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import List
|
||||
import os
|
||||
import re
|
||||
import numpy as np
|
||||
|
|
@ -19,7 +18,7 @@ def get_stepwise(param, step, steps): # from https://github.com/cheald/sd-webui-
|
|||
return steps[0][0]
|
||||
steps = [[s[0], s[1] if len(s) == 2 else 1] for s in steps] # Add implicit 1s to any steps which don't have a weight
|
||||
steps.sort(key=lambda k: k[1]) # Sort by index
|
||||
steps = [list(v) for v in zip(*steps)]
|
||||
steps = [list(v) for v in zip(*steps, strict=False)]
|
||||
return steps
|
||||
|
||||
def calculate_weight(m, step, max_steps, step_offset=2):
|
||||
|
|
@ -170,10 +169,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||
self.model = None
|
||||
self.errors = {}
|
||||
|
||||
def signature(self, names: List[str], te_multipliers: List, unet_multipliers: List):
|
||||
return [f'{name}:{te}:{unet}' for name, te, unet in zip(names, te_multipliers, unet_multipliers)]
|
||||
def signature(self, names: list[str], te_multipliers: list, unet_multipliers: list):
|
||||
return [f'{name}:{te}:{unet}' for name, te, unet in zip(names, te_multipliers, unet_multipliers, strict=False)]
|
||||
|
||||
def changed(self, requested: List[str], include: List[str] = None, exclude: List[str] = None) -> bool:
|
||||
def changed(self, requested: list[str], include: list[str] = None, exclude: list[str] = None) -> bool:
|
||||
if shared.opts.lora_force_reload:
|
||||
debug_log(f'Network check: type=LoRA requested={requested} status=forced')
|
||||
return True
|
||||
|
|
@ -190,7 +189,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||
sd_model.loaded_loras[key] = requested
|
||||
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=changed')
|
||||
return True
|
||||
for req, load in zip(requested, loaded):
|
||||
for req, load in zip(requested, loaded, strict=False):
|
||||
if req != load:
|
||||
sd_model.loaded_loras[key] = requested
|
||||
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=changed')
|
||||
|
|
@ -198,7 +197,11 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||
debug_log(f'Network check: type=LoRA key="{key}" requested={requested} loaded={loaded} status=same')
|
||||
return False
|
||||
|
||||
def activate(self, p, params_list, step=0, include=[], exclude=[]):
|
||||
def activate(self, p, params_list, step=0, include=None, exclude=None):
|
||||
if exclude is None:
|
||||
exclude = []
|
||||
if include is None:
|
||||
include = []
|
||||
self.errors.clear()
|
||||
if self.active:
|
||||
if self.model != shared.opts.sd_model_checkpoint: # reset if model changed
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Union
|
||||
import re
|
||||
import time
|
||||
import torch
|
||||
|
|
@ -12,7 +11,7 @@ bnb = None
|
|||
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
|
||||
|
||||
|
||||
def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, wanted_names: tuple):
|
||||
def network_backup_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, network_layer_name: str, wanted_names: tuple):
|
||||
global bnb # pylint: disable=W0603
|
||||
backup_size = 0
|
||||
if len(l.loaded_networks) > 0 and network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in l.loaded_networks]): # noqa: C419 # pylint: disable=R1729
|
||||
|
|
@ -76,7 +75,7 @@ def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n
|
|||
return backup_size
|
||||
|
||||
|
||||
def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, use_previous: bool = False):
|
||||
def network_calc_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, network_layer_name: str, use_previous: bool = False):
|
||||
if shared.opts.diffusers_offload_mode == "none":
|
||||
try:
|
||||
self.to(devices.device)
|
||||
|
|
@ -147,7 +146,7 @@ def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
|
|||
return batch_updown, batch_ex_bias
|
||||
|
||||
|
||||
def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], model_weights: Union[None, torch.Tensor] = None, lora_weights: torch.Tensor = None, deactivate: bool = False, device: torch.device = None, bias: bool = False):
|
||||
def network_add_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, model_weights: None | torch.Tensor = None, lora_weights: torch.Tensor = None, deactivate: bool = False, device: torch.device = None, bias: bool = False):
|
||||
if lora_weights is None:
|
||||
return
|
||||
if deactivate:
|
||||
|
|
@ -239,7 +238,7 @@ def network_add_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.G
|
|||
del model_weights, lora_weights, new_weight, weight # required to avoid memory leak
|
||||
|
||||
|
||||
def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False, device: torch.device = devices.device):
|
||||
def network_apply_direct(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False, device: torch.device = devices.device):
|
||||
weights_backup = getattr(self, "network_weights_backup", False)
|
||||
bias_backup = getattr(self, "network_bias_backup", False)
|
||||
if not isinstance(weights_backup, bool): # remove previous backup if we switched settings
|
||||
|
|
@ -266,7 +265,7 @@ def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
|
|||
l.timer.apply += time.time() - t0
|
||||
|
||||
|
||||
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, device: torch.device, deactivate: bool = False):
|
||||
def network_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.GroupNorm | torch.nn.LayerNorm | diffusers.models.lora.LoRACompatibleLinear | diffusers.models.lora.LoRACompatibleConv, updown: torch.Tensor, ex_bias: torch.Tensor, device: torch.device, deactivate: bool = False):
|
||||
weights_backup = getattr(self, "network_weights_backup", None)
|
||||
bias_backup = getattr(self, "network_bias_backup", None)
|
||||
if weights_backup is None and bias_backup is None:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import List
|
||||
import os
|
||||
from modules.lora import lora_timers
|
||||
from modules.lora import network_lora, network_hada, network_ia3, network_oft, network_lokr, network_full, network_norm, network_glora
|
||||
|
|
@ -16,6 +15,6 @@ module_types = [
|
|||
network_norm.ModuleTypeNorm(),
|
||||
network_glora.ModuleTypeGLora(),
|
||||
]
|
||||
loaded_networks: List = [] # no type due to circular import
|
||||
previously_loaded_networks: List = [] # no type due to circular import
|
||||
loaded_networks: list = [] # no type due to circular import
|
||||
previously_loaded_networks: list = [] # no type due to circular import
|
||||
extra_network_lora = None # initialized in extra_networks.py
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
import re
|
||||
import bisect
|
||||
from typing import Dict
|
||||
import torch
|
||||
from modules import shared
|
||||
|
||||
|
|
@ -23,7 +22,7 @@ re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
|
|||
re_compiled = {}
|
||||
|
||||
|
||||
def make_unet_conversion_map() -> Dict[str, str]:
|
||||
def make_unet_conversion_map() -> dict[str, str]:
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
for i in range(4): # num_blocks is 3 in sdxl
|
||||
|
|
@ -213,10 +212,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
|
|||
ait_sd.update({k: down_weight for k in ait_down_keys})
|
||||
|
||||
# up_weight is split to each split
|
||||
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension
|
||||
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0), strict=False)}) # noqa: C416 # pylint: disable=unnecessary-comprehension
|
||||
else:
|
||||
# down_weight is chunked to each split
|
||||
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension
|
||||
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0), strict=False)}) # noqa: C416 # pylint: disable=unnecessary-comprehension
|
||||
|
||||
# up_weight is sparse: only non-zero values are copied to each split
|
||||
i = 0
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Union
|
||||
import os
|
||||
import time
|
||||
import diffusers
|
||||
|
|
@ -50,7 +49,7 @@ def load_per_module(sd_model: diffusers.DiffusionPipeline, filename: str, adapte
|
|||
return adapter_name
|
||||
|
||||
|
||||
def load_diffusers(name: str, network_on_disk: network.NetworkOnDisk, lora_scale:float=shared.opts.extra_networks_default_multiplier, lora_module=None) -> Union[network.Network, None]:
|
||||
def load_diffusers(name: str, network_on_disk: network.NetworkOnDisk, lora_scale:float=shared.opts.extra_networks_default_multiplier, lora_module=None) -> network.Network | None:
|
||||
t0 = time.time()
|
||||
name = name.replace(".", "_")
|
||||
sd_model: diffusers.DiffusionPipeline = getattr(shared.sd_model, "pipe", shared.sd_model)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from typing import Union
|
||||
import os
|
||||
import time
|
||||
import concurrent
|
||||
|
|
@ -39,7 +38,7 @@ def lora_dump(lora, dct):
|
|||
f.write(line + "\n")
|
||||
|
||||
|
||||
def load_safetensors(name, network_on_disk: network.NetworkOnDisk) -> Union[network.Network, None]:
|
||||
def load_safetensors(name, network_on_disk: network.NetworkOnDisk) -> network.Network | None:
|
||||
if not shared.sd_loaded:
|
||||
return None
|
||||
|
||||
|
|
@ -241,7 +240,7 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
|
|||
lora_diffusers.diffuser_scales.clear()
|
||||
t0 = time.time()
|
||||
|
||||
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
|
||||
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names, strict=False)):
|
||||
net = None
|
||||
if network_on_disk is not None:
|
||||
shorthash = getattr(network_on_disk, 'shorthash', '').lower()
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ def load_nunchaku(names, strengths):
|
|||
global previously_loaded # pylint: disable=global-statement
|
||||
strengths = [s[0] if isinstance(s, list) else s for s in strengths]
|
||||
networks = lora_load.gather_networks(names)
|
||||
networks = [(network, strength) for network, strength in zip(networks, strengths) if network is not None and strength > 0]
|
||||
networks = [(network, strength) for network, strength in zip(networks, strengths, strict=False) if network is not None and strength > 0]
|
||||
loras = [(network.filename, strength) for network, strength in networks]
|
||||
is_changed = loras != previously_loaded
|
||||
if not is_changed:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
class Timer():
|
||||
class Timer:
|
||||
list: float = 0
|
||||
load: float = 0
|
||||
backup: float = 0
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
import enum
|
||||
from typing import Union
|
||||
from collections import namedtuple
|
||||
from modules import sd_models, hashes, shared
|
||||
|
||||
|
|
@ -120,7 +119,7 @@ class NetworkOnDisk:
|
|||
if self.filename is not None:
|
||||
fn = os.path.splitext(self.filename)[0] + '.txt'
|
||||
if os.path.exists(fn):
|
||||
with open(fn, "r", encoding="utf-8") as file:
|
||||
with open(fn, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
return None
|
||||
|
||||
|
|
@ -144,7 +143,7 @@ class Network: # LoraModule
|
|||
|
||||
|
||||
class ModuleType:
|
||||
def create_module(self, net: Network, weights: NetworkWeights) -> Union[Network, None]: # pylint: disable=W0613
|
||||
def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: # pylint: disable=W0613
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,11 @@ applied_layers: list[str] = []
|
|||
default_components = ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'text_encoder_4', 'unet', 'transformer', 'transformer_2']
|
||||
|
||||
|
||||
def network_activate(include=[], exclude=[]):
|
||||
def network_activate(include=None, exclude=None):
|
||||
if exclude is None:
|
||||
exclude = []
|
||||
if include is None:
|
||||
include = []
|
||||
t0 = time.time()
|
||||
with limit_errors("network_activate"):
|
||||
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
|
||||
|
|
@ -77,7 +81,11 @@ def network_activate(include=[], exclude=[]):
|
|||
sd_models.set_diffuser_offload(sd_model, op="model")
|
||||
|
||||
|
||||
def network_deactivate(include=[], exclude=[]):
|
||||
def network_deactivate(include=None, exclude=None):
|
||||
if exclude is None:
|
||||
exclude = []
|
||||
if include is None:
|
||||
include = []
|
||||
if not shared.opts.lora_fuse_native or shared.opts.lora_force_diffusers:
|
||||
return
|
||||
if len(l.previously_loaded_networks) == 0:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from types import SimpleNamespace
|
||||
from typing import List
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
|
@ -235,7 +234,7 @@ def run_segment(input_image: gr.Image, input_mask: np.ndarray):
|
|||
combined_mask = np.zeros(input_mask.shape, dtype='uint8')
|
||||
input_mask_size = np.count_nonzero(input_mask)
|
||||
debug(f'Segment SAM: {vars(opts)}')
|
||||
for mask, score in zip(outputs['masks'], outputs['scores']):
|
||||
for mask, score in zip(outputs['masks'], outputs['scores'], strict=False):
|
||||
mask = mask.astype('uint8')
|
||||
mask_size = np.count_nonzero(mask)
|
||||
if mask_size == 0:
|
||||
|
|
@ -561,7 +560,7 @@ def create_segment_ui():
|
|||
return controls
|
||||
|
||||
|
||||
def bind_controls(image_controls: List[gr.Image], preview_image: gr.Image, output_image: gr.Image):
|
||||
def bind_controls(image_controls: list[gr.Image], preview_image: gr.Image, output_image: gr.Image):
|
||||
for image_control in image_controls:
|
||||
btn_mask.click(run_mask, inputs=[image_control], outputs=[preview_image])
|
||||
btn_lama.click(run_lama, inputs=[image_control], outputs=[output_image])
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from collections import defaultdict
|
|||
import torch
|
||||
|
||||
|
||||
class MemUsageMonitor():
|
||||
class MemUsageMonitor:
|
||||
device = None
|
||||
disabled = False
|
||||
opts = None
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ def get_docker_limit():
|
|||
if docker_limit is not None:
|
||||
return docker_limit
|
||||
try:
|
||||
with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r', encoding='utf8') as f:
|
||||
with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', encoding='utf8') as f:
|
||||
docker_limit = float(f.read())
|
||||
except Exception:
|
||||
docker_limit = sys.float_info.max
|
||||
|
|
@ -145,7 +145,9 @@ class Object:
|
|||
return f'{self.fn}.{self.name} type={self.type} size={self.size} ref={self.refcount}'
|
||||
|
||||
|
||||
def get_objects(gcl={}, threshold:int=0):
|
||||
def get_objects(gcl=None, threshold:int=0):
|
||||
if gcl is None:
|
||||
gcl = {}
|
||||
objects = []
|
||||
seen = []
|
||||
|
||||
|
|
|
|||
|
|
@ -260,7 +260,9 @@ def calculate_model_hash(state_dict):
|
|||
return func.hexdigest()
|
||||
|
||||
|
||||
def convert(model_path:str, checkpoint_path:str, metadata:dict={}):
|
||||
def convert(model_path:str, checkpoint_path:str, metadata:dict=None):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
unet_path = os.path.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
||||
vae_path = os.path.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
||||
text_enc_path = os.path.join(model_path, "text_encoder", "model.safetensors")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Optional, Tuple, Set
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import modules.memstats
|
||||
|
|
@ -37,7 +36,7 @@ KEY_POSITION_IDS = ".".join(
|
|||
)
|
||||
|
||||
|
||||
def fix_clip(model: Dict) -> Dict:
|
||||
def fix_clip(model: dict) -> dict:
|
||||
if KEY_POSITION_IDS in model.keys():
|
||||
model[KEY_POSITION_IDS] = torch.tensor(
|
||||
[list(range(MAX_TOKENS))],
|
||||
|
|
@ -48,7 +47,7 @@ def fix_clip(model: Dict) -> Dict:
|
|||
return model
|
||||
|
||||
|
||||
def prune_sd_model(model: Dict, keyset: Set) -> Dict:
|
||||
def prune_sd_model(model: dict, keyset: set) -> dict:
|
||||
keys = list(model.keys())
|
||||
for k in keys:
|
||||
if (
|
||||
|
|
@ -60,7 +59,7 @@ def prune_sd_model(model: Dict, keyset: Set) -> Dict:
|
|||
return model
|
||||
|
||||
|
||||
def restore_sd_model(original_model: Dict, merged_model: Dict) -> Dict:
|
||||
def restore_sd_model(original_model: dict, merged_model: dict) -> dict:
|
||||
for k in original_model:
|
||||
if k not in merged_model:
|
||||
merged_model[k] = original_model[k]
|
||||
|
|
@ -72,11 +71,11 @@ def log_vram(txt=""):
|
|||
|
||||
|
||||
def load_thetas(
|
||||
models: Dict[str, os.PathLike],
|
||||
models: dict[str, os.PathLike],
|
||||
prune: bool,
|
||||
device: torch.device,
|
||||
precision: str,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
from tensordict import TensorDict
|
||||
thetas = {k: TensorDict.from_dict(read_state_dict(m, "cpu")) for k, m in models.items()}
|
||||
if prune:
|
||||
|
|
@ -95,7 +94,7 @@ def load_thetas(
|
|||
|
||||
|
||||
def merge_models(
|
||||
models: Dict[str, os.PathLike],
|
||||
models: dict[str, os.PathLike],
|
||||
merge_mode: str,
|
||||
precision: str = "fp16",
|
||||
weights_clip: bool = False,
|
||||
|
|
@ -104,7 +103,7 @@ def merge_models(
|
|||
prune: bool = False,
|
||||
threads: int = 4,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
thetas = load_thetas(models, prune, device, precision)
|
||||
# log.info(f'Merge start: models={models.values()} precision={precision} clip={weights_clip} rebasin={re_basin} prune={prune} threads={threads}')
|
||||
weight_matcher = WeightClass(thetas["model_a"], **kwargs)
|
||||
|
|
@ -136,13 +135,13 @@ def merge_models(
|
|||
|
||||
|
||||
def un_prune_model(
|
||||
merged: Dict,
|
||||
thetas: Dict,
|
||||
models: Dict,
|
||||
merged: dict,
|
||||
thetas: dict,
|
||||
models: dict,
|
||||
device: torch.device,
|
||||
prune: bool,
|
||||
precision: str,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
if prune:
|
||||
log.info("Merge restoring pruned keys")
|
||||
del thetas
|
||||
|
|
@ -180,7 +179,7 @@ def un_prune_model(
|
|||
|
||||
|
||||
def simple_merge(
|
||||
thetas: Dict[str, Dict],
|
||||
thetas: dict[str, dict],
|
||||
weight_matcher: WeightClass,
|
||||
merge_mode: str,
|
||||
precision: str = "fp16",
|
||||
|
|
@ -188,7 +187,7 @@ def simple_merge(
|
|||
device: torch.device = None,
|
||||
work_device: torch.device = None,
|
||||
threads: int = 4,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
futures = []
|
||||
import rich.progress as p
|
||||
with p.Progress(p.TextColumn('[cyan]{task.description}'), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TextColumn('[cyan]keys={task.fields[keys]}'), console=console) as progress:
|
||||
|
|
@ -227,7 +226,7 @@ def simple_merge(
|
|||
|
||||
|
||||
def rebasin_merge(
|
||||
thetas: Dict[str, os.PathLike],
|
||||
thetas: dict[str, os.PathLike],
|
||||
weight_matcher: WeightClass,
|
||||
merge_mode: str,
|
||||
precision: str = "fp16",
|
||||
|
|
@ -306,14 +305,14 @@ def simple_merge_key(progress, task, key, thetas, *args, **kwargs):
|
|||
|
||||
def merge_key( # pylint: disable=inconsistent-return-statements
|
||||
key: str,
|
||||
thetas: Dict,
|
||||
thetas: dict,
|
||||
weight_matcher: WeightClass,
|
||||
merge_mode: str,
|
||||
precision: str = "fp16",
|
||||
weights_clip: bool = False,
|
||||
device: torch.device = None,
|
||||
work_device: torch.device = None,
|
||||
) -> Optional[Tuple[str, Dict]]:
|
||||
) -> tuple[str, dict] | None:
|
||||
if work_device is None:
|
||||
work_device = device
|
||||
|
||||
|
|
@ -376,11 +375,11 @@ def merge_key_context(*args, **kwargs):
|
|||
|
||||
|
||||
def get_merge_method_args(
|
||||
current_bases: Dict,
|
||||
thetas: Dict,
|
||||
current_bases: dict,
|
||||
thetas: dict,
|
||||
key: str,
|
||||
work_device: torch.device,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
merge_method_args = {
|
||||
"a": thetas["model_a"][key].to(work_device),
|
||||
"b": thetas["model_b"][key].to(work_device),
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
|
@ -151,7 +150,7 @@ def kth_abs_value(a: Tensor, k: int) -> Tensor:
|
|||
return torch.kthvalue(torch.abs(a.float()), k)[0]
|
||||
|
||||
|
||||
def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
|
||||
def ratio_to_region(width: float, offset: float, n: int) -> tuple[int, int, bool]:
|
||||
if width < 0:
|
||||
offset += width
|
||||
width = -width
|
||||
|
|
@ -233,7 +232,7 @@ def ties_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: flo
|
|||
delta_filters = (signs == final_sign).float()
|
||||
|
||||
res = torch.zeros_like(c, device=c.device)
|
||||
for delta_filter, delta in zip(delta_filters, deltas):
|
||||
for delta_filter, delta in zip(delta_filters, deltas, strict=False):
|
||||
res += delta_filter * delta
|
||||
|
||||
param_count = torch.sum(delta_filters, dim=0)
|
||||
|
|
|
|||
|
|
@ -206,7 +206,7 @@ def test_model(pipe: diffusers.StableDiffusionXLPipeline, fn: str, **kwargs):
|
|||
if not test.generate:
|
||||
return
|
||||
try:
|
||||
generator = torch.Generator(devices.device).manual_seed(int(4242))
|
||||
generator = torch.Generator(devices.device).manual_seed(4242)
|
||||
args = {
|
||||
'prompt': test.prompt,
|
||||
'negative_prompt': test.negative,
|
||||
|
|
@ -278,7 +278,7 @@ def save_model(pipe: diffusers.StableDiffusionXLPipeline):
|
|||
yield msg(f'pretrained={folder}')
|
||||
shared.log.info(f'Modules merge save: type=sdxl diffusers="{folder}"')
|
||||
pipe.save_pretrained(folder, safe_serialization=True, push_to_hub=False)
|
||||
with open(os.path.join(folder, 'vae', 'config.json'), 'r', encoding='utf8') as f:
|
||||
with open(os.path.join(folder, 'vae', 'config.json'), encoding='utf8') as f:
|
||||
vae_config = json.load(f)
|
||||
vae_config['force_upcast'] = False
|
||||
vae_config['scaling_factor'] = 0.13025
|
||||
|
|
|
|||
|
|
@ -53,8 +53,6 @@ def install_nunchaku():
|
|||
import os
|
||||
import sys
|
||||
import platform
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import torch
|
||||
python_ver = f'{sys.version_info.major}{sys.version_info.minor}'
|
||||
if python_ver not in ['311', '312', '313']:
|
||||
|
|
|
|||
|
|
@ -220,5 +220,13 @@ class Shared(sys.modules[__name__].__class__):
|
|||
model_type = 'unknown'
|
||||
return model_type
|
||||
|
||||
@property
|
||||
def console(self):
|
||||
try:
|
||||
from installer import get_console
|
||||
return get_console()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
model_data = ModelData()
|
||||
|
|
|
|||
|
|
@ -4,13 +4,12 @@ import time
|
|||
import shutil
|
||||
import importlib
|
||||
import contextlib
|
||||
from typing import Dict
|
||||
from urllib.parse import urlparse
|
||||
import huggingface_hub as hf
|
||||
from installer import install, log
|
||||
from modules import shared, errors, files_cache
|
||||
from modules.upscaler import Upscaler
|
||||
from modules.paths import script_path, models_path
|
||||
from modules import paths
|
||||
|
||||
|
||||
loggedin = None
|
||||
|
|
@ -55,7 +54,7 @@ def hf_login(token=None):
|
|||
return True
|
||||
|
||||
|
||||
def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config: Dict[str, str] = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None):
|
||||
def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config: dict[str, str] = None, token = None, variant = None, revision = None, mirror = None, custom_pipeline = None):
|
||||
if hub_id is None or len(hub_id) == 0:
|
||||
return None
|
||||
from diffusers import DiffusionPipeline
|
||||
|
|
@ -117,7 +116,7 @@ def load_diffusers_models(clear=True):
|
|||
# t0 = time.time()
|
||||
place = shared.opts.diffusers_dir
|
||||
if place is None or len(place) == 0 or not os.path.isdir(place):
|
||||
place = os.path.join(models_path, 'Diffusers')
|
||||
place = os.path.join(paths.models_path, 'Diffusers')
|
||||
if clear:
|
||||
diffuser_repos.clear()
|
||||
already_found = []
|
||||
|
|
@ -382,25 +381,25 @@ def cleanup_models():
|
|||
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
|
||||
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
|
||||
# somehow auto-register and just do these things...
|
||||
root_path = script_path
|
||||
src_path = models_path
|
||||
dest_path = os.path.join(models_path, "Stable-diffusion")
|
||||
root_path = paths.script_path
|
||||
src_path = paths.models_path
|
||||
dest_path = os.path.join(paths.models_path, "Stable-diffusion")
|
||||
# move_files(src_path, dest_path, ".ckpt")
|
||||
# move_files(src_path, dest_path, ".safetensors")
|
||||
src_path = os.path.join(root_path, "ESRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
dest_path = os.path.join(paths.models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(models_path, "BSRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
src_path = os.path.join(paths.models_path, "BSRGAN")
|
||||
dest_path = os.path.join(paths.models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path, ".pth")
|
||||
src_path = os.path.join(root_path, "SwinIR")
|
||||
dest_path = os.path.join(models_path, "SwinIR")
|
||||
dest_path = os.path.join(paths.models_path, "SwinIR")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
|
||||
dest_path = os.path.join(models_path, "LDSR")
|
||||
dest_path = os.path.join(paths.models_path, "LDSR")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(root_path, "SCUNet")
|
||||
dest_path = os.path.join(models_path, "SCUNet")
|
||||
dest_path = os.path.join(paths.models_path, "SCUNet")
|
||||
move_files(src_path, dest_path)
|
||||
|
||||
|
||||
|
|
@ -430,7 +429,7 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
|||
def load_upscalers():
|
||||
# We can only do this 'magic' method to dynamically load upscalers if they are referenced, so we'll try to import any _model.py files before looking in __subclasses__
|
||||
t0 = time.time()
|
||||
modules_dir = os.path.join(shared.script_path, "modules", "postprocess")
|
||||
modules_dir = os.path.join(paths.script_path, "modules", "postprocess")
|
||||
for file in os.listdir(modules_dir):
|
||||
if "_model.py" in file:
|
||||
model_name = file.replace("_model.py", "")
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def stat(fn: str):
|
|||
return size, mtime
|
||||
|
||||
|
||||
class Module():
|
||||
class Module:
|
||||
name: str = ''
|
||||
cls: str = None
|
||||
device: str = None
|
||||
|
|
@ -61,7 +61,7 @@ class Module():
|
|||
return s
|
||||
|
||||
|
||||
class Model():
|
||||
class Model:
|
||||
name: str = ''
|
||||
fn: str = ''
|
||||
type: str = ''
|
||||
|
|
|
|||
|
|
@ -1,17 +1,18 @@
|
|||
import os
|
||||
from typing import Type, Callable, TypeVar, Dict, Any
|
||||
from typing import TypeVar, Any
|
||||
from collections.abc import Callable
|
||||
import torch
|
||||
import diffusers
|
||||
from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
|
||||
class ENVStore:
|
||||
__DESERIALIZER: Dict[Type, Callable[[str,], Any]] = {
|
||||
__DESERIALIZER: dict[type, Callable[[str,], Any]] = {
|
||||
bool: lambda x: bool(int(x)),
|
||||
int: int,
|
||||
str: lambda x: x,
|
||||
}
|
||||
__SERIALIZER: Dict[Type, Callable[[Any,], str]] = {
|
||||
__SERIALIZER: dict[type, Callable[[Any,], str]] = {
|
||||
bool: lambda x: str(int(x)),
|
||||
int: str,
|
||||
str: lambda x: x,
|
||||
|
|
@ -89,7 +90,7 @@ def get_loader_arguments(no_variant: bool = False):
|
|||
|
||||
|
||||
T = TypeVar("T")
|
||||
def from_pretrained(cls: Type[T], pretrained_model_name_or_path: os.PathLike, *args, no_variant: bool = False, **kwargs) -> T:
|
||||
def from_pretrained(cls: type[T], pretrained_model_name_or_path: os.PathLike, *args, no_variant: bool = False, **kwargs) -> T:
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if pretrained_model_name_or_path.endswith(".onnx"):
|
||||
cls = diffusers.OnnxRuntimeModel
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ except Exception as e:
|
|||
|
||||
|
||||
class DynamicSessionOptions(ort.SessionOptions):
|
||||
config: Optional[Dict] = None
|
||||
config: dict | None = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
@ -28,7 +28,7 @@ class DynamicSessionOptions(ort.SessionOptions):
|
|||
return sess_options.copy()
|
||||
return DynamicSessionOptions()
|
||||
|
||||
def enable_static_dims(self, config: Dict):
|
||||
def enable_static_dims(self, config: dict):
|
||||
self.config = config
|
||||
self.add_free_dimension_override_by_name("unet_sample_batch", config["hidden_batch_size"])
|
||||
self.add_free_dimension_override_by_name("unet_sample_channels", 4)
|
||||
|
|
@ -103,9 +103,9 @@ class OnnxRuntimeModel(TorchCompatibleModule, diffusers.OnnxRuntimeModel):
|
|||
|
||||
class VAEConfig:
|
||||
DEFAULTS = { "scaling_factor": 0.18215 }
|
||||
config: Dict
|
||||
config: dict
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
|
||||
def __getattr__(self, key):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import sys
|
||||
from enum import Enum
|
||||
from typing import Tuple, List
|
||||
from installer import log
|
||||
from modules import devices
|
||||
|
||||
|
|
@ -33,7 +32,7 @@ TORCH_DEVICE_TO_EP = {
|
|||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
available_execution_providers: List[ExecutionProvider] = ort.get_available_providers()
|
||||
available_execution_providers: list[ExecutionProvider] = ort.get_available_providers()
|
||||
except Exception as e:
|
||||
log.error(f'ONNX import error: {e}')
|
||||
available_execution_providers = []
|
||||
|
|
@ -90,7 +89,7 @@ def get_execution_provider_options():
|
|||
return execution_provider_options
|
||||
|
||||
|
||||
def get_provider() -> Tuple:
|
||||
def get_provider() -> tuple:
|
||||
from modules.shared import opts
|
||||
return (opts.onnx_execution_provider, get_execution_provider_options(),)
|
||||
|
||||
|
|
|
|||
|
|
@ -103,12 +103,12 @@ class OnnxRawPipeline(PipelineBase):
|
|||
path: os.PathLike
|
||||
original_filename: str
|
||||
|
||||
constructor: Type[PipelineBase]
|
||||
init_dict: Dict[str, Tuple[str]] = {}
|
||||
constructor: type[PipelineBase]
|
||||
init_dict: dict[str, tuple[str]] = {}
|
||||
|
||||
default_scheduler: Any = None # for Img2Img
|
||||
|
||||
def __init__(self, constructor: Type[PipelineBase], path: os.PathLike): # pylint: disable=super-init-not-called
|
||||
def __init__(self, constructor: type[PipelineBase], path: os.PathLike): # pylint: disable=super-init-not-called
|
||||
self._is_sdxl = check_pipeline_sdxl(constructor)
|
||||
self.from_diffusers_cache = check_diffusers_cache(path)
|
||||
self.path = path
|
||||
|
|
@ -150,7 +150,7 @@ class OnnxRawPipeline(PipelineBase):
|
|||
pipeline.scheduler = self.default_scheduler
|
||||
return pipeline
|
||||
|
||||
def convert(self, submodels: List[str], in_dir: os.PathLike, out_dir: os.PathLike):
|
||||
def convert(self, submodels: list[str], in_dir: os.PathLike, out_dir: os.PathLike):
|
||||
install('onnx') # may not be installed yet, this performs check and installs as needed
|
||||
import onnx
|
||||
shutil.rmtree("cache", ignore_errors=True)
|
||||
|
|
@ -218,7 +218,7 @@ class OnnxRawPipeline(PipelineBase):
|
|||
with open(os.path.join(out_dir, "model_index.json"), 'w', encoding="utf-8") as file:
|
||||
json.dump(model_index, file)
|
||||
|
||||
def run_olive(self, submodels: List[str], in_dir: os.PathLike, out_dir: os.PathLike):
|
||||
def run_olive(self, submodels: list[str], in_dir: os.PathLike, out_dir: os.PathLike):
|
||||
from olive.model import ONNXModelHandler
|
||||
from olive.workflows import run as run_workflows
|
||||
|
||||
|
|
@ -235,8 +235,8 @@ class OnnxRawPipeline(PipelineBase):
|
|||
for submodel in submodels:
|
||||
log.info(f"\nProcessing {submodel}")
|
||||
|
||||
with open(os.path.join(sd_configs_path, "olive", 'sdxl' if self._is_sdxl else 'sd', f"{submodel}.json"), "r", encoding="utf-8") as config_file:
|
||||
olive_config: Dict[str, Dict[str, Dict]] = json.load(config_file)
|
||||
with open(os.path.join(sd_configs_path, "olive", 'sdxl' if self._is_sdxl else 'sd', f"{submodel}.json"), encoding="utf-8") as config_file:
|
||||
olive_config: dict[str, dict[str, dict]] = json.load(config_file)
|
||||
|
||||
for flow in olive_config["pass_flows"]:
|
||||
for i in range(len(flow)):
|
||||
|
|
@ -257,7 +257,7 @@ class OnnxRawPipeline(PipelineBase):
|
|||
|
||||
run_workflows(olive_config)
|
||||
|
||||
with open(os.path.join("footprints", f"{submodel}_{EP_TO_NAME[shared.opts.onnx_execution_provider]}_footprints.json"), "r", encoding="utf-8") as footprint_file:
|
||||
with open(os.path.join("footprints", f"{submodel}_{EP_TO_NAME[shared.opts.onnx_execution_provider]}_footprints.json"), encoding="utf-8") as footprint_file:
|
||||
footprints = json.load(footprint_file)
|
||||
processor_final_pass_footprint = None
|
||||
for _, footprint in footprints.items():
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import inspect
|
||||
from typing import Union, Optional, Callable, List, Any
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
import numpy as np
|
||||
import torch
|
||||
import diffusers
|
||||
|
|
@ -33,20 +34,20 @@ class OnnxStableDiffusionImg2ImgPipeline(diffusers.OnnxStableDiffusionImg2ImgPip
|
|||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt: str | list[str],
|
||||
image: PipelineImageInput = None,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
prompt_embeds: Optional[np.ndarray] = None,
|
||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
num_inference_steps: int | None = 50,
|
||||
guidance_scale: float | None = 7.5,
|
||||
negative_prompt: str | list[str] | None = None,
|
||||
num_images_per_prompt: int | None = 1,
|
||||
eta: float | None = 0.0,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
prompt_embeds: np.ndarray | None = None,
|
||||
negative_prompt_embeds: np.ndarray | None = None,
|
||||
output_type: str | None = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback: Callable[[int, int, np.ndarray], None] | None = None,
|
||||
callback_steps: int = 1,
|
||||
):
|
||||
# check inputs. Raise error if not correct
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue