diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f06fff89..860ba0038 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change Log for SD.Next -## Update for 2025-10-11 +## Update for 2025-10-12 - **Models** - [WAN 2.2 14B VACE](https://huggingface.co/alibaba-pai/Wan2.2-VACE-Fun-A14B) @@ -17,6 +17,13 @@ *how to use*: enable nunchaku in settings -> quantization and then load either sdxl-base or sdxl-base-turbo reference models - [HiDream E1.1](https://huggingface.co/HiDream-ai/HiDream-E1-1) updated version of E1 image editing model + - [SeedVR2](https://iceclear.github.io/projects/seedvr/) + originally designed for video restoration, seedvr works great for image detailing and upscaling! + available in 3B, 7B and 7B-sharp variants + use as any other upscaler! + note: seedvr is a very large model (6.4GB for the 3B variant and 16GB for the 7B variants) and not designed for lower-end hardware + note: seedvr is highly sensitive to its cfg scale (set in settings -> postprocessing), + lower values will result in smoother output while higher values add details - [X-Omni SFT](https://x-omni-team.github.io/) *experimental*: X-omni is a transformer-only discrete autoregressive image generative model trained with reinforcement learning - **Features** diff --git a/modules/model_quant.py b/modules/model_quant.py index c3d2439ec..5f88d3174 100644 --- a/modules/model_quant.py +++ b/modules/model_quant.py @@ -482,7 +482,7 @@ def sdnq_quantize_model(model, op=None, sd_model=None, do_gc: bool = True, weigh from modules.sdnq import sdnq_post_load_quant if weights_dtype is None: - if op is not None and ("text_encoder" in op or op in {"TE", "LLM"}) and shared.opts.sdnq_quantize_weights_mode_te not in {"Same as model", "default"}: + if (op is not None) and ("text_encoder" in op or op in {"TE", "LLM"}) and (shared.opts.sdnq_quantize_weights_mode_te not in {"Same as model", "default"}): weights_dtype = shared.opts.sdnq_quantize_weights_mode_te else: weights_dtype = shared.opts.sdnq_quantize_weights_mode @@ -588,6 +588,8 @@ def sdnq_quantize_weights(sd_model): log.info(f"Quantization: type=SDNQ time={t1-t0:.2f}") except Exception as e: log.warning(f"Quantization: type=SDNQ {e}") + from modules import errors + errors.display(e, 'Quantization') return sd_model diff --git a/modules/postprocess/seedvr_model.py b/modules/postprocess/seedvr_model.py index 1f833cbd5..fddeff557 100644 --- a/modules/postprocess/seedvr_model.py +++ b/modules/postprocess/seedvr_model.py @@ -1,9 +1,9 @@ import time +import random import numpy as np import torch from PIL import Image from torchvision.transforms import ToPILImage -from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn from modules import devices from modules.shared import opts, log from modules.upscaler import Upscaler, UpscalerData @@ -19,7 +19,7 @@ to_pil = ToPILImage() class UpscalerSeedVR(Upscaler): def __init__(self, dirname=None): - self.name = "SeedVR" + self.name = "SeedVR2" super().__init__() self.scalers = [ UpscalerData(name="SeedVR2 3B", path=None, upscaler=self, model=None, scale=1), @@ -32,45 +32,140 @@ class UpscalerSeedVR(Upscaler): def load_model(self, path: str): model_name = MODELS_MAP.get(path, None) if (self.model is None) or (self.model_loaded != model_name): - log.debug(f'Upscaler load: name="{self.name}" model="{model_name}"') + log.debug(f'Upscaler loading: name="{self.name}" model="{model_name}"') + t0 = time.time() from modules.seedvr.src.core.model_manager import configure_runner + from 
modules.seedvr.src.core import generation self.model = configure_runner( model_name=model_name, cache_dir=opts.hfcache_dir, device=devices.device, dtype=devices.dtype, ) + self.model_loaded = model_name + self.model.dit.device = devices.device + self.model.dit.dtype = devices.dtype + self.model.vae_encode = self.vae_encode + self.model.vae_decode = self.vae_decode + self.model.model_step = generation.generation_step + generation.generation_step = self.model_step + self.model._internal_dict = { + 'dit': self.model.dit, + 'vae': self.model.vae, + } + t1 = time.time() + self.model.dit.config = self.model.config.dit + self.model.vae.tile_sample_min_size = 1024 + self.model.vae.tile_latent_min_size = 128 + # from modules.model_quant import do_post_load_quant + # from modules.sd_offload import set_diffuser_offload + # self.model = do_post_load_quant(self.model, allow=True) + # set_diffuser_offload(self.model) + log.info(f'Upscaler loaded: name="{self.name}" model="{model_name}" time={t1 - t0:.2f}') + + def vae_encode(self, samples): + log.debug(f'Upscaler encode: samples={samples[0].shape if len(samples) > 0 else None} tile={self.model.vae.tile_sample_min_size} overlap={self.model.vae.tile_overlap_factor}') + latents = [] + if len(samples) == 0: + return latents + self.model.dit = self.model.dit.to(device="cpu") + self.model.vae = self.model.vae.to(device=self.device) + devices.torch_gc() + from einops import rearrange + from modules.seedvr.src.optimization import memory_manager + memory_manager.clear_rope_cache(self.model) + scale = self.model.config.vae.scaling_factor + shift = self.model.config.vae.get("shifting_factor", 0.0) + batches = [sample.unsqueeze(0) for sample in samples] + for sample in batches: + sample = sample.to(self.device, self.model.vae.dtype) + sample = self.model.vae.preprocess(sample) + latent = self.model.vae.encode(sample).latent + latent = latent.unsqueeze(2) if latent.ndim == 4 else latent + latent = rearrange(latent, "b c ... -> b ... c") + latent = (latent - shift) * scale + latents.append(latent) + latents = [latent.squeeze(0) for latent in latents] + self.model.vae = self.model.vae.to(device="cpu") + devices.torch_gc() + return latents + + def vae_decode(self, latents, target_dtype: torch.dtype = None): + log.debug(f'Upscaler decode: latents={latents[0].shape if len(latents) > 0 else None} tile={self.model.vae.tile_latent_min_size} overlap={self.model.vae.tile_overlap_factor}') + samples = [] + if len(latents) == 0: + return samples + from einops import rearrange + from modules.seedvr.src.optimization import memory_manager + memory_manager.clear_rope_cache(self.model) + self.model.dit = self.model.dit.to(device="cpu") + self.model.vae = self.model.vae.to(device=self.device) + devices.torch_gc() + scale = self.model.config.vae.scaling_factor + shift = self.model.config.vae.get("shifting_factor", 0.0) + latents = [latent.unsqueeze(0) for latent in latents] + with devices.inference_context(): + for _i, latent in enumerate(latents): + latent = latent.to(self.device, self.model.vae.dtype) + latent = latent / scale + shift + latent = rearrange(latent, "b ... 
c -> b c ...") + latent = latent.squeeze(2) + sample = self.model.vae.decode(latent).sample + sample = self.model.vae.postprocess(sample) + samples.append(sample) + samples = [sample.squeeze(0) for sample in samples] + self.model.vae = self.model.vae.to(device="cpu") + devices.torch_gc() + return samples + + def model_step(self, *args, **kwargs): + from modules.seedvr.src.optimization import memory_manager + self.model.vae = self.model.vae.to(device="cpu") + self.model.dit = self.model.dit.to(device=self.device) + devices.torch_gc() + log.debug(f'Upscaler inference: args={len(args)} kwargs={list(kwargs.keys())}') + memory_manager.preinitialize_rope_cache(self.model) + with devices.inference_context(): + result = self.model.model_step(*args, **kwargs) + self.model.dit = self.model.dit.to(device="cpu") + devices.torch_gc() + return result def do_upscale(self, img: Image.Image, selected_file): - devices.torch_gc() self.load_model(selected_file) if self.model is None: return img - from modules.seedvr.src.core.generation import generation_loop + from modules.seedvr.src.core import generation width = int(self.scale * img.width) // 8 * 8 image_tensor = np.array(img) image_tensor = torch.from_numpy(image_tensor).to(device=devices.device, dtype=devices.dtype).unsqueeze(0) / 255.0 + random.seed() + seed = int(random.randrange(4294967294)) + t0 = time.time() - result_tensor = generation_loop( - runner=self.model, - images=image_tensor, - cfg_scale=1.0, - seed=42, - res_w=width, - batch_size=1, - temporal_overlap=0, - device=devices.device, - ) + with devices.inference_context(): + result_tensor = generation.generation_loop( + runner=self.model, + images=image_tensor, + cfg_scale=opts.seedvt_cfg_scale, + seed=seed, + res_w=width, + batch_size=1, + temporal_overlap=0, + device=devices.device, + ) t1 = time.time() - log.info(f'Upscaler: type="{self.name}" model="{selected_file}" scale={self.scale} time={t1 - t0:.2f}') + log.info(f'Upscaler: type="{self.name}" model="{selected_file}" scale={self.scale} cfg={opts.seedvt_cfg_scale} seed={seed} time={t1 - t0:.2f}') img = to_pil(result_tensor.squeeze().permute((2, 0, 1))) - devices.torch_gc() if opts.upscaler_unload: + self.model.dit = None + self.model.vae = None + self.model.cache = None self.model = None log.debug(f'Upscaler unload: type="{self.name}" model="{selected_file}"') - devices.torch_gc(force=True) + devices.torch_gc(force=True) return img diff --git a/modules/sd_models_utils.py b/modules/sd_models_utils.py index 637f8c585..05f44ef6c 100644 --- a/modules/sd_models_utils.py +++ b/modules/sd_models_utils.py @@ -164,6 +164,8 @@ def apply_function_to_model(sd_model, function, options, op=None): sd_model.unet = function(sd_model.unet, op="unet", sd_model=sd_model) if hasattr(sd_model, 'transformer') and hasattr(sd_model.transformer, 'config'): sd_model.transformer = function(sd_model.transformer, op="transformer", sd_model=sd_model) + if hasattr(sd_model, 'dit') and hasattr(sd_model.dit, 'config'): + sd_model.dit = function(sd_model.dit, op="dit", sd_model=sd_model) if hasattr(sd_model, 'transformer_2') and hasattr(sd_model.transformer_2, 'config'): sd_model.transformer_2 = function(sd_model.transformer_2, op="transformer_2", sd_model=sd_model) if hasattr(sd_model, 'transformer_3') and hasattr(sd_model.transformer_3, 'config'): diff --git a/modules/sd_offload.py b/modules/sd_offload.py index d20fb0371..0d9defd5d 100644 --- a/modules/sd_offload.py +++ b/modules/sd_offload.py @@ -274,7 +274,10 @@ class OffloadHook(accelerate.hooks.ModelHook): def 
get_pipe_variants(pipe=None): if pipe is None: - pipe = shared.sd_model + if shared.sd_loaded: + pipe = shared.sd_model + else: + return [pipe] variants = [pipe] if hasattr(pipe, "pipe"): variants.append(pipe.pipe) @@ -287,7 +290,10 @@ def get_pipe_variants(pipe=None): def get_module_names(pipe=None, exclude=[]): if pipe is None: - pipe = shared.sd_model + if shared.sd_loaded: + pipe = shared.sd_model + else: + return [] if hasattr(pipe, "_internal_dict"): modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access else: diff --git a/modules/sdnq/quantizer.py b/modules/sdnq/quantizer.py index 56c6c6d7f..499bc02b8 100644 --- a/modules/sdnq/quantizer.py +++ b/modules/sdnq/quantizer.py @@ -385,7 +385,6 @@ def sdnq_post_load_quant( modules_dtype_dict=modules_dtype_dict.copy(), op=op, ) - model.quantization_config = SDNQConfig( weights_dtype=weights_dtype, group_size=group_size, @@ -402,8 +401,10 @@ def sdnq_post_load_quant( modules_dtype_dict=modules_dtype_dict.copy(), ) - if hasattr(model, "config"): + try: model.config.quantization_config = model.quantization_config + except Exception: + pass model.quantization_method = QuantizationMethod.SDNQ return model diff --git a/modules/seedvr/config_3b.yaml b/modules/seedvr/config_3b.yaml index d2ffda469..711d4ec71 100644 --- a/modules/seedvr/config_3b.yaml +++ b/modules/seedvr/config_3b.yaml @@ -6,9 +6,9 @@ dit: model: __object__: path: - - "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit_v2.nadit" - - "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit_v2.nadit" - - "src.models.dit_v2.nadit" + - "SeedVR2_VideoUpscaler.src.models.dit_v2.nadit" + - "SeedVR2_VideoUpscaler.src.models.dit_v2.nadit" + - "modules.seedvr.src.models.dit_v2.nadit" name: "NaDiT" args: "as_params" vid_in_channels: 33 @@ -49,9 +49,9 @@ vae: model: __object__: path: - - "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" - - "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" - - "src.models.video_vae_v3.modules.attn_video_vae" + - "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" + - "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" + - "modules.seedvr.src.models.video_vae_v3.modules.attn_video_vae" name: "VideoAutoencoderKLWrapper" args: "as_params" freeze_encoder: False diff --git a/modules/seedvr/config_7b.yaml b/modules/seedvr/config_7b.yaml index 3f80813d9..d3097c43a 100644 --- a/modules/seedvr/config_7b.yaml +++ b/modules/seedvr/config_7b.yaml @@ -6,8 +6,8 @@ dit: model: __object__: path: - - "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit.nadit" - - "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit.nadit" + - "SeedVR2_VideoUpscaler.src.models.dit.nadit" + - "SeedVR2_VideoUpscaler.src.models.dit.nadit" - "src.models.dit.nadit" name: "NaDiT" args: "as_params" @@ -46,8 +46,8 @@ vae: model: __object__: path: - - "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" - - "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" + - "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" + - "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" - "src.models.video_vae_v3.modules.attn_video_vae" name: "VideoAutoencoderKLWrapper" args: "as_params" diff --git a/modules/seedvr/src/__init__.py b/modules/seedvr/src/__init__.py index a374bbd24..d9e24b028 100644 --- 
a/modules/seedvr/src/__init__.py +++ b/modules/seedvr/src/__init__.py @@ -1,31 +1,4 @@ """ -SeedVR2 Video Upscaler - Modular Architecture -Refactored from monolithic seedvr2.py for better maintainability - -Author: Refactored codebase -Version: 2.0.0 - Modular - -Available Modules: -- utils: Download and path utilities -- optimization: Memory, performance, and compatibility optimizations -- core: Model management and generation pipeline (NEW) -- processing: Video and tensor processing (coming next) -- interfaces: ComfyUI integration -""" -''' -# Track which modules are available for progressive migration -MODULES_AVAILABLE = { - 'downloads': True, # ✅ Module 1 - Downloads and model management - 'memory_manager': True, # ✅ Module 2 - Memory optimization - 'performance': True, # ✅ Module 3 - Performance optimizations - 'compatibility': True, # ✅ Module 4 - FP8/FP16 compatibility - 'model_manager': True, # ✅ Module 5 - Model configuration and loading - 'generation': True, # ✅ Module 6 - Generation loop and inference - 'video_transforms': True, # ✅ Module 7 - Video processing and transforms - 'comfyui_node': True, # ✅ Module 8 - ComfyUI node interface (COMPLETE!) - 'infer': True, # ✅ Module 9 - Infer -} -''' # Core imports (always available) import os import sys @@ -35,3 +8,4 @@ current_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(current_dir) if parent_dir not in sys.path: sys.path.insert(0, parent_dir) +""" \ No newline at end of file diff --git a/modules/seedvr/src/common/config.py b/modules/seedvr/src/common/config.py index 653685ccd..58c4e71f8 100644 --- a/modules/seedvr/src/common/config.py +++ b/modules/seedvr/src/common/config.py @@ -2,6 +2,7 @@ import importlib from typing import Any, Callable, List, Union from omegaconf import DictConfig, ListConfig, OmegaConf + try: OmegaConf.register_new_resolver("eval", eval) except Exception as e: @@ -9,7 +10,6 @@ except Exception as e: raise - def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]: """ Load a configuration. Will resolve inheritance. 
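For reference, the `eval` resolver that `common/config.py` registers above lets the SeedVR YAML configs compute derived values inline at load time. A minimal sketch of what that enables; the `dim` and `heads` keys are hypothetical, not taken from the actual configs:

```python
from omegaconf import OmegaConf

# Same registration as in common/config.py; guarded because OmegaConf
# raises if the resolver name is already registered (e.g. on module reload).
try:
    OmegaConf.register_new_resolver("eval", eval)
except Exception:
    pass

# Hypothetical config: "heads" is evaluated lazily when accessed.
cfg = OmegaConf.create({"dim": 1024, "heads": "${eval:'1024 // 64'}"})
print(cfg.heads)  # 16
```
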
diff --git a/modules/seedvr/src/common/decorators.py b/modules/seedvr/src/common/decorators.py index 6eacf8bd5..cf504b3bf 100644 --- a/modules/seedvr/src/common/decorators.py +++ b/modules/seedvr/src/common/decorators.py @@ -2,10 +2,10 @@ import functools import threading from typing import Callable import torch - from .distributed import barrier_if_distributed, get_global_rank, get_local_rank from .logger import get_logger + logger = get_logger(__name__) diff --git a/modules/seedvr/src/common/distributed/ops.py b/modules/seedvr/src/common/distributed/ops.py index 7c9256706..bba121b85 100644 --- a/modules/seedvr/src/common/distributed/ops.py +++ b/modules/seedvr/src/common/distributed/ops.py @@ -26,7 +26,6 @@ from ..cache import Cache from .advanced import ( get_sequence_parallel_group, get_sequence_parallel_rank, - get_sequence_parallel_world_size, ) from .basic import get_device @@ -48,7 +47,7 @@ def single_all_to_all( """ A function to do all-to-all on a tensor """ - seq_world_size = dist.get_world_size(group) + seq_world_size = 1 prev_scatter_dim = scatter_dim if scatter_dim != 0: local_input = local_input.transpose(0, scatter_dim) @@ -80,7 +79,7 @@ def _all_to_all( gather_dim: int, group: dist.ProcessGroup, ): - seq_world_size = dist.get_world_size(group) + seq_world_size = 1 input_list = [ t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim) ] @@ -134,7 +133,7 @@ class Slice(torch.autograd.Function): def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor: ctx.group = group ctx.rank = dist.get_rank(group) - seq_world_size = dist.get_world_size(group) + seq_world_size = 1 ctx.seq_world_size = seq_world_size ctx.dim = dim dim_size = local_input.shape[dim] @@ -163,7 +162,7 @@ class Gather(torch.autograd.Function): ctx.rank = dist.get_rank(group) ctx.dim = dim ctx.grad_scale = grad_scale - seq_world_size = dist.get_world_size(group) + seq_world_size = 1 ctx.seq_world_size = seq_world_size dim_size = list(local_input.size()) split_size = dim_size[0] @@ -204,7 +203,7 @@ def gather_seq_scatter_heads_qkv( group = get_sequence_parallel_group() if not group: return qkv_tensor - world = get_sequence_parallel_world_size() + world = 1 orig_shape = qkv_tensor.shape scatter_dim = qkv_tensor.dim() bef_all2all_shape = list(orig_shape) @@ -237,7 +236,7 @@ def slice_inputs(x: Tensor, dim: int, padding: bool = True): if group is None: return x sp_rank = get_sequence_parallel_rank() - sp_world = get_sequence_parallel_world_size() + sp_world = 1 dim_size = x.shape[dim] unit = (dim_size + sp_world - 1) // sp_world if padding and dim_size % sp_world: @@ -255,7 +254,7 @@ def remove_seqeunce_parallel_padding(x: Tensor, dim: int, unpad_dim_size: int): group = get_sequence_parallel_group() if group is None: return x - sp_world = get_sequence_parallel_world_size() + sp_world = 1 if unpad_dim_size % sp_world == 0: return x padding_size = sp_world - (unpad_dim_size % sp_world) @@ -271,7 +270,7 @@ def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor: if not group: return x dim_size = x.size(seq_dim) - sp_world = get_sequence_parallel_world_size() + sp_world = 1 if dim_size % sp_world != 0: padding_size = sp_world - (dim_size % sp_world) x = _pad_tensor(x, seq_dim, padding_size) @@ -424,7 +423,7 @@ class SPDistForward: yield inputs else: device = self.device - sp_world = get_sequence_parallel_world_size() + sp_world = 1 sp_rank = get_sequence_parallel_rank() for local_step in range(sp_world): src_rank = 
dist.get_global_rank(group, local_step) diff --git a/modules/seedvr/src/common/logger.py b/modules/seedvr/src/common/logger.py index ada368e79..48b6eab67 100644 --- a/modules/seedvr/src/common/logger.py +++ b/modules/seedvr/src/common/logger.py @@ -19,9 +19,9 @@ Logging utility functions. import logging import sys from typing import Optional - from .distributed import get_global_rank, get_local_rank, get_world_size + _default_handler = logging.StreamHandler(sys.stdout) _default_handler.setFormatter( logging.Formatter( diff --git a/modules/seedvr/src/common/seed.py b/modules/seedvr/src/common/seed.py index f3e77f2bf..2469ad944 100644 --- a/modules/seedvr/src/common/seed.py +++ b/modules/seedvr/src/common/seed.py @@ -16,7 +16,6 @@ import random from typing import Optional import numpy as np import torch - from .distributed import get_global_rank diff --git a/modules/seedvr/src/core/generation.py b/modules/seedvr/src/core/generation.py index 5f871cfe7..401d8b874 100644 --- a/modules/seedvr/src/core/generation.py +++ b/modules/seedvr/src/core/generation.py @@ -1,11 +1,10 @@ import torch from torchvision.transforms import Compose, Lambda, Normalize - -from src.optimization.performance import optimized_video_rearrange, optimized_single_video_rearrange, optimized_sample_to_image_format -from src.common.seed import set_seed -from src.data.image.transforms.divisible_crop import DivisibleCrop -from src.data.image.transforms.na_resize import NaResize -from src.utils.color_fix import wavelet_reconstruction +from ..optimization.performance import optimized_video_rearrange, optimized_single_video_rearrange, optimized_sample_to_image_format +from ..common.seed import set_seed +from ..data.image.transforms.divisible_crop import DivisibleCrop +from ..data.image.transforms.na_resize import NaResize +from ..utils.color_fix import wavelet_reconstruction @@ -67,8 +66,6 @@ def generation_step(runner, text_embeds_dict, cond_latents, temporal_overlap, de x = runner.schedule.forward(x, aug_noise, t) return x - # Generate conditions with memory optimization - runner.dit.to(device=device) condition = runner.get_condition( noises[0], task="sr", @@ -76,8 +73,8 @@ def generation_step(runner, text_embeds_dict, cond_latents, temporal_overlap, de ) conditions = [condition] - # Use adaptive autocast for optimal performance with torch.no_grad(): + # Use adaptive autocast for optimal performance video_tensors = runner.inference( noises=noises, conditions=conditions, @@ -92,6 +89,7 @@ def generation_step(runner, text_embeds_dict, cond_latents, temporal_overlap, de cond_latents = cond_latents[0].to("cpu") conditions = conditions[0].to("cpu") condition = condition.to("cpu") + del noises, aug_noises, cond_latents, conditions, condition return samples #, last_latents @@ -133,7 +131,6 @@ def generation_loop(runner, images, cfg_scale=1.0, seed=666, res_w=720, batch_si - Intelligent VRAM management throughout process - Real-time progress reporting """ - model_dtype = None model_dtype = next(runner.dit.parameters()).dtype compute_dtype = model_dtype @@ -230,7 +227,6 @@ def generation_loop(runner, images, cfg_scale=1.0, seed=666, res_w=720, batch_si # Apply color correction if available transformed_video = transformed_video.to(device) - input_video = [optimized_single_video_rearrange(transformed_video)] del transformed_video sample = wavelet_reconstruction(sample, input_video[0][:sample.size(0)]) @@ -244,56 +240,33 @@ def generation_loop(runner, images, cfg_scale=1.0, seed=666, res_w=720, batch_si batch_samples.append(sample_cpu) #del 
sample - # Aggressive cleanup after each batch - # Progress callback - batch start if progress_callback: progress_callback(batch_count+1, total_batches, current_frames, "Processing batch...") - runner.vae.to(device="cpu") - runner.dit.to(device="cpu") - # OPTIMISATION ULTIME : Pré-allocation et copie directe (évite les torch.cat multiples) # 1. Calculer la taille totale finale total_frames = sum(batch.shape[0] for batch in batch_samples) if len(batch_samples) > 0: sample_shape = batch_samples[0].shape H, W, C = sample_shape[1], sample_shape[2], sample_shape[3] - - # 2. Pré-allouer le tensor final directement sur CPU (évite concatenations) final_video_images = torch.empty((total_frames, H, W, C), dtype=torch.float16) - - # 3. Copier par blocs directement dans le tensor final block_size = 500 current_idx = 0 for block_start in range(0, len(batch_samples), block_size): block_end = min(block_start + block_size, len(batch_samples)) - - # Charger le bloc en VRAM current_block = [] for i in range(block_start, block_end): current_block.append(batch_samples[i].to(device)) - - # Concatener en VRAM (rapide) block_result = torch.cat(current_block, dim=0) - - # Convertir en Float16 sur GPU - #if block_result.dtype != torch.float16: - # block_result = block_result.to(torch.float16) - - # Copier directement dans le tensor final (pas de concatenation!) block_frames = block_result.shape[0] final_video_images[current_idx:current_idx + block_frames] = block_result.to("cpu") current_idx += block_frames - - # Nettoyage immédiat VRAM del current_block, block_result else: print("SeedVR2: No batch_samples to process") final_video_images = torch.empty((0, 0, 0, 0), dtype=torch.float16) - # Cleanup batch_samples - #del batch_samples return final_video_images diff --git a/modules/seedvr/src/core/infer.py b/modules/seedvr/src/core/infer.py index fd8a595a2..7882d3a61 100644 --- a/modules/seedvr/src/core/infer.py +++ b/modules/seedvr/src/core/infer.py @@ -1,38 +1,12 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - from typing import List, Optional, Tuple, Union import torch from einops import rearrange from omegaconf import DictConfig, ListConfig -from torch import Tensor -from src.common.diffusion import ( - classifier_free_guidance_dispatcher, - create_sampler_from_config, - create_sampling_timesteps_from_config, - create_schedule_from_config, -) -from src.common.distributed import ( - get_device, -) - -# from common.fs import download - -from src.models.dit_v2 import na +from ..common.diffusion import classifier_free_guidance_dispatcher, create_sampler_from_config, create_sampling_timesteps_from_config, create_schedule_from_config +from ..models.dit_v2 import na -def optimized_channels_to_last(tensor): +def optimized_channels_to_last(tensor: torch.Tensor) -> torch.Tensor: """🚀 Optimized replacement for rearrange(tensor, 'b c ... -> b ... 
c') Moves channels from position 1 to last position using PyTorch native operations. """ @@ -74,7 +48,7 @@ class VideoDiffusionInfer(): self.dit = None self.sampler = None self.schedule = None - def get_condition(self, latent: Tensor, latent_blur: Tensor, task: str) -> Tensor: + def get_condition(self, latent: torch.Tensor, latent_blur: torch.Tensor, task: str) -> torch.Tensor: t, h, w, c = latent.shape cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) if task == "t2v" or t == 1: @@ -118,22 +92,18 @@ class VideoDiffusionInfer(): # -------------------------------- Helper ------------------------------- # @torch.no_grad() - def vae_encode(self, samples: List[Tensor]) -> List[Tensor]: - self.dit.to(device="cpu") - self.vae.to(device=self.device) - + def vae_encode(self, samples: List[torch.Tensor]) -> List[torch.Tensor]: use_sample = self.config.vae.get("use_sample", True) latents = [] if len(samples) > 0: - device = get_device() dtype = self.vae.dtype scale = self.config.vae.scaling_factor shift = self.config.vae.get("shifting_factor", 0.0) if isinstance(scale, ListConfig): - scale = torch.tensor(scale, device=device, dtype=dtype) + scale = torch.tensor(scale, device=self.device, dtype=dtype) if isinstance(shift, ListConfig): - shift = torch.tensor(shift, device=device, dtype=dtype) + shift = torch.tensor(shift, device=self.device, dtype=dtype) # Group samples of the same shape to batches if enabled. if self.config.vae.grouping: @@ -143,7 +113,7 @@ class VideoDiffusionInfer(): # Vae process by each group. for sample in batches: - sample = sample.to(device, dtype) + sample = sample.to(self.device, dtype) if hasattr(self.vae, "preprocess"): sample = self.vae.preprocess(sample) if use_sample: @@ -162,19 +132,15 @@ class VideoDiffusionInfer(): latents = na.unpack(latents, indices) else: latents = [latent.squeeze(0) for latent in latents] - - self.vae.to(device="cpu") return latents @torch.no_grad() - def vae_decode(self, latents: List[Tensor], target_dtype: torch.dtype = None) -> List[Tensor]: + def vae_decode(self, latents: List[torch.Tensor], target_dtype: torch.dtype = None) -> List[torch.Tensor]: """🚀 VAE decode optimisé - décodage direct sans chunking, compatible avec autocast externe""" - self.dit.to(device="cpu") - self.vae.to(device=self.device) samples = [] if len(latents) > 0: - device = get_device() + device = self.device dtype = self.vae.dtype scale = self.config.vae.scaling_factor shift = self.config.vae.get("shifting_factor", 0.0) @@ -218,10 +184,9 @@ class VideoDiffusionInfer(): samples = na.unpack(samples, indices) else: samples = [sample.squeeze(0) for sample in samples] - self.vae.to(device="cpu") return samples - def timestep_transform(self, timesteps: Tensor, latents_shapes: Tensor): + def timestep_transform(self, timesteps: torch.Tensor, latents_shapes: torch.Tensor): # Skip if not needed. 
if not self.config.diffusion.timesteps.get("transform", False): return timesteps @@ -256,13 +221,13 @@ class VideoDiffusionInfer(): @torch.no_grad() def inference( self, - noises: List[Tensor], - conditions: List[Tensor], - texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]], - texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]], + noises: List[torch.Tensor], + conditions: List[torch.Tensor], + texts_pos: Union[List[str], List[torch.Tensor], List[Tuple[torch.Tensor]]], + texts_neg: Union[List[str], List[torch.Tensor], List[Tuple[torch.Tensor]]], cfg_scale: Optional[float] = None, temporal_overlap: int = 0, # pylint: disable=unused-argument - ) -> List[Tensor]: + ) -> List[torch.Tensor]: assert len(noises) == len(conditions) == len(texts_pos) == len(texts_neg) batch_size = len(noises) @@ -316,6 +281,7 @@ class VideoDiffusionInfer(): # Adapter les latents au dtype cible (compatible avec FP8) latents = latents.to(target_dtype) if latents.dtype != target_dtype else latents latents_cond = latents_cond.to(target_dtype) if latents_cond.dtype != target_dtype else latents_cond + self.dit = self.dit.to(device=self.device, dtype=target_dtype) latents = self.sampler.sample( x=latents, diff --git a/modules/seedvr/src/core/model_manager.py b/modules/seedvr/src/core/model_manager.py index 043e5ac61..3abade938 100644 --- a/modules/seedvr/src/core/model_manager.py +++ b/modules/seedvr/src/core/model_manager.py @@ -3,10 +3,9 @@ import torch from omegaconf import OmegaConf from safetensors.torch import load_file as load_safetensors_file from huggingface_hub import hf_hub_download - -from src.optimization.memory_manager import preinitialize_rope_cache -from src.common.config import load_config, create_object -from src.core.infer import VideoDiffusionInfer +from ..optimization.memory_manager import preinitialize_rope_cache +from ..common.config import load_config, create_object +from ..core.infer import VideoDiffusionInfer def configure_runner(model_name, cache_dir, device:str='cpu', dtype:torch.dtype=None): @@ -15,41 +14,39 @@ def configure_runner(model_name, cache_dir, device:str='cpu', dtype:torch.dtype= config_path = os.path.join(script_directory, './config_7b.yaml') if "7b" in model_name else os.path.join(script_directory, './config_3b.yaml') config = load_config(config_path) - vae_config_path = os.path.join(script_directory, 'src/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml') - vae_config = OmegaConf.load(vae_config_path) - vae_config.spatial_downsample_factor = vae_config.get('spatial_downsample_factor', 8) - vae_config.temporal_downsample_factor = vae_config.get('temporal_downsample_factor', 4) - config.vae.model = OmegaConf.merge(config.vae.model, vae_config) - runner = VideoDiffusionInfer(config, device=device, dtype=dtype) OmegaConf.set_readonly(runner.config, False) # load dit with torch.device("meta"): runner.dit = create_object(config.dit.model) - runner.dit.eval().to(dtype) + runner.dit.requires_grad_(False).eval() runner.dit.to_empty(device="cpu") - model_file = hf_hub_download(repo_id=repo_id, filename=model_name, cache_dir=cache_dir) state_dict = load_safetensors_file(model_file) runner.dit.load_state_dict(state_dict, assign=True) + runner.dit = runner.dit.to(device="cpu", dtype=dtype) del state_dict - runner.dit = runner.dit.to(device=device, dtype=dtype) # load vae + vae_config_path = os.path.join(script_directory, 'src/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml') + vae_config = OmegaConf.load(vae_config_path) + config.vae.model = 
OmegaConf.merge(config.vae.model, vae_config) + vae_file = hf_hub_download(repo_id=repo_id, filename=config.vae.checkpoint, cache_dir=cache_dir) with torch.device("meta"): runner.vae = create_object(config.vae.model) runner.vae.requires_grad_(False).eval() runner.vae.to_empty(device="cpu") - state_dict = load_safetensors_file(vae_file) runner.vae.load_state_dict(state_dict) - del state_dict - runner.vae = runner.vae.to(device=device, dtype=dtype) + runner.vae = runner.vae.to(device="cpu", dtype=dtype) runner.config.vae.dtype = str(dtype) - runner.vae.set_causal_slicing(**config.vae.slicing) + runner.config.vae.slicing = {'split_size': 8, 'memory_device': 'same'} + runner.config.vae.memory_limit = {'conv_max_mem': 0.2, 'norm_max_mem': 0.2} + runner.vae.set_causal_slicing(**runner.config.vae.slicing) runner.vae.set_memory_limit(**runner.config.vae.memory_limit) + del state_dict # load embeds pos_embeds_file = hf_hub_download(repo_id=repo_id, filename='pos_emb.pt', cache_dir=cache_dir) @@ -57,5 +54,4 @@ def configure_runner(model_name, cache_dir, device:str='cpu', dtype:torch.dtype= runner.text_pos_embeds = torch.load(pos_embeds_file).to(device=device, dtype=dtype) runner.text_neg_embeds = torch.load(neg_embeds_file).to(device=device, dtype=dtype) - preinitialize_rope_cache(runner) return runner diff --git a/modules/seedvr/src/models/dit/blocks/mmdit_window_block.py b/modules/seedvr/src/models/dit/blocks/mmdit_window_block.py index ee9ae9ae8..eb551aacc 100644 --- a/modules/seedvr/src/models/dit/blocks/mmdit_window_block.py +++ b/modules/seedvr/src/models/dit/blocks/mmdit_window_block.py @@ -19,13 +19,7 @@ from torch import nn from torch.nn import functional as F from torch.nn.modules.utils import _triple from ....common.half_precision_fixes import safe_pad_operation -from ....common.distributed.ops import ( - gather_heads, - gather_heads_scatter_seq, - gather_seq_scatter_heads_qkv, - scatter_heads, -) - +from ....common.distributed.ops import gather_heads, gather_heads_scatter_seq, gather_seq_scatter_heads_qkv, scatter_heads from ..attention import TorchAttention from ..mlp import get_mlp from ..mm import MMArg, MMModule diff --git a/modules/seedvr/src/models/dit_v2/rope.py b/modules/seedvr/src/models/dit_v2/rope.py index dde5b4756..89d851792 100644 --- a/modules/seedvr/src/models/dit_v2/rope.py +++ b/modules/seedvr/src/models/dit_v2/rope.py @@ -18,8 +18,7 @@ import torch from einops import rearrange from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb from torch import nn - -from src.common.cache import Cache +from ...common.cache import Cache class RotaryEmbeddingBase(nn.Module): diff --git a/modules/seedvr/src/models/video_vae_v3/modules/attn_video_vae.py b/modules/seedvr/src/models/video_vae_v3/modules/attn_video_vae.py index 1f282dfa1..3aba37b66 100644 --- a/modules/seedvr/src/models/video_vae_v3/modules/attn_video_vae.py +++ b/modules/seedvr/src/models/video_vae_v3/modules/attn_video_vae.py @@ -15,7 +15,6 @@ from typing import Literal, Optional, Tuple, Union import diffusers import torch import torch.nn as nn -import torch.nn.functional as F from diffusers.models.attention_processor import Attention, SpatialNorm from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution from diffusers.models.downsampling import Downsample2D @@ -28,29 +27,12 @@ from diffusers.utils import is_torch_version from diffusers.utils.accelerate_utils import apply_forward_hook from einops import rearrange from ....common.half_precision_fixes import 
safe_pad_operation, safe_interpolate_operation - -from ....common.distributed.advanced import get_sequence_parallel_world_size from ....common.logger import get_logger -from .causal_inflation_lib import ( - InflatedCausalConv3d, - causal_norm_wrapper, - init_causal_conv3d, - remove_head, -) -from .context_parallel_lib import ( - causal_conv_gather_outputs, - causal_conv_slice_inputs, -) +from .causal_inflation_lib import InflatedCausalConv3d, causal_norm_wrapper, init_causal_conv3d, remove_head +from .context_parallel_lib import causal_conv_gather_outputs, causal_conv_slice_inputs from .global_config import set_norm_limit -from .types import ( - CausalAutoencoderOutput, - CausalDecoderOutput, - CausalEncoderOutput, - MemoryState, - _inflation_mode_t, - _memory_device_t, - _receptive_field_t, -) +from .types import CausalAutoencoderOutput, CausalDecoderOutput, CausalEncoderOutput, MemoryState, _inflation_mode_t, _memory_device_t, _receptive_field_t + logger = get_logger(__name__) # pylint: disable=invalid-name @@ -1064,7 +1046,7 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL): scaling_factor: float = 0.18215, force_upcast: float = True, attention: bool = True, - temporal_scale_num: int = 2, + temporal_scale_num: int = 0, slicing_up_num: int = 0, gradient_checkpoint: bool = False, inflation_mode: _inflation_mode_t = "tail", @@ -1164,7 +1146,8 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL): @apply_forward_hook def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: - h = self.slicing_encode(x) + # h = self.slicing_encode(x) + h = self.tiled_encode(x) posterior = DiagonalGaussianDistribution(h) if not return_dict: @@ -1176,7 +1159,8 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL): def decode( self, z: torch.Tensor, return_dict: bool = True ) -> Union[DecoderOutput, torch.Tensor]: - decoded = self.slicing_decode(z) + # decoded = self.slicing_decode(z) + decoded = self.tiled_decode(z) if not return_dict: return (decoded,) @@ -1208,7 +1192,7 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL): return output.to(z.device) def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: - sp_size = get_sequence_parallel_world_size() + sp_size = 1 if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2) encoded_slices = [ @@ -1226,7 +1210,7 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL): return self._encode(x) def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: - sp_size = get_sequence_parallel_world_size() + sp_size = 1 if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) decoded_slices = [ @@ -1243,12 +1227,68 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL): else: return self._decode(z) - def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - raise NotImplementedError + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b - def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: - raise NotImplementedError + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], 
b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self._encode(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + enc = torch.cat(result_rows, dim=3) + return enc + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + dec = torch.cat(result_rows, dim=3) + return dec + def forward( self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs ): diff --git a/modules/seedvr/src/models/video_vae_v3/modules/causal_inflation_lib.py b/modules/seedvr/src/models/video_vae_v3/modules/causal_inflation_lib.py index 26f14739e..c6d35f0cb 100644 --- a/modules/seedvr/src/models/video_vae_v3/modules/causal_inflation_lib.py +++ b/modules/seedvr/src/models/video_vae_v3/modules/causal_inflation_lib.py @@ -21,30 +21,11 @@ from diffusers.models.normalization import RMSNorm from einops import rearrange from torch import Tensor, nn from torch.nn import Conv3d - from .context_parallel_lib import cache_send_recv, get_cache_size from .global_config import get_norm_limit from .types import MemoryState, _inflation_mode_t, _memory_device_t from ....common.half_precision_fixes import safe_pad_operation -# Single GPU inference - no distributed processing needed - -# Mock distributed functions for single GPU inference -def get_sequence_parallel_group(): - return None - -def get_sequence_parallel_rank(): - return 0 - -def get_sequence_parallel_world_size(): - return 1 - -def get_next_sequence_parallel_rank(): - return 0 - -def get_prev_sequence_parallel_rank(): - return 0 - @contextmanager def ignore_padding(model): @@ -172,7 +153,6 @@ class InflatedCausalConv3d(Conv3d): if ( math.isinf(self.memory_limit) and torch.is_tensor(input) - and get_sequence_parallel_group() is None ): return self.basic_forward(input, 
memory_state) return self.slicing_forward(input, memory_state) diff --git a/modules/seedvr/src/models/video_vae_v3/modules/context_parallel_lib.py b/modules/seedvr/src/models/video_vae_v3/modules/context_parallel_lib.py index ffc16938a..830d18b87 100644 --- a/modules/seedvr/src/models/video_vae_v3/modules/context_parallel_lib.py +++ b/modules/seedvr/src/models/video_vae_v3/modules/context_parallel_lib.py @@ -14,11 +14,8 @@ from typing import List import torch -import torch.nn.functional as F from torch import Tensor -from .types import MemoryState - # Single GPU inference - no distributed processing needed diff --git a/modules/seedvr/src/models/video_vae_v3/modules/inflated_layers.py b/modules/seedvr/src/models/video_vae_v3/modules/inflated_layers.py index 4b8e6dfb6..da8c1304f 100644 --- a/modules/seedvr/src/models/video_vae_v3/modules/inflated_layers.py +++ b/modules/seedvr/src/models/video_vae_v3/modules/inflated_layers.py @@ -16,14 +16,8 @@ from functools import partial from typing import Literal, Optional from torch import Tensor from torch.nn import Conv3d +from .inflated_lib import MemoryState, extend_head, inflate_bias, inflate_weight, modify_state_dict -from .inflated_lib import ( - MemoryState, - extend_head, - inflate_bias, - inflate_weight, - modify_state_dict, -) _inflation_mode_t = Literal["none", "tail", "replicate"] _memory_device_t = Optional[Literal["cpu", "same"]] diff --git a/modules/seedvr/src/models/video_vae_v3/modules/inflated_lib.py b/modules/seedvr/src/models/video_vae_v3/modules/inflated_lib.py index 486c63ca1..6feda1751 100644 --- a/modules/seedvr/src/models/video_vae_v3/modules/inflated_lib.py +++ b/modules/seedvr/src/models/video_vae_v3/modules/inflated_lib.py @@ -19,9 +19,9 @@ import torch from diffusers.models.normalization import RMSNorm from einops import rearrange from torch import Tensor, nn - from ....common.logger import get_logger + logger = get_logger(__name__) diff --git a/modules/seedvr/src/models/video_vae_v3/modules/video_vae.py b/modules/seedvr/src/models/video_vae_v3/modules/video_vae.py.old similarity index 97% rename from modules/seedvr/src/models/video_vae_v3/modules/video_vae.py rename to modules/seedvr/src/models/video_vae_v3/modules/video_vae.py.old index 98147e82f..2696c32ca 100644 --- a/modules/seedvr/src/models/video_vae_v3/modules/video_vae.py +++ b/modules/seedvr/src/models/video_vae_v3/modules/video_vae.py.old @@ -11,37 +11,17 @@ from contextlib import nullcontext from typing import Optional, Tuple, Literal, Callable, Union - import torch import torch.nn as nn -import torch.nn.functional as F from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution from einops import rearrange from ....common.half_precision_fixes import safe_pad_operation - -from ....common.distributed.advanced import get_sequence_parallel_world_size from ....common.logger import get_logger -from .causal_inflation_lib import ( - InflatedCausalConv3d, - causal_norm_wrapper, - init_causal_conv3d, - remove_head, -) -from .context_parallel_lib import ( - causal_conv_gather_outputs, - causal_conv_slice_inputs, -) +from .causal_inflation_lib import InflatedCausalConv3d, causal_norm_wrapper, init_causal_conv3d, remove_head +from .context_parallel_lib import causal_conv_gather_outputs, causal_conv_slice_inputs from .global_config import set_norm_limit -from .types import ( - CausalAutoencoderOutput, - CausalDecoderOutput, - CausalEncoderOutput, - MemoryState, - _inflation_mode_t, - _memory_device_t, - _receptive_field_t, - _selective_checkpointing_t, -) 
+from .types import CausalAutoencoderOutput, CausalDecoderOutput, CausalEncoderOutput, MemoryState, _inflation_mode_t, _memory_device_t, _receptive_field_t, _selective_checkpointing_t + logger = get_logger(__name__) # pylint: disable=invalid-name @@ -717,7 +697,7 @@ class VideoAutoencoderKL(nn.Module): use_post_quant_conv: bool = True, enc_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), dec_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), - temporal_scale_num: int = 3, + temporal_scale_num: int = 0, slicing_up_num: int = 0, inflation_mode: _inflation_mode_t = "tail", time_receptive_field: _receptive_field_t = "half", @@ -824,7 +804,7 @@ class VideoAutoencoderKL(nn.Module): return x def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: - sp_size = get_sequence_parallel_world_size() + sp_size = 1 if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2) encoded_slices = [ @@ -842,7 +822,7 @@ class VideoAutoencoderKL(nn.Module): return self._encode(x, memory_state=MemoryState.DISABLED) def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: - sp_size = get_sequence_parallel_world_size() + sp_size = 1 if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) decoded_slices = [ diff --git a/modules/seedvr/src/optimization/memory_manager.py b/modules/seedvr/src/optimization/memory_manager.py index 3be323f67..39feada97 100644 --- a/modules/seedvr/src/optimization/memory_manager.py +++ b/modules/seedvr/src/optimization/memory_manager.py @@ -5,12 +5,9 @@ Handles VRAM usage, cache management, and memory optimization Extracted from: seedvr2.py (lines 373-405, 607-626, 1016-1044) """ -import os import torch -import gc -from typing import Tuple, Optional -from src.common.cache import Cache -from src.models.dit_v2.rope import RotaryEmbeddingBase +from ..common.cache import Cache +from ..models.dit_v2.rope import RotaryEmbeddingBase def preinitialize_rope_cache(runner) -> None: @@ -21,73 +18,63 @@ def preinitialize_rope_cache(runner) -> None: runner: The model runner containing DiT and VAE models """ - try: - # Create dummy tensors to simulate common shapes - # Format: [batch, channels, frames, height, width] for vid_shape - # Format: [batch, seq_len] for txt_shape - common_shapes = [ - # Common video resolutions - (torch.tensor([[1, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 1 frame, 77 tokens - (torch.tensor([[4, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 4 frames - (torch.tensor([[5, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 5 frames (4n+1 format) - (torch.tensor([[1, 4, 4]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # Higher resolution - ] + # Create dummy tensors to simulate common shapes + # Format: [batch, channels, frames, height, width] for vid_shape + # Format: [batch, seq_len] for txt_shape + common_shapes = [ + # Common video resolutions + (torch.tensor([[1, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 1 frame, 77 tokens + (torch.tensor([[4, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 4 frames + (torch.tensor([[5, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 5 frames (4n+1 format) + (torch.tensor([[1, 4, 4]], dtype=torch.long), 
torch.tensor([[77]], dtype=torch.long)), # Higher resolution + ] - # Create mock cache for pre-initialization + # Create mock cache for pre-initialization - temp_cache = Cache() + temp_cache = Cache() - # Access RoPE modules in DiT (recursive search) - def find_rope_modules(module): - rope_modules = [] - for name, child in module.named_modules(): - if hasattr(child, 'get_freqs') and callable(child.get_freqs): - rope_modules.append((name, child)) - return rope_modules + # Access RoPE modules in DiT (recursive search) + def find_rope_modules(module): + rope_modules = [] + for name, child in module.named_modules(): + if hasattr(child, 'get_freqs') and callable(child.get_freqs): + rope_modules.append((name, child)) + return rope_modules - rope_modules = find_rope_modules(runner.dit) + rope_modules = find_rope_modules(runner.dit) - # Pre-calculate for each RoPE module found - for name, rope_module in rope_modules: - # Temporarily move module to CPU if necessary - original_device = next(rope_module.parameters()).device if list(rope_module.parameters()) else torch.device('cpu') - rope_module.to('cpu') + # Pre-calculate for each RoPE module found + for _name, rope_module in rope_modules: + # Temporarily move module to CPU if necessary + original_device = next(rope_module.parameters()).device if list(rope_module.parameters()) else torch.device('cpu') + rope_module.to('cpu') - try: - for vid_shape, txt_shape in common_shapes: - cache_key = f"720pswin_by_size_bysize_{tuple(vid_shape[0].tolist())}_sd3.mmrope_freqs_3d" + for vid_shape, txt_shape in common_shapes: + cache_key = f"720pswin_by_size_bysize_{tuple(vid_shape[0].tolist())}_sd3.mmrope_freqs_3d" - def compute_freqs(): - # Calculate with reduced dimensions to avoid OOM - with torch.no_grad(): - # Detect RoPE module type - module_type = type(rope_module).__name__ + def compute_freqs(): + # Calculate with reduced dimensions to avoid OOM + with torch.no_grad(): + # Detect RoPE module type + module_type = type(rope_module).__name__ - if module_type == 'NaRotaryEmbedding3d': - # NaRotaryEmbedding3d: only takes shape (vid_shape) - return rope_module.get_freqs(vid_shape.cpu()) - else: - # Standard RoPE: takes vid_shape and txt_shape - return rope_module.get_freqs(vid_shape.cpu(), txt_shape.cpu()) + if module_type == 'NaRotaryEmbedding3d': + # NaRotaryEmbedding3d: only takes shape (vid_shape) + return rope_module.get_freqs(vid_shape.cpu()) + else: + # Standard RoPE: takes vid_shape and txt_shape + return rope_module.get_freqs(vid_shape.cpu(), txt_shape.cpu()) - # Store in cache - temp_cache(cache_key, compute_freqs) + # Store in cache + temp_cache(cache_key, compute_freqs) - except Exception as e: - print(f" ❌ Error in module {name}: {e}") - finally: - # Restore to original device - rope_module.to(original_device) + rope_module.to(original_device) - # Copy temporary cache to runner cache - if hasattr(runner, 'cache'): - runner.cache.cache.update(temp_cache.cache) - else: - runner.cache = temp_cache - - except Exception as e: - print(f" ⚠️ Error during RoPE pre-init: {e}") - print(" 🔄 Model will work but could have OOM at first launch") + # Copy temporary cache to runner cache + if hasattr(runner, 'cache'): + runner.cache.cache.update(temp_cache.cache) + else: + runner.cache = temp_cache def clear_rope_cache(runner) -> None: @@ -97,8 +84,6 @@ def clear_rope_cache(runner) -> None: Args: runner: The model runner containing the cache """ - print("🧹 Cleaning RoPE cache...") - if hasattr(runner, 'cache') and hasattr(runner.cache, 'cache'): # Count entries 
before cleanup cache_size = len(runner.cache.cache) @@ -116,7 +101,6 @@ def clear_rope_cache(runner) -> None: # Clear the cache runner.cache.cache.clear() - print(f" ✅ RoPE cache cleared ({cache_size} entries removed)") if hasattr(runner, 'dit'): cleared_lru_count = 0 @@ -125,7 +109,3 @@ def clear_rope_cache(runner) -> None: if hasattr(module.get_axial_freqs, 'cache_clear'): module.get_axial_freqs.cache_clear() cleared_lru_count += 1 - if cleared_lru_count > 0: - print(f" ✅ Cleared {cleared_lru_count} LRU caches from RoPE modules.") - - print("🎯 RoPE cache cleanup completed!") diff --git a/modules/seedvr/src/utils/color_fix.py b/modules/seedvr/src/utils/color_fix.py index 8a042892b..a8b0da509 100644 --- a/modules/seedvr/src/utils/color_fix.py +++ b/modules/seedvr/src/utils/color_fix.py @@ -2,7 +2,7 @@ import torch from PIL import Image from torch import Tensor from torch.nn import functional as F -from src.common.half_precision_fixes import safe_pad_operation, safe_interpolate_operation +from ..common.half_precision_fixes import safe_pad_operation, safe_interpolate_operation from torchvision.transforms import ToTensor, ToPILImage def adain_color_fix(target: Image, source: Image): diff --git a/modules/seedvr/test.py b/modules/seedvr/test.py index 6ecf022ab..b8c53f874 100644 --- a/modules/seedvr/test.py +++ b/modules/seedvr/test.py @@ -1,12 +1,10 @@ import os -import argparse import numpy as np import torch from PIL import Image -from huggingface_hub import snapshot_download from torchvision.transforms import ToPILImage -from src.core.generation import generation_loop -from src.core.model_manager import configure_runner +from .src.core.generation import generation_loop +from .src.core.model_manager import configure_runner diff --git a/modules/shared.py b/modules/shared.py index 16009b1e0..f987cdeb8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -620,7 +620,6 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), { "detailer_iou": OptionInfo(0.5, "Max overlap", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}), "detailer_sigma_adjust": OptionInfo(1.0, "Detailer sigma adjust", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}), "detailer_sigma_adjust_max": OptionInfo(1.0, "Detailer sigma end", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}), - # "detailer_resolution": OptionInfo(1024, "Detailer resolution", gr.Slider, {"minimum": 256, "maximum": 4096, "step": 8, "visible": False}), "detailer_min_size": OptionInfo(0.0, "Min object size", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05, "visible": False}), "detailer_max_size": OptionInfo(1.0, "Max object size", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05, "visible": False}), "detailer_padding": OptionInfo(20, "Item padding", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1, "visible": False}), @@ -631,6 +630,9 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), { "detailer_unload": OptionInfo(False, "Move detailer model to CPU when complete"), "detailer_augment": OptionInfo(True, "Detailer use model augment"), + "postprocessing_sep_seedvt": OptionInfo("
<h2>SeedVR</h2>
", "", gr.HTML), + "seedvt_cfg_scale": OptionInfo(3.5, "SeedVR CFG Scale", gr.Slider, {"minimum": 1, "maximum": 15, "step": 1}), + "postprocessing_sep_face_restore": OptionInfo("
<h2>Face Restore</h2>
", "", gr.HTML), "face_restoration_model": OptionInfo("None", "Face restoration", gr.Radio, lambda: {"choices": ['None'] + [x.name() for x in face_restorers]}), "code_former_weight": OptionInfo(0.2, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 926b31762..673e45941 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -227,7 +227,7 @@ def create_ui(): quicksettings_list.append((key, item)) components.append(dummy_component) else: - with gr.Row(elem_id=f"settings_section_row_{section_id}"): # only so we can add dirty indicator at the start of the row + with gr.Row(elem_id=f"settings_section_row_{section_id}", elem_classes=["settings_section"]): # only so we can add dirty indicator at the start of the row component = create_setting_component(key) shared.settings_components[key] = component current_items.append(key)