mirror of https://github.com/vladmandic/automatic
initial pyright lint
Signed-off-by: Vladimir Mandic <mandic00@live.com>pull/4663/head
parent
0c9c86c3f9
commit
9a63ec758a
15
CHANGELOG.md
15
CHANGELOG.md
|
|
@ -1,12 +1,12 @@
|
||||||
# Change Log for SD.Next
|
# Change Log for SD.Next
|
||||||
|
|
||||||
## Update for 2026-02-20
|
## Update for 2026-02-21
|
||||||
|
|
||||||
### Highlights for 2026-02-20
|
### Highlights for 2026-02-21
|
||||||
|
|
||||||
TBD
|
TBD
|
||||||
|
|
||||||
### Details for 2026-02-20
|
### Details for 2026-02-21
|
||||||
|
|
||||||
- **Models**
|
- **Models**
|
||||||
- [FireRed Image Edit](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.0)
|
- [FireRed Image Edit](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.0)
|
||||||
|
|
@ -47,8 +47,7 @@ TBD
|
||||||
`clip, numba, skimage, torchsde, omegaconf, antlr, patch-ng, patch-ng, astunparse, addict, inflection, jsonmerge, kornia`,
|
`clip, numba, skimage, torchsde, omegaconf, antlr, patch-ng, patch-ng, astunparse, addict, inflection, jsonmerge, kornia`,
|
||||||
`resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets, imp`
|
`resize-right, voluptuous, yapf, sqlalchemy, invisible-watermark, pi-heif, ftfy, blendmodes, PyWavelets, imp`
|
||||||
these are now installed on-demand when needed
|
these are now installed on-demand when needed
|
||||||
- refactor to/from image/tensor logic, thanks @CalamitousFelicitousness
|
- refactor to/from *image/tensor* logic, thanks @CalamitousFelicitousness
|
||||||
- switch to `pyproject.toml` for tool configs
|
|
||||||
- refactor reorganize `cli` scripts
|
- refactor reorganize `cli` scripts
|
||||||
- refactor move tests to dedicated `/test/`
|
- refactor move tests to dedicated `/test/`
|
||||||
- refactor all image handling to `modules/image/`
|
- refactor all image handling to `modules/image/`
|
||||||
|
|
@ -67,8 +66,12 @@ TBD
|
||||||
- remove requirements: `clip`, `open-clip`
|
- remove requirements: `clip`, `open-clip`
|
||||||
- remove `normalbae` pre-processor
|
- remove `normalbae` pre-processor
|
||||||
- captioning part-2, thanks @CalamitousFelicitousness
|
- captioning part-2, thanks @CalamitousFelicitousness
|
||||||
- update `lint` rules, thanks @awsr
|
|
||||||
- add new build of `insightface`, thanks @hameerabbasi
|
- add new build of `insightface`, thanks @hameerabbasi
|
||||||
|
- **Checks**
|
||||||
|
- switch to `pyproject.toml` for tool configs
|
||||||
|
- update `lint` rules, thanks @awsr
|
||||||
|
- add `ty` to optional lint tooling
|
||||||
|
- add `pyright` to optional lint tooling
|
||||||
- **Fixes**
|
- **Fixes**
|
||||||
- handle `clip` installer doing unwanted `setuptools` update
|
- handle `clip` installer doing unwanted `setuptools` update
|
||||||
- cleanup for `uv` installer fallback
|
- cleanup for `uv` installer fallback
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,21 @@ inpainting_is_aggressive_raunet = False
|
||||||
playground_is_aggressive_raunet = False
|
playground_is_aggressive_raunet = False
|
||||||
|
|
||||||
|
|
||||||
|
def _chunked_feed_forward(ff: torch.nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
||||||
|
# "feed_forward_chunk_size" can be used to save memory
|
||||||
|
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
||||||
|
)
|
||||||
|
|
||||||
|
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
||||||
|
ff_output = torch.cat(
|
||||||
|
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
||||||
|
dim=chunk_dim,
|
||||||
|
)
|
||||||
|
return ff_output
|
||||||
|
|
||||||
|
|
||||||
def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
|
def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
|
||||||
# replace global self-attention with MSW-MSA
|
# replace global self-attention with MSW-MSA
|
||||||
class transformer_block(block_class):
|
class transformer_block(block_class):
|
||||||
|
|
@ -238,7 +253,7 @@ def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type
|
||||||
norm_hidden_states = self.norm2(hidden_states)
|
norm_hidden_states = self.norm2(hidden_states)
|
||||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||||
if self._chunk_size is not None:
|
if self._chunk_size is not None:
|
||||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) # pylint: disable=undefined-variable
|
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||||
else:
|
else:
|
||||||
ff_output = self.ff(norm_hidden_states)
|
ff_output = self.ff(norm_hidden_states)
|
||||||
if self.use_ada_layer_norm_zero:
|
if self.use_ada_layer_norm_zero:
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,9 @@ from diffusers.models import ControlNetModel
|
||||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
|
||||||
|
|
||||||
|
|
||||||
|
controlnet_apply_steps_rate = 0.6
|
||||||
|
|
||||||
|
|
||||||
def make_diffusers_unet_2d_condition(block_class):
|
def make_diffusers_unet_2d_condition(block_class):
|
||||||
|
|
||||||
class unet_2d_condition(block_class):
|
class unet_2d_condition(block_class):
|
||||||
|
|
|
||||||
|
|
@ -53,9 +53,6 @@ class TDDScheduler(DPMSolverSinglestepScheduler):
|
||||||
elif beta_schedule == "scaled_linear":
|
elif beta_schedule == "scaled_linear":
|
||||||
# this schedule is very specific to the latent diffusion model.
|
# this schedule is very specific to the latent diffusion model.
|
||||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||||
elif beta_schedule == "squaredcos_cap_v2":
|
|
||||||
# Glide cosine schedule
|
|
||||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -63,10 +63,10 @@ if devices.backend != "ipex":
|
||||||
if torch.__version__.startswith("2.6"):
|
if torch.__version__.startswith("2.6"):
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from torch.compiler import disable as disable_compile # pylint: disable=ungrouped-imports
|
from torch.compiler import disable as disable_compile # pylint: disable=ungrouped-imports
|
||||||
import diffusers.models.autoencoders.autoencoder_kl # pylint: disable=ungrouped-imports
|
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution # pylint: disable=ungrouped-imports
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@disable_compile
|
@disable_compile
|
||||||
class AutoencoderKLOutput(diffusers.utils.BaseOutput):
|
class AutoencoderKLOutput(diffusers.utils.BaseOutput):
|
||||||
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
|
latent_dist: DiagonalGaussianDistribution
|
||||||
diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput = AutoencoderKLOutput
|
diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput = AutoencoderKLOutput
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import torch
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from ...common.cache import Cache
|
from ...common.cache import Cache
|
||||||
from ....rotary_embedding import RotaryEmbedding
|
from ....rotary_embedding import RotaryEmbedding, apply_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbeddingBase(nn.Module):
|
class RotaryEmbeddingBase(nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,10 @@
|
||||||
"ruff-win": "venv\\scripts\\activate && ruff check",
|
"ruff-win": "venv\\scripts\\activate && ruff check",
|
||||||
"pylint": ". venv/bin/activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/ | grep -v '^*'",
|
"pylint": ". venv/bin/activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/ | grep -v '^*'",
|
||||||
"pylint-win": "venv\\scripts\\activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/",
|
"pylint-win": "venv\\scripts\\activate && pylint *.py modules/ pipelines/ scripts/ extensions-builtin/",
|
||||||
|
"pyright": ". venv/bin/activate && pyright --threads 4",
|
||||||
|
"pyright-win": "venv\\scripts\\activate && pyright --threads 4",
|
||||||
|
"ty": ". venv/bin/activate && ty check",
|
||||||
|
"ty-win": "venv\\scripts\\activate && ty check",
|
||||||
"lint": "npm run format && npm run eslint && npm run eslint-ui && npm run ruff && npm run pylint",
|
"lint": "npm run format && npm run eslint && npm run eslint-ui && npm run ruff && npm run pylint",
|
||||||
"lint-win": "npm run format-win && npm run eslint && npm run eslint-ui && npm run ruff-win && npm run pylint-win",
|
"lint-win": "npm run format-win && npm run eslint && npm run eslint-ui && npm run ruff-win && npm run pylint-win",
|
||||||
"test": ". venv/bin/activate; python launch.py --debug --test",
|
"test": ". venv/bin/activate; python launch.py --debug --test",
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ class Scheduler(SchedulerMixin, ConfigMixin):
|
||||||
def step(
|
def step(
|
||||||
self,
|
self,
|
||||||
model_output: torch.Tensor,
|
model_output: torch.Tensor,
|
||||||
timestep: torch.long,
|
timestep: torch.Tensor,
|
||||||
sample: torch.LongTensor,
|
sample: torch.LongTensor,
|
||||||
starting_mask_ratio: int = 1,
|
starting_mask_ratio: int = 1,
|
||||||
generator: Optional[torch.Generator] = None,
|
generator: Optional[torch.Generator] = None,
|
||||||
|
|
|
||||||
|
|
@ -366,3 +366,42 @@ variables.redefining-builtins-modules=["six.moves","past.builtins","future.built
|
||||||
pythonVersion = "3.10"
|
pythonVersion = "3.10"
|
||||||
pythonPlatform = "All"
|
pythonPlatform = "All"
|
||||||
typeCheckingMode = "off"
|
typeCheckingMode = "off"
|
||||||
|
venvPath = "./venv"
|
||||||
|
include = [
|
||||||
|
"*.py",
|
||||||
|
"modules/**/*.py",
|
||||||
|
"pipelines/**/*.py",
|
||||||
|
"scripts/**/*.py",
|
||||||
|
"extensions-builtin/**/*.py"
|
||||||
|
]
|
||||||
|
exclude = [
|
||||||
|
"**/.*",
|
||||||
|
".git/",
|
||||||
|
"**/node_modules",
|
||||||
|
"**/__pycache__",
|
||||||
|
"venv",
|
||||||
|
]
|
||||||
|
reportMissingImports = "none"
|
||||||
|
reportInvalidTypeForm = "none"
|
||||||
|
|
||||||
|
[tool.ty.environment]
|
||||||
|
python = "./venv/bin/python"
|
||||||
|
python-platform = "all"
|
||||||
|
python-version = "3.10"
|
||||||
|
|
||||||
|
[tool.ty.src]
|
||||||
|
include = [
|
||||||
|
"./extensions-builtin"
|
||||||
|
]
|
||||||
|
exclude = [
|
||||||
|
"venv/",
|
||||||
|
"*.git/",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ty.rules]
|
||||||
|
invalid-method-overrides = "ignore"
|
||||||
|
invalid-argument-types = "ignore"
|
||||||
|
unresolved-imports = "ignore"
|
||||||
|
unresolved-attributes = "ignore"
|
||||||
|
invalid-assignments = "ignore"
|
||||||
|
unsupported-operators = "ignore"
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ import io
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Tuple, Set, Iterable
|
from typing import Any, Dict, Tuple, Set, Iterable
|
||||||
|
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
@ -19,7 +18,6 @@ __all__ = ['GlobalHeatMap', 'RawHeatMapCollection', 'WordHeatMap', 'ParsedHeatMa
|
||||||
|
|
||||||
|
|
||||||
def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, color_normalize=True, ax=None, cmap='jet'):
|
def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, color_normalize=True, ax=None, cmap='jet'):
|
||||||
# type: (PIL.Image.Image | np.ndarray, torch.Tensor, str, Path, int, bool, plt.Axes) -> None
|
|
||||||
if ax is None:
|
if ax is None:
|
||||||
plt.rcParams['font.size'] = 16
|
plt.rcParams['font.size'] = 16
|
||||||
plt.rcParams['figure.facecolor'] = 'black'
|
plt.rcParams['figure.facecolor'] = 'black'
|
||||||
|
|
@ -76,7 +74,6 @@ class WordHeatMap:
|
||||||
return self.heatmap
|
return self.heatmap
|
||||||
|
|
||||||
def plot_overlay(self, image, out_file=None, color_normalize=True, ax=None, cmap='jet', **expand_kwargs):
|
def plot_overlay(self, image, out_file=None, color_normalize=True, ax=None, cmap='jet', **expand_kwargs):
|
||||||
# type: (PIL.Image.Image | np.ndarray, Path, bool, plt.Axes, Dict[str, Any]) -> None
|
|
||||||
return plot_overlay_heat_map(
|
return plot_overlay_heat_map(
|
||||||
image,
|
image,
|
||||||
self.expand_as(image, **expand_kwargs),
|
self.expand_as(image, **expand_kwargs),
|
||||||
|
|
|
||||||
|
|
@ -1369,7 +1369,7 @@ def init_attn_proc(unet, ip_adapter_tokens=16, use_lcm=False, use_adaln=True, us
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
time_embedding_dim=1280,
|
time_embedding_dim=1280,
|
||||||
) if hasattr(F, "scaled_dot_product_attention") else AdditiveKV_AttnProcessor()
|
)
|
||||||
else:
|
else:
|
||||||
attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ import os
|
||||||
import torch
|
import torch
|
||||||
from typing import List
|
from typing import List
|
||||||
from collections import namedtuple, OrderedDict
|
from collections import namedtuple, OrderedDict
|
||||||
|
from utils import revise_state_dict
|
||||||
|
|
||||||
|
|
||||||
def is_torch2_available():
|
def is_torch2_available():
|
||||||
return hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
return hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
import torch
|
|
||||||
from collections import namedtuple, OrderedDict
|
from collections import namedtuple, OrderedDict
|
||||||
|
import torch
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from .attention_processor import init_attn_proc
|
|
||||||
from .ip_adapter import MultiIPAdapterImageProjection
|
|
||||||
from .resampler import Resampler
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModel, AutoImageProcessor,
|
AutoModel, AutoImageProcessor,
|
||||||
CLIPVisionModelWithProjection, CLIPImageProcessor)
|
CLIPVisionModelWithProjection, CLIPImageProcessor)
|
||||||
|
from .attention_processor import init_attn_proc
|
||||||
|
from .ip_adapter import MultiIPAdapterImageProjection
|
||||||
|
from .resampler import Resampler
|
||||||
|
|
||||||
|
|
||||||
def init_adapter_in_unet(
|
def init_adapter_in_unet(
|
||||||
|
|
@ -80,85 +80,84 @@ def load_adapter_to_pipe(
|
||||||
use_lcm=False,
|
use_lcm=False,
|
||||||
use_adaln=True,
|
use_adaln=True,
|
||||||
):
|
):
|
||||||
|
if not isinstance(pretrained_model_path_or_dict, dict):
|
||||||
if not isinstance(pretrained_model_path_or_dict, dict):
|
if pretrained_model_path_or_dict.endswith(".safetensors"):
|
||||||
if pretrained_model_path_or_dict.endswith(".safetensors"):
|
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f:
|
||||||
with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f:
|
for key in f.keys():
|
||||||
for key in f.keys():
|
if key.startswith("image_proj."):
|
||||||
if key.startswith("image_proj."):
|
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
elif key.startswith("ip_adapter."):
|
||||||
elif key.startswith("ip_adapter."):
|
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
|
||||||
else:
|
|
||||||
state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
|
|
||||||
else:
|
else:
|
||||||
state_dict = pretrained_model_path_or_dict
|
state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
|
||||||
keys = list(state_dict.keys())
|
else:
|
||||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
state_dict = pretrained_model_path_or_dict
|
||||||
state_dict = revise_state_dict(state_dict)
|
keys = list(state_dict.keys())
|
||||||
|
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||||
|
state_dict = revise_state_dict(state_dict)
|
||||||
|
|
||||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||||
if image_encoder_or_path is not None:
|
if image_encoder_or_path is not None:
|
||||||
if isinstance(image_encoder_or_path, str):
|
if isinstance(image_encoder_or_path, str):
|
||||||
feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path
|
feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path
|
||||||
|
|
||||||
image_encoder_or_path = (
|
image_encoder_or_path = (
|
||||||
CLIPVisionModelWithProjection.from_pretrained(
|
CLIPVisionModelWithProjection.from_pretrained(
|
||||||
image_encoder_or_path
|
image_encoder_or_path
|
||||||
) if use_clip_encoder else
|
) if use_clip_encoder else
|
||||||
AutoModel.from_pretrained(image_encoder_or_path)
|
AutoModel.from_pretrained(image_encoder_or_path)
|
||||||
)
|
)
|
||||||
|
|
||||||
if feature_extractor_or_path is not None:
|
if feature_extractor_or_path is not None:
|
||||||
if isinstance(feature_extractor_or_path, str):
|
if isinstance(feature_extractor_or_path, str):
|
||||||
feature_extractor_or_path = (
|
feature_extractor_or_path = (
|
||||||
CLIPImageProcessor() if use_clip_encoder else
|
CLIPImageProcessor() if use_clip_encoder else
|
||||||
AutoImageProcessor.from_pretrained(feature_extractor_or_path)
|
AutoImageProcessor.from_pretrained(feature_extractor_or_path)
|
||||||
)
|
)
|
||||||
|
|
||||||
# create image encoder if it has not been registered to the pipeline yet
|
# create image encoder if it has not been registered to the pipeline yet
|
||||||
if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None:
|
if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None:
|
||||||
image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype)
|
image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype)
|
||||||
pipe.register_modules(image_encoder=image_encoder)
|
pipe.register_modules(image_encoder=image_encoder)
|
||||||
else:
|
else:
|
||||||
image_encoder = pipe.image_encoder
|
image_encoder = pipe.image_encoder
|
||||||
|
|
||||||
# create feature extractor if it has not been registered to the pipeline yet
|
# create feature extractor if it has not been registered to the pipeline yet
|
||||||
if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None:
|
if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None:
|
||||||
feature_extractor = feature_extractor_or_path
|
feature_extractor = feature_extractor_or_path
|
||||||
pipe.register_modules(feature_extractor=feature_extractor)
|
pipe.register_modules(feature_extractor=feature_extractor)
|
||||||
else:
|
else:
|
||||||
feature_extractor = pipe.feature_extractor
|
feature_extractor = pipe.feature_extractor
|
||||||
|
|
||||||
# load adapter into unet
|
# load adapter into unet
|
||||||
unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet
|
unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet
|
||||||
attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
|
attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
|
||||||
unet.set_attn_processor(attn_procs)
|
unet.set_attn_processor(attn_procs)
|
||||||
image_proj_model = Resampler(
|
image_proj_model = Resampler(
|
||||||
embedding_dim=image_encoder.config.hidden_size,
|
embedding_dim=image_encoder.config.hidden_size,
|
||||||
output_dim=unet.config.cross_attention_dim,
|
output_dim=unet.config.cross_attention_dim,
|
||||||
num_queries=adapter_tokens,
|
num_queries=adapter_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load pretrinaed model if needed.
|
# Load pretrinaed model if needed.
|
||||||
if "ip_adapter" in state_dict.keys():
|
if "ip_adapter" in state_dict.keys():
|
||||||
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
||||||
missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
|
missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
|
||||||
for mk in missing:
|
for mk in missing:
|
||||||
if "ln" not in mk:
|
if "ln" not in mk:
|
||||||
raise ValueError(f"Missing keys in adapter_modules: {missing}")
|
raise ValueError(f"Missing keys in adapter_modules: {missing}")
|
||||||
if "image_proj" in state_dict.keys():
|
if "image_proj" in state_dict.keys():
|
||||||
image_proj_model.load_state_dict(state_dict["image_proj"])
|
image_proj_model.load_state_dict(state_dict["image_proj"])
|
||||||
|
|
||||||
# convert IP-Adapter Image Projection layers to diffusers
|
# convert IP-Adapter Image Projection layers to diffusers
|
||||||
image_projection_layers = []
|
image_projection_layers = []
|
||||||
image_projection_layers.append(image_proj_model)
|
image_projection_layers.append(image_proj_model)
|
||||||
unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||||
|
|
||||||
# Adjust unet config to handle addtional ip hidden states.
|
# Adjust unet config to handle addtional ip hidden states.
|
||||||
unet.config.encoder_hid_dim_type = "ip_image_proj"
|
unet.config.encoder_hid_dim_type = "ip_image_proj"
|
||||||
unet.to(dtype=pipe.dtype, device=pipe.device)
|
unet.to(dtype=pipe.dtype, device=pipe.device)
|
||||||
|
|
||||||
|
|
||||||
def revise_state_dict(old_state_dict_or_path, map_location="cpu"):
|
def revise_state_dict(old_state_dict_or_path, map_location="cpu"):
|
||||||
|
|
@ -198,51 +197,3 @@ def encode_image(image_encoder, feature_extractor, image, device, num_images_per
|
||||||
image_embeds = image_encoder(image).last_hidden_state
|
image_embeds = image_encoder(image).last_hidden_state
|
||||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||||
return image_embeds
|
return image_embeds
|
||||||
|
|
||||||
|
|
||||||
def prepare_training_image_embeds(
|
|
||||||
image_encoder, feature_extractor,
|
|
||||||
ip_adapter_image, ip_adapter_image_embeds,
|
|
||||||
device, drop_rate, output_hidden_state, idx_to_replace=None
|
|
||||||
):
|
|
||||||
if ip_adapter_image_embeds is None:
|
|
||||||
if not isinstance(ip_adapter_image, list):
|
|
||||||
ip_adapter_image = [ip_adapter_image]
|
|
||||||
|
|
||||||
# if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers):
|
|
||||||
# raise ValueError(
|
|
||||||
# f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
|
||||||
# )
|
|
||||||
|
|
||||||
image_embeds = []
|
|
||||||
for single_ip_adapter_image in ip_adapter_image:
|
|
||||||
if idx_to_replace is None:
|
|
||||||
idx_to_replace = torch.rand(len(single_ip_adapter_image)) < drop_rate
|
|
||||||
zero_ip_adapter_image = torch.zeros_like(single_ip_adapter_image)
|
|
||||||
single_ip_adapter_image[idx_to_replace] = zero_ip_adapter_image[idx_to_replace]
|
|
||||||
single_image_embeds = encode_image(
|
|
||||||
image_encoder, feature_extractor, single_ip_adapter_image, device, 1, output_hidden_state
|
|
||||||
)
|
|
||||||
single_image_embeds = torch.stack([single_image_embeds], dim=1) # FIXME
|
|
||||||
|
|
||||||
image_embeds.append(single_image_embeds)
|
|
||||||
else:
|
|
||||||
repeat_dims = [1]
|
|
||||||
image_embeds = []
|
|
||||||
for single_image_embeds in ip_adapter_image_embeds:
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
|
||||||
single_image_embeds = single_image_embeds.repeat(
|
|
||||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
|
||||||
)
|
|
||||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
|
||||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
|
||||||
)
|
|
||||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
|
||||||
else:
|
|
||||||
single_image_embeds = single_image_embeds.repeat(
|
|
||||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
|
||||||
)
|
|
||||||
image_embeds.append(single_image_embeds)
|
|
||||||
|
|
||||||
return image_embeds
|
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ except Exception:
|
||||||
from .modified_resnet import ModifiedResNet
|
from .modified_resnet import ModifiedResNet
|
||||||
from .timm_model import TimmModel
|
from .timm_model import TimmModel
|
||||||
from .eva_vit_model import EVAVisionTransformer
|
from .eva_vit_model import EVAVisionTransformer
|
||||||
from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
from .transformer import LayerNorm, LayerNormFp32, QuickGELU, Attention, VisionTransformer, TextTransformer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from apex.normalization import FusedLayerNorm
|
from apex.normalization import FusedLayerNorm
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,15 @@ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
|
||||||
from .utils import to_2tuple
|
from .utils import to_2tuple
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import xformers
|
||||||
|
import xformers.ops as xops
|
||||||
|
XFORMERS_IS_AVAILBLE = True
|
||||||
|
except Exception:
|
||||||
|
XFORMERS_IS_AVAILBLE = False
|
||||||
|
xops = None
|
||||||
|
|
||||||
|
|
||||||
class LayerNormFp32(nn.LayerNorm):
|
class LayerNormFp32(nn.LayerNorm):
|
||||||
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
|
@ -417,10 +426,7 @@ class CustomTransformer(nn.Module):
|
||||||
if k is None and v is None:
|
if k is None and v is None:
|
||||||
k = v = q
|
k = v = q
|
||||||
for r in self.resblocks:
|
for r in self.resblocks:
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
q = r(q, k, v, attn_mask=attn_mask)
|
||||||
q = checkpoint(r, q, k, v, attn_mask)
|
|
||||||
else:
|
|
||||||
q = r(q, k, v, attn_mask=attn_mask)
|
|
||||||
return q
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -494,10 +500,7 @@ class Transformer(nn.Module):
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
||||||
for r in self.resblocks:
|
for r in self.resblocks:
|
||||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
x = r(x, attn_mask=attn_mask)
|
||||||
x = checkpoint(r, x, attn_mask)
|
|
||||||
else:
|
|
||||||
x = r(x, attn_mask=attn_mask)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from typing import List, Tuple
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
@ -46,7 +47,7 @@ def get_parameter_dtype(parameter: torch.nn.Module):
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
||||||
|
|
||||||
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, torch.Tensor]]:
|
||||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||||
return tuples
|
return tuples
|
||||||
|
|
||||||
|
|
@ -76,7 +77,7 @@ class Downsample(nn.Module):
|
||||||
else:
|
else:
|
||||||
assert self.channels == self.out_channels
|
assert self.channels == self.out_channels
|
||||||
from torch.nn import MaxUnpool2d
|
from torch.nn import MaxUnpool2d
|
||||||
self.op = MaxUnpool2d(dims, kernel_size=stride, stride=stride)
|
self.op = MaxUnpool2d(kernel_size=stride, stride=stride)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
assert x.shape[1] == self.channels
|
assert x.shape[1] == self.channels
|
||||||
|
|
@ -267,10 +268,7 @@ class Adapter_XL(nn.Module):
|
||||||
if t is not None:
|
if t is not None:
|
||||||
if not torch.is_tensor(t):
|
if not torch.is_tensor(t):
|
||||||
is_mps = x[0].device.type == "mps"
|
is_mps = x[0].device.type == "mps"
|
||||||
if isinstance(timestep, float):
|
dtype = torch.int32 if is_mps else torch.int64
|
||||||
dtype = torch.float32 if is_mps else torch.float64
|
|
||||||
else:
|
|
||||||
dtype = torch.int32 if is_mps else torch.int64
|
|
||||||
t = torch.tensor([t], dtype=dtype, device=x[0].device)
|
t = torch.tensor([t], dtype=dtype, device=x[0].device)
|
||||||
elif len(t.shape) == 0:
|
elif len(t.shape) == 0:
|
||||||
t = t[None].to(x[0].device)
|
t = t[None].to(x[0].device)
|
||||||
|
|
|
||||||
|
|
@ -728,7 +728,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||||
down_bridge_residuals: Optional[Tuple[torch.Tensor]] = None,
|
down_bridge_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||||
fusion_guidance_scale: Optional[torch.FloatTensor] = None,
|
fusion_guidance_scale: Optional[torch.FloatTensor] = None,
|
||||||
fusion_type: Optional[str] = 'ADD',
|
fusion_type: Optional[str] = 'ADD',
|
||||||
adapter: Optional = None
|
adapter = None
|
||||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||||
r"""
|
r"""
|
||||||
The [`UNet2DConditionModel`] forward method.
|
The [`UNet2DConditionModel`] forward method.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue