From f3b4ef2551b3bb3cb21a604c20da5a49b478f4a3 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 13 Oct 2025 10:20:42 -0400 Subject: [PATCH] simplify seedvr depedencies Signed-off-by: Vladimir Mandic --- modules/postprocess/seedvr_model.py | 2 - modules/seedvr/rotary_embedding.py | 346 ++++++++++++++++++ .../src/common/distributed/meta_init_utils.py | 3 +- modules/seedvr/src/models/dit/rope.py | 3 +- modules/seedvr/src/models/dit_v2/rope.py | 2 +- 5 files changed, 349 insertions(+), 7 deletions(-) create mode 100644 modules/seedvr/rotary_embedding.py diff --git a/modules/postprocess/seedvr_model.py b/modules/postprocess/seedvr_model.py index 84a49c22b..fddeff557 100644 --- a/modules/postprocess/seedvr_model.py +++ b/modules/postprocess/seedvr_model.py @@ -4,7 +4,6 @@ import numpy as np import torch from PIL import Image from torchvision.transforms import ToPILImage -from installer import install from modules import devices from modules.shared import opts, log from modules.upscaler import Upscaler, UpscalerData @@ -33,7 +32,6 @@ class UpscalerSeedVR(Upscaler): def load_model(self, path: str): model_name = MODELS_MAP.get(path, None) if (self.model is None) or (self.model_loaded != model_name): - install('rotary_embedding_torch') log.debug(f'Upscaler loading: name="{self.name}" model="{model_name}"') t0 = time.time() from modules.seedvr.src.core.model_manager import configure_runner diff --git a/modules/seedvr/rotary_embedding.py b/modules/seedvr/rotary_embedding.py new file mode 100644 index 000000000..021ae9e52 --- /dev/null +++ b/modules/seedvr/rotary_embedding.py @@ -0,0 +1,346 @@ +from __future__ import annotations +from typing import Literal +from math import pi +import torch +from torch.amp import autocast +from torch.nn import Module +from torch import nn, einsum, broadcast_tensors, is_tensor, Tensor +from einops import rearrange, repeat + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# broadcat, as tortoise-tts was using it + +def broadcat(tensors, dim = -1): + broadcasted_tensors = broadcast_tensors(*tensors) + return torch.cat(broadcasted_tensors, dim = dim) + +def slice_at_dim(t, dim_slice: slice, *, dim): + dim += (t.ndim if dim < 0 else 0) + colons = [slice(None)] * t.ndim + colons[dim] = dim_slice + return t[tuple(colons)] + +# rotary embedding helper functions + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') + +@autocast('cuda', enabled = False) +def apply_rotary_emb( + freqs, + t, + start_index = 0, + scale = 1., + seq_dim = -2, + freqs_seq_dim = None +): + dtype = t.dtype + + if not exists(freqs_seq_dim): + if freqs.ndim == 2 or t.ndim == 3: + freqs_seq_dim = 0 + + if t.ndim == 3 or exists(freqs_seq_dim): + seq_len = t.shape[seq_dim] + freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim) + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + + # Split t into three parts: left, middle (to be transformed), and right + t_left = t[..., :start_index] + t_middle = t[..., start_index:end_index] + t_right = t[..., end_index:] + + # Apply rotary embeddings without modifying t in place + t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale) + out = torch.cat((t_left, t_transformed, t_right), dim=-1) + + return out.type(dtype) + +# learned rotation helpers + +def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): + if exists(freq_ranges): + rotations = einsum('..., f -> ... f', rotations, freq_ranges) + rotations = rearrange(rotations, '... r f -> ... (r f)') + + rotations = repeat(rotations, '... n -> ... (n r)', r = 2) + return apply_rotary_emb(rotations, t, start_index = start_index) + +# classes + +class RotaryEmbedding(Module): + def __init__( + self, + dim, + custom_freqs: Tensor | None = None, + freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + learned_freq = False, + use_xpos = False, + xpos_scale_base = 512, + interpolate_factor = 1., + theta_rescale_factor = 1., + seq_before_head_dim = False, + cache_if_possible = True, + cache_max_seq_len = 8192 + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + + self.cache_if_possible = cache_if_possible + self.cache_max_seq_len = cache_max_seq_len + + self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) + self.cached_freqs_seq_len = 0 + + self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) # pylint: disable=possibly-used-before-assignment + + self.learned_freq = learned_freq + + # dummy for device + + self.register_buffer('dummy', torch.tensor(0), persistent = False) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1. + self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + + if not use_xpos: + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + + self.register_buffer('scale', scale, persistent = False) + self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False) + self.cached_scales_seq_len = 0 + + # add apply_rotary_emb as static method + + self.apply_rotary_emb = staticmethod(apply_rotary_emb) + + @property + def device(self): + return self.dummy.device + + def get_seq_pos(self, seq_len, device = None, dtype = None, offset = 0): + device = default(device, self.device) + dtype = default(dtype, self.cached_freqs.dtype) + + return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, scale = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset) + + freqs = self.forward(seq, seq_len = seq_len, offset = offset) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + + return apply_rotary_emb(freqs, t, scale = default(scale, 1.), seq_dim = seq_dim) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0): + dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim) + + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + + q_scale = k_scale = 1. + + if self.use_xpos: + seq = self.get_seq_pos(k_len, dtype = dtype, device = device) + + q_scale = self.get_scale(seq[-q_len:]).type(dtype) + k_scale = self.get_scale(seq).type(dtype) + + rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset) + rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def rotate_queries_and_keys(self, q, k, seq_dim = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, dtype = dtype, device = device) + + freqs = self.forward(seq, seq_len = seq_len) + scale = self.get_scale(seq, seq_len = seq_len).to(dtype) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + scale = rearrange(scale, 'n d -> n 1 d') + + rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def get_scale( + self, + t: Tensor, + seq_len: int | None = None, + offset = 0 + ): + assert self.use_xpos + + should_cache = ( + self.cache_if_possible and + exists(seq_len) and + (offset + seq_len) <= self.cache_max_seq_len + ) + + if ( + should_cache and \ + exists(self.cached_scales) and \ + (seq_len + offset) <= self.cached_scales_seq_len + ): + return self.cached_scales[offset:(offset + seq_len)] + + scale = 1. + if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = repeat(scale, 'n d -> n (d r)', r = 2) + + if should_cache and offset == 0: + self.cached_scales[:seq_len] = scale.detach() + self.cached_scales_seq_len = seq_len + + return scale + + def get_axial_freqs( + self, + *dims, + offsets: ( + tuple[int | float, ...] | + Tensor | + None + ) = None + ): + Colon = slice(None) + all_freqs = [] + + # handle offset + + if exists(offsets): + if not is_tensor(offsets): + offsets = torch.tensor(offsets) + + assert len(offsets) == len(dims) + + # get frequencies for each axis + + for ind, dim in enumerate(dims): + + offset = 0 + if exists(offsets): + offset = offsets[ind] + + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps = dim, device = self.device) + else: + pos = torch.arange(dim, device = self.device) + + pos = pos + offset + + freqs = self.forward(pos, seq_len = dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + # concat all freqs + + all_freqs = broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim = -1) + + @autocast('cuda', enabled = False) + def forward( + self, + t: Tensor, + seq_len: int | None = None, + offset = 0 + ): + should_cache = ( + self.cache_if_possible and + not self.learned_freq and + exists(seq_len) and + self.freqs_for != 'pixel' and + (offset + seq_len) <= self.cache_max_seq_len + ) + + if ( + should_cache and \ + exists(self.cached_freqs) and \ + (offset + seq_len) <= self.cached_freqs_seq_len + ): + return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + + if should_cache and offset == 0: + self.cached_freqs[:seq_len] = freqs.detach() + self.cached_freqs_seq_len = seq_len + + return freqs diff --git a/modules/seedvr/src/common/distributed/meta_init_utils.py b/modules/seedvr/src/common/distributed/meta_init_utils.py index 794cd0b81..9a6570646 100644 --- a/modules/seedvr/src/common/distributed/meta_init_utils.py +++ b/modules/seedvr/src/common/distributed/meta_init_utils.py @@ -13,9 +13,8 @@ # // limitations under the License. import torch -from rotary_embedding_torch import RotaryEmbedding from torch import nn -from torch.distributed.fsdp._common_utils import _is_fsdp_flattened +from ....rotary_embedding import RotaryEmbedding __all__ = ["meta_non_persistent_buffer_init_fn"] diff --git a/modules/seedvr/src/models/dit/rope.py b/modules/seedvr/src/models/dit/rope.py index 3aded8fed..35b91ea8b 100644 --- a/modules/seedvr/src/models/dit/rope.py +++ b/modules/seedvr/src/models/dit/rope.py @@ -16,10 +16,9 @@ from functools import lru_cache from typing import Tuple import torch from einops import rearrange -from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb from torch import nn - from ...common.cache import Cache +from ....rotary_embedding import RotaryEmbedding class RotaryEmbeddingBase(nn.Module): diff --git a/modules/seedvr/src/models/dit_v2/rope.py b/modules/seedvr/src/models/dit_v2/rope.py index 89d851792..3d294621c 100644 --- a/modules/seedvr/src/models/dit_v2/rope.py +++ b/modules/seedvr/src/models/dit_v2/rope.py @@ -16,9 +16,9 @@ from functools import lru_cache from typing import Optional, Tuple import torch from einops import rearrange -from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb from torch import nn from ...common.cache import Cache +from ....rotary_embedding import RotaryEmbedding, apply_rotary_emb class RotaryEmbeddingBase(nn.Module):