Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4268/head
Vladimir Mandic 2025-10-12 15:35:08 -04:00
parent 8d36a5aebb
commit 2e4e741d47
30 changed files with 335 additions and 351 deletions

View File

@ -1,6 +1,6 @@
# Change Log for SD.Next
## Update for 2025-10-11
## Update for 2025-10-12
- **Models**
- [WAN 2.2 14B VACE](https://huggingface.co/alibaba-pai/Wan2.2-VACE-Fun-A14B)
@ -17,6 +17,13 @@
*how to use*: enable nunchaku in settings -> quantization and then load either sdxl-base or sdxl-base-turbo reference models
- [HiDream E1.1](https://huggingface.co/HiDream-ai/HiDream-E1-1)
updated version of E1 image editing model
- [SeedVR2](https://iceclear.github.io/projects/seedvr/)
originally designed for video restoration, seedvr works great for image detailing and upscaling!
available in 3B, 7B and 7B-sharp variants
use as any other upscaler!
note: seedvr is a very large model (6.4GB and 16GB respectively) and not designed for lower-end hardware
note: seedvr is highly sensitive to its cfg scale (set in settings -> postprocessing),
lower values will result in smoother output while higher values add details
- [X-Omni SFT](https://x-omni-team.github.io/)
*experimental*: X-omni is a transformer-only discrete autoregressive image generative model trained with reinforcement learning
- **Features**

View File

@ -482,7 +482,7 @@ def sdnq_quantize_model(model, op=None, sd_model=None, do_gc: bool = True, weigh
from modules.sdnq import sdnq_post_load_quant
if weights_dtype is None:
if op is not None and ("text_encoder" in op or op in {"TE", "LLM"}) and shared.opts.sdnq_quantize_weights_mode_te not in {"Same as model", "default"}:
if (op is not None) and ("text_encoder" in op or op in {"TE", "LLM"}) and (shared.opts.sdnq_quantize_weights_mode_te not in {"Same as model", "default"}):
weights_dtype = shared.opts.sdnq_quantize_weights_mode_te
else:
weights_dtype = shared.opts.sdnq_quantize_weights_mode
@ -588,6 +588,8 @@ def sdnq_quantize_weights(sd_model):
log.info(f"Quantization: type=SDNQ time={t1-t0:.2f}")
except Exception as e:
log.warning(f"Quantization: type=SDNQ {e}")
from modules import errors
errors.display(e, 'Quantization')
return sd_model

View File

@ -1,9 +1,9 @@
import time
import random
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToPILImage
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn
from modules import devices
from modules.shared import opts, log
from modules.upscaler import Upscaler, UpscalerData
@ -19,7 +19,7 @@ to_pil = ToPILImage()
class UpscalerSeedVR(Upscaler):
def __init__(self, dirname=None):
self.name = "SeedVR"
self.name = "SeedVR2"
super().__init__()
self.scalers = [
UpscalerData(name="SeedVR2 3B", path=None, upscaler=self, model=None, scale=1),
@ -32,45 +32,140 @@ class UpscalerSeedVR(Upscaler):
def load_model(self, path: str):
model_name = MODELS_MAP.get(path, None)
if (self.model is None) or (self.model_loaded != model_name):
log.debug(f'Upscaler load: name="{self.name}" model="{model_name}"')
log.debug(f'Upscaler loading: name="{self.name}" model="{model_name}"')
t0 = time.time()
from modules.seedvr.src.core.model_manager import configure_runner
from modules.seedvr.src.core import generation
self.model = configure_runner(
model_name=model_name,
cache_dir=opts.hfcache_dir,
device=devices.device,
dtype=devices.dtype,
)
self.model_loaded = model_name
self.model.dit.device = devices.device
self.model.dit.dtype = devices.dtype
self.model.vae_encode = self.vae_encode
self.model.vae_decode = self.vae_decode
self.model.model_step = generation.generation_step
generation.generation_step = self.model_step
self.model._internal_dict = {
'dit': self.model.dit,
'vae': self.model.vae,
}
t1 = time.time()
self.model.dit.config = self.model.config.dit
self.model.vae.tile_sample_min_size = 1024
self.model.vae.tile_latent_min_size = 128
# from modules.model_quant import do_post_load_quant
# from modules.sd_offload import set_diffuser_offload
# self.model = do_post_load_quant(self.model, allow=True)
# set_diffuser_offload(self.model)
log.info(f'Upscaler loaded: name="{self.name}" model="{model_name}" time={t1 - t0:.2f}')
def vae_encode(self, samples):
log.debug(f'Upscaler encode: samples={samples[0].shape if len(samples) > 0 else None} tile={self.model.vae.tile_sample_min_size} overlap={self.model.vae.tile_overlap_factor}')
latents = []
if len(samples) == 0:
return latents
self.model.dit = self.model.dit.to(device="cpu")
self.model.vae = self.model.vae.to(device=self.device)
devices.torch_gc()
from einops import rearrange
from modules.seedvr.src.optimization import memory_manager
memory_manager.clear_rope_cache(self.model)
scale = self.model.config.vae.scaling_factor
shift = self.model.config.vae.get("shifting_factor", 0.0)
batches = [sample.unsqueeze(0) for sample in samples]
for sample in batches:
sample = sample.to(self.device, self.model.vae.dtype)
sample = self.model.vae.preprocess(sample)
latent = self.model.vae.encode(sample).latent
latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
latent = rearrange(latent, "b c ... -> b ... c")
latent = (latent - shift) * scale
latents.append(latent)
latents = [latent.squeeze(0) for latent in latents]
self.model.vae = self.model.vae.to(device="cpu")
devices.torch_gc()
return latents
def vae_decode(self, latents, target_dtype: torch.dtype = None):
log.debug(f'Upscaler decode: latents={latents[0].shape if len(latents) > 0 else None} tile={self.model.vae.tile_latent_min_size} overlap={self.model.vae.tile_overlap_factor}')
samples = []
if len(latents) == 0:
return samples
from einops import rearrange
from modules.seedvr.src.optimization import memory_manager
memory_manager.clear_rope_cache(self.model)
self.model.dit = self.model.dit.to(device="cpu")
self.model.vae = self.model.vae.to(device=self.device)
devices.torch_gc()
scale = self.model.config.vae.scaling_factor
shift = self.model.config.vae.get("shifting_factor", 0.0)
latents = [latent.unsqueeze(0) for latent in latents]
with devices.inference_context():
for _i, latent in enumerate(latents):
latent = latent.to(self.device, self.model.vae.dtype)
latent = latent / scale + shift
latent = rearrange(latent, "b ... c -> b c ...")
latent = latent.squeeze(2)
sample = self.model.vae.decode(latent).sample
sample = self.model.vae.postprocess(sample)
samples.append(sample)
samples = [sample.squeeze(0) for sample in samples]
self.model.vae = self.model.vae.to(device="cpu")
devices.torch_gc()
return samples
def model_step(self, *args, **kwargs):
from modules.seedvr.src.optimization import memory_manager
self.model.vae = self.model.vae.to(device="cpu")
self.model.dit = self.model.dit.to(device=self.device)
devices.torch_gc()
log.debug(f'Upscaler inference: args={len(args)} kwargs={list(kwargs.keys())}')
memory_manager.preinitialize_rope_cache(self.model)
with devices.inference_context():
result = self.model.model_step(*args, **kwargs)
self.model.dit = self.model.dit.to(device="cpu")
devices.torch_gc()
return result
def do_upscale(self, img: Image.Image, selected_file):
devices.torch_gc()
self.load_model(selected_file)
if self.model is None:
return img
from modules.seedvr.src.core.generation import generation_loop
from modules.seedvr.src.core import generation
width = int(self.scale * img.width) // 8 * 8
image_tensor = np.array(img)
image_tensor = torch.from_numpy(image_tensor).to(device=devices.device, dtype=devices.dtype).unsqueeze(0) / 255.0
random.seed()
seed = int(random.randrange(4294967294))
t0 = time.time()
result_tensor = generation_loop(
runner=self.model,
images=image_tensor,
cfg_scale=1.0,
seed=42,
res_w=width,
batch_size=1,
temporal_overlap=0,
device=devices.device,
)
with devices.inference_context():
result_tensor = generation.generation_loop(
runner=self.model,
images=image_tensor,
cfg_scale=opts.seedvt_cfg_scale,
seed=seed,
res_w=width,
batch_size=1,
temporal_overlap=0,
device=devices.device,
)
t1 = time.time()
log.info(f'Upscaler: type="{self.name}" model="{selected_file}" scale={self.scale} time={t1 - t0:.2f}')
log.info(f'Upscaler: type="{self.name}" model="{selected_file}" scale={self.scale} cfg={opts.seedvt_cfg_scale} seed={seed} time={t1 - t0:.2f}')
img = to_pil(result_tensor.squeeze().permute((2, 0, 1)))
devices.torch_gc()
if opts.upscaler_unload:
self.model.dit = None
self.model.vae = None
self.model.cache = None
self.model = None
log.debug(f'Upscaler unload: type="{self.name}" model="{selected_file}"')
devices.torch_gc(force=True)
devices.torch_gc(force=True)
return img

View File

@ -164,6 +164,8 @@ def apply_function_to_model(sd_model, function, options, op=None):
sd_model.unet = function(sd_model.unet, op="unet", sd_model=sd_model)
if hasattr(sd_model, 'transformer') and hasattr(sd_model.transformer, 'config'):
sd_model.transformer = function(sd_model.transformer, op="transformer", sd_model=sd_model)
if hasattr(sd_model, 'dit') and hasattr(sd_model.dit, 'config'):
sd_model.dit = function(sd_model.dit, op="dit", sd_model=sd_model)
if hasattr(sd_model, 'transformer_2') and hasattr(sd_model.transformer_2, 'config'):
sd_model.transformer_2 = function(sd_model.transformer_2, op="transformer_2", sd_model=sd_model)
if hasattr(sd_model, 'transformer_3') and hasattr(sd_model.transformer_3, 'config'):

View File

@ -274,7 +274,10 @@ class OffloadHook(accelerate.hooks.ModelHook):
def get_pipe_variants(pipe=None):
if pipe is None:
pipe = shared.sd_model
if shared.sd_loaded:
pipe = shared.sd_model
else:
return [pipe]
variants = [pipe]
if hasattr(pipe, "pipe"):
variants.append(pipe.pipe)
@ -287,7 +290,10 @@ def get_pipe_variants(pipe=None):
def get_module_names(pipe=None, exclude=[]):
if pipe is None:
pipe = shared.sd_model
if shared.sd_loaded:
pipe = shared.sd_model
else:
return []
if hasattr(pipe, "_internal_dict"):
modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access
else:

View File

@ -385,7 +385,6 @@ def sdnq_post_load_quant(
modules_dtype_dict=modules_dtype_dict.copy(),
op=op,
)
model.quantization_config = SDNQConfig(
weights_dtype=weights_dtype,
group_size=group_size,
@ -402,8 +401,10 @@ def sdnq_post_load_quant(
modules_dtype_dict=modules_dtype_dict.copy(),
)
if hasattr(model, "config"):
try:
model.config.quantization_config = model.quantization_config
except Exception:
pass
model.quantization_method = QuantizationMethod.SDNQ
return model

View File

@ -6,9 +6,9 @@ dit:
model:
__object__:
path:
- "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit_v2.nadit"
- "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit_v2.nadit"
- "src.models.dit_v2.nadit"
- "SeedVR2_VideoUpscaler.src.models.dit_v2.nadit"
- "SeedVR2_VideoUpscaler.src.models.dit_v2.nadit"
- "modules.seedvr.src.models.dit_v2.nadit"
name: "NaDiT"
args: "as_params"
vid_in_channels: 33
@ -49,9 +49,9 @@ vae:
model:
__object__:
path:
- "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "src.models.video_vae_v3.modules.attn_video_vae"
- "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "modules.seedvr.src.models.video_vae_v3.modules.attn_video_vae"
name: "VideoAutoencoderKLWrapper"
args: "as_params"
freeze_encoder: False

View File

@ -6,8 +6,8 @@ dit:
model:
__object__:
path:
- "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit.nadit"
- "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit.nadit"
- "SeedVR2_VideoUpscaler.src.models.dit.nadit"
- "SeedVR2_VideoUpscaler.src.models.dit.nadit"
- "src.models.dit.nadit"
name: "NaDiT"
args: "as_params"
@ -46,8 +46,8 @@ vae:
model:
__object__:
path:
- "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "src.models.video_vae_v3.modules.attn_video_vae"
name: "VideoAutoencoderKLWrapper"
args: "as_params"

View File

@ -1,31 +1,4 @@
"""
SeedVR2 Video Upscaler - Modular Architecture
Refactored from monolithic seedvr2.py for better maintainability
Author: Refactored codebase
Version: 2.0.0 - Modular
Available Modules:
- utils: Download and path utilities
- optimization: Memory, performance, and compatibility optimizations
- core: Model management and generation pipeline (NEW)
- processing: Video and tensor processing (coming next)
- interfaces: ComfyUI integration
"""
'''
# Track which modules are available for progressive migration
MODULES_AVAILABLE = {
'downloads': True, # ✅ Module 1 - Downloads and model management
'memory_manager': True, # ✅ Module 2 - Memory optimization
'performance': True, # ✅ Module 3 - Performance optimizations
'compatibility': True, # ✅ Module 4 - FP8/FP16 compatibility
'model_manager': True, # ✅ Module 5 - Model configuration and loading
'generation': True, # ✅ Module 6 - Generation loop and inference
'video_transforms': True, # ✅ Module 7 - Video processing and transforms
'comfyui_node': True, # ✅ Module 8 - ComfyUI node interface (COMPLETE!)
'infer': True, # ✅ Module 9 - Infer
}
'''
# Core imports (always available)
import os
import sys
@ -35,3 +8,4 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
"""

View File

@ -2,6 +2,7 @@ import importlib
from typing import Any, Callable, List, Union
from omegaconf import DictConfig, ListConfig, OmegaConf
try:
OmegaConf.register_new_resolver("eval", eval)
except Exception as e:
@ -9,7 +10,6 @@ except Exception as e:
raise
def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
"""
Load a configuration. Will resolve inheritance.

View File

@ -2,10 +2,10 @@ import functools
import threading
from typing import Callable
import torch
from .distributed import barrier_if_distributed, get_global_rank, get_local_rank
from .logger import get_logger
logger = get_logger(__name__)

View File

@ -26,7 +26,6 @@ from ..cache import Cache
from .advanced import (
get_sequence_parallel_group,
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
)
from .basic import get_device
@ -48,7 +47,7 @@ def single_all_to_all(
"""
A function to do all-to-all on a tensor
"""
seq_world_size = dist.get_world_size(group)
seq_world_size = 1
prev_scatter_dim = scatter_dim
if scatter_dim != 0:
local_input = local_input.transpose(0, scatter_dim)
@ -80,7 +79,7 @@ def _all_to_all(
gather_dim: int,
group: dist.ProcessGroup,
):
seq_world_size = dist.get_world_size(group)
seq_world_size = 1
input_list = [
t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)
]
@ -134,7 +133,7 @@ class Slice(torch.autograd.Function):
def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor:
ctx.group = group
ctx.rank = dist.get_rank(group)
seq_world_size = dist.get_world_size(group)
seq_world_size = 1
ctx.seq_world_size = seq_world_size
ctx.dim = dim
dim_size = local_input.shape[dim]
@ -163,7 +162,7 @@ class Gather(torch.autograd.Function):
ctx.rank = dist.get_rank(group)
ctx.dim = dim
ctx.grad_scale = grad_scale
seq_world_size = dist.get_world_size(group)
seq_world_size = 1
ctx.seq_world_size = seq_world_size
dim_size = list(local_input.size())
split_size = dim_size[0]
@ -204,7 +203,7 @@ def gather_seq_scatter_heads_qkv(
group = get_sequence_parallel_group()
if not group:
return qkv_tensor
world = get_sequence_parallel_world_size()
world = 1
orig_shape = qkv_tensor.shape
scatter_dim = qkv_tensor.dim()
bef_all2all_shape = list(orig_shape)
@ -237,7 +236,7 @@ def slice_inputs(x: Tensor, dim: int, padding: bool = True):
if group is None:
return x
sp_rank = get_sequence_parallel_rank()
sp_world = get_sequence_parallel_world_size()
sp_world = 1
dim_size = x.shape[dim]
unit = (dim_size + sp_world - 1) // sp_world
if padding and dim_size % sp_world:
@ -255,7 +254,7 @@ def remove_seqeunce_parallel_padding(x: Tensor, dim: int, unpad_dim_size: int):
group = get_sequence_parallel_group()
if group is None:
return x
sp_world = get_sequence_parallel_world_size()
sp_world = 1
if unpad_dim_size % sp_world == 0:
return x
padding_size = sp_world - (unpad_dim_size % sp_world)
@ -271,7 +270,7 @@ def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
if not group:
return x
dim_size = x.size(seq_dim)
sp_world = get_sequence_parallel_world_size()
sp_world = 1
if dim_size % sp_world != 0:
padding_size = sp_world - (dim_size % sp_world)
x = _pad_tensor(x, seq_dim, padding_size)
@ -424,7 +423,7 @@ class SPDistForward:
yield inputs
else:
device = self.device
sp_world = get_sequence_parallel_world_size()
sp_world = 1
sp_rank = get_sequence_parallel_rank()
for local_step in range(sp_world):
src_rank = dist.get_global_rank(group, local_step)

View File

@ -19,9 +19,9 @@ Logging utility functions.
import logging
import sys
from typing import Optional
from .distributed import get_global_rank, get_local_rank, get_world_size
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(
logging.Formatter(

View File

@ -16,7 +16,6 @@ import random
from typing import Optional
import numpy as np
import torch
from .distributed import get_global_rank

View File

@ -1,11 +1,10 @@
import torch
from torchvision.transforms import Compose, Lambda, Normalize
from src.optimization.performance import optimized_video_rearrange, optimized_single_video_rearrange, optimized_sample_to_image_format
from src.common.seed import set_seed
from src.data.image.transforms.divisible_crop import DivisibleCrop
from src.data.image.transforms.na_resize import NaResize
from src.utils.color_fix import wavelet_reconstruction
from ..optimization.performance import optimized_video_rearrange, optimized_single_video_rearrange, optimized_sample_to_image_format
from ..common.seed import set_seed
from ..data.image.transforms.divisible_crop import DivisibleCrop
from ..data.image.transforms.na_resize import NaResize
from ..utils.color_fix import wavelet_reconstruction
@ -67,8 +66,6 @@ def generation_step(runner, text_embeds_dict, cond_latents, temporal_overlap, de
x = runner.schedule.forward(x, aug_noise, t)
return x
# Generate conditions with memory optimization
runner.dit.to(device=device)
condition = runner.get_condition(
noises[0],
task="sr",
@ -76,8 +73,8 @@ def generation_step(runner, text_embeds_dict, cond_latents, temporal_overlap, de
)
conditions = [condition]
# Use adaptive autocast for optimal performance
with torch.no_grad():
# Use adaptive autocast for optimal performance
video_tensors = runner.inference(
noises=noises,
conditions=conditions,
@ -92,6 +89,7 @@ def generation_step(runner, text_embeds_dict, cond_latents, temporal_overlap, de
cond_latents = cond_latents[0].to("cpu")
conditions = conditions[0].to("cpu")
condition = condition.to("cpu")
del noises, aug_noises, cond_latents, conditions, condition
return samples #, last_latents
@ -133,7 +131,6 @@ def generation_loop(runner, images, cfg_scale=1.0, seed=666, res_w=720, batch_si
- Intelligent VRAM management throughout process
- Real-time progress reporting
"""
model_dtype = None
model_dtype = next(runner.dit.parameters()).dtype
compute_dtype = model_dtype
@ -230,7 +227,6 @@ def generation_loop(runner, images, cfg_scale=1.0, seed=666, res_w=720, batch_si
# Apply color correction if available
transformed_video = transformed_video.to(device)
input_video = [optimized_single_video_rearrange(transformed_video)]
del transformed_video
sample = wavelet_reconstruction(sample, input_video[0][:sample.size(0)])
@ -244,56 +240,33 @@ def generation_loop(runner, images, cfg_scale=1.0, seed=666, res_w=720, batch_si
batch_samples.append(sample_cpu)
#del sample
# Aggressive cleanup after each batch
# Progress callback - batch start
if progress_callback:
progress_callback(batch_count+1, total_batches, current_frames, "Processing batch...")
runner.vae.to(device="cpu")
runner.dit.to(device="cpu")
# OPTIMISATION ULTIME : Pré-allocation et copie directe (évite les torch.cat multiples)
# 1. Calculer la taille totale finale
total_frames = sum(batch.shape[0] for batch in batch_samples)
if len(batch_samples) > 0:
sample_shape = batch_samples[0].shape
H, W, C = sample_shape[1], sample_shape[2], sample_shape[3]
# 2. Pré-allouer le tensor final directement sur CPU (évite concatenations)
final_video_images = torch.empty((total_frames, H, W, C), dtype=torch.float16)
# 3. Copier par blocs directement dans le tensor final
block_size = 500
current_idx = 0
for block_start in range(0, len(batch_samples), block_size):
block_end = min(block_start + block_size, len(batch_samples))
# Charger le bloc en VRAM
current_block = []
for i in range(block_start, block_end):
current_block.append(batch_samples[i].to(device))
# Concatener en VRAM (rapide)
block_result = torch.cat(current_block, dim=0)
# Convertir en Float16 sur GPU
#if block_result.dtype != torch.float16:
# block_result = block_result.to(torch.float16)
# Copier directement dans le tensor final (pas de concatenation!)
block_frames = block_result.shape[0]
final_video_images[current_idx:current_idx + block_frames] = block_result.to("cpu")
current_idx += block_frames
# Nettoyage immédiat VRAM
del current_block, block_result
else:
print("SeedVR2: No batch_samples to process")
final_video_images = torch.empty((0, 0, 0, 0), dtype=torch.float16)
# Cleanup batch_samples
#del batch_samples
return final_video_images

View File

@ -1,38 +1,12 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
from einops import rearrange
from omegaconf import DictConfig, ListConfig
from torch import Tensor
from src.common.diffusion import (
classifier_free_guidance_dispatcher,
create_sampler_from_config,
create_sampling_timesteps_from_config,
create_schedule_from_config,
)
from src.common.distributed import (
get_device,
)
# from common.fs import download
from src.models.dit_v2 import na
from ..common.diffusion import classifier_free_guidance_dispatcher, create_sampler_from_config, create_sampling_timesteps_from_config, create_schedule_from_config
from ..models.dit_v2 import na
def optimized_channels_to_last(tensor):
def optimized_channels_to_last(tensor: torch.Tensor) -> torch.Tensor:
"""🚀 Optimized replacement for rearrange(tensor, 'b c ... -> b ... c')
Moves channels from position 1 to last position using PyTorch native operations.
"""
@ -74,7 +48,7 @@ class VideoDiffusionInfer():
self.dit = None
self.sampler = None
self.schedule = None
def get_condition(self, latent: Tensor, latent_blur: Tensor, task: str) -> Tensor:
def get_condition(self, latent: torch.Tensor, latent_blur: torch.Tensor, task: str) -> torch.Tensor:
t, h, w, c = latent.shape
cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
if task == "t2v" or t == 1:
@ -118,22 +92,18 @@ class VideoDiffusionInfer():
# -------------------------------- Helper ------------------------------- #
@torch.no_grad()
def vae_encode(self, samples: List[Tensor]) -> List[Tensor]:
self.dit.to(device="cpu")
self.vae.to(device=self.device)
def vae_encode(self, samples: List[torch.Tensor]) -> List[torch.Tensor]:
use_sample = self.config.vae.get("use_sample", True)
latents = []
if len(samples) > 0:
device = get_device()
dtype = self.vae.dtype
scale = self.config.vae.scaling_factor
shift = self.config.vae.get("shifting_factor", 0.0)
if isinstance(scale, ListConfig):
scale = torch.tensor(scale, device=device, dtype=dtype)
scale = torch.tensor(scale, device=self.device, dtype=dtype)
if isinstance(shift, ListConfig):
shift = torch.tensor(shift, device=device, dtype=dtype)
shift = torch.tensor(shift, device=self.device, dtype=dtype)
# Group samples of the same shape to batches if enabled.
if self.config.vae.grouping:
@ -143,7 +113,7 @@ class VideoDiffusionInfer():
# Vae process by each group.
for sample in batches:
sample = sample.to(device, dtype)
sample = sample.to(self.device, dtype)
if hasattr(self.vae, "preprocess"):
sample = self.vae.preprocess(sample)
if use_sample:
@ -162,19 +132,15 @@ class VideoDiffusionInfer():
latents = na.unpack(latents, indices)
else:
latents = [latent.squeeze(0) for latent in latents]
self.vae.to(device="cpu")
return latents
@torch.no_grad()
def vae_decode(self, latents: List[Tensor], target_dtype: torch.dtype = None) -> List[Tensor]:
def vae_decode(self, latents: List[torch.Tensor], target_dtype: torch.dtype = None) -> List[torch.Tensor]:
"""🚀 VAE decode optimisé - décodage direct sans chunking, compatible avec autocast externe"""
self.dit.to(device="cpu")
self.vae.to(device=self.device)
samples = []
if len(latents) > 0:
device = get_device()
device = self.device
dtype = self.vae.dtype
scale = self.config.vae.scaling_factor
shift = self.config.vae.get("shifting_factor", 0.0)
@ -218,10 +184,9 @@ class VideoDiffusionInfer():
samples = na.unpack(samples, indices)
else:
samples = [sample.squeeze(0) for sample in samples]
self.vae.to(device="cpu")
return samples
def timestep_transform(self, timesteps: Tensor, latents_shapes: Tensor):
def timestep_transform(self, timesteps: torch.Tensor, latents_shapes: torch.Tensor):
# Skip if not needed.
if not self.config.diffusion.timesteps.get("transform", False):
return timesteps
@ -256,13 +221,13 @@ class VideoDiffusionInfer():
@torch.no_grad()
def inference(
self,
noises: List[Tensor],
conditions: List[Tensor],
texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
noises: List[torch.Tensor],
conditions: List[torch.Tensor],
texts_pos: Union[List[str], List[torch.Tensor], List[Tuple[torch.Tensor]]],
texts_neg: Union[List[str], List[torch.Tensor], List[Tuple[torch.Tensor]]],
cfg_scale: Optional[float] = None,
temporal_overlap: int = 0, # pylint: disable=unused-argument
) -> List[Tensor]:
) -> List[torch.Tensor]:
assert len(noises) == len(conditions) == len(texts_pos) == len(texts_neg)
batch_size = len(noises)
@ -316,6 +281,7 @@ class VideoDiffusionInfer():
# Adapter les latents au dtype cible (compatible avec FP8)
latents = latents.to(target_dtype) if latents.dtype != target_dtype else latents
latents_cond = latents_cond.to(target_dtype) if latents_cond.dtype != target_dtype else latents_cond
self.dit = self.dit.to(device=self.device, dtype=target_dtype)
latents = self.sampler.sample(
x=latents,

View File

@ -3,10 +3,9 @@ import torch
from omegaconf import OmegaConf
from safetensors.torch import load_file as load_safetensors_file
from huggingface_hub import hf_hub_download
from src.optimization.memory_manager import preinitialize_rope_cache
from src.common.config import load_config, create_object
from src.core.infer import VideoDiffusionInfer
from ..optimization.memory_manager import preinitialize_rope_cache
from ..common.config import load_config, create_object
from ..core.infer import VideoDiffusionInfer
def configure_runner(model_name, cache_dir, device:str='cpu', dtype:torch.dtype=None):
@ -15,41 +14,39 @@ def configure_runner(model_name, cache_dir, device:str='cpu', dtype:torch.dtype=
config_path = os.path.join(script_directory, './config_7b.yaml') if "7b" in model_name else os.path.join(script_directory, './config_3b.yaml')
config = load_config(config_path)
vae_config_path = os.path.join(script_directory, 'src/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml')
vae_config = OmegaConf.load(vae_config_path)
vae_config.spatial_downsample_factor = vae_config.get('spatial_downsample_factor', 8)
vae_config.temporal_downsample_factor = vae_config.get('temporal_downsample_factor', 4)
config.vae.model = OmegaConf.merge(config.vae.model, vae_config)
runner = VideoDiffusionInfer(config, device=device, dtype=dtype)
OmegaConf.set_readonly(runner.config, False)
# load dit
with torch.device("meta"):
runner.dit = create_object(config.dit.model)
runner.dit.eval().to(dtype)
runner.dit.requires_grad_(False).eval()
runner.dit.to_empty(device="cpu")
model_file = hf_hub_download(repo_id=repo_id, filename=model_name, cache_dir=cache_dir)
state_dict = load_safetensors_file(model_file)
runner.dit.load_state_dict(state_dict, assign=True)
runner.dit = runner.dit.to(device="cpu", dtype=dtype)
del state_dict
runner.dit = runner.dit.to(device=device, dtype=dtype)
# load vae
vae_config_path = os.path.join(script_directory, 'src/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml')
vae_config = OmegaConf.load(vae_config_path)
config.vae.model = OmegaConf.merge(config.vae.model, vae_config)
vae_file = hf_hub_download(repo_id=repo_id, filename=config.vae.checkpoint, cache_dir=cache_dir)
with torch.device("meta"):
runner.vae = create_object(config.vae.model)
runner.vae.requires_grad_(False).eval()
runner.vae.to_empty(device="cpu")
state_dict = load_safetensors_file(vae_file)
runner.vae.load_state_dict(state_dict)
del state_dict
runner.vae = runner.vae.to(device=device, dtype=dtype)
runner.vae = runner.vae.to(device="cpu", dtype=dtype)
runner.config.vae.dtype = str(dtype)
runner.vae.set_causal_slicing(**config.vae.slicing)
runner.config.vae.slicing = {'split_size': 8, 'memory_device': 'same'}
runner.config.vae.memory_limit = {'conv_max_mem': 0.2, 'norm_max_mem': 0.2}
runner.vae.set_causal_slicing(**runner.config.vae.slicing)
runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
del state_dict
# load embeds
pos_embeds_file = hf_hub_download(repo_id=repo_id, filename='pos_emb.pt', cache_dir=cache_dir)
@ -57,5 +54,4 @@ def configure_runner(model_name, cache_dir, device:str='cpu', dtype:torch.dtype=
runner.text_pos_embeds = torch.load(pos_embeds_file).to(device=device, dtype=dtype)
runner.text_neg_embeds = torch.load(neg_embeds_file).to(device=device, dtype=dtype)
preinitialize_rope_cache(runner)
return runner

View File

@ -19,13 +19,7 @@ from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _triple
from ....common.half_precision_fixes import safe_pad_operation
from ....common.distributed.ops import (
gather_heads,
gather_heads_scatter_seq,
gather_seq_scatter_heads_qkv,
scatter_heads,
)
from ....common.distributed.ops import gather_heads, gather_heads_scatter_seq, gather_seq_scatter_heads_qkv, scatter_heads
from ..attention import TorchAttention
from ..mlp import get_mlp
from ..mm import MMArg, MMModule

View File

@ -18,8 +18,7 @@ import torch
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
from torch import nn
from src.common.cache import Cache
from ...common.cache import Cache
class RotaryEmbeddingBase(nn.Module):

View File

@ -15,7 +15,6 @@ from typing import Literal, Optional, Tuple, Union
import diffusers
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, SpatialNorm
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.downsampling import Downsample2D
@ -28,29 +27,12 @@ from diffusers.utils import is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from einops import rearrange
from ....common.half_precision_fixes import safe_pad_operation, safe_interpolate_operation
from ....common.distributed.advanced import get_sequence_parallel_world_size
from ....common.logger import get_logger
from .causal_inflation_lib import (
InflatedCausalConv3d,
causal_norm_wrapper,
init_causal_conv3d,
remove_head,
)
from .context_parallel_lib import (
causal_conv_gather_outputs,
causal_conv_slice_inputs,
)
from .causal_inflation_lib import InflatedCausalConv3d, causal_norm_wrapper, init_causal_conv3d, remove_head
from .context_parallel_lib import causal_conv_gather_outputs, causal_conv_slice_inputs
from .global_config import set_norm_limit
from .types import (
CausalAutoencoderOutput,
CausalDecoderOutput,
CausalEncoderOutput,
MemoryState,
_inflation_mode_t,
_memory_device_t,
_receptive_field_t,
)
from .types import CausalAutoencoderOutput, CausalDecoderOutput, CausalEncoderOutput, MemoryState, _inflation_mode_t, _memory_device_t, _receptive_field_t
logger = get_logger(__name__) # pylint: disable=invalid-name
@ -1064,7 +1046,7 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
scaling_factor: float = 0.18215,
force_upcast: float = True,
attention: bool = True,
temporal_scale_num: int = 2,
temporal_scale_num: int = 0,
slicing_up_num: int = 0,
gradient_checkpoint: bool = False,
inflation_mode: _inflation_mode_t = "tail",
@ -1164,7 +1146,8 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
@apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
h = self.slicing_encode(x)
# h = self.slicing_encode(x)
h = self.tiled_encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
@ -1176,7 +1159,8 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
def decode(
self, z: torch.Tensor, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
decoded = self.slicing_decode(z)
# decoded = self.slicing_decode(z)
decoded = self.tiled_decode(z)
if not return_dict:
return (decoded,)
@ -1208,7 +1192,7 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
return output.to(z.device)
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
sp_size = get_sequence_parallel_world_size()
sp_size = 1
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
encoded_slices = [
@ -1226,7 +1210,7 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
return self._encode(x)
def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
sp_size = get_sequence_parallel_world_size()
sp_size = 1
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
decoded_slices = [
@ -1243,12 +1227,68 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
else:
return self._decode(z)
def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
    """Spatially-tiled encode hook; intentionally unimplemented here — presumably a subclass or a replacement implementation supplies it (TODO: confirm)."""
    raise NotImplementedError
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Blend the bottom edge of tile `a` into the top edge of tile `b` along dim 3 (presumably height).

    A linear ramp over `blend_extent` rows fades from pure `a` (first row)
    toward pure `b`. `b` is modified in place and returned.
    """
    # Clamp the ramp length to what both tensors can actually provide.
    extent = min(a.shape[3], b.shape[3], blend_extent)
    for row in range(extent):
        weight = row / extent  # 0.0 at the seam start, approaching 1.0 at the end
        b[:, :, :, row, :] = a[:, :, :, row - extent, :] * (1.0 - weight) + b[:, :, :, row, :] * weight
    return b
def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
    """Spatially-tiled decode hook; intentionally unimplemented here — presumably a subclass or a replacement implementation supplies it (TODO: confirm)."""
    raise NotImplementedError
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Blend the right edge of tile `a` into the left edge of tile `b` along dim 4 (presumably width).

    A linear ramp over `blend_extent` columns fades from pure `a` (first
    column) toward pure `b`. `b` is modified in place and returned.
    """
    # Clamp the ramp length to what both tensors can actually provide.
    extent = min(a.shape[4], b.shape[4], blend_extent)
    for col in range(extent):
        weight = col / extent  # 0.0 at the seam start, approaching 1.0 at the end
        b[:, :, :, :, col] = a[:, :, :, :, col - extent] * (1.0 - weight) + b[:, :, :, :, col] * weight
    return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
    """Encode `x` tile-by-tile over the last two (spatial) dims.

    Overlapping sample-space tiles are each passed through `self._encode`;
    latent tiles are then seam-blended against their top/left neighbours
    (`blend_v` / `blend_h`), cropped to the non-overlapping region, and
    concatenated back into one latent tensor.
    """
    # Stride between tile origins in sample space; latent-space blend width and crop size.
    stride = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
    blend = int(self.tile_latent_min_size * self.tile_overlap_factor)
    keep = self.tile_latent_min_size - blend
    # Pass 1: encode every overlapping tile into a 2-D grid of latents.
    latent_grid = []
    for top in range(0, x.shape[3], stride):
        latent_row = [
            self._encode(x[:, :, :, top:top + self.tile_sample_min_size, left:left + self.tile_sample_min_size])
            for left in range(0, x.shape[4], stride)
        ]
        latent_grid.append(latent_row)
    # Pass 2: blend each tile with its upper and left neighbours, crop, and stitch.
    stitched = []
    for r, latent_row in enumerate(latent_grid):
        pieces = []
        for c, tile in enumerate(latent_row):
            if r > 0:
                tile = self.blend_v(latent_grid[r - 1][c], tile, blend)
            if c > 0:
                tile = self.blend_h(latent_row[c - 1], tile, blend)
            pieces.append(tile[:, :, :, :keep, :keep])
        stitched.append(torch.cat(pieces, dim=4))
    return torch.cat(stitched, dim=3)
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
    """Decode latent `z` tile-by-tile over the last two (spatial) dims.

    Overlapping latent-space tiles are each passed through `self.decoder`;
    decoded tiles are then seam-blended against their top/left neighbours
    (`blend_v` / `blend_h`), cropped to the non-overlapping region, and
    concatenated back into one sample tensor.
    """
    # Stride between tile origins in latent space; sample-space blend width and crop size.
    stride = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
    blend = int(self.tile_sample_min_size * self.tile_overlap_factor)
    keep = self.tile_sample_min_size - blend
    # Pass 1: decode every overlapping latent tile into a 2-D grid of samples.
    pixel_grid = []
    for top in range(0, z.shape[3], stride):
        pixel_row = [
            self.decoder(z[:, :, :, top:top + self.tile_latent_min_size, left:left + self.tile_latent_min_size])
            for left in range(0, z.shape[4], stride)
        ]
        pixel_grid.append(pixel_row)
    # Pass 2: blend each tile with its upper and left neighbours, crop, and stitch.
    stitched = []
    for r, pixel_row in enumerate(pixel_grid):
        pieces = []
        for c, tile in enumerate(pixel_row):
            if r > 0:
                tile = self.blend_v(pixel_grid[r - 1][c], tile, blend)
            if c > 0:
                tile = self.blend_h(pixel_row[c - 1], tile, blend)
            pieces.append(tile[:, :, :, :keep, :keep])
        stitched.append(torch.cat(pieces, dim=4))
    return torch.cat(stitched, dim=3)
def forward(
self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs
):

View File

@ -21,30 +21,11 @@ from diffusers.models.normalization import RMSNorm
from einops import rearrange
from torch import Tensor, nn
from torch.nn import Conv3d
from .context_parallel_lib import cache_send_recv, get_cache_size
from .global_config import get_norm_limit
from .types import MemoryState, _inflation_mode_t, _memory_device_t
from ....common.half_precision_fixes import safe_pad_operation
# Single GPU inference - no distributed processing needed
# Mock distributed functions for single GPU inference
def get_sequence_parallel_group():
    """Return the sequence-parallel process group; always None in this single-GPU build."""
    return None


def get_sequence_parallel_rank():
    """Return this process's rank in the sequence-parallel group; always 0 (only one process)."""
    return 0


def get_sequence_parallel_world_size():
    """Return the sequence-parallel world size; always 1 on a single GPU."""
    return 1


def get_next_sequence_parallel_rank():
    """Return the next rank in the sequence-parallel ring; always 0 (ring of one)."""
    return 0


def get_prev_sequence_parallel_rank():
    """Return the previous rank in the sequence-parallel ring; always 0 (ring of one)."""
    return 0
@contextmanager
def ignore_padding(model):
@ -172,7 +153,6 @@ class InflatedCausalConv3d(Conv3d):
if (
math.isinf(self.memory_limit)
and torch.is_tensor(input)
and get_sequence_parallel_group() is None
):
return self.basic_forward(input, memory_state)
return self.slicing_forward(input, memory_state)

View File

@ -14,11 +14,8 @@
from typing import List
import torch
import torch.nn.functional as F
from torch import Tensor
from .types import MemoryState
# Single GPU inference - no distributed processing needed

View File

@ -16,14 +16,8 @@ from functools import partial
from typing import Literal, Optional
from torch import Tensor
from torch.nn import Conv3d
from .inflated_lib import MemoryState, extend_head, inflate_bias, inflate_weight, modify_state_dict
from .inflated_lib import (
MemoryState,
extend_head,
inflate_bias,
inflate_weight,
modify_state_dict,
)
_inflation_mode_t = Literal["none", "tail", "replicate"]
_memory_device_t = Optional[Literal["cpu", "same"]]

View File

@ -19,9 +19,9 @@ import torch
from diffusers.models.normalization import RMSNorm
from einops import rearrange
from torch import Tensor, nn
from ....common.logger import get_logger
logger = get_logger(__name__)

View File

@ -11,37 +11,17 @@
from contextlib import nullcontext
from typing import Optional, Tuple, Literal, Callable, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from einops import rearrange
from ....common.half_precision_fixes import safe_pad_operation
from ....common.distributed.advanced import get_sequence_parallel_world_size
from ....common.logger import get_logger
from .causal_inflation_lib import (
InflatedCausalConv3d,
causal_norm_wrapper,
init_causal_conv3d,
remove_head,
)
from .context_parallel_lib import (
causal_conv_gather_outputs,
causal_conv_slice_inputs,
)
from .causal_inflation_lib import InflatedCausalConv3d, causal_norm_wrapper, init_causal_conv3d, remove_head
from .context_parallel_lib import causal_conv_gather_outputs, causal_conv_slice_inputs
from .global_config import set_norm_limit
from .types import (
CausalAutoencoderOutput,
CausalDecoderOutput,
CausalEncoderOutput,
MemoryState,
_inflation_mode_t,
_memory_device_t,
_receptive_field_t,
_selective_checkpointing_t,
)
from .types import CausalAutoencoderOutput, CausalDecoderOutput, CausalEncoderOutput, MemoryState, _inflation_mode_t, _memory_device_t, _receptive_field_t, _selective_checkpointing_t
logger = get_logger(__name__) # pylint: disable=invalid-name
@ -717,7 +697,7 @@ class VideoAutoencoderKL(nn.Module):
use_post_quant_conv: bool = True,
enc_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",),
dec_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",),
temporal_scale_num: int = 3,
temporal_scale_num: int = 0,
slicing_up_num: int = 0,
inflation_mode: _inflation_mode_t = "tail",
time_receptive_field: _receptive_field_t = "half",
@ -824,7 +804,7 @@ class VideoAutoencoderKL(nn.Module):
return x
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
sp_size = get_sequence_parallel_world_size()
sp_size = 1
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
encoded_slices = [
@ -842,7 +822,7 @@ class VideoAutoencoderKL(nn.Module):
return self._encode(x, memory_state=MemoryState.DISABLED)
def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
sp_size = get_sequence_parallel_world_size()
sp_size = 1
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
decoded_slices = [

View File

@ -5,12 +5,9 @@ Handles VRAM usage, cache management, and memory optimization
Extracted from: seedvr2.py (lines 373-405, 607-626, 1016-1044)
"""
import os
import torch
import gc
from typing import Tuple, Optional
from src.common.cache import Cache
from src.models.dit_v2.rope import RotaryEmbeddingBase
from ..common.cache import Cache
from ..models.dit_v2.rope import RotaryEmbeddingBase
def preinitialize_rope_cache(runner) -> None:
@ -21,73 +18,63 @@ def preinitialize_rope_cache(runner) -> None:
runner: The model runner containing DiT and VAE models
"""
try:
# Create dummy tensors to simulate common shapes
# Format: [batch, channels, frames, height, width] for vid_shape
# Format: [batch, seq_len] for txt_shape
common_shapes = [
# Common video resolutions
(torch.tensor([[1, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 1 frame, 77 tokens
(torch.tensor([[4, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 4 frames
(torch.tensor([[5, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 5 frames (4n+1 format)
(torch.tensor([[1, 4, 4]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # Higher resolution
]
# Create dummy tensors to simulate common shapes
# Format: [batch, channels, frames, height, width] for vid_shape
# Format: [batch, seq_len] for txt_shape
common_shapes = [
# Common video resolutions
(torch.tensor([[1, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 1 frame, 77 tokens
(torch.tensor([[4, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 4 frames
(torch.tensor([[5, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 5 frames (4n+1 format)
(torch.tensor([[1, 4, 4]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # Higher resolution
]
# Create mock cache for pre-initialization
# Create mock cache for pre-initialization
temp_cache = Cache()
temp_cache = Cache()
# Access RoPE modules in DiT (recursive search)
def find_rope_modules(module):
rope_modules = []
for name, child in module.named_modules():
if hasattr(child, 'get_freqs') and callable(child.get_freqs):
rope_modules.append((name, child))
return rope_modules
# Access RoPE modules in DiT (recursive search)
def find_rope_modules(module):
rope_modules = []
for name, child in module.named_modules():
if hasattr(child, 'get_freqs') and callable(child.get_freqs):
rope_modules.append((name, child))
return rope_modules
rope_modules = find_rope_modules(runner.dit)
rope_modules = find_rope_modules(runner.dit)
# Pre-calculate for each RoPE module found
for name, rope_module in rope_modules:
# Temporarily move module to CPU if necessary
original_device = next(rope_module.parameters()).device if list(rope_module.parameters()) else torch.device('cpu')
rope_module.to('cpu')
# Pre-calculate for each RoPE module found
for _name, rope_module in rope_modules:
# Temporarily move module to CPU if necessary
original_device = next(rope_module.parameters()).device if list(rope_module.parameters()) else torch.device('cpu')
rope_module.to('cpu')
try:
for vid_shape, txt_shape in common_shapes:
cache_key = f"720pswin_by_size_bysize_{tuple(vid_shape[0].tolist())}_sd3.mmrope_freqs_3d"
for vid_shape, txt_shape in common_shapes:
cache_key = f"720pswin_by_size_bysize_{tuple(vid_shape[0].tolist())}_sd3.mmrope_freqs_3d"
def compute_freqs():
# Calculate with reduced dimensions to avoid OOM
with torch.no_grad():
# Detect RoPE module type
module_type = type(rope_module).__name__
def compute_freqs():
# Calculate with reduced dimensions to avoid OOM
with torch.no_grad():
# Detect RoPE module type
module_type = type(rope_module).__name__
if module_type == 'NaRotaryEmbedding3d':
# NaRotaryEmbedding3d: only takes shape (vid_shape)
return rope_module.get_freqs(vid_shape.cpu())
else:
# Standard RoPE: takes vid_shape and txt_shape
return rope_module.get_freqs(vid_shape.cpu(), txt_shape.cpu())
if module_type == 'NaRotaryEmbedding3d':
# NaRotaryEmbedding3d: only takes shape (vid_shape)
return rope_module.get_freqs(vid_shape.cpu())
else:
# Standard RoPE: takes vid_shape and txt_shape
return rope_module.get_freqs(vid_shape.cpu(), txt_shape.cpu())
# Store in cache
temp_cache(cache_key, compute_freqs)
# Store in cache
temp_cache(cache_key, compute_freqs)
except Exception as e:
print(f" ❌ Error in module {name}: {e}")
finally:
# Restore to original device
rope_module.to(original_device)
rope_module.to(original_device)
# Copy temporary cache to runner cache
if hasattr(runner, 'cache'):
runner.cache.cache.update(temp_cache.cache)
else:
runner.cache = temp_cache
except Exception as e:
print(f" ⚠️ Error during RoPE pre-init: {e}")
print(" 🔄 Model will work but could have OOM at first launch")
# Copy temporary cache to runner cache
if hasattr(runner, 'cache'):
runner.cache.cache.update(temp_cache.cache)
else:
runner.cache = temp_cache
def clear_rope_cache(runner) -> None:
@ -97,8 +84,6 @@ def clear_rope_cache(runner) -> None:
Args:
runner: The model runner containing the cache
"""
print("🧹 Cleaning RoPE cache...")
if hasattr(runner, 'cache') and hasattr(runner.cache, 'cache'):
# Count entries before cleanup
cache_size = len(runner.cache.cache)
@ -116,7 +101,6 @@ def clear_rope_cache(runner) -> None:
# Clear the cache
runner.cache.cache.clear()
print(f" ✅ RoPE cache cleared ({cache_size} entries removed)")
if hasattr(runner, 'dit'):
cleared_lru_count = 0
@ -125,7 +109,3 @@ def clear_rope_cache(runner) -> None:
if hasattr(module.get_axial_freqs, 'cache_clear'):
module.get_axial_freqs.cache_clear()
cleared_lru_count += 1
if cleared_lru_count > 0:
print(f" ✅ Cleared {cleared_lru_count} LRU caches from RoPE modules.")
print("🎯 RoPE cache cleanup completed!")

View File

@ -2,7 +2,7 @@ import torch
from PIL import Image
from torch import Tensor
from torch.nn import functional as F
from src.common.half_precision_fixes import safe_pad_operation, safe_interpolate_operation
from ..common.half_precision_fixes import safe_pad_operation, safe_interpolate_operation
from torchvision.transforms import ToTensor, ToPILImage
def adain_color_fix(target: Image, source: Image):

View File

@ -1,12 +1,10 @@
import os
import argparse
import numpy as np
import torch
from PIL import Image
from huggingface_hub import snapshot_download
from torchvision.transforms import ToPILImage
from src.core.generation import generation_loop
from src.core.model_manager import configure_runner
from .src.core.generation import generation_loop
from .src.core.model_manager import configure_runner

View File

@ -620,7 +620,6 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
"detailer_iou": OptionInfo(0.5, "Max overlap", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}),
"detailer_sigma_adjust": OptionInfo(1.0, "Detailer sigma adjust", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}),
"detailer_sigma_adjust_max": OptionInfo(1.0, "Detailer sigma end", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}),
# "detailer_resolution": OptionInfo(1024, "Detailer resolution", gr.Slider, {"minimum": 256, "maximum": 4096, "step": 8, "visible": False}),
"detailer_min_size": OptionInfo(0.0, "Min object size", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05, "visible": False}),
"detailer_max_size": OptionInfo(1.0, "Max object size", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05, "visible": False}),
"detailer_padding": OptionInfo(20, "Item padding", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1, "visible": False}),
@ -631,6 +630,9 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
"detailer_unload": OptionInfo(False, "Move detailer model to CPU when complete"),
"detailer_augment": OptionInfo(True, "Detailer use model augment"),
"postprocessing_sep_seedvt": OptionInfo("<h2>SeedVT</h2>", "", gr.HTML),
"seedvt_cfg_scale": OptionInfo(3.5, "SeedVR CFG Scale", gr.Slider, {"minimum": 1, "maximum": 15, "step": 1}),
"postprocessing_sep_face_restore": OptionInfo("<h2>Face Restore</h2>", "", gr.HTML),
"face_restoration_model": OptionInfo("None", "Face restoration", gr.Radio, lambda: {"choices": ['None'] + [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.2, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),

View File

@ -227,7 +227,7 @@ def create_ui():
quicksettings_list.append((key, item))
components.append(dummy_component)
else:
with gr.Row(elem_id=f"settings_section_row_{section_id}"): # only so we can add dirty indicator at the start of the row
with gr.Row(elem_id=f"settings_section_row_{section_id}", elem_classes=["settings_section"]): # only so we can add dirty indicator at the start of the row
component = create_setting_component(key)
shared.settings_components[key] = component
current_items.append(key)