# mirror of https://github.com/vladmandic/automatic (530 lines, 40 KiB, Python)
import os
|
|
import re
|
|
import copy
|
|
import inspect
|
|
import diffusers
|
|
from modules import shared, errors
|
|
from modules.logger import log
|
|
from modules.sd_samplers_common import SamplerData, flow_models
|
|
|
|
|
|
# Setting SD_SAMPLER_DEBUG (to any value) turns on verbose sampler tracing.
debug = 'SD_SAMPLER_DEBUG' in os.environ
if debug:
    debug_log = log.trace
else:
    debug_log = lambda *args, **kwargs: None # pylint: disable=unnecessary-lambda-assignment

# Scheduler option overrides consulted by get_override() ahead of shared.opts;
# populated by sd_samplers.create_sampler() before a DiffusionSampler is constructed.
scheduler_overrides = {}
|
|
|
|
# Diffusers schedulers
# Import is best-effort: on failure the error is logged (with a full trace when `debug` is set)
# and the module keeps loading; only the factories referencing a missing name will fail later.
try:
    from diffusers import (
        CMStochasticIterativeScheduler,
        CosineDPMSolverMultistepScheduler,
        DDIMScheduler,
        DDPMScheduler,
        DEISMultistepScheduler,
        DPMSolverMultistepInverseScheduler,
        DPMSolverMultistepScheduler,
        DPMSolverSDEScheduler,
        DPMSolverSinglestepScheduler,
        EDMDPMSolverMultistepScheduler,
        EDMEulerScheduler,
        EulerAncestralDiscreteScheduler,
        EulerDiscreteScheduler,
        FlowMatchEulerDiscreteScheduler,
        FlowMatchHeunDiscreteScheduler,
        FlowMatchLCMScheduler,
        HeunDiscreteScheduler,
        IPNDMScheduler,
        KDPM2AncestralDiscreteScheduler,
        KDPM2DiscreteScheduler,
        LCMScheduler,
        LMSDiscreteScheduler,
        PNDMScheduler,
        SASolverScheduler,
        UniPCMultistepScheduler,
        CogVideoXDDIMScheduler,
        DDIMParallelScheduler,
        DDPMParallelScheduler,
        TCDScheduler,
    )
except Exception as e:
    log.error(f'Sampler import: version={diffusers.__version__} error: {e}')
    if debug: # reuse the module-level flag instead of re-reading SD_SAMPLER_DEBUG
        errors.display(e, 'Samplers')
|
|
|
|
# SD.Next Schedulers
# Best-effort import of schedulers bundled with SD.Next; TCDScheduler is taken from diffusers instead.
try:
    from modules.schedulers.scheduler_tdd import TDDScheduler # pylint: disable=ungrouped-imports
    from modules.schedulers.scheduler_dc import DCSolverMultistepScheduler # pylint: disable=ungrouped-imports
    from modules.schedulers.scheduler_vdm import VDMScheduler # pylint: disable=ungrouped-imports
    from modules.schedulers.scheduler_dpm_flowmatch import FlowMatchDPMSolverMultistepScheduler # pylint: disable=ungrouped-imports
    from modules.schedulers.scheduler_bdia import BDIA_DDIMScheduler # pylint: disable=ungrouped-imports
    from modules.schedulers.scheduler_ufogen import UFOGenScheduler # pylint: disable=ungrouped-imports
    from modules.schedulers.scheduler_unipc_flowmatch import FlowUniPCMultistepScheduler # pylint: disable=ungrouped-imports
    from modules.schedulers.scheduler_flashflow import FlashFlowMatchEulerDiscreteScheduler # pylint: disable=ungrouped-imports
    from modules.schedulers.perflow import PeRFlowScheduler # pylint: disable=ungrouped-imports
except Exception as e:
    log.error(f'Sampler import: version={diffusers.__version__} error: {e}')
    if debug: # reuse the module-level flag instead of re-reading SD_SAMPLER_DEBUG
        errors.display(e, 'Samplers')
|
|
|
|
# Res4Lyf Schedulers
# Best-effort import; the commented-out names below are currently not imported.
try:
    from modules.res4lyf import (
        ABNorsettScheduler,
        CommonSigmaScheduler,
        ETDRKScheduler,
        LangevinDynamicsScheduler,
        LawsonScheduler,
        PECScheduler,
        RESUnifiedScheduler,
        RESSinglestepScheduler,
        RESMultistepScheduler,
        RESSinglestepSDEScheduler,
        RiemannianFlowScheduler,
        RESDEISMultistepScheduler,
        LinearRKScheduler,
        LobattoScheduler,
        RadauIIAScheduler,
        GaussLegendreScheduler,
        RungeKutta44Scheduler,
        RungeKutta57Scheduler,
        RungeKutta67Scheduler,
        SpecializedRKScheduler,
        # RESMultistepSDEScheduler,
        # BongTangentScheduler,
        # SimpleExponentialScheduler,
    )
except Exception as e:
    log.error(f'Sampler import: version={diffusers.__version__} error: {e}')
    if debug: # reuse the module-level flag instead of re-reading SD_SAMPLER_DEBUG
        errors.display(e, 'Samplers')
|
|
|
|
# Scheduler defaults shared by many samplers.
# beta_start/beta_end/beta_schedule are normally per-scheduler, but we prefer the values from the
# model itself, since those are what the model was trained with; the ones below are only fallbacks.
# prediction_type should ideally also come from the model; auto-detecting the model type may be needed in the future.
config = {}
config['All'] = {
    'num_train_timesteps': 1000,
    'beta_start': 0.0001,
    'beta_end': 0.02,
    'beta_schedule': 'linear',
    'prediction_type': 'epsilon',
}
# Common settings spread into every Res4Lyf sampler entry.
config['Res4Lyf'] = {
    'timestep_spacing': 'linspace',
    'steps_offset': 0,
    'rescale_betas_zero_snr': False,
    'use_karras_sigmas': False,
    'use_exponential_sigmas': False,
    'use_beta_sigmas': False,
    'use_flow_sigmas': False,
    'shift': 1,
    'base_shift': 0.5,
    'max_shift': 1.15,
    'use_dynamic_shifting': False,
}
|
|
|
|
# Per-scheduler defaults, looked up by the name given to DiffusionSampler (config.get(name, {})).
# DiffusionSampler applies them after the 'All' defaults and the model's own scheduler values,
# so keys listed here win over both; later user option overrides may still replace them.
config.update({
    # diffusers: UniPC / DDIM
    'UniPC': { 'flow_shift': 1, 'predict_x0': True, 'sample_max_value': 1.0, 'solver_order': 2, 'solver_type': 'bh2', 'thresholding': False, 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_karras_sigmas': False, 'lower_order_final': True, 'timestep_spacing': 'linspace', 'final_sigmas_type': 'zero', 'rescale_betas_zero_snr': False },
    'DDIM': { 'clip_sample': False, 'set_alpha_to_one': True, 'steps_offset': 0, 'clip_sample_range': 1.0, 'sample_max_value': 1.0, 'timestep_spacing': 'leading', 'rescale_betas_zero_snr': False, 'thresholding': False },

    # Euler family
    'Euler': { 'steps_offset': 0, 'interpolation_type': "linear", 'rescale_betas_zero_snr': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_karras_sigmas': False },
    'Euler a': { 'steps_offset': 0, 'rescale_betas_zero_snr': False, 'timestep_spacing': 'linspace' },
    'Euler SGM': { 'steps_offset': 0, 'interpolation_type': "linear", 'rescale_betas_zero_snr': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'trailing', 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_karras_sigmas': False, 'prediction_type': "sample" },
    'Euler EDM': { 'sigma_schedule': "karras" },
    'Euler FlowMatch': { 'timestep_spacing': "linspace", 'shift': 1, 'use_dynamic_shifting': False, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'base_shift': 0.5, 'max_shift': 1.15 },

    # DPM-Solver family
    'DPM++': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 1 },
    'DPM++ 2M': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 2 },
    'DPM++ 3M': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 3 },
    'DPM++ 1S': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'final_sigmas_type': 'sigma_min' },
    'DPM++ SDE': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "sde-dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 1 },
    'DPM++ 2M SDE': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "sde-dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 2 },
    'DPM++ 2M EDM': { 'solver_order': 2, 'solver_type': 'midpoint', 'final_sigmas_type': 'zero', 'algorithm_type': 'dpmsolver++' },
    # diffusers spells this prediction type 'v_prediction' (underscore); 'v-prediction' is rejected by the scheduler
    'DPM++ Cosine': { 'solver_order': 2, 'sigma_schedule': "exponential", 'prediction_type': "v_prediction" },
    'DPM SDE': { 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'noise_sampler_seed': None, 'timestep_spacing': 'linspace', 'steps_offset': 0, },

    # DPM-Solver inverse variants
    'DPM++ Inverse': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 1 },
    'DPM++ 2M Inverse': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 2 },
    'DPM++ 3M Inverse': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 3 },

    # FlowMatch solvers
    'UniPC FlowMatch': { 'predict_x0': True, 'sample_max_value': 1.0, 'solver_order': 2, 'solver_type': 'bh2', 'thresholding': False, 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_karras_sigmas': False, 'lower_order_final': True, 'timestep_spacing': 'linspace', 'final_sigmas_type': 'zero', 'rescale_betas_zero_snr': False },
    'DPM2 FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver2', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012, 'base_shift': 0.5, 'max_shift': 1.15 },
    'DPM2a FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver2A', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012, 'base_shift': 0.5, 'max_shift': 1.15 },
    'DPM2++ 2M FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2M', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012, 'base_shift': 0.5, 'max_shift': 1.15 },
    'DPM2++ 2S FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2S', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012, 'base_shift': 0.5, 'max_shift': 1.15 },
    'DPM2++ SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++sde', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012, 'base_shift': 0.5, 'max_shift': 1.15 },
    'DPM2++ 2M SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2Msde', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012, 'base_shift': 0.5, 'max_shift': 1.15 },
    'DPM2++ 3M SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 3, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++3Msde', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012, 'base_shift': 0.5, 'max_shift': 1.15 },

    # Heun / LCM
    'Heun': { 'use_beta_sigmas': False, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'timestep_spacing': 'linspace' },
    'Heun FlowMatch': { 'timestep_spacing': "linspace", 'shift': 1 },
    'LCM FlowMatch': { 'beta_start': 0.00085, 'beta_end': 0.012, 'beta_schedule': "scaled_linear", 'set_alpha_to_one': True, 'rescale_betas_zero_snr': False, 'thresholding': False, 'timestep_spacing': 'linspace', 'base_shift': 0.5, 'max_shift': 1.15 },

    # other solvers
    'DEIS': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "deis", 'solver_type': "logrho", 'lower_order_final': True, 'timestep_spacing': 'linspace', 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False },
    'SA Solver': {'predictor_order': 2, 'corrector_order': 2, 'thresholding': False, 'lower_order_final': True, 'use_karras_sigmas': False, 'use_flow_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'timestep_spacing': 'linspace'},
    'DC Solver': { 'beta_start': 0.0001, 'beta_end': 0.02, 'solver_order': 2, 'prediction_type': "epsilon", 'thresholding': False, 'solver_type': 'bh2', 'lower_order_final': True, 'dc_order': 2, 'disable_corrector': [0] },
    'VDM Solver': { 'clip_sample_range': 2.0, },
    'TCD': { 'set_alpha_to_one': True, 'rescale_betas_zero_snr': False, 'beta_schedule': 'scaled_linear' },
    'TDD': { },
    'Flash FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'base_shift': 0.5, 'max_shift': 1.15 },
    'PeRFlow': { 'prediction_type': 'ddim_eps' },
    'UFOGen': { },
    'BDIA DDIM': { 'clip_sample': False, 'set_alpha_to_one': True, 'steps_offset': 0, 'clip_sample_range': 1.0, 'sample_max_value': 1.0, 'timestep_spacing': 'leading', 'rescale_betas_zero_snr': False, 'thresholding': False, 'gamma': 1.0 },

    # legacy / misc
    'PNDM': { 'skip_prk_steps': False, 'set_alpha_to_one': False, 'steps_offset': 0, 'timestep_spacing': 'linspace' },
    'IPNDM': { },
    'DDPM': { 'variance_type': "fixed_small", 'clip_sample': False, 'thresholding': False, 'clip_sample_range': 1.0, 'sample_max_value': 1.0, 'timestep_spacing': 'linspace', 'rescale_betas_zero_snr': False },
    'LMSD': { 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'timestep_spacing': 'linspace', 'steps_offset': 0 },
    'KDPM2': { 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'steps_offset': 0, 'timestep_spacing': 'linspace' },
    'KDPM2 a': { 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'steps_offset': 0, 'timestep_spacing': 'linspace' },
    'CMSI': { },
    'CogX DDIM': { 'beta_schedule': "scaled_linear", 'beta_start': 0.00085, 'beta_end': 0.012, 'set_alpha_to_one': True, 'rescale_betas_zero_snr': False },
    'DDIM Parallel': {},
    'DDPM Parallel': {},

    # res4lyf: every entry spreads the shared config['Res4Lyf'] settings after its own selector key
    'ABNorsett 2M': { 'variant': 'abnorsett_2m', **config['Res4Lyf'] },
    'ABNorsett 3M': { 'variant': 'abnorsett_3m', **config['Res4Lyf'] },
    'ABNorsett 4M': { 'variant': 'abnorsett_4m', **config['Res4Lyf'] },
    'Lawson 2S A': { 'variant': 'lawson2a_2s', **config['Res4Lyf'] },
    'Lawson 2S B': { 'variant': 'lawson2b_2s', **config['Res4Lyf'] },
    'Lawson 4S': { 'variant': 'lawson4_4s', **config['Res4Lyf'] },
    'ETD-RK 2S': { 'variant': 'etdrk2_2s', **config['Res4Lyf'] },
    'ETD-RK 3S A': { 'variant': 'etdrk3_a_3s', **config['Res4Lyf'] },
    'ETD-RK 3S B': { 'variant': 'etdrk3_b_3s', **config['Res4Lyf'] },
    'ETD-RK 4S A': { 'variant': 'etdrk4_4s', **config['Res4Lyf'] },
    'ETD-RK 4S B': { 'variant': 'etdrk4_4s_alt', **config['Res4Lyf'] },
    'RES-Unified 2M': { 'rk_type': 'res_2m', **config['Res4Lyf'] },
    'RES-Unified 3M': { 'rk_type': 'res_3m', **config['Res4Lyf'] },
    'RES-Unified 2S': { 'rk_type': 'res_2s', **config['Res4Lyf'] },
    'RES-Unified 3S': { 'rk_type': 'res_3s', **config['Res4Lyf'] },
    'RES-Singlestep 2S': { 'variant': 'res_2s', **config['Res4Lyf'] },
    'RES-Singlestep 3S': { 'variant': 'res_3s', **config['Res4Lyf'] },
    'RES-Multistep 2M': { 'variant': 'res_2m', **config['Res4Lyf'] },
    'RES-Multistep 3M': { 'variant': 'res_3m', **config['Res4Lyf'] },
    'RES-SDE 2S': { 'variant': 'res_2s', **config['Res4Lyf'] },
    'RES-SDE 3S': { 'variant': 'res_3s', **config['Res4Lyf'] },
    'DEIS-Multistep': { 'order': 2, **config['Res4Lyf'] },
    'DEIS-Unified 1S': { 'rk_type': 'deis_1s', **config['Res4Lyf'] },
    'DEIS-Unified 2M': { 'rk_type': 'deis_2m', **config['Res4Lyf'] },
    'PEC 423': { 'variant': 'pec423_2h2s', **config['Res4Lyf'] },
    'PEC 433': { 'variant': 'pec433_2h3s', **config['Res4Lyf'] },
    'Sigmoid Sigma': { 'profile': 'sigmoid', **config['Res4Lyf'] },
    'Sine Sigma': { 'profile': 'sine', **config['Res4Lyf'] },
    'Easing Sigma': { 'profile': 'easing', **config['Res4Lyf'] },
    'Arcsine Sigma': { 'profile': 'arcsine', **config['Res4Lyf'] },
    'Smoothstep Sigma': { 'profile': 'smoothstep', **config['Res4Lyf'] },
    'Langevin Dynamics': { **config['Res4Lyf'] },
    'Euclidean Flow': { 'metric_type': 'euclidean', **config['Res4Lyf'] },
    'Hyperbolic Flow': { 'metric_type': 'hyperbolic', **config['Res4Lyf'] },
    'Spherical Flow': { 'metric_type': 'spherical', **config['Res4Lyf'] },
    'Lorentzian Flow': { 'metric_type': 'lorentzian', **config['Res4Lyf'] },
    'Linear-RK 2': { 'variant': 'rk2', **config['Res4Lyf'] },
    'Linear-RK 3': { 'variant': 'rk3', **config['Res4Lyf'] },
    'Linear-RK 4': { 'variant': 'rk4', **config['Res4Lyf'] },
    'Linear-RK Euler': { 'variant': 'euler', **config['Res4Lyf'] },
    'Linear-RK Heun': { 'variant': 'heun', **config['Res4Lyf'] },
    'Linear-RK Ralston': { 'variant': 'ralston', **config['Res4Lyf'] },
    'Lobatto 2': { 'variant': 'lobatto_iiia_2s', **config['Res4Lyf'] },
    'Lobatto 3': { 'variant': 'lobatto_iiia_3s', **config['Res4Lyf'] },
    'Lobatto 4': { 'variant': 'lobatto_iiia_4s', **config['Res4Lyf'] },
    'Radau IIA 2': { 'variant': 'radau_iia_2s', **config['Res4Lyf'] },
    'Radau IIA 3': { 'variant': 'radau_iia_3s', **config['Res4Lyf'] },
    'Gauss-Legendre 2S': { 'variant': 'gauss-legendre_2s', **config['Res4Lyf'] },
    'Gauss-Legendre 3S': { 'variant': 'gauss-legendre_3s', **config['Res4Lyf'] },
    'Gauss-Legendre 4S': { 'variant': 'gauss-legendre_4s', **config['Res4Lyf'] },
    'Runge-Kutta 4/4': { **config['Res4Lyf'] },
    'Runge-Kutta 5/7': { **config['Res4Lyf'] },
    'Runge-Kutta 6/7': { **config['Res4Lyf'] },
    'Specialized-RK 3S': { 'variant': 'ssprk3_3s', **config['Res4Lyf'] },
    'Specialized-RK 4S': { 'variant': 'ssprk4_4s', **config['Res4Lyf'] },
})
|
|
|
|
# Sampler registry. Each entry pairs a sampler name with a factory that builds a DiffusionSampler
# ('Default' and 'Same as primary' are placeholders with no factory). The first DiffusionSampler
# argument is the key DiffusionSampler uses for config.get(name, {}), so it should normally equal
# the SamplerData name; a mismatch means the per-scheduler defaults in `config` are not applied.
samplers_data_diffusers = [
    SamplerData('Default', None, [], {}),

    SamplerData('UniPC', lambda model: DiffusionSampler('UniPC', UniPCMultistepScheduler, model), [], {}),
    SamplerData('DDIM', lambda model: DiffusionSampler('DDIM', DDIMScheduler, model), [], {}),
    SamplerData('Euler', lambda model: DiffusionSampler('Euler', EulerDiscreteScheduler, model), [], {}),
    SamplerData('Euler a', lambda model: DiffusionSampler('Euler a', EulerAncestralDiscreteScheduler, model), [], {}),
    SamplerData('Euler SGM', lambda model: DiffusionSampler('Euler SGM', EulerDiscreteScheduler, model), [], {}),
    SamplerData('Euler EDM', lambda model: DiffusionSampler('Euler EDM', EDMEulerScheduler, model), [], {}),
    SamplerData('Euler FlowMatch', lambda model: DiffusionSampler('Euler FlowMatch', FlowMatchEulerDiscreteScheduler, model), [], {}),

    SamplerData('DPM++', lambda model: DiffusionSampler('DPM++', DPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM++ 2M', lambda model: DiffusionSampler('DPM++ 2M', DPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM++ 3M', lambda model: DiffusionSampler('DPM++ 3M', DPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM++ 1S', lambda model: DiffusionSampler('DPM++ 1S', DPMSolverSinglestepScheduler, model), [], {}),
    SamplerData('DPM++ SDE', lambda model: DiffusionSampler('DPM++ SDE', DPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM++ 2M SDE', lambda model: DiffusionSampler('DPM++ 2M SDE', DPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM++ 2M EDM', lambda model: DiffusionSampler('DPM++ 2M EDM', EDMDPMSolverMultistepScheduler, model), [], {}),
    # NOTE(review): uses the 'DPM++ 2M EDM' config key, so config['DPM++ Cosine'] is never applied;
    # switching the key would also force that entry's prediction_type onto every model - confirm intent first
    SamplerData('DPM++ Cosine', lambda model: DiffusionSampler('DPM++ 2M EDM', CosineDPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM SDE', lambda model: DiffusionSampler('DPM SDE', DPMSolverSDEScheduler, model), [], {}),

    SamplerData('DPM++ Inverse', lambda model: DiffusionSampler('DPM++ Inverse', DPMSolverMultistepInverseScheduler, model), [], {}),
    SamplerData('DPM++ 2M Inverse', lambda model: DiffusionSampler('DPM++ 2M Inverse', DPMSolverMultistepInverseScheduler, model), [], {}),
    SamplerData('DPM++ 3M Inverse', lambda model: DiffusionSampler('DPM++ 3M Inverse', DPMSolverMultistepInverseScheduler, model), [], {}),

    SamplerData('UniPC FlowMatch', lambda model: DiffusionSampler('UniPC FlowMatch', FlowUniPCMultistepScheduler, model), [], {}),
    SamplerData('DPM2 FlowMatch', lambda model: DiffusionSampler('DPM2 FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM2a FlowMatch', lambda model: DiffusionSampler('DPM2a FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM2++ 2M FlowMatch', lambda model: DiffusionSampler('DPM2++ 2M FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM2++ 2S FlowMatch', lambda model: DiffusionSampler('DPM2++ 2S FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM2++ SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM2++ 2M SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ 2M SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
    SamplerData('DPM2++ 3M SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ 3M SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),

    SamplerData('Heun', lambda model: DiffusionSampler('Heun', HeunDiscreteScheduler, model), [], {}),
    SamplerData('Heun FlowMatch', lambda model: DiffusionSampler('Heun FlowMatch', FlowMatchHeunDiscreteScheduler, model), [], {}),
    SamplerData('Flash FlowMatch', lambda model: DiffusionSampler('Flash FlowMatch', FlashFlowMatchEulerDiscreteScheduler, model), [], {}),

    SamplerData('DEIS', lambda model: DiffusionSampler('DEIS', DEISMultistepScheduler, model), [], {}),
    SamplerData('SA Solver', lambda model: DiffusionSampler('SA Solver', SASolverScheduler, model), [], {}),
    SamplerData('DC Solver', lambda model: DiffusionSampler('DC Solver', DCSolverMultistepScheduler, model), [], {}),

    SamplerData('DDPM', lambda model: DiffusionSampler('DDPM', DDPMScheduler, model), [], {}),
    SamplerData('DDPM Parallel', lambda model: DiffusionSampler('DDPM Parallel', DDPMParallelScheduler, model), [], {}),
    SamplerData('DDIM Parallel', lambda model: DiffusionSampler('DDIM Parallel', DDIMParallelScheduler, model), [], {}),
    SamplerData('PNDM', lambda model: DiffusionSampler('PNDM', PNDMScheduler, model), [], {}),
    SamplerData('IPNDM', lambda model: DiffusionSampler('IPNDM', IPNDMScheduler, model), [], {}),
    SamplerData('LMSD', lambda model: DiffusionSampler('LMSD', LMSDiscreteScheduler, model), [], {}),
    SamplerData('KDPM2', lambda model: DiffusionSampler('KDPM2', KDPM2DiscreteScheduler, model), [], {}),
    SamplerData('KDPM2 a', lambda model: DiffusionSampler('KDPM2 a', KDPM2AncestralDiscreteScheduler, model), [], {}),
    SamplerData('CMSI', lambda model: DiffusionSampler('CMSI', CMStochasticIterativeScheduler, model), [], {}),

    SamplerData('VDM Solver', lambda model: DiffusionSampler('VDM Solver', VDMScheduler, model), [], {}),
    # NOTE(review): 'BDIA DDIM g=0' is not a config key, so config['BDIA DDIM'] is not applied - confirm this is intended
    SamplerData('BDIA DDIM', lambda model: DiffusionSampler('BDIA DDIM g=0', BDIA_DDIMScheduler, model), [], {}),
    SamplerData('LCM', lambda model: DiffusionSampler('LCM', LCMScheduler, model), [], {}),
    SamplerData('LCM FlowMatch', lambda model: DiffusionSampler('LCM FlowMatch', FlowMatchLCMScheduler, model), [], {}),
    SamplerData('TCD', lambda model: DiffusionSampler('TCD', TCDScheduler, model), [], {}),
    SamplerData('TDD', lambda model: DiffusionSampler('TDD', TDDScheduler, model), [], {}),
    SamplerData('PeRFlow', lambda model: DiffusionSampler('PeRFlow', PeRFlowScheduler, model), [], {}),
    SamplerData('UFOGen', lambda model: DiffusionSampler('UFOGen', UFOGenScheduler, model), [], {}),
    SamplerData('CogX DDIM', lambda model: DiffusionSampler('CogX DDIM', CogVideoXDDIMScheduler, model), [], {}),

    SamplerData('ABNorsett 2M', lambda model: DiffusionSampler('ABNorsett 2M', ABNorsettScheduler, model), [], {}),
    SamplerData('ABNorsett 3M', lambda model: DiffusionSampler('ABNorsett 3M', ABNorsettScheduler, model), [], {}),
    SamplerData('ABNorsett 4M', lambda model: DiffusionSampler('ABNorsett 4M', ABNorsettScheduler, model), [], {}),
    SamplerData('Lawson 2S A', lambda model: DiffusionSampler('Lawson 2S A', LawsonScheduler, model), [], {}),
    SamplerData('Lawson 2S B', lambda model: DiffusionSampler('Lawson 2S B', LawsonScheduler, model), [], {}),
    SamplerData('Lawson 4S', lambda model: DiffusionSampler('Lawson 4S', LawsonScheduler, model), [], {}),
    SamplerData('ETD-RK 2S', lambda model: DiffusionSampler('ETD-RK 2S', ETDRKScheduler, model), [], {}),
    SamplerData('ETD-RK 3S A', lambda model: DiffusionSampler('ETD-RK 3S A', ETDRKScheduler, model), [], {}),
    SamplerData('ETD-RK 3S B', lambda model: DiffusionSampler('ETD-RK 3S B', ETDRKScheduler, model), [], {}),
    SamplerData('ETD-RK 4S A', lambda model: DiffusionSampler('ETD-RK 4S A', ETDRKScheduler, model), [], {}),
    SamplerData('ETD-RK 4S B', lambda model: DiffusionSampler('ETD-RK 4S B', ETDRKScheduler, model), [], {}),
    SamplerData('PEC 423', lambda model: DiffusionSampler('PEC 423', PECScheduler, model), [], {}),
    SamplerData('PEC 433', lambda model: DiffusionSampler('PEC 433', PECScheduler, model), [], {}),
    SamplerData('RES-Unified 2S', lambda model: DiffusionSampler('RES-Unified 2S', RESUnifiedScheduler, model), [], {}),
    SamplerData('RES-Unified 3S', lambda model: DiffusionSampler('RES-Unified 3S', RESUnifiedScheduler, model), [], {}),
    SamplerData('RES-Unified 2M', lambda model: DiffusionSampler('RES-Unified 2M', RESUnifiedScheduler, model), [], {}),
    SamplerData('RES-Unified 3M', lambda model: DiffusionSampler('RES-Unified 3M', RESUnifiedScheduler, model), [], {}),
    SamplerData('RES-Singlestep 2S', lambda model: DiffusionSampler('RES-Singlestep 2S', RESSinglestepScheduler, model), [], {}),
    SamplerData('RES-Singlestep 3S', lambda model: DiffusionSampler('RES-Singlestep 3S', RESSinglestepScheduler, model), [], {}),
    SamplerData('RES-Multistep 2M', lambda model: DiffusionSampler('RES-Multistep 2M', RESMultistepScheduler, model), [], {}),
    SamplerData('RES-Multistep 3M', lambda model: DiffusionSampler('RES-Multistep 3M', RESMultistepScheduler, model), [], {}),
    SamplerData('RES-SDE 2S', lambda model: DiffusionSampler('RES-SDE 2S', RESSinglestepSDEScheduler, model), [], {}),
    SamplerData('RES-SDE 3S', lambda model: DiffusionSampler('RES-SDE 3S', RESSinglestepSDEScheduler, model), [], {}),
    # key must match config['DEIS-Multistep'] (previously 'DEIS Multistep', so its config was never applied)
    SamplerData('DEIS-Multistep', lambda model: DiffusionSampler('DEIS-Multistep', RESDEISMultistepScheduler, model), [], {}),
    SamplerData('DEIS-Unified 1S', lambda model: DiffusionSampler('DEIS-Unified 1S', RESUnifiedScheduler, model), [], {}),
    SamplerData('DEIS-Unified 2M', lambda model: DiffusionSampler('DEIS-Unified 2M', RESUnifiedScheduler, model), [], {}),
    SamplerData('Sigmoid Sigma', lambda model: DiffusionSampler('Sigmoid Sigma', CommonSigmaScheduler, model), [], {}),
    SamplerData('Sine Sigma', lambda model: DiffusionSampler('Sine Sigma', CommonSigmaScheduler, model), [], {}),
    SamplerData('Easing Sigma', lambda model: DiffusionSampler('Easing Sigma', CommonSigmaScheduler, model), [], {}),
    SamplerData('Arcsine Sigma', lambda model: DiffusionSampler('Arcsine Sigma', CommonSigmaScheduler, model), [], {}),
    SamplerData('Smoothstep Sigma', lambda model: DiffusionSampler('Smoothstep Sigma', CommonSigmaScheduler, model), [], {}),
    SamplerData('Langevin Dynamics', lambda model: DiffusionSampler('Langevin Dynamics', LangevinDynamicsScheduler, model), [], {}),
    SamplerData('Euclidean Flow', lambda model: DiffusionSampler('Euclidean Flow', RiemannianFlowScheduler, model), [], {}),
    SamplerData('Hyperbolic Flow', lambda model: DiffusionSampler('Hyperbolic Flow', RiemannianFlowScheduler, model), [], {}),
    SamplerData('Spherical Flow', lambda model: DiffusionSampler('Spherical Flow', RiemannianFlowScheduler, model), [], {}),
    SamplerData('Lorentzian Flow', lambda model: DiffusionSampler('Lorentzian Flow', RiemannianFlowScheduler, model), [], {}),
    SamplerData('Linear-RK 2', lambda model: DiffusionSampler('Linear-RK 2', LinearRKScheduler, model), [], {}),
    SamplerData('Linear-RK 3', lambda model: DiffusionSampler('Linear-RK 3', LinearRKScheduler, model), [], {}),
    SamplerData('Linear-RK 4', lambda model: DiffusionSampler('Linear-RK 4', LinearRKScheduler, model), [], {}),
    SamplerData('Linear-RK Euler', lambda model: DiffusionSampler('Linear-RK Euler', LinearRKScheduler, model), [], {}),
    SamplerData('Linear-RK Heun', lambda model: DiffusionSampler('Linear-RK Heun', LinearRKScheduler, model), [], {}),
    SamplerData('Linear-RK Ralston', lambda model: DiffusionSampler('Linear-RK Ralston', LinearRKScheduler, model), [], {}),
    SamplerData('Lobatto 2', lambda model: DiffusionSampler('Lobatto 2', LobattoScheduler, model), [], {}),
    SamplerData('Lobatto 3', lambda model: DiffusionSampler('Lobatto 3', LobattoScheduler, model), [], {}),
    SamplerData('Lobatto 4', lambda model: DiffusionSampler('Lobatto 4', LobattoScheduler, model), [], {}),
    SamplerData('Radau IIA 2', lambda model: DiffusionSampler('Radau IIA 2', RadauIIAScheduler, model), [], {}),
    # key must match config['Radau IIA 3'] (variant radau_iia_3s); it previously reused the 2-stage config
    SamplerData('Radau IIA 3', lambda model: DiffusionSampler('Radau IIA 3', RadauIIAScheduler, model), [], {}),
    # NOTE(review): there is no config['Radau IIA 4'] entry, so this still runs with the 'Radau IIA 2' settings;
    # add a 4-stage variant entry (if the scheduler supports one) or drop this sampler
    SamplerData('Radau IIA 4', lambda model: DiffusionSampler('Radau IIA 2', RadauIIAScheduler, model), [], {}),
    SamplerData('Gauss-Legendre 2S', lambda model: DiffusionSampler('Gauss-Legendre 2S', GaussLegendreScheduler, model), [], {}),
    SamplerData('Gauss-Legendre 3S', lambda model: DiffusionSampler('Gauss-Legendre 3S', GaussLegendreScheduler, model), [], {}),
    SamplerData('Gauss-Legendre 4S', lambda model: DiffusionSampler('Gauss-Legendre 4S', GaussLegendreScheduler, model), [], {}),
    SamplerData('Specialized-RK 3S', lambda model: DiffusionSampler('Specialized-RK 3S', SpecializedRKScheduler, model), [], {}),
    SamplerData('Specialized-RK 4S', lambda model: DiffusionSampler('Specialized-RK 4S', SpecializedRKScheduler, model), [], {}),
    SamplerData('Runge-Kutta 4/4', lambda model: DiffusionSampler('Runge-Kutta 4/4', RungeKutta44Scheduler, model), [], {}),
    SamplerData('Runge-Kutta 5/7', lambda model: DiffusionSampler('Runge-Kutta 5/7', RungeKutta57Scheduler, model), [], {}),
    SamplerData('Runge-Kutta 6/7', lambda model: DiffusionSampler('Runge-Kutta 6/7', RungeKutta67Scheduler, model), [], {}),

    SamplerData('Same as primary', None, [], {}),
]
|
|
|
|
|
|
def get_sampler_compatibility(sd_model):  # pylint: disable=unused-argument # TODO enso-required
    """Return sampler compatibility information for *sd_model*.

    Placeholder: the model is ignored and a new, empty dict is returned on every call.
    """
    compatibility = {}
    return compatibility
|
|
|
|
|
|
def get_sampler_capability():  # TODO enso-required
    """Return sampler capability information.

    Placeholder: always yields a new, empty dict.
    """
    capabilities = {}
    return capabilities
|
|
|
|
|
|
def get_override(key, default=None):
    """Resolve a scheduler setting, preferring per-call overrides over global options.

    Values in the module-level ``scheduler_overrides`` mapping take precedence (even when the
    stored value is falsy or None). Otherwise the attribute *key* is read from ``shared.opts``,
    and *default* is returned if that attribute does not exist.
    """
    if key not in scheduler_overrides:
        return getattr(shared.opts, key, default)
    return scheduler_overrides[key]
|
|
|
|
|
|
class DiffusionSampler:
    """Build and hold a diffusers-compatible scheduler instance for a named sampler.

    The constructor config is layered in this order:
      1. module ``config['All']`` defaults;
      2. values from the model's original scheduler config (only for keys already present);
      3. ``config[name]`` per-sampler defaults;
      4. caller ``kwargs`` (only for keys already present);
      5. user options read through ``get_override`` plus name-based special cases.
    Keys the constructor does not accept are dropped before instantiation.

    Attributes:
        name: sampler display name.
        config: keyword arguments actually passed to ``constructor``.
        sampler: the instantiated scheduler, or ``None`` when construction failed or the scheduler
            was rejected as incompatible with the model's default scheduler.
        None of these attributes are set when ``name == 'Default'``.
    """

    # user-facing beta schedule option -> diffusers `beta_schedule` value
    _BETA_SCHEDULES = {'linear': 'linear', 'scaled': 'scaled_linear', 'cosine': 'squaredcos_cap_v2', 'sigmoid': 'sigmoid'}
    # user-facing sigma option -> boolean config flag that enables it
    _SIGMA_FLAGS = {'betas': 'use_beta_sigmas', 'karras': 'use_karras_sigmas', 'flowmatch': 'use_flow_sigmas', 'exponential': 'use_exponential_sigmas', 'lambdas': 'use_lu_lambdas'}

    def __init__(self, name, constructor, model, **kwargs):
        """Create the scheduler for *name* using *constructor*, taking defaults from *model* when given.

        Args:
            name: sampler name; also selects per-sampler defaults and name-based special cases.
            constructor: scheduler class (or callable) invoked with the assembled config.
            model: pipeline whose ``scheduler`` provides defaults, or ``None``.
            **kwargs: explicit config values; only applied for keys already present in the config.
        """
        if name == 'Default':
            return
        self.name = name
        self.config = {}
        self.sampler = None
        if model is not None and getattr(model, 'default_scheduler', None) is None:
            # keep a copy of the model's original scheduler (first time only); its config and
            # capabilities are used below as defaults and as the compatibility baseline
            model.default_scheduler = copy.deepcopy(getattr(model, 'scheduler', None))
        self._apply_defaults(name, model, kwargs)
        self._apply_overrides(name, model)
        self._filter_config(name, constructor)

        # finally create the new sampler
        try:
            sampler = constructor(**self.config)
        except Exception as e:
            log.error(f'Sampler: "{name}" {e}')
            if debug:
                errors.display(e, 'Samplers')
            self.sampler = None
            return
        if not self._is_compatible(name, constructor, sampler, model):
            self.sampler = None
            return

        # monkey-patch to allow sdxl pipeline to execute flowmatch samplers
        if not hasattr(sampler, 'scale_model_input'):
            sampler.scale_model_input = lambda x, _y: x
        if not hasattr(sampler, 'init_noise_sigma'):
            sampler.init_noise_sigma = 1.0
        self.sampler = sampler
        self.sampler.name = name

    def _apply_defaults(self, name, model, kwargs):
        """Populate ``self.config`` from global, model, per-sampler and caller defaults, in that order."""
        for key, value in config.get('All', {}).items():  # apply global defaults
            self.config[key] = value
        debug_log(f'Sampler: all="{self.config}"')
        default_scheduler = getattr(model, 'default_scheduler', None)
        if default_scheduler is None:  # no model, or the model had no scheduler to copy
            orig_config = {}
        elif hasattr(default_scheduler, 'scheduler_config'):
            orig_config = default_scheduler.scheduler_config or {}
        else:
            orig_config = getattr(default_scheduler, 'config', None) or {}
        debug_log(f'Sampler: diffusers="{self.config}"')
        debug_log(f'Sampler: original="{orig_config}"')
        for key, value in orig_config.items():  # apply model defaults, only for keys already known
            if key in self.config:
                self.config[key] = value
        debug_log(f'Sampler: default="{self.config}"')
        for key, value in config.get(name, {}).items():  # apply per-sampler defaults
            self.config[key] = value
        for key, value in kwargs.items():  # apply caller args, only for keys already known
            if key in self.config:
                self.config[key] = value

    def _apply_overrides(self, name, model):
        """Apply user-selected scheduler options and name-based special cases on top of the defaults."""
        prediction_type = get_override('schedulers_prediction_type')
        if prediction_type != 'default':
            self.config['prediction_type'] = prediction_type
        beta_schedule = self._BETA_SCHEDULES.get(get_override('schedulers_beta_schedule'))
        if beta_schedule is not None:
            self.config['beta_schedule'] = beta_schedule

        # tolerate a missing/None option; isdecimal() matches exactly what int() can parse
        timesteps = [int(x) for x in re.split(',| ', str(get_override('schedulers_timesteps') or '')) if x.isdecimal()]
        sched_sigma = get_override('schedulers_sigma')
        if len(timesteps) == 0:  # with explicit timesteps, they are set using set_timesteps in set_pipeline_args instead
            if 'sigma_schedule' in self.config:
                self.config['sigma_schedule'] = sched_sigma if sched_sigma != 'default' else None
            if sched_sigma == 'default' and shared.sd_model_type in flow_models:
                sigma_flag = 'use_flow_sigmas'  # flow models get flow sigmas unless the user chose otherwise
            else:
                sigma_flag = self._SIGMA_FLAGS.get(sched_sigma)
            if sigma_flag is not None and sigma_flag in self.config:
                self.config[sigma_flag] = True

        if 'thresholding' in self.config:
            self.config['thresholding'] = get_override('schedulers_use_thresholding')
        if 'lower_order_final' in self.config:
            self.config['lower_order_final'] = get_override('schedulers_use_loworder')
        if 'solver_order' in self.config:
            solver_order = int(get_override('schedulers_solver_order'))
            if solver_order > 0:  # 0 or less keeps the scheduler's own order
                self.config['solver_order'] = solver_order
        if 'predict_x0' in self.config:  # schedulers exposing predict_x0 take the uni_pc_variant option as solver_type
            self.config['solver_type'] = get_override('uni_pc_variant')
        if 'beta_start' in self.config and get_override('schedulers_beta_start') > 0:
            self.config['beta_start'] = get_override('schedulers_beta_start')
        if 'beta_end' in self.config and get_override('schedulers_beta_end') > 0:
            self.config['beta_end'] = get_override('schedulers_beta_end')

        sched_shift = get_override('schedulers_shift')
        for key in ('shift', 'flow_shift'):
            if key in self.config:
                self.config[key] = sched_shift if sched_shift > 0 else 3  # non-positive shift falls back to 3
        if 'use_dynamic_shifting' in self.config:
            self.config['use_dynamic_shifting'] = True if sched_shift == 0 else get_override('schedulers_dynamic_shift')
        if 'base_shift' in self.config:
            self.config['base_shift'] = get_override('schedulers_base_shift')
        if 'max_shift' in self.config:
            self.config['max_shift'] = get_override('schedulers_max_shift')
        if 'use_beta_sigmas' in self.config and 'sigma_schedule' in self.config:
            self.config['use_beta_sigmas'] = 'StableDiffusion3' in model.__class__.__name__
        if 'rescale_betas_zero_snr' in self.config:
            self.config['rescale_betas_zero_snr'] = get_override('schedulers_rescale_betas')
        sched_ts_spacing = get_override('schedulers_timestep_spacing')
        if 'timestep_spacing' in self.config and sched_ts_spacing != 'default' and sched_ts_spacing is not None:
            self.config['timestep_spacing'] = sched_ts_spacing
        if 'num_train_timesteps' in self.config:
            self.config['num_train_timesteps'] = get_override('schedulers_timesteps_range')

        # name-based special cases; pop() tolerates keys that were never set
        if 'EDM' in name:
            for key in ('beta_start', 'beta_end', 'beta_schedule'):
                self.config.pop(key, None)
        if name in {'IPNDM', 'CMSI', 'VDM Solver'}:
            for key in ('beta_start', 'beta_end', 'beta_schedule', 'prediction_type'):
                self.config.pop(key, None)
        if 'prediction_type' in self.config and 'Flow' in name:
            self.config['prediction_type'] = 'flow_prediction'
        if 'SGM' in name:
            self.config['timestep_spacing'] = 'trailing'

    def _filter_config(self, name, constructor):
        """Drop config keys that *constructor* does not accept as parameters."""
        signature = inspect.signature(constructor, follow_wrapped=True)
        possible = signature.parameters.keys()
        for key in [k for k in self.config if k not in possible]:
            del self.config[key]
        debug_log(f'Sampler: name="{name}"')
        debug_log(f'Sampler: config={self.config}')
        debug_log(f'Sampler: signature={possible}')

    def _is_compatible(self, name, constructor, sampler, model):
        """Return False (after logging a warning) if *sampler* cannot stand in for the model's default scheduler."""
        if self.config.get('prediction_type') == 'flow_prediction' and 'FlowMatch' not in constructor.__name__:
            # no capability flag exists, so look for the literal in the scheduler's source
            try:
                cls_source = inspect.getsource(constructor)
            except (TypeError, OSError):
                cls_source = None  # source unavailable: cannot check, assume supported
            if cls_source is not None and '"flow_prediction"' not in cls_source and "'flow_prediction'" not in cls_source:
                log.warning(f'Sampler: "{name}" does not support flow_prediction')
                return False
        if hasattr(sampler, 'set_timesteps'):
            params = set(inspect.signature(sampler.set_timesteps).parameters.keys())
            accept_sigmas = 'sigmas' in params
            accepts_timesteps = 'timesteps' in params
            accept_scale_noise = hasattr(sampler, 'scale_noise')
            debug_log(f'Sampler: "{name}" sigmas={accept_sigmas} timesteps={accepts_timesteps}')
            default_scheduler = getattr(model, 'default_scheduler', None)
            default_accept_sigmas = hasattr(default_scheduler, 'set_timesteps') and 'sigmas' in set(inspect.signature(default_scheduler.set_timesteps).parameters.keys())
            default_accept_scale_noise = hasattr(default_scheduler, 'scale_noise')
            if default_accept_sigmas and not accept_sigmas:
                log.warning(f'Sampler: "{name}" does not accept sigmas')
                return False
            if default_accept_scale_noise and not accept_scale_noise:
                log.warning(f'Sampler: "{name}" does not implement scale noise')
                return False
        return True
|