automatic/cli/test-schedulers.py

261 lines
12 KiB
Python

import os
import sys
import time
import numpy as np
import torch
# Ensure we can import modules
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from modules.errors import log
from modules.res4lyf import (
BASE, SIMPLE, VARIANTS,
RESUnifiedScheduler, RESMultistepScheduler, RESDEISMultistepScheduler,
ETDRKScheduler, LawsonScheduler, ABNorsettScheduler, PECScheduler,
RiemannianFlowScheduler, RESSinglestepScheduler, RESSinglestepSDEScheduler,
RESMultistepSDEScheduler, SimpleExponentialScheduler, LinearRKScheduler,
LobattoScheduler, GaussLegendreScheduler, RungeKutta44Scheduler,
RungeKutta57Scheduler, RungeKutta67Scheduler, SpecializedRKScheduler,
BongTangentScheduler, CommonSigmaScheduler, RadauIIAScheduler,
LangevinDynamicsScheduler
)
from modules.schedulers.scheduler_vdm import VDMScheduler
from modules.schedulers.scheduler_unipc_flowmatch import FlowUniPCMultistepScheduler
from modules.schedulers.scheduler_ufogen import UFOGenScheduler
from modules.schedulers.scheduler_tdd import TDDScheduler
from modules.schedulers.scheduler_tcd import TCDScheduler
from modules.schedulers.scheduler_flashflow import FlashFlowMatchEulerDiscreteScheduler
from modules.schedulers.scheduler_dpm_flowmatch import FlowMatchDPMSolverMultistepScheduler
from modules.schedulers.scheduler_dc import DCSolverMultistepScheduler
from modules.schedulers.scheduler_bdia import BDIA_DDIMScheduler
def test_scheduler(name, scheduler_class, config):
try:
scheduler = scheduler_class(**config)
except Exception as e:
log.error(f'scheduler="{name}" cls={scheduler_class} config={config} error="Init failed: {e}"')
return False
num_steps = 20
scheduler.set_timesteps(num_steps)
sample = torch.randn((1, 4, 64, 64))
has_changed = False
t0 = time.time()
messages = []
try:
for i, t in enumerate(scheduler.timesteps):
# Simulate model output (noise or x0 or v), Using random noise for stability check
model_output = torch.randn_like(sample)
# Scaling Check
step_idx = scheduler.step_index if hasattr(scheduler, "step_index") and scheduler.step_index is not None else i
# Clamp index
if hasattr(scheduler, 'sigmas'):
step_idx = min(step_idx, len(scheduler.sigmas) - 1)
sigma = scheduler.sigmas[step_idx]
else:
sigma = torch.tensor(1.0) # Dummy for non-sigma schedulers
# Re-introduce scaling calculation first
scaled_sample = scheduler.scale_model_input(sample, t)
if config.get("prediction_type") == "flow_prediction" or name in ["UFOGenScheduler", "TDDScheduler", "TCDScheduler", "BDIA_DDIMScheduler", "DCSolverMultistepScheduler"]:
# Some new schedulers don't use K-diffusion scaling
expected_scale = 1.0
else:
expected_scale = 1.0 / ((sigma**2 + 1) ** 0.5)
# Simple check with loose tolerance due to float precision
expected_scaled_sample = sample * expected_scale
if not torch.allclose(scaled_sample, expected_scaled_sample, atol=1e-4):
# If failed, double check if it's just 'sample' (no scaling)
if torch.allclose(scaled_sample, sample, atol=1e-4):
messages.append('warning="scaling is identity"')
else:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} expected={expected_scale} error="scaling mismatch"')
return False
if torch.isnan(scaled_sample).any():
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="NaN in scaled_sample"')
return False
if torch.isinf(scaled_sample).any():
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Inf in scaled_sample"')
return False
output = scheduler.step(model_output, t, sample)
# Shape and Dtype check
if output.prev_sample.shape != sample.shape:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Shape mismatch: {output.prev_sample.shape} vs {sample.shape}"')
return False
if output.prev_sample.dtype != sample.dtype:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Dtype mismatch: {output.prev_sample.dtype} vs {sample.dtype}"')
return False
# Update check: Did the sample change?
if not torch.equal(sample, output.prev_sample):
has_changed = True
# Sample Evolution Check
step_diff = (sample - output.prev_sample).abs().mean().item()
if step_diff < 1e-6:
messages.append(f'warning="minimal sample change: {step_diff}"')
sample = output.prev_sample
if torch.isnan(sample).any():
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="NaN in sample"')
return False
if torch.isinf(sample).any():
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Inf in sample"')
return False
# Divergence check
if sample.abs().max() > 1e10:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="divergence detected"')
return False
# External check for Sigma Monotonicity
if hasattr(scheduler, 'sigmas'):
sigmas = scheduler.sigmas.cpu().numpy()
if len(sigmas) > 1:
diffs = np.diff(sigmas) # Check if potentially monotonic decreasing (standard) OR increasing (some flow/inverse setups). We allow flat sections (diff=0) hence 1e-6 slack
is_monotonic_decreasing = np.all(diffs <= 1e-6)
is_monotonic_increasing = np.all(diffs >= -1e-6)
if not (is_monotonic_decreasing or is_monotonic_increasing):
messages.append('warning="sigmas are not monotonic"')
except Exception as e:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} exception: {e}')
import traceback
traceback.print_exc()
return False
if not has_changed:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} error="sample never changed"')
return False
final_std = sample.std().item()
if final_std > 50.0 or final_std < 0.1:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} std={final_std} error="variance drift"')
t1 = time.time()
messages = list(set(messages))
log.info(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} time={t1-t0} messages={messages}')
return True
def run_tests():
prediction_types = ["epsilon", "v_prediction", "sample"] # flow_prediction is special, usually requires flow sigmas or specific setup, checking standard ones first
# Test BASE schedulers with their specific parameters
log.warning('type="base"')
for name, cls in BASE:
configs = []
# prediction_types
for pt in prediction_types:
configs.append({"prediction_type": pt})
# Specific params for specific classes
if cls == RESUnifiedScheduler:
rk_types = ["res_2m", "res_3m", "res_2s", "res_3s", "res_5s", "res_6s", "deis_1s", "deis_2m", "deis_3m"]
for rk in rk_types:
for pt in prediction_types:
configs.append({"rk_type": rk, "prediction_type": pt})
elif cls == RESMultistepScheduler:
variants = ["res_2m", "res_3m", "deis_2m", "deis_3m"]
for v in variants:
for pt in prediction_types:
configs.append({"variant": v, "prediction_type": pt})
elif cls == RESDEISMultistepScheduler:
for order in range(1, 6):
for pt in prediction_types:
configs.append({"solver_order": order, "prediction_type": pt})
elif cls == ETDRKScheduler:
variants = ["etdrk2_2s", "etdrk3_a_3s", "etdrk3_b_3s", "etdrk4_4s", "etdrk4_4s_alt"]
for v in variants:
for pt in prediction_types:
configs.append({"variant": v, "prediction_type": pt})
elif cls == LawsonScheduler:
variants = ["lawson2a_2s", "lawson2b_2s", "lawson4_4s"]
for v in variants:
for pt in prediction_types:
configs.append({"variant": v, "prediction_type": pt})
elif cls == ABNorsettScheduler:
variants = ["abnorsett_2m", "abnorsett_3m", "abnorsett_4m"]
for v in variants:
for pt in prediction_types:
configs.append({"variant": v, "prediction_type": pt})
elif cls == PECScheduler:
variants = ["pec423_2h2s", "pec433_2h3s"]
for v in variants:
for pt in prediction_types:
configs.append({"variant": v, "prediction_type": pt})
elif cls == RiemannianFlowScheduler:
metrics = ["euclidean", "hyperbolic", "spherical", "lorentzian"]
for m in metrics:
configs.append({"metric_type": m, "prediction_type": "epsilon"}) # Flow usually uses v or raw, but epsilon check matches others
if not configs:
for pt in prediction_types:
configs.append({"prediction_type": pt})
for conf in configs:
test_scheduler(name, cls, conf)
log.warning('type="simple"')
for name, cls in SIMPLE:
for pt in prediction_types:
test_scheduler(name, cls, {"prediction_type": pt})
log.warning('type="variants"')
for name, cls in VARIANTS:
# these classes preset their variants/rk_types in __init__ so we just test prediction types
for pt in prediction_types:
test_scheduler(name, cls, {"prediction_type": pt})
# Extra robustness check: Flow Prediction Type
log.warning('type="flow"')
flow_schedulers = [
# res4lyf schedulers
RESUnifiedScheduler, RESMultistepScheduler, ABNorsettScheduler,
RESSinglestepScheduler, RESSinglestepSDEScheduler, RESDEISMultistepScheduler,
RESMultistepSDEScheduler, ETDRKScheduler, LawsonScheduler, PECScheduler,
SimpleExponentialScheduler, LinearRKScheduler, LobattoScheduler,
GaussLegendreScheduler, RungeKutta44Scheduler, RungeKutta57Scheduler,
RungeKutta67Scheduler, SpecializedRKScheduler, BongTangentScheduler,
CommonSigmaScheduler, RadauIIAScheduler, LangevinDynamicsScheduler,
RiemannianFlowScheduler,
# sdnext schedulers
FlowUniPCMultistepScheduler, FlashFlowMatchEulerDiscreteScheduler, FlowMatchDPMSolverMultistepScheduler,
]
for cls in flow_schedulers:
test_scheduler(cls.__name__, cls, {"prediction_type": "flow_prediction", "use_flow_sigmas": True})
log.warning('type="sdnext"')
extended_schedulers = [
VDMScheduler,
UFOGenScheduler,
TDDScheduler,
TCDScheduler,
DCSolverMultistepScheduler,
BDIA_DDIMScheduler
]
for prediction_type in ["epsilon", "v_prediction", "sample"]:
for cls in extended_schedulers:
test_scheduler(cls.__name__, cls, {"prediction_type": prediction_type})
if __name__ == "__main__":
run_tests()