diff --git a/CHANGELOG.md b/CHANGELOG.md
index 46cf8803c..82b510e03 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,12 +1,12 @@
 # Change Log for SD.Next

-## Update for 2026-02-20
+## Update for 2026-02-21

-### Highlights for 2026-02-20
+### Highlights for 2026-02-21

 TBD

-### Details for 2026-02-20
+### Details for 2026-02-21

 - **Models**
   - [FireRed Image Edit](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.0)
@@ -47,8 +47,7 @@ TBD
   `clip, numba, skimage, torchsde, omegaconf, antlr, patch-ng, patch-ng, astunparse, addict, inflection, jsonmerge, kornia`,
   `resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets, imp`
   these are now installed on-demand when needed
-  - refactor to/from image/tensor logic, thanks @CalamitousFelicitousness
-  - switch to `pyproject.toml` for tool configs
+  - refactor to/from *image/tensor* logic, thanks @CalamitousFelicitousness
   - refactor reorganize `cli` scripts
   - refactor move tests to dedicated `/test/`
   - refactor all image handling to `modules/image/`
@@ -67,8 +66,12 @@ TBD
   - remove requirements: `clip`, `open-clip`
   - remove `normalbae` pre-processor
   - captioning part-2, thanks @CalamitousFelicitousness
-  - update `lint` rules, thanks @awsr
   - add new build of `insightface`, thanks @hameerabbasi
+- **Checks**
+  - switch to `pyproject.toml` for tool configs
+  - update `lint` rules, thanks @awsr
+  - add `ty` to optional lint tooling
+  - add `pyright` to optional lint tooling
 - **Fixes**
   - handle `clip` installer doing unwanted `setuptools` update
   - cleanup for `uv` installer fallback
diff --git a/modules/hidiffusion/hidiffusion.py b/modules/hidiffusion/hidiffusion.py
index b00d19132..b0ddf3b58 100644
--- a/modules/hidiffusion/hidiffusion.py
+++ b/modules/hidiffusion/hidiffusion.py
@@ -79,6 +79,21 @@
 inpainting_is_aggressive_raunet = False
 playground_is_aggressive_raunet = False

+def _chunked_feed_forward(ff: torch.nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
+    # "feed_forward_chunk_size" can be used to save memory
+    if hidden_states.shape[chunk_dim] % chunk_size != 0:
+        raise ValueError(
+            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+        )
+
+    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+    ff_output = torch.cat(
+        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+        dim=chunk_dim,
+    )
+    return ff_output
+
+
 def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
     # replace global self-attention with MSW-MSA
     class transformer_block(block_class):
@@ -238,7 +253,7 @@ def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type
             norm_hidden_states = self.norm2(hidden_states)
             norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
             if self._chunk_size is not None:
-                ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) # pylint: disable=undefined-variable
+                ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
             else:
                 ff_output = self.ff(norm_hidden_states)
             if self.use_ada_layer_norm_zero:
diff --git a/modules/hidiffusion/hidiffusion_controlnet.py b/modules/hidiffusion/hidiffusion_controlnet.py
index 7a81ab066..990ecf8da 100644
--- a/modules/hidiffusion/hidiffusion_controlnet.py
+++ b/modules/hidiffusion/hidiffusion_controlnet.py
@@ -10,6 +10,9 @@
 from diffusers.models import ControlNetModel
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

+controlnet_apply_steps_rate = 0.6
+
+
 def make_diffusers_unet_2d_condition(block_class):

     class unet_2d_condition(block_class):
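Editor's note on the `_chunked_feed_forward` helper vendored into `hidiffusion.py` above: it bounds peak activation memory by running the feed-forward MLP on slices of the token dimension and concatenating the results, which is equivalent to a single full pass because the MLP acts on each token independently. A minimal self-contained sketch of the idea; the `nn.Linear` stand-in for `ff` and the shapes are illustrative assumptions, not the HiDiffusion block:

```python
import torch

def chunked_feed_forward(ff: torch.nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int) -> torch.Tensor:
    # same shape contract as the vendored helper: the chunked dim must divide evenly
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(f"dim {hidden_states.shape[chunk_dim]} is not divisible by chunk size {chunk_size}")
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    # run ff on one slice at a time, then stitch the outputs back together
    return torch.cat([ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)], dim=chunk_dim)

ff = torch.nn.Linear(64, 64)   # illustrative stand-in for the transformer feed-forward
x = torch.randn(2, 1024, 64)   # (batch, tokens, channels)
full = ff(x)                                                        # one large matmul
chunked = chunked_feed_forward(ff, x, chunk_dim=1, chunk_size=256)  # four smaller matmuls
assert torch.allclose(full, chunked, atol=1e-6)                     # same result, lower peak memory
```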
diff --git a/modules/schedulers/scheduler_tdd.py b/modules/schedulers/scheduler_tdd.py
index 7dbeb7010..05b49f35c 100644
--- a/modules/schedulers/scheduler_tdd.py
+++ b/modules/schedulers/scheduler_tdd.py
@@ -53,9 +53,6 @@ class TDDScheduler(DPMSolverSinglestepScheduler):
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-        elif beta_schedule == "squaredcos_cap_v2":
-            # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 9be31ccdf..ea5a28ec3 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -63,10 +63,10 @@ if devices.backend != "ipex":
     if torch.__version__.startswith("2.6"):
         from dataclasses import dataclass
         from torch.compiler import disable as disable_compile # pylint: disable=ungrouped-imports
-        import diffusers.models.autoencoders.autoencoder_kl # pylint: disable=ungrouped-imports
+        from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution # pylint: disable=ungrouped-imports

         @dataclass
         @disable_compile
         class AutoencoderKLOutput(diffusers.utils.BaseOutput):
-            latent_dist: "DiagonalGaussianDistribution" # noqa: F821
+            latent_dist: DiagonalGaussianDistribution
         diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput = AutoencoderKLOutput
diff --git a/modules/seedvr/src/models/dit/rope.py b/modules/seedvr/src/models/dit/rope.py
index 35b91ea8b..2647c17c8 100644
--- a/modules/seedvr/src/models/dit/rope.py
+++ b/modules/seedvr/src/models/dit/rope.py
@@ -18,7 +18,7 @@ import torch
 from einops import rearrange
 from torch import nn
 from ...common.cache import Cache
-from ....rotary_embedding import RotaryEmbedding
+from ....rotary_embedding import RotaryEmbedding, apply_rotary_emb


 class RotaryEmbeddingBase(nn.Module):
diff --git a/package.json b/package.json
index 2f6cfe69c..e3e948c65 100644
--- a/package.json
+++ b/package.json
@@ -27,6 +27,10 @@
     "ruff-win": "venv\\scripts\\activate && ruff check",
     "pylint": ". venv/bin/activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/ | grep -v '^*'",
     "pylint-win": "venv\\scripts\\activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/",
+    "pyright": ". venv/bin/activate && pyright --threads 4",
+    "pyright-win": "venv\\scripts\\activate && pyright --threads 4",
+    "ty": ". venv/bin/activate && ty check",
+    "ty-win": "venv\\scripts\\activate && ty check",
     "lint": "npm run format && npm run eslint && npm run eslint-ui && npm run ruff && npm run pylint",
     "lint-win": "npm run format-win && npm run eslint && npm run eslint-ui && npm run ruff-win && npm run pylint-win",
     "test": ". venv/bin/activate; python launch.py --debug --test",
diff --git a/pipelines/meissonic/scheduler.py b/pipelines/meissonic/scheduler.py
index 757469b4d..155076ea1 100644
--- a/pipelines/meissonic/scheduler.py
+++ b/pipelines/meissonic/scheduler.py
@@ -88,7 +88,7 @@ class Scheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.Tensor,
-        timestep: torch.long,
+        timestep: torch.Tensor,
         sample: torch.LongTensor,
         starting_mask_ratio: int = 1,
         generator: Optional[torch.Generator] = None,
diff --git a/pyproject.toml b/pyproject.toml
index 53c963c93..66cd892f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -366,3 +366,42 @@ variables.redefining-builtins-modules=["six.moves","past.builtins","future.built
 pythonVersion = "3.10"
 pythonPlatform = "All"
 typeCheckingMode = "off"
+venvPath = "./venv"
+include = [
+  "*.py",
+  "modules/**/*.py",
+  "pipelines/**/*.py",
+  "scripts/**/*.py",
+  "extensions-builtin/**/*.py"
+  ]
+exclude = [
+  "**/.*",
+  ".git/",
+  "**/node_modules",
+  "**/__pycache__",
+  "venv",
+  ]
+reportMissingImports = "none"
+reportInvalidTypeForm = "none"
+
+[tool.ty.environment]
+python = "./venv/bin/python"
+python-platform = "all"
+python-version = "3.10"
+
+[tool.ty.src]
+include = [
+  "./extensions-builtin"
+  ]
+exclude = [
+  "venv/",
+  "*.git/",
+  ]
+
+[tool.ty.rules]
+invalid-method-overrides = "ignore"
+invalid-argument-types = "ignore"
+unresolved-imports = "ignore"
+unresolved-attributes = "ignore"
+invalid-assignments = "ignore"
+unsupported-operators = "ignore"
diff --git a/scripts/daam/heatmap.py b/scripts/daam/heatmap.py
index 0f7a311f0..99378f05c 100644
--- a/scripts/daam/heatmap.py
+++ b/scripts/daam/heatmap.py
@@ -2,7 +2,6 @@ import io
 from collections import defaultdict
 from dataclasses import dataclass
 from functools import lru_cache
-from pathlib import Path
 from typing import Any, Dict, Tuple, Set, Iterable

 from matplotlib import pyplot as plt
@@ -19,7 +18,6 @@ __all__ = ['GlobalHeatMap', 'RawHeatMapCollection', 'WordHeatMap', 'ParsedHeatMa


 def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, color_normalize=True, ax=None, cmap='jet'):
-    # type: (PIL.Image.Image | np.ndarray, torch.Tensor, str, Path, int, bool, plt.Axes) -> None
     if ax is None:
         plt.rcParams['font.size'] = 16
         plt.rcParams['figure.facecolor'] = 'black'
@@ -76,7 +74,6 @@ class WordHeatMap:
         return self.heatmap

     def plot_overlay(self, image, out_file=None, color_normalize=True, ax=None, cmap='jet', **expand_kwargs):
-        # type: (PIL.Image.Image | np.ndarray, Path, bool, plt.Axes, Dict[str, Any]) -> None
         return plot_overlay_heat_map(
             image,
             self.expand_as(image, **expand_kwargs),
diff --git a/scripts/instantir/ip_adapter/attention_processor.py b/scripts/instantir/ip_adapter/attention_processor.py
index ca57b8754..191f45526 100644
--- a/scripts/instantir/ip_adapter/attention_processor.py
+++ b/scripts/instantir/ip_adapter/attention_processor.py
@@ -1369,7 +1369,7 @@ def init_attn_proc(unet, ip_adapter_tokens=16, use_lcm=False, use_adaln=True, us
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                time_embedding_dim=1280,
-            ) if hasattr(F, "scaled_dot_product_attention") else AdditiveKV_AttnProcessor()
+            )
        else:
            attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
    else:
diff --git a/scripts/instantir/ip_adapter/ip_adapter.py b/scripts/instantir/ip_adapter/ip_adapter.py
index a244cd681..8a5df0a23 100644
--- a/scripts/instantir/ip_adapter/ip_adapter.py
+++ b/scripts/instantir/ip_adapter/ip_adapter.py
@@ -2,6 +2,8 @@ import os
 import torch
 from typing import List
 from collections import namedtuple, OrderedDict
+from utils import revise_state_dict
+

 def is_torch2_available():
     return hasattr(torch.nn.functional, "scaled_dot_product_attention")
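For the `utils.py` rework below: IP-Adapter checkpoints ship as a flat key space where projector weights are prefixed `image_proj.` and attention weights `ip_adapter.`, and `load_adapter_to_pipe` buckets them into two sub-dicts before loading. A small sketch of that bucketing step, using an in-memory dict in place of a real `safe_open` file handle; the key names and tensors here are made up for illustration:

```python
import torch

# stand-in for the key/tensor pairs a safetensors file would yield via f.get_tensor(key)
flat = {
    "image_proj.proj.weight": torch.zeros(4, 4),
    "image_proj.proj.bias": torch.zeros(4),
    "ip_adapter.1.to_k_ip.weight": torch.zeros(4, 4),
}

state_dict = {"image_proj": {}, "ip_adapter": {}}
for key, tensor in flat.items():
    if key.startswith("image_proj."):
        state_dict["image_proj"][key.replace("image_proj.", "")] = tensor
    elif key.startswith("ip_adapter."):
        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = tensor

print(sorted(state_dict["image_proj"]))  # ['proj.bias', 'proj.weight']
print(sorted(state_dict["ip_adapter"]))  # ['1.to_k_ip.weight']
```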
+ if "ip_adapter" in state_dict.keys(): + adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) + missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False) + for mk in missing: + if "ln" not in mk: + raise ValueError(f"Missing keys in adapter_modules: {missing}") + if "image_proj" in state_dict.keys(): + image_proj_model.load_state_dict(state_dict["image_proj"]) - # convert IP-Adapter Image Projection layers to diffusers - image_projection_layers = [] - image_projection_layers.append(image_proj_model) - unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) + # convert IP-Adapter Image Projection layers to diffusers + image_projection_layers = [] + image_projection_layers.append(image_proj_model) + unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) - # Adjust unet config to handle addtional ip hidden states. - unet.config.encoder_hid_dim_type = "ip_image_proj" - unet.to(dtype=pipe.dtype, device=pipe.device) + # Adjust unet config to handle addtional ip hidden states. + unet.config.encoder_hid_dim_type = "ip_image_proj" + unet.to(dtype=pipe.dtype, device=pipe.device) def revise_state_dict(old_state_dict_or_path, map_location="cpu"): @@ -198,51 +197,3 @@ def encode_image(image_encoder, feature_extractor, image, device, num_images_per image_embeds = image_encoder(image).last_hidden_state image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds - - -def prepare_training_image_embeds( - image_encoder, feature_extractor, - ip_adapter_image, ip_adapter_image_embeds, - device, drop_rate, output_hidden_state, idx_to_replace=None -): - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - # if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers): - # raise ValueError( - # f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
diff --git a/scripts/pulid/eva_clip/model.py b/scripts/pulid/eva_clip/model.py
index 9e755d683..05b055794 100644
--- a/scripts/pulid/eva_clip/model.py
+++ b/scripts/pulid/eva_clip/model.py
@@ -19,7 +19,7 @@ except Exception:
 from .modified_resnet import ModifiedResNet
 from .timm_model import TimmModel
 from .eva_vit_model import EVAVisionTransformer
-from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
+from .transformer import LayerNorm, LayerNormFp32, QuickGELU, Attention, VisionTransformer, TextTransformer

 try:
     from apex.normalization import FusedLayerNorm
diff --git a/scripts/pulid/eva_clip/transformer.py b/scripts/pulid/eva_clip/transformer.py
index b1d42de77..57a24ad2c 100644
--- a/scripts/pulid/eva_clip/transformer.py
+++ b/scripts/pulid/eva_clip/transformer.py
@@ -17,6 +17,15 @@
 from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
 from .utils import to_2tuple

+try:
+    import xformers
+    import xformers.ops as xops
+    XFORMERS_IS_AVAILBLE = True
+except Exception:
+    XFORMERS_IS_AVAILBLE = False
+    xops = None
+
+
 class LayerNormFp32(nn.LayerNorm):
     """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
     def __init__(self, *args, **kwargs):
@@ -417,10 +426,7 @@ class CustomTransformer(nn.Module):
         if k is None and v is None:
             k = v = q
         for r in self.resblocks:
-            if self.grad_checkpointing and not torch.jit.is_scripting():
-                q = checkpoint(r, q, k, v, attn_mask)
-            else:
-                q = r(q, k, v, attn_mask=attn_mask)
+            q = r(q, k, v, attn_mask=attn_mask)
         return q


@@ -494,10 +500,7 @@ class Transformer(nn.Module):

     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         for r in self.resblocks:
-            if self.grad_checkpointing and not torch.jit.is_scripting():
-                x = checkpoint(r, x, attn_mask)
-            else:
-                x = r(x, attn_mask=attn_mask)
+            x = r(x, attn_mask=attn_mask)
         return x

diff --git a/scripts/xadapter/adapter.py b/scripts/xadapter/adapter.py
index 4096f71e7..4b3f9f625 100644
--- a/scripts/xadapter/adapter.py
+++ b/scripts/xadapter/adapter.py
@@ -1,3 +1,4 @@
+from typing import List, Tuple
 import torch
 import torch.nn as nn
 from collections import OrderedDict
@@ -46,7 +47,7 @@ def get_parameter_dtype(parameter: torch.nn.Module):
     except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5

-        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, torch.Tensor]]:
             tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
             return tuples

@@ -76,7 +77,7 @@ class Downsample(nn.Module):
         else:
             assert self.channels == self.out_channels
             from torch.nn import MaxUnpool2d
-            self.op = MaxUnpool2d(dims, kernel_size=stride, stride=stride)
+            self.op = MaxUnpool2d(kernel_size=stride, stride=stride)

     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -267,10 +268,7 @@ class Adapter_XL(nn.Module):
         if t is not None:
             if not torch.is_tensor(t):
                 is_mps = x[0].device.type == "mps"
-                if isinstance(timestep, float):
-                    dtype = torch.float32 if is_mps else torch.float64
-                else:
-                    dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if is_mps else torch.int64
                 t = torch.tensor([t], dtype=dtype, device=x[0].device)
             elif len(t.shape) == 0:
                 t = t[None].to(x[0].device)
diff --git a/scripts/xadapter/unet_adapter.py b/scripts/xadapter/unet_adapter.py
index 5890c7749..7ee6af94e 100644
--- a/scripts/xadapter/unet_adapter.py
+++ b/scripts/xadapter/unet_adapter.py
@@ -728,7 +728,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         down_bridge_residuals: Optional[Tuple[torch.Tensor]] = None,
         fusion_guidance_scale: Optional[torch.FloatTensor] = None,
         fusion_type: Optional[str] = 'ADD',
-        adapter: Optional = None
+        adapter = None
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
         The [`UNet2DConditionModel`] forward method.
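A closing note on the `Adapter_XL` hunk above: the removed branch tests `isinstance(timestep, float)`, but no `timestep` name appears to exist in that scope (the local variable is `t`), so the float path looks like dead code; the fix keeps only the integer path, with `int32` on MPS, mirroring the diffusers convention for Apple devices. A standalone sketch of the surviving promotion logic; the function name is mine, not from the patch:

```python
import torch

def promote_timestep(t, device: torch.device) -> torch.Tensor:
    # scalar timesteps become 1-element tensors; MPS gets int32 instead of int64
    if not torch.is_tensor(t):
        dtype = torch.int32 if device.type == "mps" else torch.int64
        t = torch.tensor([t], dtype=dtype, device=device)
    elif len(t.shape) == 0:
        t = t[None].to(device)
    return t

print(promote_timestep(999, torch.device("cpu")))                  # tensor([999])
print(promote_timestep(torch.tensor(999.0), torch.device("cpu")))  # tensor([999.])
```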