# automatic/modules/res4lyf/variants.py
#
# 398 lines
# 11 KiB
# Python

from .abnorsett_scheduler import ABNorsettScheduler
from .common_sigma_scheduler import CommonSigmaScheduler
from .deis_scheduler_alt import RESDEISMultistepScheduler
from .etdrk_scheduler import ETDRKScheduler
from .gauss_legendre_scheduler import GaussLegendreScheduler
from .lawson_scheduler import LawsonScheduler
from .linear_rk_scheduler import LinearRKScheduler
from .lobatto_scheduler import LobattoScheduler
from .pec_scheduler import PECScheduler
from .radau_iia_scheduler import RadauIIAScheduler
from .res_multistep_scheduler import RESMultistepScheduler
from .res_multistep_sde_scheduler import RESMultistepSDEScheduler
from .res_singlestep_scheduler import RESSinglestepScheduler
from .res_singlestep_sde_scheduler import RESSinglestepSDEScheduler
from .res_unified_scheduler import RESUnifiedScheduler
from .riemannian_flow_scheduler import RiemannianFlowScheduler
# RES Unified Variants
"""
Supports RES 2M, 3M, 2S, 3S, 5S, 6S
Supports DEIS 1S, 2M, 3M
"""
class RESUnified2MScheduler(RESUnifiedScheduler):
    """RESUnifiedScheduler preset that always passes ``rk_type="res_2m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``rk_type``.
        super().__init__(**{**kwargs, "rk_type": "res_2m"})
class RESUnified3MScheduler(RESUnifiedScheduler):
    """RESUnifiedScheduler preset that always passes ``rk_type="res_3m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``rk_type``.
        super().__init__(**{**kwargs, "rk_type": "res_3m"})
class RESUnified2SScheduler(RESUnifiedScheduler):
    """RESUnifiedScheduler preset that always passes ``rk_type="res_2s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``rk_type``.
        super().__init__(**{**kwargs, "rk_type": "res_2s"})
class RESUnified3SScheduler(RESUnifiedScheduler):
    """RESUnifiedScheduler preset that always passes ``rk_type="res_3s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``rk_type``.
        super().__init__(**{**kwargs, "rk_type": "res_3s"})
class RESUnified5SScheduler(RESUnifiedScheduler):
    """RESUnifiedScheduler preset that always passes ``rk_type="res_5s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``rk_type``.
        super().__init__(**{**kwargs, "rk_type": "res_5s"})
class RESUnified6SScheduler(RESUnifiedScheduler):
    """RESUnifiedScheduler preset that always passes ``rk_type="res_6s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``rk_type``.
        super().__init__(**{**kwargs, "rk_type": "res_6s"})
class DEISUnified1SScheduler(RESUnifiedScheduler):
    """RESUnifiedScheduler preset that always passes ``rk_type="deis_1s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``rk_type``.
        super().__init__(**{**kwargs, "rk_type": "deis_1s"})
class DEISUnified2MScheduler(RESUnifiedScheduler):
    """RESUnifiedScheduler preset that always passes ``rk_type="deis_2m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``rk_type``.
        super().__init__(**{**kwargs, "rk_type": "deis_2m"})
class DEISUnified3MScheduler(RESUnifiedScheduler):
    """RESUnifiedScheduler preset that always passes ``rk_type="deis_3m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``rk_type``.
        super().__init__(**{**kwargs, "rk_type": "deis_3m"})
# RES Multistep Variants
class RES2MScheduler(RESMultistepScheduler):
    """RESMultistepScheduler preset that always passes ``variant="res_2m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_2m"})
class RES3MScheduler(RESMultistepScheduler):
    """RESMultistepScheduler preset that always passes ``variant="res_3m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_3m"})
class DEIS2MScheduler(RESMultistepScheduler):
    """RESMultistepScheduler preset that always passes ``variant="deis_2m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "deis_2m"})
class DEIS3MScheduler(RESMultistepScheduler):
    """RESMultistepScheduler preset that always passes ``variant="deis_3m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "deis_3m"})
# RES Multistep SDE Variants
class RES2MSDEScheduler(RESMultistepSDEScheduler):
    """RESMultistepSDEScheduler preset that always passes ``variant="res_2m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_2m"})
class RES3MSDEScheduler(RESMultistepSDEScheduler):
    """RESMultistepSDEScheduler preset that always passes ``variant="res_3m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_3m"})
# RES Singlestep (Multistage) Variants
class RES2SScheduler(RESSinglestepScheduler):
    """RESSinglestepScheduler preset that always passes ``variant="res_2s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_2s"})
class RES3SScheduler(RESSinglestepScheduler):
    """RESSinglestepScheduler preset that always passes ``variant="res_3s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_3s"})
class RES5SScheduler(RESSinglestepScheduler):
    """RESSinglestepScheduler preset that always passes ``variant="res_5s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_5s"})
class RES6SScheduler(RESSinglestepScheduler):
    """RESSinglestepScheduler preset that always passes ``variant="res_6s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_6s"})
# RES Singlestep SDE Variants
class RES2SSDEScheduler(RESSinglestepSDEScheduler):
    """RESSinglestepSDEScheduler preset that always passes ``variant="res_2s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_2s"})
class RES3SSDEScheduler(RESSinglestepSDEScheduler):
    """RESSinglestepSDEScheduler preset that always passes ``variant="res_3s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_3s"})
class RES5SSDEScheduler(RESSinglestepSDEScheduler):
    """RESSinglestepSDEScheduler preset that always passes ``variant="res_5s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_5s"})
class RES6SSDEScheduler(RESSinglestepSDEScheduler):
    """RESSinglestepSDEScheduler preset that always passes ``variant="res_6s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "res_6s"})
# ETDRK Variants
class ETDRK2Scheduler(ETDRKScheduler):
    """ETDRKScheduler preset that always passes ``variant="etdrk2_2s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "etdrk2_2s"})
class ETDRK3AScheduler(ETDRKScheduler):
    """ETDRKScheduler preset that always passes ``variant="etdrk3_a_3s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "etdrk3_a_3s"})
class ETDRK3BScheduler(ETDRKScheduler):
    """ETDRKScheduler preset that always passes ``variant="etdrk3_b_3s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "etdrk3_b_3s"})
class ETDRK4Scheduler(ETDRKScheduler):
    """ETDRKScheduler preset that always passes ``variant="etdrk4_4s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "etdrk4_4s"})
class ETDRK4AltScheduler(ETDRKScheduler):
    """ETDRKScheduler preset that always passes ``variant="etdrk4_4s_alt"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "etdrk4_4s_alt"})
# Lawson Variants
class Lawson2AScheduler(LawsonScheduler):
    """LawsonScheduler preset that always passes ``variant="lawson2a_2s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "lawson2a_2s"})
class Lawson2BScheduler(LawsonScheduler):
    """LawsonScheduler preset that always passes ``variant="lawson2b_2s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "lawson2b_2s"})
class Lawson4Scheduler(LawsonScheduler):
    """LawsonScheduler preset that always passes ``variant="lawson4_4s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "lawson4_4s"})
# ABNorsett Variants
class ABNorsett2MScheduler(ABNorsettScheduler):
    """ABNorsettScheduler preset that always passes ``variant="abnorsett_2m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "abnorsett_2m"})
class ABNorsett3MScheduler(ABNorsettScheduler):
    """ABNorsettScheduler preset that always passes ``variant="abnorsett_3m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "abnorsett_3m"})
class ABNorsett4MScheduler(ABNorsettScheduler):
    """ABNorsettScheduler preset that always passes ``variant="abnorsett_4m"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "abnorsett_4m"})
# PEC Variants
class PEC2H2SScheduler(PECScheduler):
    """PECScheduler preset that always passes ``variant="pec423_2h2s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "pec423_2h2s"})
class PEC2H3SScheduler(PECScheduler):
    """PECScheduler preset that always passes ``variant="pec433_2h3s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "pec433_2h3s"})
# Riemannian Flow Variants
class FlowEuclideanScheduler(RiemannianFlowScheduler):
    """RiemannianFlowScheduler preset that always passes ``metric_type="euclidean"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``metric_type``.
        super().__init__(**{**kwargs, "metric_type": "euclidean"})
class FlowHyperbolicScheduler(RiemannianFlowScheduler):
    """RiemannianFlowScheduler preset that always passes ``metric_type="hyperbolic"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``metric_type``.
        super().__init__(**{**kwargs, "metric_type": "hyperbolic"})
class FlowSphericalScheduler(RiemannianFlowScheduler):
    """RiemannianFlowScheduler preset that always passes ``metric_type="spherical"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``metric_type``.
        super().__init__(**{**kwargs, "metric_type": "spherical"})
class FlowLorentzianScheduler(RiemannianFlowScheduler):
    """RiemannianFlowScheduler preset that always passes ``metric_type="lorentzian"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``metric_type``.
        super().__init__(**{**kwargs, "metric_type": "lorentzian"})
# Common Sigma Variants
class SigmaSigmoidScheduler(CommonSigmaScheduler):
    """CommonSigmaScheduler preset that always passes ``profile="sigmoid"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``profile``.
        super().__init__(**{**kwargs, "profile": "sigmoid"})
class SigmaSineScheduler(CommonSigmaScheduler):
    """CommonSigmaScheduler preset that always passes ``profile="sine"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``profile``.
        super().__init__(**{**kwargs, "profile": "sine"})
class SigmaEasingScheduler(CommonSigmaScheduler):
    """CommonSigmaScheduler preset that always passes ``profile="easing"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``profile``.
        super().__init__(**{**kwargs, "profile": "easing"})
class SigmaArcsineScheduler(CommonSigmaScheduler):
    """CommonSigmaScheduler preset that always passes ``profile="arcsine"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``profile``.
        super().__init__(**{**kwargs, "profile": "arcsine"})
class SigmaSmoothScheduler(CommonSigmaScheduler):
    """CommonSigmaScheduler preset that always passes ``profile="smoothstep"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``profile``.
        super().__init__(**{**kwargs, "profile": "smoothstep"})
## DEIS Multistep Variants
class DEIS1MultistepScheduler(RESDEISMultistepScheduler):
    """RESDEISMultistepScheduler preset that always passes ``solver_order=1``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``solver_order``.
        super().__init__(**{**kwargs, "solver_order": 1})
class DEIS2MultistepScheduler(RESDEISMultistepScheduler):
    """RESDEISMultistepScheduler preset that always passes ``solver_order=2``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``solver_order``.
        super().__init__(**{**kwargs, "solver_order": 2})
class DEIS3MultistepScheduler(RESDEISMultistepScheduler):
    """RESDEISMultistepScheduler preset that always passes ``solver_order=3``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``solver_order``.
        super().__init__(**{**kwargs, "solver_order": 3})
## Linear RK Variants
class LinearRKEulerScheduler(LinearRKScheduler):
    """LinearRKScheduler preset that always passes ``variant="euler"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "euler"})
class LinearRKHeunScheduler(LinearRKScheduler):
    """LinearRKScheduler preset that always passes ``variant="heun"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "heun"})
class LinearRK2Scheduler(LinearRKScheduler):
    """LinearRKScheduler preset that always passes ``variant="rk2"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "rk2"})
class LinearRK3Scheduler(LinearRKScheduler):
    """LinearRKScheduler preset that always passes ``variant="rk3"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "rk3"})
class LinearRK4Scheduler(LinearRKScheduler):
    """LinearRKScheduler preset that always passes ``variant="rk4"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "rk4"})
class LinearRKRalsstonScheduler(LinearRKScheduler):
    """LinearRKScheduler preset that always passes ``variant="ralston"``.

    The class name misspells "Ralston" (note the doubled "s"). It is kept
    unchanged so existing imports and references keep working.
    """

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "ralston"})


# Correctly spelled alias; both names refer to the very same class object.
LinearRKRalstonScheduler = LinearRKRalsstonScheduler
class LinearRKMidpointScheduler(LinearRKScheduler):
    """LinearRKScheduler preset that always passes ``variant="midpoint"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "midpoint"})
## Lobatto Variants
class Lobatto2Scheduler(LobattoScheduler):
    """LobattoScheduler preset that always passes ``variant="lobatto_iiia_2s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "lobatto_iiia_2s"})
class Lobatto3Scheduler(LobattoScheduler):
    """LobattoScheduler preset that always passes ``variant="lobatto_iiia_3s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "lobatto_iiia_3s"})
class Lobatto4Scheduler(LobattoScheduler):
    """LobattoScheduler preset that always passes ``variant="lobatto_iiia_4s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "lobatto_iiia_4s"})
## Radau IIA Variants
class RadauIIA2Scheduler(RadauIIAScheduler):
    """RadauIIAScheduler preset that always passes ``variant="radau_iia_2s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "radau_iia_2s"})
class RadauIIA3Scheduler(RadauIIAScheduler):
    """RadauIIAScheduler preset that always passes ``variant="radau_iia_3s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "radau_iia_3s"})
## Gauss Legendre Variants
class GaussLegendre2SScheduler(GaussLegendreScheduler):
    """GaussLegendreScheduler preset that always passes ``variant="gauss-legendre_2s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "gauss-legendre_2s"})
class GaussLegendre3SScheduler(GaussLegendreScheduler):
    """GaussLegendreScheduler preset that always passes ``variant="gauss-legendre_3s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "gauss-legendre_3s"})
class GaussLegendre4SScheduler(GaussLegendreScheduler):
    """GaussLegendreScheduler preset that always passes ``variant="gauss-legendre_4s"``."""

    def __init__(self, **kwargs):
        # The preset value wins over any caller-supplied ``variant``.
        super().__init__(**{**kwargs, "variant": "gauss-legendre_4s"})