Merge pull request #968 from rhyswynn/automatic1111-webui

Update to split sampler and scheduler selections
automatic1111-webui
Robin Fernandes 2024-05-15 10:05:36 +10:00 committed by GitHub
commit 91df0fa852
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 51 additions and 28 deletions

View File

@ -22,7 +22,7 @@ from types import SimpleNamespace
import modules.paths as ph
import modules.shared as sh
from modules.processing import get_fixed_seed
from .defaults import get_guided_imgs_default_json, mask_fill_choices, get_samplers_list
from .defaults import get_guided_imgs_default_json, mask_fill_choices, get_samplers_list, get_schedulers_list
from .deforum_controlnet import controlnet_component_names
from .general_utils import get_os, substitute_placeholders
@ -766,6 +766,12 @@ def DeforumArgs():
"choices": get_samplers_list().values(),
"value": "Euler a",
},
"scheduler": {
"label": "Scheduler",
"type": "dropdown",
"choices": get_schedulers_list().values(),
"value": "Automatic",
},
"steps": {
"label": "Steps",
"type": "slider",

View File

@ -24,27 +24,27 @@ def get_samplers_list():
'dpm2 a': 'DPM2 a',
'dpm++ 2s a': 'DPM++ 2S a',
'dpm++ 2m': 'DPM++ 2M',
'dpm++ 2m sde': 'DPM++ 2M SDE',
'dpm++ sde': 'DPM++ SDE',
'dpm++ 2m sde karras': 'DPM++ 2M SDE Karras',
'dpm fast': 'DPM fast',
'dpm adaptive': 'DPM adaptive',
'lms karras': 'LMS Karras',
'dpm2 karras': 'DPM2 Karras',
'dpm2 a karras': 'DPM2 a Karras',
'dpm++ 2s a karras': 'DPM++ 2S a Karras',
'dpm++ 2m karras': 'DPM++ 2M Karras',
'dpm++ sde karras': 'DPM++ SDE Karras',
'dpm++ 2m sde exponential': 'DPM++ 2M SDE Exponential',
'dpm++ 2m sde heun': 'DPM++ 2M SDE Heun',
'dpm++ 2m sde heun karras': 'DPM++ 2M SDE Heun Karras',
'dpm++ 2m sde Heun Exponential': 'DPM++ 2M SDE Heun Exponential',
'dpm++ 3m sde': 'DPM++ 3M SDE',
'dpm++ 3m sde karras': 'DPM++ 3M SDE Karras',
'dpm++ 3m sde exponential': 'DPM++ 3M SDE Exponential',
'ddim': 'DDIM',
'plms': 'PLMS',
'unipc': 'UniPC',
'restart': 'Restart'
'restart': 'Restart',
'lcm': 'LCM'
}
def get_schedulers_list():
return {
'automatic': 'Automatic',
'uniform': 'Uniform',
'karras': 'Karras',
'exponential': 'Exponential',
'polyexponential': 'Polyexponential',
'sgm uniform': 'SGM Uniform'
}
def DeforumAnimPrompts():

View File

@ -28,7 +28,7 @@ from .prompt import split_weighted_subprompts
from .load_images import load_img, prepare_mask, check_mask_for_errors
from .webui_sd_pipeline import get_webui_sd_pipeline
from .rich import console
from .defaults import get_samplers_list
from .defaults import get_samplers_list, get_schedulers_list
from .prompt import check_is_number
from .opts_overrider import A1111OptionsOverrider
import cv2
@ -70,14 +70,14 @@ def pairwise_repl(iterable):
next(b, None)
return zip(a, b)
def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None):
if state.interrupted:
return None
if args.reroll_blank_frames == 'ignore':
return generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
return generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
if caught_vae_exception or not image.getbbox():
patience = args.reroll_patience
@ -86,7 +86,7 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
while caught_vae_exception or not image.getbbox():
print("Rerolling with +1 seed...")
args.seed += 1
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name=None)
patience -= 1
if patience == 0:
print("Rerolling with +1 seed failed for 10 iterations! Try setting webui's precision to 'full' and if it fails, please report this to the devs! Interrupting...")
@ -100,12 +100,12 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
return None
return image
def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None):
if cmd_opts.disable_nan_check:
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
else:
try:
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
except Exception as e:
if "A tensor with all NaNs was produced in VAE." in repr(e):
print(e)
@ -114,7 +114,7 @@ def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args,
raise e
return image, False
def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None):
# Setup the pipeline
p = get_webui_sd_pipeline(args, root)
p.prompt, p.negative_prompt = split_weighted_subprompts(args.prompt, frame, anim_args.max_frames)
@ -176,6 +176,13 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
else:
raise RuntimeError(f"Sampler name '{sampler_name}' is invalid. Please check the available sampler list in the 'Run' tab")
available_schedulers = get_schedulers_list()
if scheduler_name is not None:
if scheduler_name in available_schedulers.keys():
p.scheduler = available_schedulers[scheduler_name]
else:
raise RuntimeError(f"Scheduler name '{scheduler_name}' is invalid. Please check the available scheduler list in the 'Run' tab")
if args.checkpoint is not None:
info = sd_models.get_closet_checkpoint_match(args.checkpoint)
if info is None:
@ -220,6 +227,7 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
seed_resize_from_h=p.seed_resize_from_h,
seed_resize_from_w=p.seed_resize_from_w,
sampler_name=p.sampler_name,
scheduler=p.scheduler,
batch_size=p.batch_size,
n_iter=p.n_iter,
steps=p.steps,

View File

@ -129,9 +129,18 @@ def load_all_settings(*args, ui_launch=False, **kwargs):
result = {}
for key, default_val in data.items():
val = jdata.get(key, default_val)
if key == 'sampler' and isinstance(val, int):
from modules.sd_samplers import samplers_for_img2img
val = samplers_for_img2img[val].name
if key == 'sampler' and isinstance(val, str):
samp_val = val.split()
scheduler_val = None
if samp_val[-1] in ['Uniform','SGM Uniform','Karras','Exponential','Polyexponential']:
scheduler_val = samp_val[-1]
val = (val.split(" " + samp_val[-1]))[0]
if key == 'scheduler' and isinstance(val, str):
if scheduler_val is not None:
val = scheduler_val
else:
from modules.sd_schedulers import schedulers_map
val = schedulers_map[val].label
elif key == 'fill' and isinstance(val, int):
val = mask_fill_choices[val]
elif key in {'reroll_blank_frames', 'noise_type'} and key not in jdata:
@ -142,7 +151,6 @@ def load_all_settings(*args, ui_launch=False, **kwargs):
val = jdata.get(key, default_val)
elif key == 'animation_prompts':
val = json.dumps(jdata['prompts'], ensure_ascii=False, indent=4)
result[key] = val
if ui_launch:

View File

@ -46,6 +46,7 @@ def get_tab_run(d, da):
motion_preview_mode = create_gr_elem(d.motion_preview_mode)
with FormRow():
sampler = create_gr_elem(d.sampler)
scheduler = create_gr_elem(d.scheduler)
steps = create_gr_elem(d.steps)
with FormRow():
W = create_gr_elem(d.W)

View File

@ -40,7 +40,7 @@ def get_webui_sd_pipeline(args, root):
p.batch_size = 1 # b.size 1 as this is DEFORUM :)
p.seed = args.seed
p.do_not_save_samples = True # Setting this to False will trigger webui's saving mechanism - and we will end up with duplicated files, and another folder within our destination folder - big no no.
p.sampler_name = args.sampler
p.scheduler = args.scheduler
p.mask_blur = args.mask_overlay_blur
p.extra_generation_params["Mask blur"] = args.mask_overlay_blur
p.n_iter = 1