Merge pull request #968 from rhyswynn/automatic1111-webui
Update to split sampler and scheduler selectionsautomatic1111-webui
commit
91df0fa852
|
|
@ -22,7 +22,7 @@ from types import SimpleNamespace
|
|||
import modules.paths as ph
|
||||
import modules.shared as sh
|
||||
from modules.processing import get_fixed_seed
|
||||
from .defaults import get_guided_imgs_default_json, mask_fill_choices, get_samplers_list
|
||||
from .defaults import get_guided_imgs_default_json, mask_fill_choices, get_samplers_list, get_schedulers_list
|
||||
from .deforum_controlnet import controlnet_component_names
|
||||
from .general_utils import get_os, substitute_placeholders
|
||||
|
||||
|
|
@ -766,6 +766,12 @@ def DeforumArgs():
|
|||
"choices": get_samplers_list().values(),
|
||||
"value": "Euler a",
|
||||
},
|
||||
"scheduler": {
|
||||
"label": "Scheduler",
|
||||
"type": "dropdown",
|
||||
"choices": get_schedulers_list().values(),
|
||||
"value": "Automatic",
|
||||
},
|
||||
"steps": {
|
||||
"label": "Steps",
|
||||
"type": "slider",
|
||||
|
|
|
|||
|
|
@ -24,27 +24,27 @@ def get_samplers_list():
|
|||
'dpm2 a': 'DPM2 a',
|
||||
'dpm++ 2s a': 'DPM++ 2S a',
|
||||
'dpm++ 2m': 'DPM++ 2M',
|
||||
'dpm++ 2m sde': 'DPM++ 2M SDE',
|
||||
'dpm++ sde': 'DPM++ SDE',
|
||||
'dpm++ 2m sde karras': 'DPM++ 2M SDE Karras',
|
||||
'dpm fast': 'DPM fast',
|
||||
'dpm adaptive': 'DPM adaptive',
|
||||
'lms karras': 'LMS Karras',
|
||||
'dpm2 karras': 'DPM2 Karras',
|
||||
'dpm2 a karras': 'DPM2 a Karras',
|
||||
'dpm++ 2s a karras': 'DPM++ 2S a Karras',
|
||||
'dpm++ 2m karras': 'DPM++ 2M Karras',
|
||||
'dpm++ sde karras': 'DPM++ SDE Karras',
|
||||
'dpm++ 2m sde exponential': 'DPM++ 2M SDE Exponential',
|
||||
'dpm++ 2m sde heun': 'DPM++ 2M SDE Heun',
|
||||
'dpm++ 2m sde heun karras': 'DPM++ 2M SDE Heun Karras',
|
||||
'dpm++ 2m sde Heun Exponential': 'DPM++ 2M SDE Heun Exponential',
|
||||
'dpm++ 3m sde': 'DPM++ 3M SDE',
|
||||
'dpm++ 3m sde karras': 'DPM++ 3M SDE Karras',
|
||||
'dpm++ 3m sde exponential': 'DPM++ 3M SDE Exponential',
|
||||
'ddim': 'DDIM',
|
||||
'plms': 'PLMS',
|
||||
'unipc': 'UniPC',
|
||||
'restart': 'Restart'
|
||||
'restart': 'Restart',
|
||||
'lcm': 'LCM'
|
||||
}
|
||||
|
||||
def get_schedulers_list():
|
||||
return {
|
||||
'automatic': 'Automatic',
|
||||
'uniform': 'Uniform',
|
||||
'karras': 'Karras',
|
||||
'exponential': 'Exponential',
|
||||
'polyexponential': 'Polyexponential',
|
||||
'sgm uniform': 'SGM Uniform'
|
||||
}
|
||||
|
||||
def DeforumAnimPrompts():
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from .prompt import split_weighted_subprompts
|
|||
from .load_images import load_img, prepare_mask, check_mask_for_errors
|
||||
from .webui_sd_pipeline import get_webui_sd_pipeline
|
||||
from .rich import console
|
||||
from .defaults import get_samplers_list
|
||||
from .defaults import get_samplers_list, get_schedulers_list
|
||||
from .prompt import check_is_number
|
||||
from .opts_overrider import A1111OptionsOverrider
|
||||
import cv2
|
||||
|
|
@ -70,14 +70,14 @@ def pairwise_repl(iterable):
|
|||
next(b, None)
|
||||
return zip(a, b)
|
||||
|
||||
def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
|
||||
def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None):
|
||||
if state.interrupted:
|
||||
return None
|
||||
|
||||
if args.reroll_blank_frames == 'ignore':
|
||||
return generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
|
||||
return generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
|
||||
|
||||
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
|
||||
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
|
||||
|
||||
if caught_vae_exception or not image.getbbox():
|
||||
patience = args.reroll_patience
|
||||
|
|
@ -86,7 +86,7 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
|
|||
while caught_vae_exception or not image.getbbox():
|
||||
print("Rerolling with +1 seed...")
|
||||
args.seed += 1
|
||||
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
|
||||
image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name=None)
|
||||
patience -= 1
|
||||
if patience == 0:
|
||||
print("Rerolling with +1 seed failed for 10 iterations! Try setting webui's precision to 'full' and if it fails, please report this to the devs! Interrupting...")
|
||||
|
|
@ -100,12 +100,12 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
|
|||
return None
|
||||
return image
|
||||
|
||||
def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
|
||||
def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None):
|
||||
if cmd_opts.disable_nan_check:
|
||||
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
|
||||
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
|
||||
else:
|
||||
try:
|
||||
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
|
||||
image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name, scheduler_name)
|
||||
except Exception as e:
|
||||
if "A tensor with all NaNs was produced in VAE." in repr(e):
|
||||
print(e)
|
||||
|
|
@ -114,7 +114,7 @@ def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args,
|
|||
raise e
|
||||
return image, False
|
||||
|
||||
def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
|
||||
def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None, scheduler_name=None):
|
||||
# Setup the pipeline
|
||||
p = get_webui_sd_pipeline(args, root)
|
||||
p.prompt, p.negative_prompt = split_weighted_subprompts(args.prompt, frame, anim_args.max_frames)
|
||||
|
|
@ -176,6 +176,13 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
|
|||
else:
|
||||
raise RuntimeError(f"Sampler name '{sampler_name}' is invalid. Please check the available sampler list in the 'Run' tab")
|
||||
|
||||
available_schedulers = get_schedulers_list()
|
||||
if scheduler_name is not None:
|
||||
if scheduler_name in available_schedulers.keys():
|
||||
p.scheduler = available_schedulers[scheduler_name]
|
||||
else:
|
||||
raise RuntimeError(f"Scheduler name '{scheduler_name}' is invalid. Please check the available scheduler list in the 'Run' tab")
|
||||
|
||||
if args.checkpoint is not None:
|
||||
info = sd_models.get_closet_checkpoint_match(args.checkpoint)
|
||||
if info is None:
|
||||
|
|
@ -220,6 +227,7 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
|
|||
seed_resize_from_h=p.seed_resize_from_h,
|
||||
seed_resize_from_w=p.seed_resize_from_w,
|
||||
sampler_name=p.sampler_name,
|
||||
scheduler=p.scheduler,
|
||||
batch_size=p.batch_size,
|
||||
n_iter=p.n_iter,
|
||||
steps=p.steps,
|
||||
|
|
|
|||
|
|
@ -129,9 +129,18 @@ def load_all_settings(*args, ui_launch=False, **kwargs):
|
|||
result = {}
|
||||
for key, default_val in data.items():
|
||||
val = jdata.get(key, default_val)
|
||||
if key == 'sampler' and isinstance(val, int):
|
||||
from modules.sd_samplers import samplers_for_img2img
|
||||
val = samplers_for_img2img[val].name
|
||||
if key == 'sampler' and isinstance(val, str):
|
||||
samp_val = val.split()
|
||||
scheduler_val = None
|
||||
if samp_val[-1] in ['Uniform','SGM Uniform','Karras','Exponential','Polyexponential']:
|
||||
scheduler_val = samp_val[-1]
|
||||
val = (val.split(" " + samp_val[-1]))[0]
|
||||
if key == 'scheduler' and isinstance(val, str):
|
||||
if scheduler_val is not None:
|
||||
val = scheduler_val
|
||||
else:
|
||||
from modules.sd_schedulers import schedulers_map
|
||||
val = schedulers_map[val].label
|
||||
elif key == 'fill' and isinstance(val, int):
|
||||
val = mask_fill_choices[val]
|
||||
elif key in {'reroll_blank_frames', 'noise_type'} and key not in jdata:
|
||||
|
|
@ -142,7 +151,6 @@ def load_all_settings(*args, ui_launch=False, **kwargs):
|
|||
val = jdata.get(key, default_val)
|
||||
elif key == 'animation_prompts':
|
||||
val = json.dumps(jdata['prompts'], ensure_ascii=False, indent=4)
|
||||
|
||||
result[key] = val
|
||||
|
||||
if ui_launch:
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ def get_tab_run(d, da):
|
|||
motion_preview_mode = create_gr_elem(d.motion_preview_mode)
|
||||
with FormRow():
|
||||
sampler = create_gr_elem(d.sampler)
|
||||
scheduler = create_gr_elem(d.scheduler)
|
||||
steps = create_gr_elem(d.steps)
|
||||
with FormRow():
|
||||
W = create_gr_elem(d.W)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def get_webui_sd_pipeline(args, root):
|
|||
p.batch_size = 1 # b.size 1 as this is DEFORUM :)
|
||||
p.seed = args.seed
|
||||
p.do_not_save_samples = True # Setting this to False will trigger webui's saving mechanism - and we will end up with duplicated files, and another folder within our destination folder - big no no.
|
||||
p.sampler_name = args.sampler
|
||||
p.scheduler = args.scheduler
|
||||
p.mask_blur = args.mask_overlay_blur
|
||||
p.extra_generation_params["Mask blur"] = args.mask_overlay_blur
|
||||
p.n_iter = 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue