From 9775118b0e94623e0b2337bc31682c0c21fe5c8f Mon Sep 17 00:00:00 2001 From: Rhys Edwards <6010457+rhyswynn@users.noreply.github.com> Date: Mon, 22 Apr 2024 07:21:24 -0400 Subject: [PATCH 1/4] Update to split sampler and scheduler selections --- scripts/deforum_helpers/args.py | 8 +++++- scripts/deforum_helpers/defaults.py | 22 ++++++++--------- scripts/deforum_helpers/generate.py | 26 +++++++++++++------- scripts/deforum_helpers/settings.py | 3 +++ scripts/deforum_helpers/ui_elements.py | 1 + scripts/deforum_helpers/webui_sd_pipeline.py | 2 +- 6 files changed, 39 insertions(+), 23 deletions(-) diff --git a/scripts/deforum_helpers/args.py b/scripts/deforum_helpers/args.py index 98257597..5be0b3ea 100644 --- a/scripts/deforum_helpers/args.py +++ b/scripts/deforum_helpers/args.py @@ -22,7 +22,7 @@ from types import SimpleNamespace import modules.paths as ph import modules.shared as sh from modules.processing import get_fixed_seed -from .defaults import get_guided_imgs_default_json, mask_fill_choices, get_samplers_list +from .defaults import get_guided_imgs_default_json, mask_fill_choices, get_samplers_list, get_schedulers_list from .deforum_controlnet import controlnet_component_names from .general_utils import get_os, substitute_placeholders @@ -766,6 +766,12 @@ def DeforumArgs(): "choices": get_samplers_list().values(), "value": "Euler a", }, + "scheduler": { + "label": "Scheduler", + "type": "dropdown", + "choices": get_schedulers_list().values(), + "value": "Automatic", + }, "steps": { "label": "Steps", "type": "slider", diff --git a/scripts/deforum_helpers/defaults.py b/scripts/deforum_helpers/defaults.py index b8aaac58..59678049 100644 --- a/scripts/deforum_helpers/defaults.py +++ b/scripts/deforum_helpers/defaults.py @@ -25,28 +25,26 @@ def get_samplers_list(): 'dpm++ 2s a': 'DPM++ 2S a', 'dpm++ 2m': 'DPM++ 2M', 'dpm++ sde': 'DPM++ SDE', - 'dpm++ 2m sde karras': 'DPM++ 2M SDE Karras', 'dpm fast': 'DPM fast', 'dpm adaptive': 'DPM adaptive', - 'lms karras': 'LMS Karras', - 'dpm2 karras': 'DPM2 Karras', - 'dpm2 a karras': 'DPM2 a Karras', - 'dpm++ 2s a karras': 'DPM++ 2S a Karras', - 'dpm++ 2m karras': 'DPM++ 2M Karras', - 'dpm++ sde karras': 'DPM++ SDE Karras', - 'dpm++ 2m sde exponential': 'DPM++ 2M SDE Exponential', 'dpm++ 2m sde heun': 'DPM++ 2M SDE Heun', - 'dpm++ 2m sde heun karras': 'DPM++ 2M SDE Heun Karras', - 'dpm++ 2m sde Heun Exponential': 'DPM++ 2M SDE Heun Exponential', 'dpm++ 3m sde': 'DPM++ 3M SDE', - 'dpm++ 3m sde karras': 'DPM++ 3M SDE Karras', - 'dpm++ 3m sde exponential': 'DPM++ 3M SDE Exponential', 'ddim': 'DDIM', 'plms': 'PLMS', 'unipc': 'UniPC', 'restart': 'Restart' } +def get_schedulers_list(): + return { + 'automatic': 'Automatic', + 'uniform': 'Uniform', + 'karras': 'Karras', + 'exponential': 'Exponential', + 'polyexponential': 'Polyexponential', + 'sgm uniform': 'SGM Uniform' + } + def DeforumAnimPrompts(): return r"""{ "0": "tiny cute bunny, vibrant diffraction, highly detailed, intricate, ultra hd, sharp photo, crepuscular rays, in focus", diff --git a/scripts/deforum_helpers/generate.py b/scripts/deforum_helpers/generate.py index fe48055e..64fc6135 100644 --- a/scripts/deforum_helpers/generate.py +++ b/scripts/deforum_helpers/generate.py @@ -28,7 +28,7 @@ from .prompt import split_weighted_subprompts from .load_images import load_img, prepare_mask, check_mask_for_errors from .webui_sd_pipeline import get_webui_sd_pipeline from .rich import console -from .defaults import get_samplers_list +from .defaults import get_samplers_list, get_schedulers_list from .prompt import check_is_number from .opts_overrider import A1111OptionsOverrider import cv2 @@ -70,14 +70,14 @@ def pairwise_repl(iterable): next(b, None) return zip(a, b) -def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None): +def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None): if state.interrupted: return None if args.reroll_blank_frames == 'ignore': - return generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name) + return generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name) - image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name) + image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name) if caught_vae_exception or not image.getbbox(): patience = args.reroll_patience @@ -86,7 +86,7 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada while caught_vae_exception or not image.getbbox(): print("Rerolling with +1 seed...") args.seed += 1 - image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name) + image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name=None) patience -= 1 if patience == 0: print("Rerolling with +1 seed failed for 10 iterations! Try setting webui's precision to 'full' and if it fails, please report this to the devs! Interrupting...") @@ -100,12 +100,12 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada return None return image -def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None): +def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None): if cmd_opts.disable_nan_check: - image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name) + image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name) else: try: - image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name) + image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name) except Exception as e: if "A tensor with all NaNs was produced in VAE." in repr(e): print(e) @@ -114,7 +114,7 @@ def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, raise e return image, False -def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None): +def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None): # Setup the pipeline p = get_webui_sd_pipeline(args, root) p.prompt, p.negative_prompt = split_weighted_subprompts(args.prompt, frame, anim_args.max_frames) @@ -176,6 +176,13 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars else: raise RuntimeError(f"Sampler name '{sampler_name}' is invalid. Please check the available sampler list in the 'Run' tab") + available_schedulers = get_schedulers_list() + if scheduler_name is not None: + if scheduler_name in available_schedulers.keys(): + p.scheduler = available_schedulers[scheduler_name] + else: + raise RuntimeError(f"Scheduler name '{scheduler_name}' is invalid. Please check the available scheduler list in the 'Run' tab") + if args.checkpoint is not None: info = sd_models.get_closet_checkpoint_match(args.checkpoint) if info is None: @@ -220,6 +227,7 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, sampler_name=p.sampler_name, + scheduler=p.scheduler, batch_size=p.batch_size, n_iter=p.n_iter, steps=p.steps, diff --git a/scripts/deforum_helpers/settings.py b/scripts/deforum_helpers/settings.py index bd1132d9..bf88f313 100644 --- a/scripts/deforum_helpers/settings.py +++ b/scripts/deforum_helpers/settings.py @@ -132,6 +132,9 @@ def load_all_settings(*args, ui_launch=False, **kwargs): if key == 'sampler' and isinstance(val, int): from modules.sd_samplers import samplers_for_img2img val = samplers_for_img2img[val].name + if key == 'scheduler' and isinstance(val, int): + from modules.sd_schedulers import schedulers_map + val = schedulers_map[val] elif key == 'fill' and isinstance(val, int): val = mask_fill_choices[val] elif key in {'reroll_blank_frames', 'noise_type'} and key not in jdata: diff --git a/scripts/deforum_helpers/ui_elements.py b/scripts/deforum_helpers/ui_elements.py index 277c5d3d..f9467e7c 100644 --- a/scripts/deforum_helpers/ui_elements.py +++ b/scripts/deforum_helpers/ui_elements.py @@ -46,6 +46,7 @@ def get_tab_run(d, da): motion_preview_mode = create_gr_elem(d.motion_preview_mode) with FormRow(): sampler = create_gr_elem(d.sampler) + scheduler = create_gr_elem(d.scheduler) steps = create_gr_elem(d.steps) with FormRow(): W = create_gr_elem(d.W) diff --git a/scripts/deforum_helpers/webui_sd_pipeline.py b/scripts/deforum_helpers/webui_sd_pipeline.py index 81696461..3b3b269c 100644 --- a/scripts/deforum_helpers/webui_sd_pipeline.py +++ b/scripts/deforum_helpers/webui_sd_pipeline.py @@ -40,7 +40,7 @@ def get_webui_sd_pipeline(args, root): p.batch_size = 1 # b.size 1 as this is DEFORUM :) p.seed = args.seed p.do_not_save_samples = True # Setting this to False will trigger webui's saving mechanism - and we will end up with duplicated files, and another folder within our destination folder - big no no. - p.sampler_name = args.sampler + p.scheduler = args.scheduler p.mask_blur = args.mask_overlay_blur p.extra_generation_params["Mask blur"] = args.mask_overlay_blur p.n_iter = 1 From 1ebb9a483e93dbc8004d4439fd187d420dc5804a Mon Sep 17 00:00:00 2001 From: Rhys Edwards <6010457+rhyswynn@users.noreply.github.com> Date: Thu, 25 Apr 2024 16:49:33 -0400 Subject: [PATCH 2/4] Enabled legacy sampler settings to be imported --- scripts/deforum_helpers/defaults.py | 3 ++- scripts/deforum_helpers/settings.py | 31 ++++++++++++++++++++++------- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/scripts/deforum_helpers/defaults.py b/scripts/deforum_helpers/defaults.py index 59678049..0bb7f6f3 100644 --- a/scripts/deforum_helpers/defaults.py +++ b/scripts/deforum_helpers/defaults.py @@ -32,7 +32,8 @@ def get_samplers_list(): 'ddim': 'DDIM', 'plms': 'PLMS', 'unipc': 'UniPC', - 'restart': 'Restart' + 'restart': 'Restart', + 'lcm': 'LCM' } def get_schedulers_list(): diff --git a/scripts/deforum_helpers/settings.py b/scripts/deforum_helpers/settings.py index bf88f313..d2795540 100644 --- a/scripts/deforum_helpers/settings.py +++ b/scripts/deforum_helpers/settings.py @@ -129,12 +129,30 @@ def load_all_settings(*args, ui_launch=False, **kwargs): result = {} for key, default_val in data.items(): val = jdata.get(key, default_val) - if key == 'sampler' and isinstance(val, int): - from modules.sd_samplers import samplers_for_img2img - val = samplers_for_img2img[val].name - if key == 'scheduler' and isinstance(val, int): - from modules.sd_schedulers import schedulers_map - val = schedulers_map[val] + if key == 'sampler' and isinstance(val, str): + samp_val = val.split() + scheduler_val = None + if samp_val[-1] == 'Uniform': + scheduler_val = samp_val[-1] + val = (val.split(" Uniform"))[0] + if samp_val[-1] == 'Karras': + scheduler_val = samp_val[-1] + val = (val.split(" Karras"))[0] + if samp_val[-1] == 'Exponential': + scheduler_val = samp_val[-1] + val = (val.split(" Exponential"))[0] + if samp_val[-1] == 'Polyexponential': + scheduler_val = samp_val[-1] + val = (val.split(" Polyexponential"))[0] + if samp_val[-1] == 'SGM Uniform': + scheduler_val = samp_val[-1] + val = (val.split(" SGM Uniform"))[0] + if key == 'scheduler' and isinstance(val, str): + if scheduler_val is not None: + val = scheduler_val + else: + from modules.sd_schedulers import schedulers_map + val = schedulers_map[val] elif key == 'fill' and isinstance(val, int): val = mask_fill_choices[val] elif key in {'reroll_blank_frames', 'noise_type'} and key not in jdata: @@ -145,7 +163,6 @@ def load_all_settings(*args, ui_launch=False, **kwargs): val = jdata.get(key, default_val) elif key == 'animation_prompts': val = json.dumps(jdata['prompts'], ensure_ascii=False, indent=4) - result[key] = val if ui_launch: From c0042ab2ffe6df688def7c2a9789bb54b7c87fd9 Mon Sep 17 00:00:00 2001 From: Rhys Edwards <6010457+rhyswynn@users.noreply.github.com> Date: Thu, 25 Apr 2024 17:27:41 -0400 Subject: [PATCH 3/4] Enabled legacy sampler settings to be imported --- scripts/deforum_helpers/defaults.py | 1 + scripts/deforum_helpers/settings.py | 16 ++-------------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/scripts/deforum_helpers/defaults.py b/scripts/deforum_helpers/defaults.py index 0bb7f6f3..ac5f1719 100644 --- a/scripts/deforum_helpers/defaults.py +++ b/scripts/deforum_helpers/defaults.py @@ -24,6 +24,7 @@ def get_samplers_list(): 'dpm2 a': 'DPM2 a', 'dpm++ 2s a': 'DPM++ 2S a', 'dpm++ 2m': 'DPM++ 2M', + 'dpm++ 2m sde': 'DPM++ 2M SDE', 'dpm++ sde': 'DPM++ SDE', 'dpm fast': 'DPM fast', 'dpm adaptive': 'DPM adaptive', diff --git a/scripts/deforum_helpers/settings.py b/scripts/deforum_helpers/settings.py index d2795540..c4c79554 100644 --- a/scripts/deforum_helpers/settings.py +++ b/scripts/deforum_helpers/settings.py @@ -132,21 +132,9 @@ def load_all_settings(*args, ui_launch=False, **kwargs): if key == 'sampler' and isinstance(val, str): samp_val = val.split() scheduler_val = None - if samp_val[-1] == 'Uniform': + if samp_val[-1] in ['Uniform','SGM Uniform','Karras','Exponential','Polyexponential']: scheduler_val = samp_val[-1] - val = (val.split(" Uniform"))[0] - if samp_val[-1] == 'Karras': - scheduler_val = samp_val[-1] - val = (val.split(" Karras"))[0] - if samp_val[-1] == 'Exponential': - scheduler_val = samp_val[-1] - val = (val.split(" Exponential"))[0] - if samp_val[-1] == 'Polyexponential': - scheduler_val = samp_val[-1] - val = (val.split(" Polyexponential"))[0] - if samp_val[-1] == 'SGM Uniform': - scheduler_val = samp_val[-1] - val = (val.split(" SGM Uniform"))[0] + val = (val.split(" " + samp_val[-1]))[0] if key == 'scheduler' and isinstance(val, str): if scheduler_val is not None: val = scheduler_val From c37c19ad3c43c639e38e6e50ed35ceeb37c149cf Mon Sep 17 00:00:00 2001 From: Rhys Edwards <6010457+rhyswynn@users.noreply.github.com> Date: Fri, 26 Apr 2024 06:36:17 -0400 Subject: [PATCH 4/4] Enabled legacy sampler settings to be imported --- scripts/deforum_helpers/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/deforum_helpers/settings.py b/scripts/deforum_helpers/settings.py index c4c79554..41b5678a 100644 --- a/scripts/deforum_helpers/settings.py +++ b/scripts/deforum_helpers/settings.py @@ -140,7 +140,7 @@ def load_all_settings(*args, ui_launch=False, **kwargs): val = scheduler_val else: from modules.sd_schedulers import schedulers_map - val = schedulers_map[val] + val = schedulers_map[val].label elif key == 'fill' and isinstance(val, int): val = mask_fill_choices[val] elif key in {'reroll_blank_frames', 'noise_type'} and key not in jdata: