Simplified and added samplers (#125)

master
Gabriel 2024-06-06 10:40:36 +02:00 committed by GitHub
parent ff9af8ade4
commit e206c36452
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 17 additions and 60 deletions

View File

@ -222,67 +222,24 @@ class StableHorde:
samplers = [
SamplerData(
"Euler a Karras",
lambda model, funcname="sample_euler_ancestral": KDiffusionSampler(
funcname, model
),
["k_euler_a_ka"],
name,
lambda model, fn=func: KDiffusionSampler(fn, model),
[alias],
{"scheduler": "karras"},
),
SamplerData(
"Euler Karras",
lambda model, funcname="sample_euler": KDiffusionSampler(
funcname, model
),
["k_euler_ka"],
{"scheduler": "karras"},
),
SamplerData(
"Heun Karras",
lambda model, funcname="sample_heun": KDiffusionSampler(
funcname, model
),
["k_heun_ka"],
{"scheduler": "karras"},
),
SamplerData(
"DPM adaptive Karras",
lambda model, funcname="sample_dpm_adaptive": KDiffusionSampler(
funcname, model
),
["k_dpm_ad_ka"],
{"scheduler": "karras"},
),
SamplerData(
"DPM fast Karras",
lambda model, funcname="sample_dpm_fast": KDiffusionSampler(
funcname, model
),
["k_dpm_fast_ka"],
{"scheduler": "karras"},
),
SamplerData(
"LMS Karras",
lambda model, funcname="sample_lms": KDiffusionSampler(funcname, model),
["k_lms_ka"],
{"scheduler": "karras"},
),
SamplerData(
"DPM++ SDE Karras",
lambda model, funcname="sample_dpmpp_sde": KDiffusionSampler(
funcname, model
),
["k_dpmpp_sde_ka"],
{"scheduler": "karras"},
),
SamplerData(
"DPM++ 2S a Karras",
lambda model, funcname="sample_dpmpp_2s_ancestral": KDiffusionSampler(
funcname, model
),
["k_dpmpp_2s_a_ka"],
{"scheduler": "karras"},
),
)
for name, func, alias in [
("Euler a Karras", "sample_euler_ancestral", "k_euler_a_ka"),
("Euler Karras", "sample_euler", "k_euler_ka"),
("LMS Karras", "sample_lms", "k_lms_ka"),
("Heun Karras", "sample_heun", "k_heun_ka"),
("DPM2 Karras", "sample_dpm_2", "k_dpm_2_ka"),
("DPM2 a Karras", "sample_dpm_2_ancestral", "k_dpm_2_a_ka"),
("DPM++ 2S a Karras", "sample_dpmpp_2s_ancestral", "k_dpmpp_2s_a_ka"),
("DPM++ 2M Karras", "sample_dpmpp_2m", "k_dpmpp_2m_ka"),
("DPM++ SDE Karras", "sample_dpmpp_sde", "k_dpmpp_sde_ka"),
("DPM fast Karras", "sample_dpm_fast", "k_dpm_fast_ka"),
("DPM adaptive Karras", "sample_dpm_adaptive", "k_dpm_ad_ka"),
]
]
sd_samplers.samplers.extend(samplers)
sd_samplers.samplers_for_img2img.extend(samplers)