initial pyright lint

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4663/head
Vladimir Mandic 2026-02-21 09:32:36 +01:00
parent 0c9c86c3f9
commit 9a63ec758a
17 changed files with 167 additions and 155 deletions

View File

@ -1,12 +1,12 @@
# Change Log for SD.Next # Change Log for SD.Next
## Update for 2026-02-20 ## Update for 2026-02-21
### Highlights for 2026-02-20 ### Highlights for 2026-02-21
TBD TBD
### Details for 2026-02-20 ### Details for 2026-02-21
- **Models** - **Models**
- [FireRed Image Edit](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.0) - [FireRed Image Edit](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.0)
@ -47,8 +47,7 @@ TBD
`clip, numba, skimage, torchsde, omegaconf, antlr, patch-ng, patch-ng, astunparse, addict, inflection, jsonmerge, kornia`, `clip, numba, skimage, torchsde, omegaconf, antlr, patch-ng, patch-ng, astunparse, addict, inflection, jsonmerge, kornia`,
`resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets, imp` `resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets, imp`
these are now installed on-demand when needed these are now installed on-demand when needed
- refactor to/from image/tensor logic, thanks @CalamitousFelicitousness - refactor to/from *image/tensor* logic, thanks @CalamitousFelicitousness
- switch to `pyproject.toml` for tool configs
- refactor reorganize `cli` scripts - refactor reorganize `cli` scripts
- refactor move tests to dedicated `/test/` - refactor move tests to dedicated `/test/`
- refactor all image handling to `modules/image/` - refactor all image handling to `modules/image/`
@ -67,8 +66,12 @@ TBD
- remove requirements: `clip`, `open-clip` - remove requirements: `clip`, `open-clip`
- remove `normalbae` pre-processor - remove `normalbae` pre-processor
- captioning part-2, thanks @CalamitousFelicitousness - captioning part-2, thanks @CalamitousFelicitousness
- update `lint` rules, thanks @awsr
- add new build of `insightface`, thanks @hameerabbasi - add new build of `insightface`, thanks @hameerabbasi
- **Checks**
- switch to `pyproject.toml` for tool configs
- update `lint` rules, thanks @awsr
- add `ty` to optional lint tooling
- add `pyright` to optional lint tooling
- **Fixes** - **Fixes**
- handle `clip` installer doing unwanted `setuptools` update - handle `clip` installer doing unwanted `setuptools` update
- cleanup for `uv` installer fallback - cleanup for `uv` installer fallback

View File

@ -79,6 +79,21 @@ inpainting_is_aggressive_raunet = False
playground_is_aggressive_raunet = False playground_is_aggressive_raunet = False
def _chunked_feed_forward(ff: torch.nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]: def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
# replace global self-attention with MSW-MSA # replace global self-attention with MSW-MSA
class transformer_block(block_class): class transformer_block(block_class):
@ -238,7 +253,7 @@ def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type
norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None: if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) # pylint: disable=undefined-variable ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else: else:
ff_output = self.ff(norm_hidden_states) ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero: if self.use_ada_layer_norm_zero:

View File

@ -10,6 +10,9 @@ from diffusers.models import ControlNetModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
controlnet_apply_steps_rate = 0.6
def make_diffusers_unet_2d_condition(block_class): def make_diffusers_unet_2d_condition(block_class):
class unet_2d_condition(block_class): class unet_2d_condition(block_class):

View File

@ -53,9 +53,6 @@ class TDDScheduler(DPMSolverSinglestepScheduler):
elif beta_schedule == "scaled_linear": elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model. # this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

View File

@ -63,10 +63,10 @@ if devices.backend != "ipex":
if torch.__version__.startswith("2.6"): if torch.__version__.startswith("2.6"):
from dataclasses import dataclass from dataclasses import dataclass
from torch.compiler import disable as disable_compile # pylint: disable=ungrouped-imports from torch.compiler import disable as disable_compile # pylint: disable=ungrouped-imports
import diffusers.models.autoencoders.autoencoder_kl # pylint: disable=ungrouped-imports from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution # pylint: disable=ungrouped-imports
@dataclass @dataclass
@disable_compile @disable_compile
class AutoencoderKLOutput(diffusers.utils.BaseOutput): class AutoencoderKLOutput(diffusers.utils.BaseOutput):
latent_dist: "DiagonalGaussianDistribution" # noqa: F821 latent_dist: DiagonalGaussianDistribution
diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput = AutoencoderKLOutput diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput = AutoencoderKLOutput

View File

@ -18,7 +18,7 @@ import torch
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from ...common.cache import Cache from ...common.cache import Cache
from ....rotary_embedding import RotaryEmbedding from ....rotary_embedding import RotaryEmbedding, apply_rotary_emb
class RotaryEmbeddingBase(nn.Module): class RotaryEmbeddingBase(nn.Module):

View File

@ -27,6 +27,10 @@
"ruff-win": "venv\\scripts\\activate && ruff check", "ruff-win": "venv\\scripts\\activate && ruff check",
"pylint": ". venv/bin/activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/ | grep -v '^*'", "pylint": ". venv/bin/activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/ | grep -v '^*'",
"pylint-win": "venv\\scripts\\activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/", "pylint-win": "venv\\scripts\\activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/",
"pyright": ". venv/bin/activate && pyright --threads 4",
"pyright-win": "venv\\scripts\\activate && pyright --threads 4",
"ty": ". venv/bin/activate && ty check",
"ty-win": "venv\\scripts\\activate && ty check",
"lint": "npm run format && npm run eslint && npm run eslint-ui && npm run ruff && npm run pylint", "lint": "npm run format && npm run eslint && npm run eslint-ui && npm run ruff && npm run pylint",
"lint-win": "npm run format-win && npm run eslint && npm run eslint-ui && npm run ruff-win && npm run pylint-win", "lint-win": "npm run format-win && npm run eslint && npm run eslint-ui && npm run ruff-win && npm run pylint-win",
"test": ". venv/bin/activate; python launch.py --debug --test", "test": ". venv/bin/activate; python launch.py --debug --test",

View File

@ -88,7 +88,7 @@ class Scheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.Tensor, model_output: torch.Tensor,
timestep: torch.long, timestep: torch.Tensor,
sample: torch.LongTensor, sample: torch.LongTensor,
starting_mask_ratio: int = 1, starting_mask_ratio: int = 1,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,

View File

@ -366,3 +366,42 @@ variables.redefining-builtins-modules=["six.moves","past.builtins","future.built
pythonVersion = "3.10" pythonVersion = "3.10"
pythonPlatform = "All" pythonPlatform = "All"
typeCheckingMode = "off" typeCheckingMode = "off"
venvPath = "./venv"
include = [
"*.py",
"modules/**/*.py",
"pipelines/**/*.py",
"scripts/**/*.py",
"extensions-builtin/**/*.py"
]
exclude = [
"**/.*",
".git/",
"**/node_modules",
"**/__pycache__",
"venv",
]
reportMissingImports = "none"
reportInvalidTypeForm = "none"
[tool.ty.environment]
python = "./venv/bin/python"
python-platform = "all"
python-version = "3.10"
[tool.ty.src]
include = [
"./extensions-builtin"
]
exclude = [
"venv/",
"*.git/",
]
[tool.ty.rules]
invalid-method-overrides = "ignore"
invalid-argument-types = "ignore"
unresolved-imports = "ignore"
unresolved-attributes = "ignore"
invalid-assignments = "ignore"
unsupported-operators = "ignore"

View File

@ -2,7 +2,6 @@ import io
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Tuple, Set, Iterable from typing import Any, Dict, Tuple, Set, Iterable
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
@ -19,7 +18,6 @@ __all__ = ['GlobalHeatMap', 'RawHeatMapCollection', 'WordHeatMap', 'ParsedHeatMa
def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, color_normalize=True, ax=None, cmap='jet'): def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, color_normalize=True, ax=None, cmap='jet'):
# type: (PIL.Image.Image | np.ndarray, torch.Tensor, str, Path, int, bool, plt.Axes) -> None
if ax is None: if ax is None:
plt.rcParams['font.size'] = 16 plt.rcParams['font.size'] = 16
plt.rcParams['figure.facecolor'] = 'black' plt.rcParams['figure.facecolor'] = 'black'
@ -76,7 +74,6 @@ class WordHeatMap:
return self.heatmap return self.heatmap
def plot_overlay(self, image, out_file=None, color_normalize=True, ax=None, cmap='jet', **expand_kwargs): def plot_overlay(self, image, out_file=None, color_normalize=True, ax=None, cmap='jet', **expand_kwargs):
# type: (PIL.Image.Image | np.ndarray, Path, bool, plt.Axes, Dict[str, Any]) -> None
return plot_overlay_heat_map( return plot_overlay_heat_map(
image, image,
self.expand_as(image, **expand_kwargs), self.expand_as(image, **expand_kwargs),

View File

@ -1369,7 +1369,7 @@ def init_attn_proc(unet, ip_adapter_tokens=16, use_lcm=False, use_adaln=True, us
hidden_size=hidden_size, hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
time_embedding_dim=1280, time_embedding_dim=1280,
) if hasattr(F, "scaled_dot_product_attention") else AdditiveKV_AttnProcessor() )
else: else:
attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor() attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
else: else:

View File

@ -2,6 +2,8 @@ import os
import torch import torch
from typing import List from typing import List
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
from utils import revise_state_dict
def is_torch2_available(): def is_torch2_available():
return hasattr(torch.nn.functional, "scaled_dot_product_attention") return hasattr(torch.nn.functional, "scaled_dot_product_attention")

View File

@ -1,12 +1,12 @@
import torch
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
import torch
from safetensors import safe_open from safetensors import safe_open
from .attention_processor import init_attn_proc
from .ip_adapter import MultiIPAdapterImageProjection
from .resampler import Resampler
from transformers import ( from transformers import (
AutoModel, AutoImageProcessor, AutoModel, AutoImageProcessor,
CLIPVisionModelWithProjection, CLIPImageProcessor) CLIPVisionModelWithProjection, CLIPImageProcessor)
from .attention_processor import init_attn_proc
from .ip_adapter import MultiIPAdapterImageProjection
from .resampler import Resampler
def init_adapter_in_unet( def init_adapter_in_unet(
@ -80,85 +80,84 @@ def load_adapter_to_pipe(
use_lcm=False, use_lcm=False,
use_adaln=True, use_adaln=True,
): ):
if not isinstance(pretrained_model_path_or_dict, dict):
if not isinstance(pretrained_model_path_or_dict, dict): if pretrained_model_path_or_dict.endswith(".safetensors"):
if pretrained_model_path_or_dict.endswith(".safetensors"): state_dict = {"image_proj": {}, "ip_adapter": {}}
state_dict = {"image_proj": {}, "ip_adapter": {}} with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f:
with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f: for key in f.keys():
for key in f.keys(): if key.startswith("image_proj."):
if key.startswith("image_proj."): state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) elif key.startswith("ip_adapter."):
elif key.startswith("ip_adapter."): state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
else: else:
state_dict = pretrained_model_path_or_dict state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
keys = list(state_dict.keys()) else:
if "image_proj" not in keys and "ip_adapter" not in keys: state_dict = pretrained_model_path_or_dict
state_dict = revise_state_dict(state_dict) keys = list(state_dict.keys())
if "image_proj" not in keys and "ip_adapter" not in keys:
state_dict = revise_state_dict(state_dict)
# load CLIP image encoder here if it has not been registered to the pipeline yet # load CLIP image encoder here if it has not been registered to the pipeline yet
if image_encoder_or_path is not None: if image_encoder_or_path is not None:
if isinstance(image_encoder_or_path, str): if isinstance(image_encoder_or_path, str):
feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path
image_encoder_or_path = ( image_encoder_or_path = (
CLIPVisionModelWithProjection.from_pretrained( CLIPVisionModelWithProjection.from_pretrained(
image_encoder_or_path image_encoder_or_path
) if use_clip_encoder else ) if use_clip_encoder else
AutoModel.from_pretrained(image_encoder_or_path) AutoModel.from_pretrained(image_encoder_or_path)
) )
if feature_extractor_or_path is not None: if feature_extractor_or_path is not None:
if isinstance(feature_extractor_or_path, str): if isinstance(feature_extractor_or_path, str):
feature_extractor_or_path = ( feature_extractor_or_path = (
CLIPImageProcessor() if use_clip_encoder else CLIPImageProcessor() if use_clip_encoder else
AutoImageProcessor.from_pretrained(feature_extractor_or_path) AutoImageProcessor.from_pretrained(feature_extractor_or_path)
) )
# create image encoder if it has not been registered to the pipeline yet # create image encoder if it has not been registered to the pipeline yet
if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None: if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None:
image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype) image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype)
pipe.register_modules(image_encoder=image_encoder) pipe.register_modules(image_encoder=image_encoder)
else: else:
image_encoder = pipe.image_encoder image_encoder = pipe.image_encoder
# create feature extractor if it has not been registered to the pipeline yet # create feature extractor if it has not been registered to the pipeline yet
if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None: if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None:
feature_extractor = feature_extractor_or_path feature_extractor = feature_extractor_or_path
pipe.register_modules(feature_extractor=feature_extractor) pipe.register_modules(feature_extractor=feature_extractor)
else: else:
feature_extractor = pipe.feature_extractor feature_extractor = pipe.feature_extractor
# load adapter into unet # load adapter into unet
unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet
attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln) attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
unet.set_attn_processor(attn_procs) unet.set_attn_processor(attn_procs)
image_proj_model = Resampler( image_proj_model = Resampler(
embedding_dim=image_encoder.config.hidden_size, embedding_dim=image_encoder.config.hidden_size,
output_dim=unet.config.cross_attention_dim, output_dim=unet.config.cross_attention_dim,
num_queries=adapter_tokens, num_queries=adapter_tokens,
) )
# Load pretrained model if needed. # Load pretrained model if needed.
if "ip_adapter" in state_dict.keys(): if "ip_adapter" in state_dict.keys():
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False) missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
for mk in missing: for mk in missing:
if "ln" not in mk: if "ln" not in mk:
raise ValueError(f"Missing keys in adapter_modules: {missing}") raise ValueError(f"Missing keys in adapter_modules: {missing}")
if "image_proj" in state_dict.keys(): if "image_proj" in state_dict.keys():
image_proj_model.load_state_dict(state_dict["image_proj"]) image_proj_model.load_state_dict(state_dict["image_proj"])
# convert IP-Adapter Image Projection layers to diffusers # convert IP-Adapter Image Projection layers to diffusers
image_projection_layers = [] image_projection_layers = []
image_projection_layers.append(image_proj_model) image_projection_layers.append(image_proj_model)
unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
# Adjust unet config to handle additional ip hidden states. # Adjust unet config to handle additional ip hidden states.
unet.config.encoder_hid_dim_type = "ip_image_proj" unet.config.encoder_hid_dim_type = "ip_image_proj"
unet.to(dtype=pipe.dtype, device=pipe.device) unet.to(dtype=pipe.dtype, device=pipe.device)
def revise_state_dict(old_state_dict_or_path, map_location="cpu"): def revise_state_dict(old_state_dict_or_path, map_location="cpu"):
@ -198,51 +197,3 @@ def encode_image(image_encoder, feature_extractor, image, device, num_images_per
image_embeds = image_encoder(image).last_hidden_state image_embeds = image_encoder(image).last_hidden_state
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds return image_embeds
def prepare_training_image_embeds(
    image_encoder, feature_extractor,
    ip_adapter_image, ip_adapter_image_embeds,
    device, drop_rate, output_hidden_state, idx_to_replace=None,
    do_classifier_free_guidance=False, num_images_per_prompt=1,
):
    """Build IP-Adapter image embeddings for training.

    When `ip_adapter_image_embeds` is None, each image batch in `ip_adapter_image`
    is encoded with `image_encoder` after conditioning dropout: a random subset of
    images (probability `drop_rate`, or the explicit boolean mask `idx_to_replace`)
    is replaced with zeros before encoding.

    Otherwise the precomputed embeddings are used directly: each tensor is repeated
    `num_images_per_prompt` times along the batch dimension and, when
    `do_classifier_free_guidance` is True, is expected to hold stacked
    [negative, positive] halves which are split, repeated, and re-concatenated.

    Fix: `do_classifier_free_guidance` and `num_images_per_prompt` were previously
    referenced in the precomputed-embeds branch without being defined anywhere
    (NameError at runtime); they are now keyword parameters with backward-compatible
    defaults.

    Returns:
        list[torch.Tensor]: one embedding tensor per adapter image / per
        precomputed embedding.
    """
    if ip_adapter_image_embeds is None:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]
        image_embeds = []
        for single_ip_adapter_image in ip_adapter_image:
            if idx_to_replace is None:
                # drop each image independently with probability `drop_rate`
                idx_to_replace = torch.rand(len(single_ip_adapter_image)) < drop_rate
            zero_ip_adapter_image = torch.zeros_like(single_ip_adapter_image)
            single_ip_adapter_image[idx_to_replace] = zero_ip_adapter_image[idx_to_replace]
            single_image_embeds = encode_image(
                image_encoder, feature_extractor, single_ip_adapter_image, device, 1, output_hidden_state
            )
            single_image_embeds = torch.stack([single_image_embeds], dim=1)  # FIXME
            image_embeds.append(single_image_embeds)
    else:
        repeat_dims = [1]
        image_embeds = []
        for single_image_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # embeds are stacked as [negative, positive]; split, tile per prompt, restack
                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                single_image_embeds = single_image_embeds.repeat(
                    num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
                )
                single_negative_image_embeds = single_negative_image_embeds.repeat(
                    num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
                )
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
            else:
                single_image_embeds = single_image_embeds.repeat(
                    num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
                )
            image_embeds.append(single_image_embeds)
    return image_embeds

View File

@ -19,7 +19,7 @@ except Exception:
from .modified_resnet import ModifiedResNet from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel from .timm_model import TimmModel
from .eva_vit_model import EVAVisionTransformer from .eva_vit_model import EVAVisionTransformer
from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer from .transformer import LayerNorm, LayerNormFp32, QuickGELU, Attention, VisionTransformer, TextTransformer
try: try:
from apex.normalization import FusedLayerNorm from apex.normalization import FusedLayerNorm

View File

@ -17,6 +17,15 @@ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
from .utils import to_2tuple from .utils import to_2tuple
# Optional dependency probe: memory-efficient attention via xformers.
# If the import fails for any reason, fall back cleanly with the flag unset
# and `xops` as None so call sites can branch on availability.
# NOTE(review): flag name keeps the historical misspelling "AVAILBLE";
# renaming would break external references, so it is left as-is.
try:
    import xformers
    import xformers.ops as xops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    XFORMERS_IS_AVAILBLE = False
    xops = None
class LayerNormFp32(nn.LayerNorm): class LayerNormFp32(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -417,10 +426,7 @@ class CustomTransformer(nn.Module):
if k is None and v is None: if k is None and v is None:
k = v = q k = v = q
for r in self.resblocks: for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting(): q = r(q, k, v, attn_mask=attn_mask)
q = checkpoint(r, q, k, v, attn_mask)
else:
q = r(q, k, v, attn_mask=attn_mask)
return q return q
@ -494,10 +500,7 @@ class Transformer(nn.Module):
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
for r in self.resblocks: for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting(): x = r(x, attn_mask=attn_mask)
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x return x

View File

@ -1,3 +1,4 @@
from typing import List, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from collections import OrderedDict from collections import OrderedDict
@ -46,7 +47,7 @@ def get_parameter_dtype(parameter: torch.nn.Module):
except StopIteration: except StopIteration:
# For torch.nn.DataParallel compatibility in PyTorch 1.5 # For torch.nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, torch.Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples return tuples
@ -76,7 +77,7 @@ class Downsample(nn.Module):
else: else:
assert self.channels == self.out_channels assert self.channels == self.out_channels
from torch.nn import MaxUnpool2d from torch.nn import MaxUnpool2d
self.op = MaxUnpool2d(dims, kernel_size=stride, stride=stride) self.op = MaxUnpool2d(kernel_size=stride, stride=stride)
def forward(self, x): def forward(self, x):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
@ -267,10 +268,7 @@ class Adapter_XL(nn.Module):
if t is not None: if t is not None:
if not torch.is_tensor(t): if not torch.is_tensor(t):
is_mps = x[0].device.type == "mps" is_mps = x[0].device.type == "mps"
if isinstance(timestep, float): dtype = torch.int32 if is_mps else torch.int64
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
t = torch.tensor([t], dtype=dtype, device=x[0].device) t = torch.tensor([t], dtype=dtype, device=x[0].device)
elif len(t.shape) == 0: elif len(t.shape) == 0:
t = t[None].to(x[0].device) t = t[None].to(x[0].device)

View File

@ -728,7 +728,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_bridge_residuals: Optional[Tuple[torch.Tensor]] = None, down_bridge_residuals: Optional[Tuple[torch.Tensor]] = None,
fusion_guidance_scale: Optional[torch.FloatTensor] = None, fusion_guidance_scale: Optional[torch.FloatTensor] = None,
fusion_type: Optional[str] = 'ADD', fusion_type: Optional[str] = 'ADD',
adapter: Optional = None adapter = None
) -> Union[UNet2DConditionOutput, Tuple]: ) -> Union[UNet2DConditionOutput, Tuple]:
r""" r"""
The [`UNet2DConditionModel`] forward method. The [`UNet2DConditionModel`] forward method.