initial pyright lint

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4663/head
Vladimir Mandic 2026-02-21 09:32:36 +01:00
parent 0c9c86c3f9
commit 9a63ec758a
17 changed files with 167 additions and 155 deletions

View File

@ -1,12 +1,12 @@
# Change Log for SD.Next
## Update for 2026-02-20
## Update for 2026-02-21
### Highlights for 2026-02-20
### Highlights for 2026-02-21
TBD
### Details for 2026-02-20
### Details for 2026-02-21
- **Models**
- [FireRed Image Edit](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.0)
@ -47,8 +47,7 @@ TBD
`clip, numba, skimage, torchsde, omegaconf, antlr, patch-ng, patch-ng, astunparse, addict, inflection, jsonmerge, kornia`,
`resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets, imp`
these are now installed on-demand when needed
- refactor to/from image/tensor logic, thanks @CalamitousFelicitousness
- switch to `pyproject.toml` for tool configs
- refactor to/from *image/tensor* logic, thanks @CalamitousFelicitousness
- refactor reorganize `cli` scripts
- refactor move tests to dedicated `/test/`
- refactor all image handling to `modules/image/`
@ -67,8 +66,12 @@ TBD
- remove requirements: `clip`, `open-clip`
- remove `normalbae` pre-processor
- captioning part-2, thanks @CalamitousFelicitousness
- update `lint` rules, thanks @awsr
- add new build of `insightface`, thanks @hameerabbasi
- **Checks**
- switch to `pyproject.toml` for tool configs
- update `lint` rules, thanks @awsr
- add `ty` to optional lint tooling
- add `pyright` to optional lint tooling
- **Fixes**
- handle `clip` installer doing unwanted `setuptools` update
- cleanup for `uv` installer fallback

View File

@ -79,6 +79,21 @@ inpainting_is_aggressive_raunet = False
playground_is_aggressive_raunet = False
def _chunked_feed_forward(ff: torch.nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
# replace global self-attention with MSW-MSA
class transformer_block(block_class):
@ -238,7 +253,7 @@ def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) # pylint: disable=undefined-variable
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:

View File

@ -10,6 +10,9 @@ from diffusers.models import ControlNetModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
controlnet_apply_steps_rate = 0.6
def make_diffusers_unet_2d_condition(block_class):
class unet_2d_condition(block_class):

View File

@ -53,9 +53,6 @@ class TDDScheduler(DPMSolverSinglestepScheduler):
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

View File

@ -63,10 +63,10 @@ if devices.backend != "ipex":
if torch.__version__.startswith("2.6"):
from dataclasses import dataclass
from torch.compiler import disable as disable_compile # pylint: disable=ungrouped-imports
import diffusers.models.autoencoders.autoencoder_kl # pylint: disable=ungrouped-imports
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution # pylint: disable=ungrouped-imports
@dataclass
@disable_compile
class AutoencoderKLOutput(diffusers.utils.BaseOutput):
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
latent_dist: DiagonalGaussianDistribution
diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput = AutoencoderKLOutput

View File

@ -18,7 +18,7 @@ import torch
from einops import rearrange
from torch import nn
from ...common.cache import Cache
from ....rotary_embedding import RotaryEmbedding
from ....rotary_embedding import RotaryEmbedding, apply_rotary_emb
class RotaryEmbeddingBase(nn.Module):

View File

@ -27,6 +27,10 @@
"ruff-win": "venv\\scripts\\activate && ruff check",
"pylint": ". venv/bin/activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/ | grep -v '^*'",
"pylint-win": "venv\\scripts\\activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/",
"pyright": ". venv/bin/activate && pyright --threads 4",
"pyright-win": "venv\\scripts\\activate && pyright --threads 4",
"ty": ". venv/bin/activate && ty check",
"ty-win": "venv\\scripts\\activate && ty check",
"lint": "npm run format && npm run eslint && npm run eslint-ui && npm run ruff && npm run pylint",
"lint-win": "npm run format-win && npm run eslint && npm run eslint-ui && npm run ruff-win && npm run pylint-win",
"test": ". venv/bin/activate; python launch.py --debug --test",

View File

@ -88,7 +88,7 @@ class Scheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: torch.long,
timestep: torch.Tensor,
sample: torch.LongTensor,
starting_mask_ratio: int = 1,
generator: Optional[torch.Generator] = None,

View File

@ -366,3 +366,42 @@ variables.redefining-builtins-modules=["six.moves","past.builtins","future.built
pythonVersion = "3.10"
pythonPlatform = "All"
typeCheckingMode = "off"
venvPath = "./venv"
include = [
"*.py",
"modules/**/*.py",
"pipelines/**/*.py",
"scripts/**/*.py",
"extensions-builtin/**/*.py"
]
exclude = [
"**/.*",
".git/",
"**/node_modules",
"**/__pycache__",
"venv",
]
reportMissingImports = "none"
reportInvalidTypeForm = "none"
[tool.ty.environment]
python = "./venv/bin/python"
python-platform = "all"
python-version = "3.10"
[tool.ty.src]
include = [
"./extensions-builtin"
]
exclude = [
"venv/",
"*.git/",
]
[tool.ty.rules]
invalid-method-overrides = "ignore"
invalid-argument-types = "ignore"
unresolved-imports = "ignore"
unresolved-attributes = "ignore"
invalid-assignments = "ignore"
unsupported-operators = "ignore"

View File

@ -2,7 +2,6 @@ import io
from collections import defaultdict
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Tuple, Set, Iterable
from matplotlib import pyplot as plt
@ -19,7 +18,6 @@ __all__ = ['GlobalHeatMap', 'RawHeatMapCollection', 'WordHeatMap', 'ParsedHeatMa
def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, color_normalize=True, ax=None, cmap='jet'):
# type: (PIL.Image.Image | np.ndarray, torch.Tensor, str, Path, int, bool, plt.Axes) -> None
if ax is None:
plt.rcParams['font.size'] = 16
plt.rcParams['figure.facecolor'] = 'black'
@ -76,7 +74,6 @@ class WordHeatMap:
return self.heatmap
def plot_overlay(self, image, out_file=None, color_normalize=True, ax=None, cmap='jet', **expand_kwargs):
# type: (PIL.Image.Image | np.ndarray, Path, bool, plt.Axes, Dict[str, Any]) -> None
return plot_overlay_heat_map(
image,
self.expand_as(image, **expand_kwargs),

View File

@ -1369,7 +1369,7 @@ def init_attn_proc(unet, ip_adapter_tokens=16, use_lcm=False, use_adaln=True, us
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
time_embedding_dim=1280,
) if hasattr(F, "scaled_dot_product_attention") else AdditiveKV_AttnProcessor()
)
else:
attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
else:

View File

@ -2,6 +2,8 @@ import os
import torch
from typing import List
from collections import namedtuple, OrderedDict
from utils import revise_state_dict
def is_torch2_available():
    """Return True when torch ships `F.scaled_dot_product_attention` (i.e. torch >= 2.0)."""
    sdpa = getattr(torch.nn.functional, "scaled_dot_product_attention", None)
    return sdpa is not None

View File

@ -1,12 +1,12 @@
import torch
from collections import namedtuple, OrderedDict
import torch
from safetensors import safe_open
from .attention_processor import init_attn_proc
from .ip_adapter import MultiIPAdapterImageProjection
from .resampler import Resampler
from transformers import (
AutoModel, AutoImageProcessor,
CLIPVisionModelWithProjection, CLIPImageProcessor)
from .attention_processor import init_attn_proc
from .ip_adapter import MultiIPAdapterImageProjection
from .resampler import Resampler
def init_adapter_in_unet(
@ -80,85 +80,84 @@ def load_adapter_to_pipe(
use_lcm=False,
use_adaln=True,
):
if not isinstance(pretrained_model_path_or_dict, dict):
if pretrained_model_path_or_dict.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
if not isinstance(pretrained_model_path_or_dict, dict):
if pretrained_model_path_or_dict.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = pretrained_model_path_or_dict
keys = list(state_dict.keys())
if "image_proj" not in keys and "ip_adapter" not in keys:
state_dict = revise_state_dict(state_dict)
state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
else:
state_dict = pretrained_model_path_or_dict
keys = list(state_dict.keys())
if "image_proj" not in keys and "ip_adapter" not in keys:
state_dict = revise_state_dict(state_dict)
# load CLIP image encoder here if it has not been registered to the pipeline yet
if image_encoder_or_path is not None:
if isinstance(image_encoder_or_path, str):
feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path
# load CLIP image encoder here if it has not been registered to the pipeline yet
if image_encoder_or_path is not None:
if isinstance(image_encoder_or_path, str):
feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path
image_encoder_or_path = (
CLIPVisionModelWithProjection.from_pretrained(
image_encoder_or_path
) if use_clip_encoder else
AutoModel.from_pretrained(image_encoder_or_path)
)
image_encoder_or_path = (
CLIPVisionModelWithProjection.from_pretrained(
image_encoder_or_path
) if use_clip_encoder else
AutoModel.from_pretrained(image_encoder_or_path)
)
if feature_extractor_or_path is not None:
if isinstance(feature_extractor_or_path, str):
feature_extractor_or_path = (
CLIPImageProcessor() if use_clip_encoder else
AutoImageProcessor.from_pretrained(feature_extractor_or_path)
)
if feature_extractor_or_path is not None:
if isinstance(feature_extractor_or_path, str):
feature_extractor_or_path = (
CLIPImageProcessor() if use_clip_encoder else
AutoImageProcessor.from_pretrained(feature_extractor_or_path)
)
# create image encoder if it has not been registered to the pipeline yet
if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None:
image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype)
pipe.register_modules(image_encoder=image_encoder)
else:
image_encoder = pipe.image_encoder
# create image encoder if it has not been registered to the pipeline yet
if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None:
image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype)
pipe.register_modules(image_encoder=image_encoder)
else:
image_encoder = pipe.image_encoder
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None:
feature_extractor = feature_extractor_or_path
pipe.register_modules(feature_extractor=feature_extractor)
else:
feature_extractor = pipe.feature_extractor
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None:
feature_extractor = feature_extractor_or_path
pipe.register_modules(feature_extractor=feature_extractor)
else:
feature_extractor = pipe.feature_extractor
# load adapter into unet
unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet
attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
unet.set_attn_processor(attn_procs)
image_proj_model = Resampler(
embedding_dim=image_encoder.config.hidden_size,
output_dim=unet.config.cross_attention_dim,
num_queries=adapter_tokens,
)
# load adapter into unet
unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet
attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
unet.set_attn_processor(attn_procs)
image_proj_model = Resampler(
embedding_dim=image_encoder.config.hidden_size,
output_dim=unet.config.cross_attention_dim,
num_queries=adapter_tokens,
)
# Load pretrained model if needed.
if "ip_adapter" in state_dict.keys():
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
for mk in missing:
if "ln" not in mk:
raise ValueError(f"Missing keys in adapter_modules: {missing}")
if "image_proj" in state_dict.keys():
image_proj_model.load_state_dict(state_dict["image_proj"])
# Load pretrained model if needed.
if "ip_adapter" in state_dict.keys():
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
for mk in missing:
if "ln" not in mk:
raise ValueError(f"Missing keys in adapter_modules: {missing}")
if "image_proj" in state_dict.keys():
image_proj_model.load_state_dict(state_dict["image_proj"])
# convert IP-Adapter Image Projection layers to diffusers
image_projection_layers = []
image_projection_layers.append(image_proj_model)
unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
# convert IP-Adapter Image Projection layers to diffusers
image_projection_layers = []
image_projection_layers.append(image_proj_model)
unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
# Adjust unet config to handle additional ip hidden states.
unet.config.encoder_hid_dim_type = "ip_image_proj"
unet.to(dtype=pipe.dtype, device=pipe.device)
# Adjust unet config to handle additional ip hidden states.
unet.config.encoder_hid_dim_type = "ip_image_proj"
unet.to(dtype=pipe.dtype, device=pipe.device)
def revise_state_dict(old_state_dict_or_path, map_location="cpu"):
@ -198,51 +197,3 @@ def encode_image(image_encoder, feature_extractor, image, device, num_images_per
image_embeds = image_encoder(image).last_hidden_state
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds
def prepare_training_image_embeds(
    image_encoder, feature_extractor,
    ip_adapter_image, ip_adapter_image_embeds,
    device, drop_rate, output_hidden_state, idx_to_replace=None,
    num_images_per_prompt=1, do_classifier_free_guidance=False,
):
    """Build IP-Adapter image embeddings for training.

    When `ip_adapter_image_embeds` is None, raw images are (with probability
    `drop_rate`, or at the explicit `idx_to_replace` positions) zeroed out
    and encoded via `encode_image`. Otherwise the precomputed embeddings are
    repeated `num_images_per_prompt` times along dim 0; when
    `do_classifier_free_guidance` is set each precomputed tensor is assumed
    to be a [negative; positive] pair stacked along dim 0 and is split,
    repeated, and re-concatenated.

    Fix: `num_images_per_prompt` and `do_classifier_free_guidance` were
    undefined free names in the precomputed-embeddings branch (NameError at
    runtime); they are now keyword parameters with backward-compatible
    defaults.

    Returns a list of embedding tensors, one entry per IP adapter.
    """
    if ip_adapter_image_embeds is None:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]
        # if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers):
        #     raise ValueError(
        #         f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
        #     )
        image_embeds = []
        for single_ip_adapter_image in ip_adapter_image:
            # randomly zero out a fraction of the images (classifier-free style dropout);
            # NOTE(review): this mutates the caller's tensor in place — presumably intended for training
            if idx_to_replace is None:
                idx_to_replace = torch.rand(len(single_ip_adapter_image)) < drop_rate
            zero_ip_adapter_image = torch.zeros_like(single_ip_adapter_image)
            single_ip_adapter_image[idx_to_replace] = zero_ip_adapter_image[idx_to_replace]
            single_image_embeds = encode_image(
                image_encoder, feature_extractor, single_ip_adapter_image, device, 1, output_hidden_state
            )
            single_image_embeds = torch.stack([single_image_embeds], dim=1)  # FIXME
            image_embeds.append(single_image_embeds)
    else:
        repeat_dims = [1]
        image_embeds = []
        for single_image_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                single_image_embeds = single_image_embeds.repeat(
                    num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
                )
                single_negative_image_embeds = single_negative_image_embeds.repeat(
                    num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
                )
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
            else:
                single_image_embeds = single_image_embeds.repeat(
                    num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
                )
            image_embeds.append(single_image_embeds)
    return image_embeds

View File

@ -19,7 +19,7 @@ except Exception:
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .eva_vit_model import EVAVisionTransformer
from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
from .transformer import LayerNorm, LayerNormFp32, QuickGELU, Attention, VisionTransformer, TextTransformer
try:
from apex.normalization import FusedLayerNorm

View File

@ -17,6 +17,15 @@ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
from .utils import to_2tuple
try:
import xformers
import xformers.ops as xops
XFORMERS_IS_AVAILBLE = True
except Exception:
XFORMERS_IS_AVAILBLE = False
xops = None
class LayerNormFp32(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
def __init__(self, *args, **kwargs):
@ -417,10 +426,7 @@ class CustomTransformer(nn.Module):
if k is None and v is None:
k = v = q
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
q = checkpoint(r, q, k, v, attn_mask)
else:
q = r(q, k, v, attn_mask=attn_mask)
q = r(q, k, v, attn_mask=attn_mask)
return q
@ -494,10 +500,7 @@ class Transformer(nn.Module):
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
x = r(x, attn_mask=attn_mask)
return x

View File

@ -1,3 +1,4 @@
from typing import List, Tuple
import torch
import torch.nn as nn
from collections import OrderedDict
@ -46,7 +47,7 @@ def get_parameter_dtype(parameter: torch.nn.Module):
except StopIteration:
# For torch.nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, torch.Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
@ -76,7 +77,7 @@ class Downsample(nn.Module):
else:
assert self.channels == self.out_channels
from torch.nn import MaxUnpool2d
self.op = MaxUnpool2d(dims, kernel_size=stride, stride=stride)
self.op = MaxUnpool2d(kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
@ -267,10 +268,7 @@ class Adapter_XL(nn.Module):
if t is not None:
if not torch.is_tensor(t):
is_mps = x[0].device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
dtype = torch.int32 if is_mps else torch.int64
t = torch.tensor([t], dtype=dtype, device=x[0].device)
elif len(t.shape) == 0:
t = t[None].to(x[0].device)

View File

@ -728,7 +728,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_bridge_residuals: Optional[Tuple[torch.Tensor]] = None,
fusion_guidance_scale: Optional[torch.FloatTensor] = None,
fusion_type: Optional[str] = 'ADD',
adapter: Optional = None
adapter = None
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
The [`UNet2DConditionModel`] forward method.