mirror of https://github.com/vladmandic/automatic
initial pyright lint
Signed-off-by: Vladimir Mandic <mandic00@live.com>pull/4663/head
parent
0c9c86c3f9
commit
9a63ec758a
15
CHANGELOG.md
15
CHANGELOG.md
|
|
@ -1,12 +1,12 @@
|
|||
# Change Log for SD.Next
|
||||
|
||||
## Update for 2026-02-20
|
||||
## Update for 2026-02-21
|
||||
|
||||
### Highlights for 2026-02-20
|
||||
### Highlights for 2026-02-21
|
||||
|
||||
TBD
|
||||
|
||||
### Details for 2026-02-20
|
||||
### Details for 2026-02-21
|
||||
|
||||
- **Models**
|
||||
- [FireRed Image Edit](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.0)
|
||||
|
|
@ -47,8 +47,7 @@ TBD
|
|||
`clip, numba, skimage, torchsde, omegaconf, antlr, patch-ng, patch-ng, astunparse, addict, inflection, jsonmerge, kornia`,
|
||||
`resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets, imp`
|
||||
these are now installed on-demand when needed
|
||||
- refactor to/from image/tensor logic, thanks @CalamitousFelicitousness
|
||||
- switch to `pyproject.toml` for tool configs
|
||||
- refactor to/from *image/tensor* logic, thanks @CalamitousFelicitousness
|
||||
- refactor reorganize `cli` scripts
|
||||
- refactor move tests to dedicated `/test/`
|
||||
- refactor all image handling to `modules/image/`
|
||||
|
|
@ -67,8 +66,12 @@ TBD
|
|||
- remove requirements: `clip`, `open-clip`
|
||||
- remove `normalbae` pre-processor
|
||||
- captioning part-2, thanks @CalamitousFelicitousness
|
||||
- update `lint` rules, thanks @awsr
|
||||
- add new build of `insightface`, thanks @hameerabbasi
|
||||
- **Checks**
|
||||
- switch to `pyproject.toml` for tool configs
|
||||
- update `lint` rules, thanks @awsr
|
||||
- add `ty` to optional lint tooling
|
||||
- add `pyright` to optional lint tooling
|
||||
- **Fixes**
|
||||
- handle `clip` installer doing unwanted `setuptools` update
|
||||
- cleanup for `uv` installer fallback
|
||||
|
|
|
|||
|
|
@ -79,6 +79,21 @@ inpainting_is_aggressive_raunet = False
|
|||
playground_is_aggressive_raunet = False
|
||||
|
||||
|
||||
def _chunked_feed_forward(ff: torch.nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
||||
raise ValueError(
|
||||
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
||||
)
|
||||
|
||||
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
||||
ff_output = torch.cat(
|
||||
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
||||
dim=chunk_dim,
|
||||
)
|
||||
return ff_output
|
||||
|
||||
|
||||
def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
|
||||
# replace global self-attention with MSW-MSA
|
||||
class transformer_block(block_class):
|
||||
|
|
@ -238,7 +253,7 @@ def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type
|
|||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
if self._chunk_size is not None:
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) # pylint: disable=undefined-variable
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
if self.use_ada_layer_norm_zero:
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ from diffusers.models import ControlNetModel
|
|||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
|
||||
|
||||
|
||||
controlnet_apply_steps_rate = 0.6
|
||||
|
||||
|
||||
def make_diffusers_unet_2d_condition(block_class):
|
||||
|
||||
class unet_2d_condition(block_class):
|
||||
|
|
|
|||
|
|
@ -53,9 +53,6 @@ class TDDScheduler(DPMSolverSinglestepScheduler):
|
|||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
|
|
|
|||
|
|
@ -63,10 +63,10 @@ if devices.backend != "ipex":
|
|||
if torch.__version__.startswith("2.6"):
|
||||
from dataclasses import dataclass
|
||||
from torch.compiler import disable as disable_compile # pylint: disable=ungrouped-imports
|
||||
import diffusers.models.autoencoders.autoencoder_kl # pylint: disable=ungrouped-imports
|
||||
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution # pylint: disable=ungrouped-imports
|
||||
|
||||
@dataclass
|
||||
@disable_compile
|
||||
class AutoencoderKLOutput(diffusers.utils.BaseOutput):
|
||||
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
|
||||
latent_dist: DiagonalGaussianDistribution
|
||||
diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput = AutoencoderKLOutput
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import torch
|
|||
from einops import rearrange
|
||||
from torch import nn
|
||||
from ...common.cache import Cache
|
||||
from ....rotary_embedding import RotaryEmbedding
|
||||
from ....rotary_embedding import RotaryEmbedding, apply_rotary_emb
|
||||
|
||||
|
||||
class RotaryEmbeddingBase(nn.Module):
|
||||
|
|
|
|||
|
|
@ -27,6 +27,10 @@
|
|||
"ruff-win": "venv\\scripts\\activate && ruff check",
|
||||
"pylint": ". venv/bin/activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/ | grep -v '^*'",
|
||||
"pylint-win": "venv\\scripts\\activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/",
|
||||
"pyright": ". venv/bin/activate && pyright --threads 4",
|
||||
"pyright-win": "venv\\scripts\\activate && pyright --threads 4",
|
||||
"ty": ". venv/bin/activate && ty check",
|
||||
"ty-win": "venv\\scripts\\activate && ty check",
|
||||
"lint": "npm run format && npm run eslint && npm run eslint-ui && npm run ruff && npm run pylint",
|
||||
"lint-win": "npm run format-win && npm run eslint && npm run eslint-ui && npm run ruff-win && npm run pylint-win",
|
||||
"test": ". venv/bin/activate; python launch.py --debug --test",
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ class Scheduler(SchedulerMixin, ConfigMixin):
|
|||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: torch.long,
|
||||
timestep: torch.Tensor,
|
||||
sample: torch.LongTensor,
|
||||
starting_mask_ratio: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
|
|
|
|||
|
|
@ -366,3 +366,42 @@ variables.redefining-builtins-modules=["six.moves","past.builtins","future.built
|
|||
pythonVersion = "3.10"
|
||||
pythonPlatform = "All"
|
||||
typeCheckingMode = "off"
|
||||
venvPath = "./venv"
|
||||
include = [
|
||||
"*.py",
|
||||
"modules/**/*.py",
|
||||
"pipelines/**/*.py",
|
||||
"scripts/**/*.py",
|
||||
"extensions-builtin/**/*.py"
|
||||
]
|
||||
exclude = [
|
||||
"**/.*",
|
||||
".git/",
|
||||
"**/node_modules",
|
||||
"**/__pycache__",
|
||||
"venv",
|
||||
]
|
||||
reportMissingImports = "none"
|
||||
reportInvalidTypeForm = "none"
|
||||
|
||||
[tool.ty.environment]
|
||||
python = "./venv/bin/python"
|
||||
python-platform = "all"
|
||||
python-version = "3.10"
|
||||
|
||||
[tool.ty.src]
|
||||
include = [
|
||||
"./extensions-builtin"
|
||||
]
|
||||
exclude = [
|
||||
"venv/",
|
||||
"*.git/",
|
||||
]
|
||||
|
||||
[tool.ty.rules]
|
||||
invalid-method-overrides = "ignore"
|
||||
invalid-argument-types = "ignore"
|
||||
unresolved-imports = "ignore"
|
||||
unresolved-attributes = "ignore"
|
||||
invalid-assignments = "ignore"
|
||||
unsupported-operators = "ignore"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import io
|
|||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple, Set, Iterable
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
|
|
@ -19,7 +18,6 @@ __all__ = ['GlobalHeatMap', 'RawHeatMapCollection', 'WordHeatMap', 'ParsedHeatMa
|
|||
|
||||
|
||||
def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, color_normalize=True, ax=None, cmap='jet'):
|
||||
# type: (PIL.Image.Image | np.ndarray, torch.Tensor, str, Path, int, bool, plt.Axes) -> None
|
||||
if ax is None:
|
||||
plt.rcParams['font.size'] = 16
|
||||
plt.rcParams['figure.facecolor'] = 'black'
|
||||
|
|
@ -76,7 +74,6 @@ class WordHeatMap:
|
|||
return self.heatmap
|
||||
|
||||
def plot_overlay(self, image, out_file=None, color_normalize=True, ax=None, cmap='jet', **expand_kwargs):
|
||||
# type: (PIL.Image.Image | np.ndarray, Path, bool, plt.Axes, Dict[str, Any]) -> None
|
||||
return plot_overlay_heat_map(
|
||||
image,
|
||||
self.expand_as(image, **expand_kwargs),
|
||||
|
|
|
|||
|
|
@ -1369,7 +1369,7 @@ def init_attn_proc(unet, ip_adapter_tokens=16, use_lcm=False, use_adaln=True, us
|
|||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
time_embedding_dim=1280,
|
||||
) if hasattr(F, "scaled_dot_product_attention") else AdditiveKV_AttnProcessor()
|
||||
)
|
||||
else:
|
||||
attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ import os
|
|||
import torch
|
||||
from typing import List
|
||||
from collections import namedtuple, OrderedDict
|
||||
from utils import revise_state_dict
|
||||
|
||||
|
||||
def is_torch2_available():
|
||||
return hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import torch
|
||||
from collections import namedtuple, OrderedDict
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from .attention_processor import init_attn_proc
|
||||
from .ip_adapter import MultiIPAdapterImageProjection
|
||||
from .resampler import Resampler
|
||||
from transformers import (
|
||||
AutoModel, AutoImageProcessor,
|
||||
CLIPVisionModelWithProjection, CLIPImageProcessor)
|
||||
from .attention_processor import init_attn_proc
|
||||
from .ip_adapter import MultiIPAdapterImageProjection
|
||||
from .resampler import Resampler
|
||||
|
||||
|
||||
def init_adapter_in_unet(
|
||||
|
|
@ -80,85 +80,84 @@ def load_adapter_to_pipe(
|
|||
use_lcm=False,
|
||||
use_adaln=True,
|
||||
):
|
||||
|
||||
if not isinstance(pretrained_model_path_or_dict, dict):
|
||||
if pretrained_model_path_or_dict.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
|
||||
if not isinstance(pretrained_model_path_or_dict, dict):
|
||||
if pretrained_model_path_or_dict.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = pretrained_model_path_or_dict
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
state_dict = revise_state_dict(state_dict)
|
||||
state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
|
||||
else:
|
||||
state_dict = pretrained_model_path_or_dict
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
state_dict = revise_state_dict(state_dict)
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if image_encoder_or_path is not None:
|
||||
if isinstance(image_encoder_or_path, str):
|
||||
feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if image_encoder_or_path is not None:
|
||||
if isinstance(image_encoder_or_path, str):
|
||||
feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path
|
||||
|
||||
image_encoder_or_path = (
|
||||
CLIPVisionModelWithProjection.from_pretrained(
|
||||
image_encoder_or_path
|
||||
) if use_clip_encoder else
|
||||
AutoModel.from_pretrained(image_encoder_or_path)
|
||||
)
|
||||
image_encoder_or_path = (
|
||||
CLIPVisionModelWithProjection.from_pretrained(
|
||||
image_encoder_or_path
|
||||
) if use_clip_encoder else
|
||||
AutoModel.from_pretrained(image_encoder_or_path)
|
||||
)
|
||||
|
||||
if feature_extractor_or_path is not None:
|
||||
if isinstance(feature_extractor_or_path, str):
|
||||
feature_extractor_or_path = (
|
||||
CLIPImageProcessor() if use_clip_encoder else
|
||||
AutoImageProcessor.from_pretrained(feature_extractor_or_path)
|
||||
)
|
||||
if feature_extractor_or_path is not None:
|
||||
if isinstance(feature_extractor_or_path, str):
|
||||
feature_extractor_or_path = (
|
||||
CLIPImageProcessor() if use_clip_encoder else
|
||||
AutoImageProcessor.from_pretrained(feature_extractor_or_path)
|
||||
)
|
||||
|
||||
# create image encoder if it has not been registered to the pipeline yet
|
||||
if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None:
|
||||
image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype)
|
||||
pipe.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
image_encoder = pipe.image_encoder
|
||||
# create image encoder if it has not been registered to the pipeline yet
|
||||
if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None:
|
||||
image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype)
|
||||
pipe.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
image_encoder = pipe.image_encoder
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None:
|
||||
feature_extractor = feature_extractor_or_path
|
||||
pipe.register_modules(feature_extractor=feature_extractor)
|
||||
else:
|
||||
feature_extractor = pipe.feature_extractor
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None:
|
||||
feature_extractor = feature_extractor_or_path
|
||||
pipe.register_modules(feature_extractor=feature_extractor)
|
||||
else:
|
||||
feature_extractor = pipe.feature_extractor
|
||||
|
||||
# load adapter into unet
|
||||
unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet
|
||||
attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
|
||||
unet.set_attn_processor(attn_procs)
|
||||
image_proj_model = Resampler(
|
||||
embedding_dim=image_encoder.config.hidden_size,
|
||||
output_dim=unet.config.cross_attention_dim,
|
||||
num_queries=adapter_tokens,
|
||||
)
|
||||
# load adapter into unet
|
||||
unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet
|
||||
attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
|
||||
unet.set_attn_processor(attn_procs)
|
||||
image_proj_model = Resampler(
|
||||
embedding_dim=image_encoder.config.hidden_size,
|
||||
output_dim=unet.config.cross_attention_dim,
|
||||
num_queries=adapter_tokens,
|
||||
)
|
||||
|
||||
# Load pretrinaed model if needed.
|
||||
if "ip_adapter" in state_dict.keys():
|
||||
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
||||
missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
|
||||
for mk in missing:
|
||||
if "ln" not in mk:
|
||||
raise ValueError(f"Missing keys in adapter_modules: {missing}")
|
||||
if "image_proj" in state_dict.keys():
|
||||
image_proj_model.load_state_dict(state_dict["image_proj"])
|
||||
# Load pretrinaed model if needed.
|
||||
if "ip_adapter" in state_dict.keys():
|
||||
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
||||
missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
|
||||
for mk in missing:
|
||||
if "ln" not in mk:
|
||||
raise ValueError(f"Missing keys in adapter_modules: {missing}")
|
||||
if "image_proj" in state_dict.keys():
|
||||
image_proj_model.load_state_dict(state_dict["image_proj"])
|
||||
|
||||
# convert IP-Adapter Image Projection layers to diffusers
|
||||
image_projection_layers = []
|
||||
image_projection_layers.append(image_proj_model)
|
||||
unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
# convert IP-Adapter Image Projection layers to diffusers
|
||||
image_projection_layers = []
|
||||
image_projection_layers.append(image_proj_model)
|
||||
unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
|
||||
# Adjust unet config to handle addtional ip hidden states.
|
||||
unet.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
unet.to(dtype=pipe.dtype, device=pipe.device)
|
||||
# Adjust unet config to handle addtional ip hidden states.
|
||||
unet.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
unet.to(dtype=pipe.dtype, device=pipe.device)
|
||||
|
||||
|
||||
def revise_state_dict(old_state_dict_or_path, map_location="cpu"):
|
||||
|
|
@ -198,51 +197,3 @@ def encode_image(image_encoder, feature_extractor, image, device, num_images_per
|
|||
image_embeds = image_encoder(image).last_hidden_state
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
return image_embeds
|
||||
|
||||
|
||||
def prepare_training_image_embeds(
|
||||
image_encoder, feature_extractor,
|
||||
ip_adapter_image, ip_adapter_image_embeds,
|
||||
device, drop_rate, output_hidden_state, idx_to_replace=None
|
||||
):
|
||||
if ip_adapter_image_embeds is None:
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
# if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers):
|
||||
# raise ValueError(
|
||||
# f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
# )
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image in ip_adapter_image:
|
||||
if idx_to_replace is None:
|
||||
idx_to_replace = torch.rand(len(single_ip_adapter_image)) < drop_rate
|
||||
zero_ip_adapter_image = torch.zeros_like(single_ip_adapter_image)
|
||||
single_ip_adapter_image[idx_to_replace] = zero_ip_adapter_image[idx_to_replace]
|
||||
single_image_embeds = encode_image(
|
||||
image_encoder, feature_extractor, single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds], dim=1) # FIXME
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ except Exception:
|
|||
from .modified_resnet import ModifiedResNet
|
||||
from .timm_model import TimmModel
|
||||
from .eva_vit_model import EVAVisionTransformer
|
||||
from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
||||
from .transformer import LayerNorm, LayerNormFp32, QuickGELU, Attention, VisionTransformer, TextTransformer
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm
|
||||
|
|
|
|||
|
|
@ -17,6 +17,15 @@ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
|
|||
from .utils import to_2tuple
|
||||
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops as xops
|
||||
XFORMERS_IS_AVAILBLE = True
|
||||
except Exception:
|
||||
XFORMERS_IS_AVAILBLE = False
|
||||
xops = None
|
||||
|
||||
|
||||
class LayerNormFp32(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
@ -417,10 +426,7 @@ class CustomTransformer(nn.Module):
|
|||
if k is None and v is None:
|
||||
k = v = q
|
||||
for r in self.resblocks:
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
q = checkpoint(r, q, k, v, attn_mask)
|
||||
else:
|
||||
q = r(q, k, v, attn_mask=attn_mask)
|
||||
q = r(q, k, v, attn_mask=attn_mask)
|
||||
return q
|
||||
|
||||
|
||||
|
|
@ -494,10 +500,7 @@ class Transformer(nn.Module):
|
|||
|
||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
||||
for r in self.resblocks:
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from typing import List, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
|
|
@ -46,7 +47,7 @@ def get_parameter_dtype(parameter: torch.nn.Module):
|
|||
except StopIteration:
|
||||
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
||||
|
||||
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, torch.Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
|
|
@ -76,7 +77,7 @@ class Downsample(nn.Module):
|
|||
else:
|
||||
assert self.channels == self.out_channels
|
||||
from torch.nn import MaxUnpool2d
|
||||
self.op = MaxUnpool2d(dims, kernel_size=stride, stride=stride)
|
||||
self.op = MaxUnpool2d(kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
|
|
@ -267,10 +268,7 @@ class Adapter_XL(nn.Module):
|
|||
if t is not None:
|
||||
if not torch.is_tensor(t):
|
||||
is_mps = x[0].device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
t = torch.tensor([t], dtype=dtype, device=x[0].device)
|
||||
elif len(t.shape) == 0:
|
||||
t = t[None].to(x[0].device)
|
||||
|
|
|
|||
|
|
@ -728,7 +728,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
|||
down_bridge_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
fusion_guidance_scale: Optional[torch.FloatTensor] = None,
|
||||
fusion_type: Optional[str] = 'ADD',
|
||||
adapter: Optional = None
|
||||
adapter = None
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet2DConditionModel`] forward method.
|
||||
|
|
|
|||
Loading…
Reference in New Issue