feat(ltx): canonical LTX-2.x Stage 2 recipe (LoRA + guidance + connectors)

Implement the Lightricks two-stage recipe (diffusers PR #13217) for the
LTX-2.x Dev family: Stage 1 at half-res with full four-way guidance,
2x latent upsample, Stage 2 with distilled LoRA + scheduler swap + identity
guidance on STAGE_2_DISTILLED_SIGMA_VALUES.

Extends to both LTX-2.0 and LTX-2.3 Dev via per-family distilled-LoRA
repos carried on the caps; Distilled variants take the same flow minus
the LoRA swap. Auto-couples Refine with a fixed 2x upsample on any Dev
variant with a known LoRA when the user enables Refine without Upsample.

- caps: is_ltx_2_3, use_cross_timestep, default_dynamic_shift,
  stage2_dev_lora_repo, supports_canonical_stage2, modality_default_scale,
  guidance_rescale_default; LTX-2.x defaults realigned to canonical
  cfg=3.0 / steps=30; per-variant STG block and four-way guidance wired
  for non-distilled 2.x
- process: canonical Stage 1/Stage 2 helpers, scheduler + opts snapshot
  under try/finally, per-family upsampler repo, audio latents threaded
  from Stage 1 into Stage 2, use_cross_timestep gated per caps
- overrides: skip the redundant unsharded LTX-2.3 connectors blob and
  share LTX2TextConnectors weights across 2.3 variants when te_shared_t5
- load: Gemma3 shared-TE path for LTX-2.3; gate use_dynamic_shifting=False
  override to 0.9.x only so LTX-2.x stays on its canonical token-count
  dynamic shift
pull/4783/head
CalamitousFelicitousness 2026-04-19 03:37:25 +01:00
parent 05abd99285
commit 5cf46d2f81
4 changed files with 489 additions and 267 deletions

View File

@ -10,23 +10,34 @@ class LTXCaps:
repo_cls_name: str
family: str # '0.9' or '2.x'
is_distilled: bool
is_ltx_2_3: bool
is_i2v: bool
supports_input_media: bool # accordion visible for any pipeline that accepts image/video input
supports_multi_condition: bool # uses conditions=[LTX(2)VideoCondition(...)] kwarg; Condition classes only
supports_input_media: bool
supports_multi_condition: bool
supports_image_cond_noise_scale: bool
supports_decode_timestep: bool
supports_stg: bool
supports_audio: bool
supports_frame_rate_kwarg: bool
# 2.3 transformer cross-attn reads the other modality's sigma; unset falls back to 2.0's
# independent-sigma path, which is a joint-distribution mismatch for 2.3 weights.
use_cross_timestep: bool
default_cfg: float
default_steps: int
default_sampler_shift: float
default_dynamic_shift: bool
default_width: int
default_height: int
default_frames: int
default_frame_rate: int
stg_default_scale: float = 0.0
stg_default_blocks: list = field(default_factory=list)
# Dev 2.x trained under cfg + stg + modality + rescale four-way composition;
# distilled bakes these into its sigma schedule and stays at pipeline identity.
modality_default_scale: float = 1.0
guidance_rescale_default: float = 0.0
supports_canonical_stage2: bool = False
stage2_dev_lora_repo: Optional[str] = None
CONDITION_CLASSES = {'LTXConditionPipeline', 'LTX2ConditionPipeline'}
@ -69,12 +80,14 @@ def get_caps(model_name: str) -> Optional[LTXCaps]:
is_i2v = 'I2V' in model_name or cls_name in ('LTXImageToVideoPipeline', 'LTX2ImageToVideoPipeline')
is_condition_cls = cls_name in CONDITION_CLASSES
supports_input_media = is_i2v or is_condition_cls
is_ltx_2_3 = is_ltx2 and '2.3' in model_name
caps = LTXCaps(
name=model_name,
repo_cls_name=cls_name,
family=family,
is_distilled=is_distilled,
is_ltx_2_3=is_ltx_2_3,
is_i2v=is_i2v,
supports_input_media=supports_input_media,
supports_multi_condition=is_condition_cls,
@ -83,9 +96,11 @@ def get_caps(model_name: str) -> Optional[LTXCaps]:
supports_stg=is_ltx2,
supports_audio=is_ltx2,
supports_frame_rate_kwarg=is_ltx2,
default_cfg=4.0 if is_ltx2 else 3.0,
default_steps=40 if is_ltx2 else 50,
use_cross_timestep=is_ltx_2_3,
default_cfg=3.0,
default_steps=30 if is_ltx2 else 50,
default_sampler_shift=-1.0,
default_dynamic_shift=is_ltx2,
default_width=768,
default_height=512,
default_frames=121 if is_ltx2 else 161,
@ -96,6 +111,13 @@ def get_caps(model_name: str) -> Optional[LTXCaps]:
caps.default_cfg = 1.0
caps.default_steps = 8
if is_ltx2 and not is_distilled:
if is_ltx_2_3:
caps.stage2_dev_lora_repo = 'CalamitousFelicitousness/LTX-2.3-distilled-lora-384-Diffusers'
elif '2.0' in model_name:
caps.stage2_dev_lora_repo = 'CalamitousFelicitousness/LTX-2.0-distilled-lora-384-Diffusers'
caps.supports_canonical_stage2 = caps.stage2_dev_lora_repo is not None
if is_ltx2:
if '2.3' in model_name:
caps.stg_default_blocks = [28]
@ -103,6 +125,10 @@ def get_caps(model_name: str) -> Optional[LTXCaps]:
caps.stg_default_blocks = [29]
else:
caps.stg_default_blocks = [28]
caps.stg_default_scale = 0.0
if not is_distilled:
# canonical T2V composition from huggingface/diffusers#13217
caps.stg_default_scale = 1.0
caps.modality_default_scale = 3.0
caps.guidance_rescale_default = 0.7
return caps

View File

@ -1,7 +1,5 @@
import os
import copy
import time
import numpy as np
import torch
from PIL import Image
@ -17,8 +15,51 @@ from modules.video_models.video_utils import check_av
debug = log.trace if os.environ.get('SD_VIDEO_DEBUG', None) is not None else lambda *args, **kwargs: None
upsample_repo_id_09 = 'a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers'
# Upsampler weights are tied to the family VAE; using the wrong one preserves structure
# but drifts per-channel latent statistics (decodes desaturated / crushed contrast).
upsample_repo_id_20 = 'Lightricks/LTX-2'
upsample_repo_id_23 = 'CalamitousFelicitousness/LTX-2.3-Spatial-Upsampler-x2-1.1-Diffusers'
upsample_pipe = None
STAGE2_DEV_LORA_ADAPTER = 'ltx2_stage2_distilled'
def _canonical_ltx2_guidance(caps) -> dict:
# Four-way composition (cfg + stg + modality + rescale) from huggingface/diffusers#13217.
# Distilled bakes these into its sigma schedule; skip or we double-apply.
if caps.family != '2.x' or caps.is_distilled:
return {}
return {
'stg_scale': caps.stg_default_scale,
'modality_scale': caps.modality_default_scale,
'guidance_rescale': caps.guidance_rescale_default,
'spatio_temporal_guidance_blocks': list(caps.stg_default_blocks),
'audio_guidance_scale': 7.0,
'audio_stg_scale': 1.0,
'audio_modality_scale': 3.0,
'audio_guidance_rescale': 0.7,
}
def _canonical_stage2_dev_kwargs() -> dict:
# Stage 2 identity guidance from huggingface/diffusers#13217. The distilled LoRA makes Dev
# behave like Distilled, which was trained at identity; Stage 1's four-way composition on
# top double-dips and produces striping/flicker.
from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
return {
'sigmas': list(STAGE_2_DISTILLED_SIGMA_VALUES),
'noise_scale': float(STAGE_2_DISTILLED_SIGMA_VALUES[0]),
'guidance_scale': 1.0,
'stg_scale': 0.0,
'modality_scale': 1.0,
'guidance_rescale': 0.0,
'audio_guidance_scale': 1.0,
'audio_stg_scale': 0.0,
'audio_modality_scale': 1.0,
'audio_guidance_rescale': 0.0,
'spatio_temporal_guidance_blocks': None,
}
def _latent_pass(caps, prompt, negative, width, height, frames, steps, guidance_scale, mp4_fps, conditions, image_cond_noise_scale, seed, image=None):
base_args = {
@ -43,17 +84,21 @@ def _latent_pass(caps, prompt, negative, width, height, frames, steps, guidance_
if caps.is_i2v and caps.repo_cls_name in ('LTXImageToVideoPipeline', 'LTX2ImageToVideoPipeline') and image is not None:
base_args['image'] = image
if caps.family == '2.x' and caps.is_distilled:
# distilled 2.x was trained with a fixed sigma schedule; override diffusers' linspace default
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES
base_args['sigmas'] = list(DISTILLED_SIGMA_VALUES)
base_args.pop('num_inference_steps', None)
base_args.update(_canonical_ltx2_guidance(caps))
if caps.use_cross_timestep:
base_args['use_cross_timestep'] = True
log.debug(f'Video: cls={shared.sd_model.__class__.__name__} op=latent_pass args_keys={list(base_args.keys())}')
result = shared.sd_model(**base_args)
# video latents strip the batch dim; audio latents keep it so LTX2Pipeline.prepare_audio_latents
# can rewrap them when re-entered as ndim==4 at Stage 2.
latents = result.frames[0] if hasattr(result, 'frames') else None
audio = None
audio_latents = None
if hasattr(result, 'audio') and result.audio is not None:
audio = result.audio[0].float().cpu()
return latents, audio
audio_latents = result.audio
return latents, audio_latents
def run_ltx(task_id,
@ -128,6 +173,40 @@ def run_ltx(task_id,
yield from abort(f'Video: cls={shared.sd_model.__class__.__name__} selected model is not LTX', ok=True)
return
# Lightricks TI2VidTwoStagesPipeline: Stage 1 at half-res, 2x upsample, Stage 2 refine at target.
# Auto-couple when the user picks Refine but not Upsample. Condition variants still need per-stage
# conditioning rebuild, so keep them on the same-resolution path.
auto_refine_upsample = (
refine_enable
and caps.supports_canonical_stage2
and not upsample_enable
and not caps.supports_multi_condition
)
effective_upsample_enable = upsample_enable or auto_refine_upsample
effective_upsample_ratio = upsample_ratio if upsample_enable else 2.0
target_w = get_bucket(width)
target_h = get_bucket(height)
if auto_refine_upsample:
# Stage 1 at target/2 needs multiple-of-32; 2x upsample then forces final divisible by 64.
# Derive final from base, otherwise Stage 2 silently falls to base*2 != target.
base_w = get_bucket(target_w // 2)
base_h = get_bucket(target_h // 2)
final_w = base_w * 2
final_h = base_h * 2
if (final_w, final_h) != (target_w, target_h):
log.warning(f'LTX: two-stage refine needs resolution divisible by 64; adjusting {target_w}x{target_h} -> {final_w}x{final_h}')
elif effective_upsample_enable:
base_w = target_w
base_h = target_h
final_w = get_bucket(effective_upsample_ratio * target_w)
final_h = get_bucket(effective_upsample_ratio * target_h)
else:
base_w = target_w
base_h = target_h
final_w = target_w
final_h = target_h
log.debug(f'LTX: resolution planning target={target_w}x{target_h} base={base_w}x{base_h} final={final_w}x{final_h} auto_refine_upsample={auto_refine_upsample}')
videojob = shared.state.begin('Video', task_id=task_id)
shared.state.job_count = 1
@ -170,8 +249,8 @@ def run_ltx(task_id,
sampler_name=sampler_name,
sampler_shift=float(sampler_shift),
steps=int(steps),
width=get_bucket(width),
height=get_bucket(height),
width=base_w,
height=base_h,
frames=get_frames(frames),
cfg_scale=float(guidance_scale) if guidance_scale is not None and guidance_scale > 0 else caps.default_cfg,
denoising_strength=float(condition_strength) if condition_strength is not None else 1.0,
@ -205,277 +284,354 @@ def run_ltx(task_id,
p.task_args['image'] = images.resize_image(resize_mode=2, im=effective_init_image, width=p.width, height=p.height, upscaler_name=None, output_type='pil')
if caps.family == '2.x' and caps.is_distilled:
# distilled 2.x was trained with a fixed sigma schedule; override diffusers' linspace default
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES
p.task_args['sigmas'] = list(DISTILLED_SIGMA_VALUES)
p.task_args.pop('num_inference_steps', None)
p.task_args.update(_canonical_ltx2_guidance(caps))
framewise = caps.family == '0.9'
set_vae_params(p, framewise=framewise)
# Snapshot scheduler + shared.opts before mutation so the try/finally restores on every exit
# path (abort, interrupt, Stage 2 scheduler swap). Without this, run-specific sampler settings
# leak into shared.opts.data and across runs/tabs, and the default_scheduler snapshot from
# video_load.py:171 gets clobbered by a deepcopy of the mutated scheduler on every run.
orig_dynamic_shift = shared.opts.schedulers_dynamic_shift
orig_sampler_shift = shared.opts.schedulers_shift
shared.opts.data['schedulers_dynamic_shift'] = dynamic_shift
shared.opts.data['schedulers_shift'] = sampler_shift
if hasattr(shared.sd_model, 'scheduler') and hasattr(shared.sd_model.scheduler, 'config') and hasattr(shared.sd_model.scheduler, 'register_to_config'):
if hasattr(shared.sd_model.scheduler.config, 'use_dynamic_shifting'):
shared.sd_model.scheduler.config.use_dynamic_shifting = dynamic_shift
shared.sd_model.scheduler.register_to_config(use_dynamic_shifting=dynamic_shift)
if hasattr(shared.sd_model.scheduler.config, 'flow_shift') and sampler_shift is not None and sampler_shift >= 0:
shared.sd_model.scheduler.config.flow_shift = sampler_shift
shared.sd_model.scheduler.register_to_config(flow_shift=sampler_shift)
shared.sd_model.default_scheduler = copy.deepcopy(shared.sd_model.scheduler)
if selected is not None:
video_overrides.set_overrides(p, selected)
t0 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
t1 = time.time()
samplejob = shared.state.begin('Sample')
yield None, 'LTX: Generate in progress...'
audio = None
pixels = None
frames_out = None
needs_latent_path = upsample_enable or refine_enable
orig_scheduler = shared.sd_model.scheduler
orig_default_scheduler = getattr(shared.sd_model, 'default_scheduler', None)
orig_use_dynamic_shifting = getattr(orig_scheduler.config, 'use_dynamic_shifting', None) if hasattr(orig_scheduler, 'config') else None
orig_flow_shift = getattr(orig_scheduler.config, 'flow_shift', None) if hasattr(orig_scheduler, 'config') else None
try:
if needs_latent_path:
prompt_final, negative_final, networks = get_prompts(prompt, negative, styles)
extra_networks.activate(p, networks)
latents, audio = _latent_pass(
caps=caps,
prompt=prompt_final,
negative=negative_final,
width=width,
height=height,
frames=frames,
steps=steps,
guidance_scale=p.cfg_scale,
mp4_fps=mp4_fps,
conditions=conditions,
image_cond_noise_scale=image_cond_noise_scale if caps.supports_image_cond_noise_scale else None,
seed=int(seed) if seed is not None else -1,
image=p.task_args.get('image'),
)
else:
processed = processing.process_images(p)
if processed is None or processed.images is None or len(processed.images) == 0:
yield from abort('Video: process_images returned no frames', ok=True, p=p)
return
pixels = processed.images
if getattr(processed, 'audio', None) is not None:
audio = processed.audio
latents = None
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
shared.opts.data['schedulers_dynamic_shift'] = dynamic_shift
shared.opts.data['schedulers_shift'] = sampler_shift
if hasattr(shared.sd_model, 'scheduler') and hasattr(shared.sd_model.scheduler, 'config') and hasattr(shared.sd_model.scheduler, 'register_to_config'):
if hasattr(shared.sd_model.scheduler.config, 'use_dynamic_shifting'):
shared.sd_model.scheduler.config.use_dynamic_shifting = dynamic_shift
shared.sd_model.scheduler.register_to_config(use_dynamic_shifting=dynamic_shift)
if hasattr(shared.sd_model.scheduler.config, 'flow_shift') and sampler_shift is not None and sampler_shift >= 0:
shared.sd_model.scheduler.config.flow_shift = sampler_shift
shared.sd_model.scheduler.register_to_config(flow_shift=sampler_shift)
# Do NOT re-snapshot default_scheduler; that overwrites video_load.py:171's load-time
# snapshot with the run-mutated config, so reset_scheduler then carries the last run's choice.
t2 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
devices.torch_gc(force=True, reason='ltx:base')
t3 = time.time()
timer.process.add('offload', t1 - t0)
timer.process.add('base', t2 - t1)
timer.process.add('offload', t3 - t2)
shared.state.end(samplejob)
if selected is not None:
video_overrides.set_overrides(p, selected)
if upsample_enable and latents is not None:
t4 = time.time()
upsamplejob = shared.state.begin('Upsample')
try:
if caps.family == '0.9':
global upsample_pipe # pylint: disable=global-statement
upsample_pipe = load_upsample(upsample_pipe, upsample_repo_id_09)
upsample_pipe = sd_models.apply_balanced_offload(upsample_pipe)
up_args = {
'width': get_bucket(upsample_ratio * width),
'height': get_bucket(upsample_ratio * height),
'generator': get_generator(int(seed) if seed is not None else -1),
'output_type': 'latent',
}
if latents.ndim == 4:
latents = latents.unsqueeze(0)
log.debug(f'Video: op=upsample family=0.9 latents={latents.shape} {up_args}')
yield None, 'LTX: Upsample in progress...'
latents = upsample_pipe(latents=latents, **up_args).frames[0]
upsample_pipe = sd_models.apply_balanced_offload(upsample_pipe)
else:
from diffusers.pipelines.ltx2.pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
log.info(f'Video load: cls={LTX2LatentUpsamplePipeline.__name__} family=2.x')
up_pipe = LTX2LatentUpsamplePipeline.from_pretrained(
'Lightricks/LTX-2-Latent-Upsampler',
vae=shared.sd_model.vae,
cache_dir=shared.opts.hfcache_dir,
torch_dtype=devices.dtype,
)
up_pipe = sd_models.apply_balanced_offload(up_pipe)
up_args = {
'width': get_bucket(upsample_ratio * width),
'height': get_bucket(upsample_ratio * height),
'num_frames': get_frames(frames),
'latents_normalized': True,
'generator': get_generator(int(seed) if seed is not None else -1),
'output_type': 'latent',
}
if latents.ndim == 4:
latents = latents.unsqueeze(0)
log.debug(f'Video: op=upsample family=2.x latents={latents.shape} {up_args}')
yield None, 'LTX: Upsample in progress...'
latents = up_pipe(latents=latents, **up_args).frames[0]
up_pipe = sd_models.apply_balanced_offload(up_pipe)
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
t5 = time.time()
timer.process.add('upsample', t5 - t4)
shared.state.end(upsamplejob)
if refine_enable and latents is not None:
t7 = time.time()
refinejob = shared.state.begin('Refine')
t0 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
devices.torch_gc(force=True, reason='ltx:refine')
# refine is the terminal stage when enabled: let the pipeline decode internally so the final vae pass runs
# inside the same offload/cudnn context as a normal generation, matching the Generic Video tab
refine_args = {
'prompt': prompt_final,
'negative_prompt': negative_final,
'width': get_bucket((upsample_ratio if upsample_enable else 1.0) * width),
'height': get_bucket((upsample_ratio if upsample_enable else 1.0) * height),
'num_frames': get_frames(frames),
'num_inference_steps': steps,
'generator': get_generator(int(seed) if seed is not None else -1),
'callback_on_step_end': diffusers_callback,
'output_type': 'pil',
}
if p.cfg_scale is not None and p.cfg_scale > 0:
refine_args['guidance_scale'] = p.cfg_scale
if caps.supports_frame_rate_kwarg:
refine_args['frame_rate'] = float(mp4_fps)
if caps.supports_image_cond_noise_scale and image_cond_noise_scale is not None:
refine_args['image_cond_noise_scale'] = image_cond_noise_scale
if caps.supports_multi_condition and conditions:
refine_args['conditions'] = conditions
if caps.family == '2.x':
if caps.is_distilled:
# distilled variants have a canonical Stage-2 refine schedule they were trained on;
# see diffusers.pipelines.ltx2.utils and Lightricks/LTX-2 ti2vid_two_stages pipeline
from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
refine_args['sigmas'] = list(STAGE_2_DISTILLED_SIGMA_VALUES)
else:
# non-distilled: truncate the default linspace schedule to match user-controlled refine_strength
default_sigmas = np.linspace(1.0, 1.0 / steps, steps)
num_skip = max(steps - max(int(steps * refine_strength), 1), 0)
refine_args['sigmas'] = default_sigmas[num_skip:].tolist()
refine_args.pop('num_inference_steps', None)
elif caps.repo_cls_name == 'LTXConditionPipeline':
refine_args['denoise_strength'] = refine_strength
if latents.ndim == 4:
latents = latents.unsqueeze(0)
log.debug(f'Video: op=refine cls={caps.repo_cls_name} latents={latents.shape}')
yield None, 'LTX: Refine in progress...'
try:
result = shared.sd_model(latents=latents, **refine_args)
pixels = result.frames[0] if hasattr(result, 'frames') else None
if hasattr(result, 'audio') and result.audio is not None:
audio = result.audio[0].float().cpu()
latents = None
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
t8 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
t9 = time.time()
timer.process.add('refine', t8 - t7)
timer.process.add('offload', t9 - t8)
shared.state.end(refinejob)
t1 = time.time()
shared.opts.data['schedulers_dynamic_shift'] = orig_dynamic_shift
shared.opts.data['schedulers_shift'] = orig_sampler_shift
samplejob = shared.state.begin('Sample')
yield None, 'LTX: Generate in progress...'
if needs_latent_path:
extra_networks.deactivate(p)
if needs_latent_path and latents is not None:
# only reached when upsample ran without refine; refine decodes through the pipeline and sets latents=None
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, exclude=['vae'], force=True)
devices.torch_gc(force=True, reason='ltx:vae')
yield None, 'LTX: VAE decode in progress...'
try:
if torch.is_tensor(latents):
# 0.9.x returns raw latents with output_type='latent'; 2.x pre-denormalizes them
frames_out = vae_decode(latents, decode_timestep if caps.supports_decode_timestep else 0.0, int(seed) if seed is not None else -1, denormalize=caps.family == '0.9')
else:
frames_out = latents
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
pixels = frames_out
t10 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
t11 = time.time()
timer.process.add('offload', t11 - t10)
if not audio_enable:
audio = None
stage1_audio_latents = None
pixels = None
frames_out = None
needs_latent_path = upsample_enable or refine_enable
try:
aac_sample_rate = shared.sd_model.vocoder.config.output_sampling_rate
except Exception:
aac_sample_rate = 24000
try:
if needs_latent_path:
prompt_final, negative_final, networks = get_prompts(prompt, negative, styles)
extra_networks.activate(p, networks)
latents, stage1_audio_latents = _latent_pass(
caps=caps,
prompt=prompt_final,
negative=negative_final,
width=base_w,
height=base_h,
frames=frames,
steps=steps,
guidance_scale=p.cfg_scale,
mp4_fps=mp4_fps,
conditions=conditions,
image_cond_noise_scale=image_cond_noise_scale if caps.supports_image_cond_noise_scale else None,
seed=int(seed) if seed is not None else -1,
image=p.task_args.get('image'),
)
else:
processed = processing.process_images(p)
if processed is None or processed.images is None or len(processed.images) == 0:
yield from abort('Video: process_images returned no frames', ok=True, p=p)
return
pixels = processed.images
if getattr(processed, 'audio', None) is not None:
audio = processed.audio
latents = None
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
num_frames, video_file, _thumb = save_video(
p=p,
pixels=pixels,
audio=audio,
mp4_fps=mp4_fps,
mp4_codec=mp4_codec,
mp4_opt=mp4_opt,
mp4_ext=mp4_ext,
mp4_sf=mp4_sf,
mp4_video=mp4_video,
mp4_frames=mp4_frames,
mp4_interpolate=mp4_interpolate,
aac_sample_rate=aac_sample_rate,
metadata={},
)
t2 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
devices.torch_gc(force=True, reason='ltx:base')
t3 = time.time()
timer.process.add('offload', t1 - t0)
timer.process.add('base', t2 - t1)
timer.process.add('offload', t3 - t2)
shared.state.end(samplejob)
t_end = time.time()
if isinstance(pixels, list) and len(pixels) > 0 and isinstance(pixels[0], Image.Image):
w, h = pixels[0].size
elif hasattr(pixels, 'ndim') and pixels.ndim == 5:
_n, _c, _t, h, w = pixels.shape
elif hasattr(pixels, 'ndim') and pixels.ndim == 4:
_n, h, w, _c = pixels.shape
elif hasattr(pixels, 'shape'):
h, w = pixels.shape[-2], pixels.shape[-1]
else:
w, h = p.width, p.height
resolution = f'{w}x{h}' if num_frames > 0 else None
summary = timer.process.summary(min_time=0.25, total=False).replace('=', ' ')
memory = shared.mem_mon.summary()
total_time = max(t_end - t0, 1e-6)
fps = f'{num_frames/total_time:.2f}'
its = f'{(steps)/total_time:.2f}'
if effective_upsample_enable and latents is not None:
t4 = time.time()
upsamplejob = shared.state.begin('Upsample')
try:
if caps.family == '0.9':
global upsample_pipe # pylint: disable=global-statement
upsample_pipe = load_upsample(upsample_pipe, upsample_repo_id_09)
upsample_pipe = sd_models.apply_balanced_offload(upsample_pipe)
up_args = {
'width': final_w,
'height': final_h,
'generator': get_generator(int(seed) if seed is not None else -1),
'output_type': 'latent',
}
if latents.ndim == 4:
latents = latents.unsqueeze(0)
log.debug(f'Video: op=upsample family=0.9 latents={latents.shape} {up_args}')
yield None, 'LTX: Upsample in progress...'
latents = upsample_pipe(latents=latents, **up_args).frames[0]
upsample_pipe = sd_models.apply_balanced_offload(upsample_pipe)
else:
from diffusers.pipelines.ltx2.pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
# Skip apply_balanced_offload on the upsampler; checkpoint_name differs from the main
# pipe so the shared OffloadHook (sd_offload.py:488) would rebuild and force a heavy
# re-init on the next refine. At ~2.3GB it fits on device; free after the pass.
upsample_repo = upsample_repo_id_23 if '2.3' in caps.name else upsample_repo_id_20
log.info(f'Video load: cls={LTX2LatentUpsamplePipeline.__name__} family=2.x repo={upsample_repo} auto={auto_refine_upsample}')
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
upsample_repo,
subfolder='latent_upsampler',
cache_dir=shared.opts.hfcache_dir,
torch_dtype=devices.dtype,
).to(devices.device)
up_pipe = LTX2LatentUpsamplePipeline(vae=shared.sd_model.vae, latent_upsampler=latent_upsampler)
# 2.x base pass returns denormalized latents; latents_normalized=False tells the
# upsampler "already raw, do not denormalize again".
up_args = {
'width': final_w,
'height': final_h,
'num_frames': get_frames(frames),
'latents_normalized': False,
'generator': get_generator(int(seed) if seed is not None else -1),
'output_type': 'latent',
}
if latents.ndim == 4:
latents = latents.unsqueeze(0)
log.debug(f'Video: op=upsample family=2.x latents={latents.shape} {up_args}')
yield None, 'LTX: Upsample in progress...'
latents = up_pipe(latents=latents, **up_args).frames[0]
latent_upsampler.to('cpu')
del up_pipe, latent_upsampler
devices.torch_gc(force=True, reason='ltx:upsample')
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
t5 = time.time()
timer.process.add('upsample', t5 - t4)
shared.state.end(upsamplejob)
shared.state.end(videojob)
progress.finish_task(task_id)
p.close()
if refine_enable and latents is not None:
t7 = time.time()
refinejob = shared.state.begin('Refine')
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
devices.torch_gc(force=True, reason='ltx:refine')
# Refine is terminal: let the pipe decode internally so the final VAE pass runs inside
# the same offload/cudnn context as a normal generation (matches Generic Video tab).
refine_args = {
'prompt': prompt_final,
'negative_prompt': negative_final,
'width': final_w,
'height': final_h,
'num_frames': get_frames(frames),
'num_inference_steps': steps,
'generator': get_generator(int(seed) if seed is not None else -1),
'callback_on_step_end': diffusers_callback,
'output_type': 'pil',
}
if p.cfg_scale is not None and p.cfg_scale > 0:
refine_args['guidance_scale'] = p.cfg_scale
if caps.supports_frame_rate_kwarg:
refine_args['frame_rate'] = float(mp4_fps)
if caps.supports_image_cond_noise_scale and image_cond_noise_scale is not None:
refine_args['image_cond_noise_scale'] = image_cond_noise_scale
if caps.supports_multi_condition and conditions:
refine_args['conditions'] = conditions
# Thread Stage-1 I2V init image through Stage 2 so first-frame identity survives refine.
if caps.is_i2v and caps.repo_cls_name in ('LTXImageToVideoPipeline', 'LTX2ImageToVideoPipeline') and p.task_args.get('image') is not None:
refine_args['image'] = p.task_args['image']
# Thread Stage-1 audio latents into Stage 2 on 2.x. The video branch cross-attends
# audio every layer; letting prepare_audio_latents fall back to fresh noise biases
# the video branch off-distribution (desaturated output on distilled 2.x).
if caps.family == '2.x':
if stage1_audio_latents is not None:
refine_args['audio_latents'] = stage1_audio_latents.to(device=devices.device)
if caps.use_cross_timestep:
refine_args['use_cross_timestep'] = True
log.info(f'Processed: fn="{video_file}" frames={num_frames} fps={fps} its={its} resolution={resolution} time={t_end-t0:.2f} timers={timer.process.dct()} memory={memstats.memory_stats()}')
yield video_file, f'LTX: Generation completed | File {video_file} | Frames {num_frames} | Resolution {resolution} | f/s {fps} | it/s {its} ' + f"<div class='performance'><p>{summary} {memory}</p></div>"
saved_scheduler_stage2 = None
try:
if caps.supports_canonical_stage2:
# Dev 2.x Stage 2: swap scheduler, fuse distilled LoRA, 3 steps on the distilled
# sigma schedule at identity guidance (huggingface/diffusers#13217).
log.info(f'LTX: canonical Stage 2 via distilled LoRA repo={caps.stage2_dev_lora_repo}')
from diffusers import FlowMatchEulerDiscreteScheduler
offline_args = {'local_files_only': True} if shared.opts.offline_mode else {}
saved_scheduler_stage2 = shared.sd_model.scheduler
shared.sd_model.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
saved_scheduler_stage2.config,
use_dynamic_shifting=False,
shift_terminal=None,
)
shared.sd_model.load_lora_weights(
caps.stage2_dev_lora_repo,
adapter_name=STAGE2_DEV_LORA_ADAPTER,
cache_dir=shared.opts.hfcache_dir,
**offline_args,
)
shared.sd_model.set_adapters([STAGE2_DEV_LORA_ADAPTER], [1.0])
# Do NOT apply _canonical_ltx2_guidance on this path; its audio-branch kwargs
# would clobber the identity set.
refine_args.update(_canonical_stage2_dev_kwargs())
refine_args.pop('num_inference_steps', None)
elif caps.family == '2.x':
# Distilled 2.x. Dev 2.x with a LoRA hit the branch above.
from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
refine_args['sigmas'] = list(STAGE_2_DISTILLED_SIGMA_VALUES)
refine_args.pop('num_inference_steps', None)
# LTX2Pipeline/LTX2ImageToVideoPipeline default noise_scale=0.0 when not passed;
# sigma=0 user latents mismatched against sigmas[0] scheduler collapses output.
# LTX2ConditionPipeline auto-infers this; do the same explicitly for T2V/I2V.
refine_args['noise_scale'] = float(refine_args['sigmas'][0])
refine_args.update(_canonical_ltx2_guidance(caps))
elif caps.repo_cls_name == 'LTXConditionPipeline':
refine_args['denoise_strength'] = refine_strength
if latents.ndim == 4:
latents = latents.unsqueeze(0)
log.debug(f'Video: op=refine cls={caps.repo_cls_name} latents={latents.shape} canonical_stage2={caps.supports_canonical_stage2}')
yield None, 'LTX: Refine in progress...'
try:
result = shared.sd_model(latents=latents, **refine_args)
pixels = result.frames[0] if hasattr(result, 'frames') else None
if hasattr(result, 'audio') and result.audio is not None:
audio = result.audio[0].float().cpu()
latents = None
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
finally:
if saved_scheduler_stage2 is not None:
try:
from modules.lora.extra_networks_lora import unload_diffusers
unload_diffusers()
except Exception as e:
log.warning(f'LTX: canonical Stage 2 LoRA unload failed: {e}')
shared.sd_model.scheduler = saved_scheduler_stage2
log.debug('LTX: canonical Stage 2 cleanup done (LoRA unloaded, scheduler restored)')
t8 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
t9 = time.time()
timer.process.add('refine', t8 - t7)
timer.process.add('offload', t9 - t8)
shared.state.end(refinejob)
if needs_latent_path:
extra_networks.deactivate(p)
if needs_latent_path and latents is not None:
# Only reached on upsample-without-refine; refine decodes through the pipe and nulls latents.
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, exclude=['vae'], force=True)
devices.torch_gc(force=True, reason='ltx:vae')
yield None, 'LTX: VAE decode in progress...'
try:
if torch.is_tensor(latents):
# 0.9.x returns raw latents with output_type='latent'; 2.x pre-denormalizes.
frames_out = vae_decode(latents, decode_timestep if caps.supports_decode_timestep else 0.0, int(seed) if seed is not None else -1, denormalize=caps.family == '0.9')
else:
frames_out = latents
except AssertionError as e:
yield from abort(e, ok=True, p=p)
return
except Exception as e:
yield from abort(e, ok=False, p=p)
return
pixels = frames_out
t10 = time.time()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
t11 = time.time()
timer.process.add('offload', t11 - t10)
if not audio_enable:
audio = None
try:
aac_sample_rate = shared.sd_model.vocoder.config.output_sampling_rate
except Exception:
aac_sample_rate = 24000
num_frames, video_file, _thumb = save_video(
p=p,
pixels=pixels,
audio=audio,
mp4_fps=mp4_fps,
mp4_codec=mp4_codec,
mp4_opt=mp4_opt,
mp4_ext=mp4_ext,
mp4_sf=mp4_sf,
mp4_video=mp4_video,
mp4_frames=mp4_frames,
mp4_interpolate=mp4_interpolate,
aac_sample_rate=aac_sample_rate,
metadata={},
)
t_end = time.time()
if isinstance(pixels, list) and len(pixels) > 0 and isinstance(pixels[0], Image.Image):
w, h = pixels[0].size
elif hasattr(pixels, 'ndim') and pixels.ndim == 5:
_n, _c, _t, h, w = pixels.shape
elif hasattr(pixels, 'ndim') and pixels.ndim == 4:
_n, h, w, _c = pixels.shape
elif hasattr(pixels, 'shape'):
h, w = pixels.shape[-2], pixels.shape[-1]
else:
w, h = p.width, p.height
resolution = f'{w}x{h}' if num_frames > 0 else None
summary = timer.process.summary(min_time=0.25, total=False).replace('=', ' ')
memory = shared.mem_mon.summary()
total_time = max(t_end - t0, 1e-6)
fps = f'{num_frames/total_time:.2f}'
its = f'{(steps)/total_time:.2f}'
shared.state.end(videojob)
progress.finish_task(task_id)
p.close()
log.info(f'Processed: fn="{video_file}" frames={num_frames} fps={fps} its={its} resolution={resolution} time={t_end-t0:.2f} timers={timer.process.dct()} memory={memstats.memory_stats()}')
yield video_file, f'LTX: Generation completed | File {video_file} | Frames {num_frames} | Resolution {resolution} | f/s {fps} | it/s {its} ' + f"<div class='performance'><p>{summary} {memory}</p></div>"
finally:
shared.opts.data['schedulers_dynamic_shift'] = orig_dynamic_shift
shared.opts.data['schedulers_shift'] = orig_sampler_shift
if shared.sd_model.scheduler is not orig_scheduler:
shared.sd_model.scheduler = orig_scheduler
if orig_default_scheduler is not None and shared.sd_model.default_scheduler is not orig_default_scheduler:
shared.sd_model.default_scheduler = orig_default_scheduler
if hasattr(shared.sd_model.scheduler, 'config') and hasattr(shared.sd_model.scheduler, 'register_to_config'):
if orig_use_dynamic_shifting is not None and hasattr(shared.sd_model.scheduler.config, 'use_dynamic_shifting'):
shared.sd_model.scheduler.config.use_dynamic_shifting = orig_use_dynamic_shifting
shared.sd_model.scheduler.register_to_config(use_dynamic_shifting=orig_use_dynamic_shifting)
if orig_flow_shift is not None and hasattr(shared.sd_model.scheduler.config, 'flow_shift'):
shared.sd_model.scheduler.config.flow_shift = orig_flow_shift
shared.sd_model.scheduler.register_to_config(flow_shift=orig_flow_shift)
log.debug(f'LTX: scheduler/opts restored dynamic_shift={orig_dynamic_shift} sampler_shift={orig_sampler_shift}')

View File

@ -88,6 +88,13 @@ def load_model(selected: models_def.Model):
selected.te = 'ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers'
selected.te_folder = 'text_encoder'
selected.te_revision = None
if selected.te_cls.__name__ == 'Gemma3ForConditionalGeneration' and shared.opts.te_shared_t5:
if 'SDNQ' in selected.name:
selected.te = 'OzzyGT/LTX-2.3-sdnq-dynamic-int4'
else:
selected.te = 'OzzyGT/LTX-2.3'
selected.te_folder = 'text_encoder'
selected.te_revision = None
log.debug(f'Video load: module=te repo="{selected.te or selected.repo}" folder="{selected.te_folder}" cls={selected.te_cls.__name__} quant={model_quant.get_quant_type(quant_args)} loader={_loader("transformers")}')
kwargs["text_encoder"] = selected.te_cls.from_pretrained(
@ -158,7 +165,11 @@ def load_model(selected: models_def.Model):
return msg
t1 = time.time()
if shared.sd_model.__class__.__name__.startswith("LTX"):
cls_name = shared.sd_model.__class__.__name__
# LTX 0.9.x is plain linear; pin use_dynamic_shifting=False against upstream config drift.
# LTX-2.x canonical is token-count-based dynamic shift (base_shift=0.95, max_shift=2.05);
# disabling it there would take the model off-distribution.
if cls_name.startswith("LTX") and not cls_name.startswith("LTX2"):
shared.sd_model.scheduler.config.use_dynamic_shifting = False
shared.sd_model.default_scheduler = copy.deepcopy(shared.sd_model.scheduler) if hasattr(shared.sd_model, "scheduler") else None
shared.sd_model.sd_checkpoint_info = sd_checkpoint.CheckpointInfo(selected.repo)

View File

@ -1,7 +1,7 @@
import os
import torch
import diffusers
from modules import shared, processing
from modules import shared, processing, devices
from modules.logger import log
from modules.video_models.models_def import Model
@ -17,6 +17,35 @@ def load_override(selected: Model, **load_args):
# LTX
if 'LTXVideo 0.9.5 I2V' in selected.name:
kwargs['vae'] = diffusers.AutoencoderKLLTXVideo.from_pretrained(selected.repo, subfolder="vae", torch_dtype=torch.float32, cache_dir=shared.opts.hfcache_dir, **load_args)
# OzzyGT LTX-2.3 mirrors ship connectors twice: sharded safetensors + .index.json plus a
# redundant unsharded diffusion_pytorch_model.safetensors of the same weights. Diffusers
# fetches both but loads sharded; skip the ~6.3 GB duplicate.
ltx2_redundant_connector_repos = {
'OzzyGT/LTX-2.3',
'OzzyGT/LTX-2.3-sdnq-dynamic-int4',
}
if selected.repo in ltx2_redundant_connector_repos:
kwargs['ignore_patterns'] = ['connectors/diffusion_pytorch_model.safetensors']
# LTX2TextConnectors weights are byte-identical across all 2.3 variants (verified by blob
# hash). Pre-load from a canonical repo so per-variant fetches skip connectors/ entirely.
# FP16 variants share OzzyGT/LTX-2.3; SDNQ variants share the pre-quantized mirror.
ltx2_connectors_cls = None
try:
from diffusers.pipelines.ltx2 import LTX2TextConnectors
ltx2_connectors_cls = LTX2TextConnectors
except ImportError as e:
log.warning(f'Video load: LTX2TextConnectors unavailable ({e}); dedup of LTX-2.3 connectors disabled')
if ('LTXVideo 2.3' in selected.name and shared.opts.te_shared_t5 and ltx2_connectors_cls is not None):
conn_repo = 'OzzyGT/LTX-2.3-sdnq-dynamic-int4' if 'SDNQ' in selected.name else 'OzzyGT/LTX-2.3'
log.debug(f'Video load: module=connectors repo="{conn_repo}" cls={ltx2_connectors_cls.__name__} shared={shared.opts.te_shared_t5}')
kwargs['connectors'] = ltx2_connectors_cls.from_pretrained(
conn_repo,
subfolder='connectors',
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
ignore_patterns=['connectors/diffusion_pytorch_model.safetensors'],
**load_args,
)
# WAN
if 'WAN 2.1 14B' in selected.name:
kwargs['vae'] = diffusers.AutoencoderKLWan.from_pretrained(selected.repo, subfolder="vae", torch_dtype=torch.float32, cache_dir=shared.opts.hfcache_dir, **load_args)