Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4268/head
Vladimir Mandic 2025-10-12 15:35:08 -04:00
parent 8d36a5aebb
commit 2e4e741d47
30 changed files with 335 additions and 351 deletions

View File

@ -1,6 +1,6 @@
# Change Log for SD.Next # Change Log for SD.Next
## Update for 2025-10-11 ## Update for 2025-10-12
- **Models** - **Models**
- [WAN 2.2 14B VACE](https://huggingface.co/alibaba-pai/Wan2.2-VACE-Fun-A14B) - [WAN 2.2 14B VACE](https://huggingface.co/alibaba-pai/Wan2.2-VACE-Fun-A14B)
@ -17,6 +17,13 @@
*how to use*: enable nunchaku in settings -> quantization and then load either sdxl-base or sdxl-base-turbo reference models *how to use*: enable nunchaku in settings -> quantization and then load either sdxl-base or sdxl-base-turbo reference models
- [HiDream E1.1](https://huggingface.co/HiDream-ai/HiDream-E1-1) - [HiDream E1.1](https://huggingface.co/HiDream-ai/HiDream-E1-1)
updated version of E1 image editing model updated version of E1 image editing model
- [SeedVR2](https://iceclear.github.io/projects/seedvr/)
originally designed for video restoration, seedvr works great for image detailing and upscaling!
available in 3B, 7B and 7B-sharp variants
use as any other upscaler!
note: seedvr is a very large model (6.4GB and 16GB respectively) and not designed for lower-end hardware
note: seedvr is highly sensitive to its cfg scale (set in settings -> postprocessing),
lower values will result in smoother output while higher values add details
- [X-Omni SFT](https://x-omni-team.github.io/) - [X-Omni SFT](https://x-omni-team.github.io/)
*experimental*: X-omni is a transformer-only discrete autoregressive image generative model trained with reinforcement learning *experimental*: X-omni is a transformer-only discrete autoregressive image generative model trained with reinforcement learning
- **Features** - **Features**

View File

@ -482,7 +482,7 @@ def sdnq_quantize_model(model, op=None, sd_model=None, do_gc: bool = True, weigh
from modules.sdnq import sdnq_post_load_quant from modules.sdnq import sdnq_post_load_quant
if weights_dtype is None: if weights_dtype is None:
if op is not None and ("text_encoder" in op or op in {"TE", "LLM"}) and shared.opts.sdnq_quantize_weights_mode_te not in {"Same as model", "default"}: if (op is not None) and ("text_encoder" in op or op in {"TE", "LLM"}) and (shared.opts.sdnq_quantize_weights_mode_te not in {"Same as model", "default"}):
weights_dtype = shared.opts.sdnq_quantize_weights_mode_te weights_dtype = shared.opts.sdnq_quantize_weights_mode_te
else: else:
weights_dtype = shared.opts.sdnq_quantize_weights_mode weights_dtype = shared.opts.sdnq_quantize_weights_mode
@ -588,6 +588,8 @@ def sdnq_quantize_weights(sd_model):
log.info(f"Quantization: type=SDNQ time={t1-t0:.2f}") log.info(f"Quantization: type=SDNQ time={t1-t0:.2f}")
except Exception as e: except Exception as e:
log.warning(f"Quantization: type=SDNQ {e}") log.warning(f"Quantization: type=SDNQ {e}")
from modules import errors
errors.display(e, 'Quantization')
return sd_model return sd_model

View File

@ -1,9 +1,9 @@
import time import time
import random
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from torchvision.transforms import ToPILImage from torchvision.transforms import ToPILImage
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn
from modules import devices from modules import devices
from modules.shared import opts, log from modules.shared import opts, log
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
@ -19,7 +19,7 @@ to_pil = ToPILImage()
class UpscalerSeedVR(Upscaler): class UpscalerSeedVR(Upscaler):
def __init__(self, dirname=None): def __init__(self, dirname=None):
self.name = "SeedVR" self.name = "SeedVR2"
super().__init__() super().__init__()
self.scalers = [ self.scalers = [
UpscalerData(name="SeedVR2 3B", path=None, upscaler=self, model=None, scale=1), UpscalerData(name="SeedVR2 3B", path=None, upscaler=self, model=None, scale=1),
@ -32,45 +32,140 @@ class UpscalerSeedVR(Upscaler):
def load_model(self, path: str): def load_model(self, path: str):
model_name = MODELS_MAP.get(path, None) model_name = MODELS_MAP.get(path, None)
if (self.model is None) or (self.model_loaded != model_name): if (self.model is None) or (self.model_loaded != model_name):
log.debug(f'Upscaler load: name="{self.name}" model="{model_name}"') log.debug(f'Upscaler loading: name="{self.name}" model="{model_name}"')
t0 = time.time()
from modules.seedvr.src.core.model_manager import configure_runner from modules.seedvr.src.core.model_manager import configure_runner
from modules.seedvr.src.core import generation
self.model = configure_runner( self.model = configure_runner(
model_name=model_name, model_name=model_name,
cache_dir=opts.hfcache_dir, cache_dir=opts.hfcache_dir,
device=devices.device, device=devices.device,
dtype=devices.dtype, dtype=devices.dtype,
) )
self.model_loaded = model_name
self.model.dit.device = devices.device
self.model.dit.dtype = devices.dtype
self.model.vae_encode = self.vae_encode
self.model.vae_decode = self.vae_decode
self.model.model_step = generation.generation_step
generation.generation_step = self.model_step
self.model._internal_dict = {
'dit': self.model.dit,
'vae': self.model.vae,
}
t1 = time.time()
self.model.dit.config = self.model.config.dit
self.model.vae.tile_sample_min_size = 1024
self.model.vae.tile_latent_min_size = 128
# from modules.model_quant import do_post_load_quant
# from modules.sd_offload import set_diffuser_offload
# self.model = do_post_load_quant(self.model, allow=True)
# set_diffuser_offload(self.model)
log.info(f'Upscaler loaded: name="{self.name}" model="{model_name}" time={t1 - t0:.2f}')
def vae_encode(self, samples):
log.debug(f'Upscaler encode: samples={samples[0].shape if len(samples) > 0 else None} tile={self.model.vae.tile_sample_min_size} overlap={self.model.vae.tile_overlap_factor}')
latents = []
if len(samples) == 0:
return latents
self.model.dit = self.model.dit.to(device="cpu")
self.model.vae = self.model.vae.to(device=self.device)
devices.torch_gc()
from einops import rearrange
from modules.seedvr.src.optimization import memory_manager
memory_manager.clear_rope_cache(self.model)
scale = self.model.config.vae.scaling_factor
shift = self.model.config.vae.get("shifting_factor", 0.0)
batches = [sample.unsqueeze(0) for sample in samples]
for sample in batches:
sample = sample.to(self.device, self.model.vae.dtype)
sample = self.model.vae.preprocess(sample)
latent = self.model.vae.encode(sample).latent
latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
latent = rearrange(latent, "b c ... -> b ... c")
latent = (latent - shift) * scale
latents.append(latent)
latents = [latent.squeeze(0) for latent in latents]
self.model.vae = self.model.vae.to(device="cpu")
devices.torch_gc()
return latents
def vae_decode(self, latents, target_dtype: torch.dtype = None):
log.debug(f'Upscaler decode: latents={latents[0].shape if len(latents) > 0 else None} tile={self.model.vae.tile_latent_min_size} overlap={self.model.vae.tile_overlap_factor}')
samples = []
if len(latents) == 0:
return samples
from einops import rearrange
from modules.seedvr.src.optimization import memory_manager
memory_manager.clear_rope_cache(self.model)
self.model.dit = self.model.dit.to(device="cpu")
self.model.vae = self.model.vae.to(device=self.device)
devices.torch_gc()
scale = self.model.config.vae.scaling_factor
shift = self.model.config.vae.get("shifting_factor", 0.0)
latents = [latent.unsqueeze(0) for latent in latents]
with devices.inference_context():
for _i, latent in enumerate(latents):
latent = latent.to(self.device, self.model.vae.dtype)
latent = latent / scale + shift
latent = rearrange(latent, "b ... c -> b c ...")
latent = latent.squeeze(2)
sample = self.model.vae.decode(latent).sample
sample = self.model.vae.postprocess(sample)
samples.append(sample)
samples = [sample.squeeze(0) for sample in samples]
self.model.vae = self.model.vae.to(device="cpu")
devices.torch_gc()
return samples
def model_step(self, *args, **kwargs):
from modules.seedvr.src.optimization import memory_manager
self.model.vae = self.model.vae.to(device="cpu")
self.model.dit = self.model.dit.to(device=self.device)
devices.torch_gc()
log.debug(f'Upscaler inference: args={len(args)} kwargs={list(kwargs.keys())}')
memory_manager.preinitialize_rope_cache(self.model)
with devices.inference_context():
result = self.model.model_step(*args, **kwargs)
self.model.dit = self.model.dit.to(device="cpu")
devices.torch_gc()
return result
def do_upscale(self, img: Image.Image, selected_file): def do_upscale(self, img: Image.Image, selected_file):
devices.torch_gc()
self.load_model(selected_file) self.load_model(selected_file)
if self.model is None: if self.model is None:
return img return img
from modules.seedvr.src.core.generation import generation_loop from modules.seedvr.src.core import generation
width = int(self.scale * img.width) // 8 * 8 width = int(self.scale * img.width) // 8 * 8
image_tensor = np.array(img) image_tensor = np.array(img)
image_tensor = torch.from_numpy(image_tensor).to(device=devices.device, dtype=devices.dtype).unsqueeze(0) / 255.0 image_tensor = torch.from_numpy(image_tensor).to(device=devices.device, dtype=devices.dtype).unsqueeze(0) / 255.0
random.seed()
seed = int(random.randrange(4294967294))
t0 = time.time() t0 = time.time()
result_tensor = generation_loop( with devices.inference_context():
runner=self.model, result_tensor = generation.generation_loop(
images=image_tensor, runner=self.model,
cfg_scale=1.0, images=image_tensor,
seed=42, cfg_scale=opts.seedvt_cfg_scale,
res_w=width, seed=seed,
batch_size=1, res_w=width,
temporal_overlap=0, batch_size=1,
device=devices.device, temporal_overlap=0,
) device=devices.device,
)
t1 = time.time() t1 = time.time()
log.info(f'Upscaler: type="{self.name}" model="{selected_file}" scale={self.scale} time={t1 - t0:.2f}') log.info(f'Upscaler: type="{self.name}" model="{selected_file}" scale={self.scale} cfg={opts.seedvt_cfg_scale} seed={seed} time={t1 - t0:.2f}')
img = to_pil(result_tensor.squeeze().permute((2, 0, 1))) img = to_pil(result_tensor.squeeze().permute((2, 0, 1)))
devices.torch_gc()
if opts.upscaler_unload: if opts.upscaler_unload:
self.model.dit = None
self.model.vae = None
self.model.cache = None
self.model = None self.model = None
log.debug(f'Upscaler unload: type="{self.name}" model="{selected_file}"') log.debug(f'Upscaler unload: type="{self.name}" model="{selected_file}"')
devices.torch_gc(force=True) devices.torch_gc(force=True)
return img return img

View File

@ -164,6 +164,8 @@ def apply_function_to_model(sd_model, function, options, op=None):
sd_model.unet = function(sd_model.unet, op="unet", sd_model=sd_model) sd_model.unet = function(sd_model.unet, op="unet", sd_model=sd_model)
if hasattr(sd_model, 'transformer') and hasattr(sd_model.transformer, 'config'): if hasattr(sd_model, 'transformer') and hasattr(sd_model.transformer, 'config'):
sd_model.transformer = function(sd_model.transformer, op="transformer", sd_model=sd_model) sd_model.transformer = function(sd_model.transformer, op="transformer", sd_model=sd_model)
if hasattr(sd_model, 'dit') and hasattr(sd_model.dit, 'config'):
sd_model.dit = function(sd_model.dit, op="dit", sd_model=sd_model)
if hasattr(sd_model, 'transformer_2') and hasattr(sd_model.transformer_2, 'config'): if hasattr(sd_model, 'transformer_2') and hasattr(sd_model.transformer_2, 'config'):
sd_model.transformer_2 = function(sd_model.transformer_2, op="transformer_2", sd_model=sd_model) sd_model.transformer_2 = function(sd_model.transformer_2, op="transformer_2", sd_model=sd_model)
if hasattr(sd_model, 'transformer_3') and hasattr(sd_model.transformer_3, 'config'): if hasattr(sd_model, 'transformer_3') and hasattr(sd_model.transformer_3, 'config'):

View File

@ -274,7 +274,10 @@ class OffloadHook(accelerate.hooks.ModelHook):
def get_pipe_variants(pipe=None): def get_pipe_variants(pipe=None):
if pipe is None: if pipe is None:
pipe = shared.sd_model if shared.sd_loaded:
pipe = shared.sd_model
else:
return [pipe]
variants = [pipe] variants = [pipe]
if hasattr(pipe, "pipe"): if hasattr(pipe, "pipe"):
variants.append(pipe.pipe) variants.append(pipe.pipe)
@ -287,7 +290,10 @@ def get_pipe_variants(pipe=None):
def get_module_names(pipe=None, exclude=[]): def get_module_names(pipe=None, exclude=[]):
if pipe is None: if pipe is None:
pipe = shared.sd_model if shared.sd_loaded:
pipe = shared.sd_model
else:
return []
if hasattr(pipe, "_internal_dict"): if hasattr(pipe, "_internal_dict"):
modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access
else: else:

View File

@ -385,7 +385,6 @@ def sdnq_post_load_quant(
modules_dtype_dict=modules_dtype_dict.copy(), modules_dtype_dict=modules_dtype_dict.copy(),
op=op, op=op,
) )
model.quantization_config = SDNQConfig( model.quantization_config = SDNQConfig(
weights_dtype=weights_dtype, weights_dtype=weights_dtype,
group_size=group_size, group_size=group_size,
@ -402,8 +401,10 @@ def sdnq_post_load_quant(
modules_dtype_dict=modules_dtype_dict.copy(), modules_dtype_dict=modules_dtype_dict.copy(),
) )
if hasattr(model, "config"): try:
model.config.quantization_config = model.quantization_config model.config.quantization_config = model.quantization_config
except Exception:
pass
model.quantization_method = QuantizationMethod.SDNQ model.quantization_method = QuantizationMethod.SDNQ
return model return model

View File

@ -6,9 +6,9 @@ dit:
model: model:
__object__: __object__:
path: path:
- "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit_v2.nadit" - "SeedVR2_VideoUpscaler.src.models.dit_v2.nadit"
- "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit_v2.nadit" - "SeedVR2_VideoUpscaler.src.models.dit_v2.nadit"
- "src.models.dit_v2.nadit" - "modules.seedvr.src.models.dit_v2.nadit"
name: "NaDiT" name: "NaDiT"
args: "as_params" args: "as_params"
vid_in_channels: 33 vid_in_channels: 33
@ -49,9 +49,9 @@ vae:
model: model:
__object__: __object__:
path: path:
- "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" - "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" - "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "src.models.video_vae_v3.modules.attn_video_vae" - "modules.seedvr.src.models.video_vae_v3.modules.attn_video_vae"
name: "VideoAutoencoderKLWrapper" name: "VideoAutoencoderKLWrapper"
args: "as_params" args: "as_params"
freeze_encoder: False freeze_encoder: False

View File

@ -6,8 +6,8 @@ dit:
model: model:
__object__: __object__:
path: path:
- "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit.nadit" - "SeedVR2_VideoUpscaler.src.models.dit.nadit"
- "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.dit.nadit" - "SeedVR2_VideoUpscaler.src.models.dit.nadit"
- "src.models.dit.nadit" - "src.models.dit.nadit"
name: "NaDiT" name: "NaDiT"
args: "as_params" args: "as_params"
@ -46,8 +46,8 @@ vae:
model: model:
__object__: __object__:
path: path:
- "custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" - "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "ComfyUI.custom_nodes.ComfyUI-SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae" - "SeedVR2_VideoUpscaler.src.models.video_vae_v3.modules.attn_video_vae"
- "src.models.video_vae_v3.modules.attn_video_vae" - "src.models.video_vae_v3.modules.attn_video_vae"
name: "VideoAutoencoderKLWrapper" name: "VideoAutoencoderKLWrapper"
args: "as_params" args: "as_params"

View File

@ -1,31 +1,4 @@
""" """
SeedVR2 Video Upscaler - Modular Architecture
Refactored from monolithic seedvr2.py for better maintainability
Author: Refactored codebase
Version: 2.0.0 - Modular
Available Modules:
- utils: Download and path utilities
- optimization: Memory, performance, and compatibility optimizations
- core: Model management and generation pipeline (NEW)
- processing: Video and tensor processing (coming next)
- interfaces: ComfyUI integration
"""
'''
# Track which modules are available for progressive migration
MODULES_AVAILABLE = {
'downloads': True, # ✅ Module 1 - Downloads and model management
'memory_manager': True, # ✅ Module 2 - Memory optimization
'performance': True, # ✅ Module 3 - Performance optimizations
'compatibility': True, # ✅ Module 4 - FP8/FP16 compatibility
'model_manager': True, # ✅ Module 5 - Model configuration and loading
'generation': True, # ✅ Module 6 - Generation loop and inference
'video_transforms': True, # ✅ Module 7 - Video processing and transforms
'comfyui_node': True, # ✅ Module 8 - ComfyUI node interface (COMPLETE!)
'infer': True, # ✅ Module 9 - Infer
}
'''
# Core imports (always available) # Core imports (always available)
import os import os
import sys import sys
@ -35,3 +8,4 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir) parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path: if parent_dir not in sys.path:
sys.path.insert(0, parent_dir) sys.path.insert(0, parent_dir)
"""

View File

@ -2,6 +2,7 @@ import importlib
from typing import Any, Callable, List, Union from typing import Any, Callable, List, Union
from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf import DictConfig, ListConfig, OmegaConf
try: try:
OmegaConf.register_new_resolver("eval", eval) OmegaConf.register_new_resolver("eval", eval)
except Exception as e: except Exception as e:
@ -9,7 +10,6 @@ except Exception as e:
raise raise
def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]: def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
""" """
Load a configuration. Will resolve inheritance. Load a configuration. Will resolve inheritance.

View File

@ -2,10 +2,10 @@ import functools
import threading import threading
from typing import Callable from typing import Callable
import torch import torch
from .distributed import barrier_if_distributed, get_global_rank, get_local_rank from .distributed import barrier_if_distributed, get_global_rank, get_local_rank
from .logger import get_logger from .logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -26,7 +26,6 @@ from ..cache import Cache
from .advanced import ( from .advanced import (
get_sequence_parallel_group, get_sequence_parallel_group,
get_sequence_parallel_rank, get_sequence_parallel_rank,
get_sequence_parallel_world_size,
) )
from .basic import get_device from .basic import get_device
@ -48,7 +47,7 @@ def single_all_to_all(
""" """
A function to do all-to-all on a tensor A function to do all-to-all on a tensor
""" """
seq_world_size = dist.get_world_size(group) seq_world_size = 1
prev_scatter_dim = scatter_dim prev_scatter_dim = scatter_dim
if scatter_dim != 0: if scatter_dim != 0:
local_input = local_input.transpose(0, scatter_dim) local_input = local_input.transpose(0, scatter_dim)
@ -80,7 +79,7 @@ def _all_to_all(
gather_dim: int, gather_dim: int,
group: dist.ProcessGroup, group: dist.ProcessGroup,
): ):
seq_world_size = dist.get_world_size(group) seq_world_size = 1
input_list = [ input_list = [
t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim) t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)
] ]
@ -134,7 +133,7 @@ class Slice(torch.autograd.Function):
def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor: def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor:
ctx.group = group ctx.group = group
ctx.rank = dist.get_rank(group) ctx.rank = dist.get_rank(group)
seq_world_size = dist.get_world_size(group) seq_world_size = 1
ctx.seq_world_size = seq_world_size ctx.seq_world_size = seq_world_size
ctx.dim = dim ctx.dim = dim
dim_size = local_input.shape[dim] dim_size = local_input.shape[dim]
@ -163,7 +162,7 @@ class Gather(torch.autograd.Function):
ctx.rank = dist.get_rank(group) ctx.rank = dist.get_rank(group)
ctx.dim = dim ctx.dim = dim
ctx.grad_scale = grad_scale ctx.grad_scale = grad_scale
seq_world_size = dist.get_world_size(group) seq_world_size = 1
ctx.seq_world_size = seq_world_size ctx.seq_world_size = seq_world_size
dim_size = list(local_input.size()) dim_size = list(local_input.size())
split_size = dim_size[0] split_size = dim_size[0]
@ -204,7 +203,7 @@ def gather_seq_scatter_heads_qkv(
group = get_sequence_parallel_group() group = get_sequence_parallel_group()
if not group: if not group:
return qkv_tensor return qkv_tensor
world = get_sequence_parallel_world_size() world = 1
orig_shape = qkv_tensor.shape orig_shape = qkv_tensor.shape
scatter_dim = qkv_tensor.dim() scatter_dim = qkv_tensor.dim()
bef_all2all_shape = list(orig_shape) bef_all2all_shape = list(orig_shape)
@ -237,7 +236,7 @@ def slice_inputs(x: Tensor, dim: int, padding: bool = True):
if group is None: if group is None:
return x return x
sp_rank = get_sequence_parallel_rank() sp_rank = get_sequence_parallel_rank()
sp_world = get_sequence_parallel_world_size() sp_world = 1
dim_size = x.shape[dim] dim_size = x.shape[dim]
unit = (dim_size + sp_world - 1) // sp_world unit = (dim_size + sp_world - 1) // sp_world
if padding and dim_size % sp_world: if padding and dim_size % sp_world:
@ -255,7 +254,7 @@ def remove_seqeunce_parallel_padding(x: Tensor, dim: int, unpad_dim_size: int):
group = get_sequence_parallel_group() group = get_sequence_parallel_group()
if group is None: if group is None:
return x return x
sp_world = get_sequence_parallel_world_size() sp_world = 1
if unpad_dim_size % sp_world == 0: if unpad_dim_size % sp_world == 0:
return x return x
padding_size = sp_world - (unpad_dim_size % sp_world) padding_size = sp_world - (unpad_dim_size % sp_world)
@ -271,7 +270,7 @@ def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
if not group: if not group:
return x return x
dim_size = x.size(seq_dim) dim_size = x.size(seq_dim)
sp_world = get_sequence_parallel_world_size() sp_world = 1
if dim_size % sp_world != 0: if dim_size % sp_world != 0:
padding_size = sp_world - (dim_size % sp_world) padding_size = sp_world - (dim_size % sp_world)
x = _pad_tensor(x, seq_dim, padding_size) x = _pad_tensor(x, seq_dim, padding_size)
@ -424,7 +423,7 @@ class SPDistForward:
yield inputs yield inputs
else: else:
device = self.device device = self.device
sp_world = get_sequence_parallel_world_size() sp_world = 1
sp_rank = get_sequence_parallel_rank() sp_rank = get_sequence_parallel_rank()
for local_step in range(sp_world): for local_step in range(sp_world):
src_rank = dist.get_global_rank(group, local_step) src_rank = dist.get_global_rank(group, local_step)

View File

@ -19,9 +19,9 @@ Logging utility functions.
import logging import logging
import sys import sys
from typing import Optional from typing import Optional
from .distributed import get_global_rank, get_local_rank, get_world_size from .distributed import get_global_rank, get_local_rank, get_world_size
_default_handler = logging.StreamHandler(sys.stdout) _default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter( _default_handler.setFormatter(
logging.Formatter( logging.Formatter(

View File

@ -16,7 +16,6 @@ import random
from typing import Optional from typing import Optional
import numpy as np import numpy as np
import torch import torch
from .distributed import get_global_rank from .distributed import get_global_rank

View File

@ -1,11 +1,10 @@
import torch import torch
from torchvision.transforms import Compose, Lambda, Normalize from torchvision.transforms import Compose, Lambda, Normalize
from ..optimization.performance import optimized_video_rearrange, optimized_single_video_rearrange, optimized_sample_to_image_format
from src.optimization.performance import optimized_video_rearrange, optimized_single_video_rearrange, optimized_sample_to_image_format from ..common.seed import set_seed
from src.common.seed import set_seed from ..data.image.transforms.divisible_crop import DivisibleCrop
from src.data.image.transforms.divisible_crop import DivisibleCrop from ..data.image.transforms.na_resize import NaResize
from src.data.image.transforms.na_resize import NaResize from ..utils.color_fix import wavelet_reconstruction
from src.utils.color_fix import wavelet_reconstruction
@ -67,8 +66,6 @@ def generation_step(runner, text_embeds_dict, cond_latents, temporal_overlap, de
x = runner.schedule.forward(x, aug_noise, t) x = runner.schedule.forward(x, aug_noise, t)
return x return x
# Generate conditions with memory optimization
runner.dit.to(device=device)
condition = runner.get_condition( condition = runner.get_condition(
noises[0], noises[0],
task="sr", task="sr",
@ -76,8 +73,8 @@ def generation_step(runner, text_embeds_dict, cond_latents, temporal_overlap, de
) )
conditions = [condition] conditions = [condition]
# Use adaptive autocast for optimal performance
with torch.no_grad(): with torch.no_grad():
# Use adaptive autocast for optimal performance
video_tensors = runner.inference( video_tensors = runner.inference(
noises=noises, noises=noises,
conditions=conditions, conditions=conditions,
@ -92,6 +89,7 @@ def generation_step(runner, text_embeds_dict, cond_latents, temporal_overlap, de
cond_latents = cond_latents[0].to("cpu") cond_latents = cond_latents[0].to("cpu")
conditions = conditions[0].to("cpu") conditions = conditions[0].to("cpu")
condition = condition.to("cpu") condition = condition.to("cpu")
del noises, aug_noises, cond_latents, conditions, condition
return samples #, last_latents return samples #, last_latents
@ -133,7 +131,6 @@ def generation_loop(runner, images, cfg_scale=1.0, seed=666, res_w=720, batch_si
- Intelligent VRAM management throughout process - Intelligent VRAM management throughout process
- Real-time progress reporting - Real-time progress reporting
""" """
model_dtype = None model_dtype = None
model_dtype = next(runner.dit.parameters()).dtype model_dtype = next(runner.dit.parameters()).dtype
compute_dtype = model_dtype compute_dtype = model_dtype
@ -230,7 +227,6 @@ def generation_loop(runner, images, cfg_scale=1.0, seed=666, res_w=720, batch_si
# Apply color correction if available # Apply color correction if available
transformed_video = transformed_video.to(device) transformed_video = transformed_video.to(device)
input_video = [optimized_single_video_rearrange(transformed_video)] input_video = [optimized_single_video_rearrange(transformed_video)]
del transformed_video del transformed_video
sample = wavelet_reconstruction(sample, input_video[0][:sample.size(0)]) sample = wavelet_reconstruction(sample, input_video[0][:sample.size(0)])
@ -244,56 +240,33 @@ def generation_loop(runner, images, cfg_scale=1.0, seed=666, res_w=720, batch_si
batch_samples.append(sample_cpu) batch_samples.append(sample_cpu)
#del sample #del sample
# Aggressive cleanup after each batch
# Progress callback - batch start
if progress_callback: if progress_callback:
progress_callback(batch_count+1, total_batches, current_frames, "Processing batch...") progress_callback(batch_count+1, total_batches, current_frames, "Processing batch...")
runner.vae.to(device="cpu")
runner.dit.to(device="cpu")
# OPTIMISATION ULTIME : Pré-allocation et copie directe (évite les torch.cat multiples)
# 1. Calculer la taille totale finale # 1. Calculer la taille totale finale
total_frames = sum(batch.shape[0] for batch in batch_samples) total_frames = sum(batch.shape[0] for batch in batch_samples)
if len(batch_samples) > 0: if len(batch_samples) > 0:
sample_shape = batch_samples[0].shape sample_shape = batch_samples[0].shape
H, W, C = sample_shape[1], sample_shape[2], sample_shape[3] H, W, C = sample_shape[1], sample_shape[2], sample_shape[3]
# 2. Pré-allouer le tensor final directement sur CPU (évite concatenations)
final_video_images = torch.empty((total_frames, H, W, C), dtype=torch.float16) final_video_images = torch.empty((total_frames, H, W, C), dtype=torch.float16)
# 3. Copier par blocs directement dans le tensor final
block_size = 500 block_size = 500
current_idx = 0 current_idx = 0
for block_start in range(0, len(batch_samples), block_size): for block_start in range(0, len(batch_samples), block_size):
block_end = min(block_start + block_size, len(batch_samples)) block_end = min(block_start + block_size, len(batch_samples))
# Charger le bloc en VRAM
current_block = [] current_block = []
for i in range(block_start, block_end): for i in range(block_start, block_end):
current_block.append(batch_samples[i].to(device)) current_block.append(batch_samples[i].to(device))
# Concatener en VRAM (rapide)
block_result = torch.cat(current_block, dim=0) block_result = torch.cat(current_block, dim=0)
# Convertir en Float16 sur GPU
#if block_result.dtype != torch.float16:
# block_result = block_result.to(torch.float16)
# Copier directement dans le tensor final (pas de concatenation!)
block_frames = block_result.shape[0] block_frames = block_result.shape[0]
final_video_images[current_idx:current_idx + block_frames] = block_result.to("cpu") final_video_images[current_idx:current_idx + block_frames] = block_result.to("cpu")
current_idx += block_frames current_idx += block_frames
# Nettoyage immédiat VRAM
del current_block, block_result del current_block, block_result
else: else:
print("SeedVR2: No batch_samples to process") print("SeedVR2: No batch_samples to process")
final_video_images = torch.empty((0, 0, 0, 0), dtype=torch.float16) final_video_images = torch.empty((0, 0, 0, 0), dtype=torch.float16)
# Cleanup batch_samples
#del batch_samples
return final_video_images return final_video_images

View File

@ -1,38 +1,12 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
from einops import rearrange from einops import rearrange
from omegaconf import DictConfig, ListConfig from omegaconf import DictConfig, ListConfig
from torch import Tensor from ..common.diffusion import classifier_free_guidance_dispatcher, create_sampler_from_config, create_sampling_timesteps_from_config, create_schedule_from_config
from src.common.diffusion import ( from ..models.dit_v2 import na
classifier_free_guidance_dispatcher,
create_sampler_from_config,
create_sampling_timesteps_from_config,
create_schedule_from_config,
)
from src.common.distributed import (
get_device,
)
# from common.fs import download
from src.models.dit_v2 import na
def optimized_channels_to_last(tensor): def optimized_channels_to_last(tensor: torch.Tensor) -> torch.Tensor:
"""🚀 Optimized replacement for rearrange(tensor, 'b c ... -> b ... c') """🚀 Optimized replacement for rearrange(tensor, 'b c ... -> b ... c')
Moves channels from position 1 to last position using PyTorch native operations. Moves channels from position 1 to last position using PyTorch native operations.
""" """
@ -74,7 +48,7 @@ class VideoDiffusionInfer():
self.dit = None self.dit = None
self.sampler = None self.sampler = None
self.schedule = None self.schedule = None
def get_condition(self, latent: Tensor, latent_blur: Tensor, task: str) -> Tensor: def get_condition(self, latent: torch.Tensor, latent_blur: torch.Tensor, task: str) -> torch.Tensor:
t, h, w, c = latent.shape t, h, w, c = latent.shape
cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
if task == "t2v" or t == 1: if task == "t2v" or t == 1:
@ -118,22 +92,18 @@ class VideoDiffusionInfer():
# -------------------------------- Helper ------------------------------- # # -------------------------------- Helper ------------------------------- #
@torch.no_grad() @torch.no_grad()
def vae_encode(self, samples: List[Tensor]) -> List[Tensor]: def vae_encode(self, samples: List[torch.Tensor]) -> List[torch.Tensor]:
self.dit.to(device="cpu")
self.vae.to(device=self.device)
use_sample = self.config.vae.get("use_sample", True) use_sample = self.config.vae.get("use_sample", True)
latents = [] latents = []
if len(samples) > 0: if len(samples) > 0:
device = get_device()
dtype = self.vae.dtype dtype = self.vae.dtype
scale = self.config.vae.scaling_factor scale = self.config.vae.scaling_factor
shift = self.config.vae.get("shifting_factor", 0.0) shift = self.config.vae.get("shifting_factor", 0.0)
if isinstance(scale, ListConfig): if isinstance(scale, ListConfig):
scale = torch.tensor(scale, device=device, dtype=dtype) scale = torch.tensor(scale, device=self.device, dtype=dtype)
if isinstance(shift, ListConfig): if isinstance(shift, ListConfig):
shift = torch.tensor(shift, device=device, dtype=dtype) shift = torch.tensor(shift, device=self.device, dtype=dtype)
# Group samples of the same shape to batches if enabled. # Group samples of the same shape to batches if enabled.
if self.config.vae.grouping: if self.config.vae.grouping:
@ -143,7 +113,7 @@ class VideoDiffusionInfer():
# Vae process by each group. # Vae process by each group.
for sample in batches: for sample in batches:
sample = sample.to(device, dtype) sample = sample.to(self.device, dtype)
if hasattr(self.vae, "preprocess"): if hasattr(self.vae, "preprocess"):
sample = self.vae.preprocess(sample) sample = self.vae.preprocess(sample)
if use_sample: if use_sample:
@ -162,19 +132,15 @@ class VideoDiffusionInfer():
latents = na.unpack(latents, indices) latents = na.unpack(latents, indices)
else: else:
latents = [latent.squeeze(0) for latent in latents] latents = [latent.squeeze(0) for latent in latents]
self.vae.to(device="cpu")
return latents return latents
@torch.no_grad() @torch.no_grad()
def vae_decode(self, latents: List[Tensor], target_dtype: torch.dtype = None) -> List[Tensor]: def vae_decode(self, latents: List[torch.Tensor], target_dtype: torch.dtype = None) -> List[torch.Tensor]:
"""🚀 VAE decode optimisé - décodage direct sans chunking, compatible avec autocast externe""" """🚀 VAE decode optimisé - décodage direct sans chunking, compatible avec autocast externe"""
self.dit.to(device="cpu")
self.vae.to(device=self.device)
samples = [] samples = []
if len(latents) > 0: if len(latents) > 0:
device = get_device() device = self.device
dtype = self.vae.dtype dtype = self.vae.dtype
scale = self.config.vae.scaling_factor scale = self.config.vae.scaling_factor
shift = self.config.vae.get("shifting_factor", 0.0) shift = self.config.vae.get("shifting_factor", 0.0)
@ -218,10 +184,9 @@ class VideoDiffusionInfer():
samples = na.unpack(samples, indices) samples = na.unpack(samples, indices)
else: else:
samples = [sample.squeeze(0) for sample in samples] samples = [sample.squeeze(0) for sample in samples]
self.vae.to(device="cpu")
return samples return samples
def timestep_transform(self, timesteps: Tensor, latents_shapes: Tensor): def timestep_transform(self, timesteps: torch.Tensor, latents_shapes: torch.Tensor):
# Skip if not needed. # Skip if not needed.
if not self.config.diffusion.timesteps.get("transform", False): if not self.config.diffusion.timesteps.get("transform", False):
return timesteps return timesteps
@ -256,13 +221,13 @@ class VideoDiffusionInfer():
@torch.no_grad() @torch.no_grad()
def inference( def inference(
self, self,
noises: List[Tensor], noises: List[torch.Tensor],
conditions: List[Tensor], conditions: List[torch.Tensor],
texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]], texts_pos: Union[List[str], List[torch.Tensor], List[Tuple[torch.Tensor]]],
texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]], texts_neg: Union[List[str], List[torch.Tensor], List[Tuple[torch.Tensor]]],
cfg_scale: Optional[float] = None, cfg_scale: Optional[float] = None,
temporal_overlap: int = 0, # pylint: disable=unused-argument temporal_overlap: int = 0, # pylint: disable=unused-argument
) -> List[Tensor]: ) -> List[torch.Tensor]:
assert len(noises) == len(conditions) == len(texts_pos) == len(texts_neg) assert len(noises) == len(conditions) == len(texts_pos) == len(texts_neg)
batch_size = len(noises) batch_size = len(noises)
@ -316,6 +281,7 @@ class VideoDiffusionInfer():
# Adapter les latents au dtype cible (compatible avec FP8) # Adapter les latents au dtype cible (compatible avec FP8)
latents = latents.to(target_dtype) if latents.dtype != target_dtype else latents latents = latents.to(target_dtype) if latents.dtype != target_dtype else latents
latents_cond = latents_cond.to(target_dtype) if latents_cond.dtype != target_dtype else latents_cond latents_cond = latents_cond.to(target_dtype) if latents_cond.dtype != target_dtype else latents_cond
self.dit = self.dit.to(device=self.device, dtype=target_dtype)
latents = self.sampler.sample( latents = self.sampler.sample(
x=latents, x=latents,

View File

@ -3,10 +3,9 @@ import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from safetensors.torch import load_file as load_safetensors_file from safetensors.torch import load_file as load_safetensors_file
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from ..optimization.memory_manager import preinitialize_rope_cache
from src.optimization.memory_manager import preinitialize_rope_cache from ..common.config import load_config, create_object
from src.common.config import load_config, create_object from ..core.infer import VideoDiffusionInfer
from src.core.infer import VideoDiffusionInfer
def configure_runner(model_name, cache_dir, device:str='cpu', dtype:torch.dtype=None): def configure_runner(model_name, cache_dir, device:str='cpu', dtype:torch.dtype=None):
@ -15,41 +14,39 @@ def configure_runner(model_name, cache_dir, device:str='cpu', dtype:torch.dtype=
config_path = os.path.join(script_directory, './config_7b.yaml') if "7b" in model_name else os.path.join(script_directory, './config_3b.yaml') config_path = os.path.join(script_directory, './config_7b.yaml') if "7b" in model_name else os.path.join(script_directory, './config_3b.yaml')
config = load_config(config_path) config = load_config(config_path)
vae_config_path = os.path.join(script_directory, 'src/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml')
vae_config = OmegaConf.load(vae_config_path)
vae_config.spatial_downsample_factor = vae_config.get('spatial_downsample_factor', 8)
vae_config.temporal_downsample_factor = vae_config.get('temporal_downsample_factor', 4)
config.vae.model = OmegaConf.merge(config.vae.model, vae_config)
runner = VideoDiffusionInfer(config, device=device, dtype=dtype) runner = VideoDiffusionInfer(config, device=device, dtype=dtype)
OmegaConf.set_readonly(runner.config, False) OmegaConf.set_readonly(runner.config, False)
# load dit # load dit
with torch.device("meta"): with torch.device("meta"):
runner.dit = create_object(config.dit.model) runner.dit = create_object(config.dit.model)
runner.dit.eval().to(dtype) runner.dit.requires_grad_(False).eval()
runner.dit.to_empty(device="cpu") runner.dit.to_empty(device="cpu")
model_file = hf_hub_download(repo_id=repo_id, filename=model_name, cache_dir=cache_dir) model_file = hf_hub_download(repo_id=repo_id, filename=model_name, cache_dir=cache_dir)
state_dict = load_safetensors_file(model_file) state_dict = load_safetensors_file(model_file)
runner.dit.load_state_dict(state_dict, assign=True) runner.dit.load_state_dict(state_dict, assign=True)
runner.dit = runner.dit.to(device="cpu", dtype=dtype)
del state_dict del state_dict
runner.dit = runner.dit.to(device=device, dtype=dtype)
# load vae # load vae
vae_config_path = os.path.join(script_directory, 'src/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml')
vae_config = OmegaConf.load(vae_config_path)
config.vae.model = OmegaConf.merge(config.vae.model, vae_config)
vae_file = hf_hub_download(repo_id=repo_id, filename=config.vae.checkpoint, cache_dir=cache_dir) vae_file = hf_hub_download(repo_id=repo_id, filename=config.vae.checkpoint, cache_dir=cache_dir)
with torch.device("meta"): with torch.device("meta"):
runner.vae = create_object(config.vae.model) runner.vae = create_object(config.vae.model)
runner.vae.requires_grad_(False).eval() runner.vae.requires_grad_(False).eval()
runner.vae.to_empty(device="cpu") runner.vae.to_empty(device="cpu")
state_dict = load_safetensors_file(vae_file) state_dict = load_safetensors_file(vae_file)
runner.vae.load_state_dict(state_dict) runner.vae.load_state_dict(state_dict)
del state_dict runner.vae = runner.vae.to(device="cpu", dtype=dtype)
runner.vae = runner.vae.to(device=device, dtype=dtype)
runner.config.vae.dtype = str(dtype) runner.config.vae.dtype = str(dtype)
runner.vae.set_causal_slicing(**config.vae.slicing) runner.config.vae.slicing = {'split_size': 8, 'memory_device': 'same'}
runner.config.vae.memory_limit = {'conv_max_mem': 0.2, 'norm_max_mem': 0.2}
runner.vae.set_causal_slicing(**runner.config.vae.slicing)
runner.vae.set_memory_limit(**runner.config.vae.memory_limit) runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
del state_dict
# load embeds # load embeds
pos_embeds_file = hf_hub_download(repo_id=repo_id, filename='pos_emb.pt', cache_dir=cache_dir) pos_embeds_file = hf_hub_download(repo_id=repo_id, filename='pos_emb.pt', cache_dir=cache_dir)
@ -57,5 +54,4 @@ def configure_runner(model_name, cache_dir, device:str='cpu', dtype:torch.dtype=
runner.text_pos_embeds = torch.load(pos_embeds_file).to(device=device, dtype=dtype) runner.text_pos_embeds = torch.load(pos_embeds_file).to(device=device, dtype=dtype)
runner.text_neg_embeds = torch.load(neg_embeds_file).to(device=device, dtype=dtype) runner.text_neg_embeds = torch.load(neg_embeds_file).to(device=device, dtype=dtype)
preinitialize_rope_cache(runner)
return runner return runner

View File

@ -19,13 +19,7 @@ from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.modules.utils import _triple from torch.nn.modules.utils import _triple
from ....common.half_precision_fixes import safe_pad_operation from ....common.half_precision_fixes import safe_pad_operation
from ....common.distributed.ops import ( from ....common.distributed.ops import gather_heads, gather_heads_scatter_seq, gather_seq_scatter_heads_qkv, scatter_heads
gather_heads,
gather_heads_scatter_seq,
gather_seq_scatter_heads_qkv,
scatter_heads,
)
from ..attention import TorchAttention from ..attention import TorchAttention
from ..mlp import get_mlp from ..mlp import get_mlp
from ..mm import MMArg, MMModule from ..mm import MMArg, MMModule

View File

@ -18,8 +18,7 @@ import torch
from einops import rearrange from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
from torch import nn from torch import nn
from ...common.cache import Cache
from src.common.cache import Cache
class RotaryEmbeddingBase(nn.Module): class RotaryEmbeddingBase(nn.Module):

View File

@ -15,7 +15,6 @@ from typing import Literal, Optional, Tuple, Union
import diffusers import diffusers
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, SpatialNorm from diffusers.models.attention_processor import Attention, SpatialNorm
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.downsampling import Downsample2D from diffusers.models.downsampling import Downsample2D
@ -28,29 +27,12 @@ from diffusers.utils import is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook from diffusers.utils.accelerate_utils import apply_forward_hook
from einops import rearrange from einops import rearrange
from ....common.half_precision_fixes import safe_pad_operation, safe_interpolate_operation from ....common.half_precision_fixes import safe_pad_operation, safe_interpolate_operation
from ....common.distributed.advanced import get_sequence_parallel_world_size
from ....common.logger import get_logger from ....common.logger import get_logger
from .causal_inflation_lib import ( from .causal_inflation_lib import InflatedCausalConv3d, causal_norm_wrapper, init_causal_conv3d, remove_head
InflatedCausalConv3d, from .context_parallel_lib import causal_conv_gather_outputs, causal_conv_slice_inputs
causal_norm_wrapper,
init_causal_conv3d,
remove_head,
)
from .context_parallel_lib import (
causal_conv_gather_outputs,
causal_conv_slice_inputs,
)
from .global_config import set_norm_limit from .global_config import set_norm_limit
from .types import ( from .types import CausalAutoencoderOutput, CausalDecoderOutput, CausalEncoderOutput, MemoryState, _inflation_mode_t, _memory_device_t, _receptive_field_t
CausalAutoencoderOutput,
CausalDecoderOutput,
CausalEncoderOutput,
MemoryState,
_inflation_mode_t,
_memory_device_t,
_receptive_field_t,
)
logger = get_logger(__name__) # pylint: disable=invalid-name logger = get_logger(__name__) # pylint: disable=invalid-name
@ -1064,7 +1046,7 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
scaling_factor: float = 0.18215, scaling_factor: float = 0.18215,
force_upcast: float = True, force_upcast: float = True,
attention: bool = True, attention: bool = True,
temporal_scale_num: int = 2, temporal_scale_num: int = 0,
slicing_up_num: int = 0, slicing_up_num: int = 0,
gradient_checkpoint: bool = False, gradient_checkpoint: bool = False,
inflation_mode: _inflation_mode_t = "tail", inflation_mode: _inflation_mode_t = "tail",
@ -1164,7 +1146,8 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
@apply_forward_hook @apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
h = self.slicing_encode(x) # h = self.slicing_encode(x)
h = self.tiled_encode(x)
posterior = DiagonalGaussianDistribution(h) posterior = DiagonalGaussianDistribution(h)
if not return_dict: if not return_dict:
@ -1176,7 +1159,8 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
def decode( def decode(
self, z: torch.Tensor, return_dict: bool = True self, z: torch.Tensor, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]: ) -> Union[DecoderOutput, torch.Tensor]:
decoded = self.slicing_decode(z) # decoded = self.slicing_decode(z)
decoded = self.tiled_decode(z)
if not return_dict: if not return_dict:
return (decoded,) return (decoded,)
@ -1208,7 +1192,7 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
return output.to(z.device) return output.to(z.device)
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
sp_size = get_sequence_parallel_world_size() sp_size = 1
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2) x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
encoded_slices = [ encoded_slices = [
@ -1226,7 +1210,7 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
return self._encode(x) return self._encode(x)
def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
sp_size = get_sequence_parallel_world_size() sp_size = 1
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
decoded_slices = [ decoded_slices = [
@ -1243,11 +1227,67 @@ class VideoAutoencoderKL(diffusers.AutoencoderKL):
else: else:
return self._decode(z) return self._decode(z)
def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
raise NotImplementedError blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
return b
def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
raise NotImplementedError blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
rows = []
for i in range(0, x.shape[3], overlap_size):
row = []
for j in range(0, x.shape[4], overlap_size):
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self._encode(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=4))
enc = torch.cat(result_rows, dim=3)
return enc
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
rows = []
for i in range(0, z.shape[3], overlap_size):
row = []
for j in range(0, z.shape[4], overlap_size):
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
return dec
def forward( def forward(
self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs

View File

@ -21,30 +21,11 @@ from diffusers.models.normalization import RMSNorm
from einops import rearrange from einops import rearrange
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn import Conv3d from torch.nn import Conv3d
from .context_parallel_lib import cache_send_recv, get_cache_size from .context_parallel_lib import cache_send_recv, get_cache_size
from .global_config import get_norm_limit from .global_config import get_norm_limit
from .types import MemoryState, _inflation_mode_t, _memory_device_t from .types import MemoryState, _inflation_mode_t, _memory_device_t
from ....common.half_precision_fixes import safe_pad_operation from ....common.half_precision_fixes import safe_pad_operation
# Single GPU inference - no distributed processing needed
# Mock distributed functions for single GPU inference
def get_sequence_parallel_group():
return None
def get_sequence_parallel_rank():
return 0
def get_sequence_parallel_world_size():
return 1
def get_next_sequence_parallel_rank():
return 0
def get_prev_sequence_parallel_rank():
return 0
@contextmanager @contextmanager
def ignore_padding(model): def ignore_padding(model):
@ -172,7 +153,6 @@ class InflatedCausalConv3d(Conv3d):
if ( if (
math.isinf(self.memory_limit) math.isinf(self.memory_limit)
and torch.is_tensor(input) and torch.is_tensor(input)
and get_sequence_parallel_group() is None
): ):
return self.basic_forward(input, memory_state) return self.basic_forward(input, memory_state)
return self.slicing_forward(input, memory_state) return self.slicing_forward(input, memory_state)

View File

@ -14,11 +14,8 @@
from typing import List from typing import List
import torch import torch
import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from .types import MemoryState
# Single GPU inference - no distributed processing needed # Single GPU inference - no distributed processing needed

View File

@ -16,14 +16,8 @@ from functools import partial
from typing import Literal, Optional from typing import Literal, Optional
from torch import Tensor from torch import Tensor
from torch.nn import Conv3d from torch.nn import Conv3d
from .inflated_lib import MemoryState, extend_head, inflate_bias, inflate_weight, modify_state_dict
from .inflated_lib import (
MemoryState,
extend_head,
inflate_bias,
inflate_weight,
modify_state_dict,
)
_inflation_mode_t = Literal["none", "tail", "replicate"] _inflation_mode_t = Literal["none", "tail", "replicate"]
_memory_device_t = Optional[Literal["cpu", "same"]] _memory_device_t = Optional[Literal["cpu", "same"]]

View File

@ -19,9 +19,9 @@ import torch
from diffusers.models.normalization import RMSNorm from diffusers.models.normalization import RMSNorm
from einops import rearrange from einops import rearrange
from torch import Tensor, nn from torch import Tensor, nn
from ....common.logger import get_logger from ....common.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -11,37 +11,17 @@
from contextlib import nullcontext from contextlib import nullcontext
from typing import Optional, Tuple, Literal, Callable, Union from typing import Optional, Tuple, Literal, Callable, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from einops import rearrange from einops import rearrange
from ....common.half_precision_fixes import safe_pad_operation from ....common.half_precision_fixes import safe_pad_operation
from ....common.distributed.advanced import get_sequence_parallel_world_size
from ....common.logger import get_logger from ....common.logger import get_logger
from .causal_inflation_lib import ( from .causal_inflation_lib import InflatedCausalConv3d, causal_norm_wrapper, init_causal_conv3d, remove_head
InflatedCausalConv3d, from .context_parallel_lib import causal_conv_gather_outputs, causal_conv_slice_inputs
causal_norm_wrapper,
init_causal_conv3d,
remove_head,
)
from .context_parallel_lib import (
causal_conv_gather_outputs,
causal_conv_slice_inputs,
)
from .global_config import set_norm_limit from .global_config import set_norm_limit
from .types import ( from .types import CausalAutoencoderOutput, CausalDecoderOutput, CausalEncoderOutput, MemoryState, _inflation_mode_t, _memory_device_t, _receptive_field_t, _selective_checkpointing_t
CausalAutoencoderOutput,
CausalDecoderOutput,
CausalEncoderOutput,
MemoryState,
_inflation_mode_t,
_memory_device_t,
_receptive_field_t,
_selective_checkpointing_t,
)
logger = get_logger(__name__) # pylint: disable=invalid-name logger = get_logger(__name__) # pylint: disable=invalid-name
@ -717,7 +697,7 @@ class VideoAutoencoderKL(nn.Module):
use_post_quant_conv: bool = True, use_post_quant_conv: bool = True,
enc_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), enc_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",),
dec_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), dec_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",),
temporal_scale_num: int = 3, temporal_scale_num: int = 0,
slicing_up_num: int = 0, slicing_up_num: int = 0,
inflation_mode: _inflation_mode_t = "tail", inflation_mode: _inflation_mode_t = "tail",
time_receptive_field: _receptive_field_t = "half", time_receptive_field: _receptive_field_t = "half",
@ -824,7 +804,7 @@ class VideoAutoencoderKL(nn.Module):
return x return x
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
sp_size = get_sequence_parallel_world_size() sp_size = 1
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2) x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
encoded_slices = [ encoded_slices = [
@ -842,7 +822,7 @@ class VideoAutoencoderKL(nn.Module):
return self._encode(x, memory_state=MemoryState.DISABLED) return self._encode(x, memory_state=MemoryState.DISABLED)
def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
sp_size = get_sequence_parallel_world_size() sp_size = 1
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
decoded_slices = [ decoded_slices = [

View File

@ -5,12 +5,9 @@ Handles VRAM usage, cache management, and memory optimization
Extracted from: seedvr2.py (lines 373-405, 607-626, 1016-1044) Extracted from: seedvr2.py (lines 373-405, 607-626, 1016-1044)
""" """
import os
import torch import torch
import gc from ..common.cache import Cache
from typing import Tuple, Optional from ..models.dit_v2.rope import RotaryEmbeddingBase
from src.common.cache import Cache
from src.models.dit_v2.rope import RotaryEmbeddingBase
def preinitialize_rope_cache(runner) -> None: def preinitialize_rope_cache(runner) -> None:
@ -21,73 +18,63 @@ def preinitialize_rope_cache(runner) -> None:
runner: The model runner containing DiT and VAE models runner: The model runner containing DiT and VAE models
""" """
try: # Create dummy tensors to simulate common shapes
# Create dummy tensors to simulate common shapes # Format: [batch, channels, frames, height, width] for vid_shape
# Format: [batch, channels, frames, height, width] for vid_shape # Format: [batch, seq_len] for txt_shape
# Format: [batch, seq_len] for txt_shape common_shapes = [
common_shapes = [ # Common video resolutions
# Common video resolutions (torch.tensor([[1, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 1 frame, 77 tokens
(torch.tensor([[1, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 1 frame, 77 tokens (torch.tensor([[4, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 4 frames
(torch.tensor([[4, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 4 frames (torch.tensor([[5, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 5 frames (4n+1 format)
(torch.tensor([[5, 3, 3]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # 5 frames (4n+1 format) (torch.tensor([[1, 4, 4]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # Higher resolution
(torch.tensor([[1, 4, 4]], dtype=torch.long), torch.tensor([[77]], dtype=torch.long)), # Higher resolution ]
]
# Create mock cache for pre-initialization # Create mock cache for pre-initialization
temp_cache = Cache() temp_cache = Cache()
# Access RoPE modules in DiT (recursive search) # Access RoPE modules in DiT (recursive search)
def find_rope_modules(module): def find_rope_modules(module):
rope_modules = [] rope_modules = []
for name, child in module.named_modules(): for name, child in module.named_modules():
if hasattr(child, 'get_freqs') and callable(child.get_freqs): if hasattr(child, 'get_freqs') and callable(child.get_freqs):
rope_modules.append((name, child)) rope_modules.append((name, child))
return rope_modules return rope_modules
rope_modules = find_rope_modules(runner.dit) rope_modules = find_rope_modules(runner.dit)
# Pre-calculate for each RoPE module found # Pre-calculate for each RoPE module found
for name, rope_module in rope_modules: for _name, rope_module in rope_modules:
# Temporarily move module to CPU if necessary # Temporarily move module to CPU if necessary
original_device = next(rope_module.parameters()).device if list(rope_module.parameters()) else torch.device('cpu') original_device = next(rope_module.parameters()).device if list(rope_module.parameters()) else torch.device('cpu')
rope_module.to('cpu') rope_module.to('cpu')
try: for vid_shape, txt_shape in common_shapes:
for vid_shape, txt_shape in common_shapes: cache_key = f"720pswin_by_size_bysize_{tuple(vid_shape[0].tolist())}_sd3.mmrope_freqs_3d"
cache_key = f"720pswin_by_size_bysize_{tuple(vid_shape[0].tolist())}_sd3.mmrope_freqs_3d"
def compute_freqs(): def compute_freqs():
# Calculate with reduced dimensions to avoid OOM # Calculate with reduced dimensions to avoid OOM
with torch.no_grad(): with torch.no_grad():
# Detect RoPE module type # Detect RoPE module type
module_type = type(rope_module).__name__ module_type = type(rope_module).__name__
if module_type == 'NaRotaryEmbedding3d': if module_type == 'NaRotaryEmbedding3d':
# NaRotaryEmbedding3d: only takes shape (vid_shape) # NaRotaryEmbedding3d: only takes shape (vid_shape)
return rope_module.get_freqs(vid_shape.cpu()) return rope_module.get_freqs(vid_shape.cpu())
else: else:
# Standard RoPE: takes vid_shape and txt_shape # Standard RoPE: takes vid_shape and txt_shape
return rope_module.get_freqs(vid_shape.cpu(), txt_shape.cpu()) return rope_module.get_freqs(vid_shape.cpu(), txt_shape.cpu())
# Store in cache # Store in cache
temp_cache(cache_key, compute_freqs) temp_cache(cache_key, compute_freqs)
except Exception as e: rope_module.to(original_device)
print(f" ❌ Error in module {name}: {e}")
finally:
# Restore to original device
rope_module.to(original_device)
# Copy temporary cache to runner cache # Copy temporary cache to runner cache
if hasattr(runner, 'cache'): if hasattr(runner, 'cache'):
runner.cache.cache.update(temp_cache.cache) runner.cache.cache.update(temp_cache.cache)
else: else:
runner.cache = temp_cache runner.cache = temp_cache
except Exception as e:
print(f" ⚠️ Error during RoPE pre-init: {e}")
print(" 🔄 Model will work but could have OOM at first launch")
def clear_rope_cache(runner) -> None: def clear_rope_cache(runner) -> None:
@ -97,8 +84,6 @@ def clear_rope_cache(runner) -> None:
Args: Args:
runner: The model runner containing the cache runner: The model runner containing the cache
""" """
print("🧹 Cleaning RoPE cache...")
if hasattr(runner, 'cache') and hasattr(runner.cache, 'cache'): if hasattr(runner, 'cache') and hasattr(runner.cache, 'cache'):
# Count entries before cleanup # Count entries before cleanup
cache_size = len(runner.cache.cache) cache_size = len(runner.cache.cache)
@ -116,7 +101,6 @@ def clear_rope_cache(runner) -> None:
# Clear the cache # Clear the cache
runner.cache.cache.clear() runner.cache.cache.clear()
print(f" ✅ RoPE cache cleared ({cache_size} entries removed)")
if hasattr(runner, 'dit'): if hasattr(runner, 'dit'):
cleared_lru_count = 0 cleared_lru_count = 0
@ -125,7 +109,3 @@ def clear_rope_cache(runner) -> None:
if hasattr(module.get_axial_freqs, 'cache_clear'): if hasattr(module.get_axial_freqs, 'cache_clear'):
module.get_axial_freqs.cache_clear() module.get_axial_freqs.cache_clear()
cleared_lru_count += 1 cleared_lru_count += 1
if cleared_lru_count > 0:
print(f" ✅ Cleared {cleared_lru_count} LRU caches from RoPE modules.")
print("🎯 RoPE cache cleanup completed!")

View File

@ -2,7 +2,7 @@ import torch
from PIL import Image from PIL import Image
from torch import Tensor from torch import Tensor
from torch.nn import functional as F from torch.nn import functional as F
from src.common.half_precision_fixes import safe_pad_operation, safe_interpolate_operation from ..common.half_precision_fixes import safe_pad_operation, safe_interpolate_operation
from torchvision.transforms import ToTensor, ToPILImage from torchvision.transforms import ToTensor, ToPILImage
def adain_color_fix(target: Image, source: Image): def adain_color_fix(target: Image, source: Image):

View File

@ -1,12 +1,10 @@
import os import os
import argparse
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from huggingface_hub import snapshot_download
from torchvision.transforms import ToPILImage from torchvision.transforms import ToPILImage
from src.core.generation import generation_loop from .src.core.generation import generation_loop
from src.core.model_manager import configure_runner from .src.core.model_manager import configure_runner

View File

@ -620,7 +620,6 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
"detailer_iou": OptionInfo(0.5, "Max overlap", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}), "detailer_iou": OptionInfo(0.5, "Max overlap", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}),
"detailer_sigma_adjust": OptionInfo(1.0, "Detailer sigma adjust", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}), "detailer_sigma_adjust": OptionInfo(1.0, "Detailer sigma adjust", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}),
"detailer_sigma_adjust_max": OptionInfo(1.0, "Detailer sigma end", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}), "detailer_sigma_adjust_max": OptionInfo(1.0, "Detailer sigma end", gr.Slider, {"minimum": 0, "maximum": 1.0, "step": 0.05, "visible": False}),
# "detailer_resolution": OptionInfo(1024, "Detailer resolution", gr.Slider, {"minimum": 256, "maximum": 4096, "step": 8, "visible": False}),
"detailer_min_size": OptionInfo(0.0, "Min object size", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05, "visible": False}), "detailer_min_size": OptionInfo(0.0, "Min object size", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05, "visible": False}),
"detailer_max_size": OptionInfo(1.0, "Max object size", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05, "visible": False}), "detailer_max_size": OptionInfo(1.0, "Max object size", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05, "visible": False}),
"detailer_padding": OptionInfo(20, "Item padding", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1, "visible": False}), "detailer_padding": OptionInfo(20, "Item padding", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1, "visible": False}),
@ -631,6 +630,9 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
"detailer_unload": OptionInfo(False, "Move detailer model to CPU when complete"), "detailer_unload": OptionInfo(False, "Move detailer model to CPU when complete"),
"detailer_augment": OptionInfo(True, "Detailer use model augment"), "detailer_augment": OptionInfo(True, "Detailer use model augment"),
"postprocessing_sep_seedvt": OptionInfo("<h2>SeedVT</h2>", "", gr.HTML),
"seedvt_cfg_scale": OptionInfo(3.5, "SeedVR CFG Scale", gr.Slider, {"minimum": 1, "maximum": 15, "step": 1}),
"postprocessing_sep_face_restore": OptionInfo("<h2>Face Restore</h2>", "", gr.HTML), "postprocessing_sep_face_restore": OptionInfo("<h2>Face Restore</h2>", "", gr.HTML),
"face_restoration_model": OptionInfo("None", "Face restoration", gr.Radio, lambda: {"choices": ['None'] + [x.name() for x in face_restorers]}), "face_restoration_model": OptionInfo("None", "Face restoration", gr.Radio, lambda: {"choices": ['None'] + [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.2, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "code_former_weight": OptionInfo(0.2, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),

View File

@ -227,7 +227,7 @@ def create_ui():
quicksettings_list.append((key, item)) quicksettings_list.append((key, item))
components.append(dummy_component) components.append(dummy_component)
else: else:
with gr.Row(elem_id=f"settings_section_row_{section_id}"): # only so we can add dirty indicator at the start of the row with gr.Row(elem_id=f"settings_section_row_{section_id}", elem_classes=["settings_section"]): # only so we can add dirty indicator at the start of the row
component = create_setting_component(key) component = create_setting_component(key)
shared.settings_components[key] = component shared.settings_components[key] = component
current_items.append(key) current_items.append(key)