stable-diffusion-aws-extension/scripts/xyz_grid_utils.py

667 lines
25 KiB
Python

from collections import namedtuple
from itertools import chain
import csv
import os.path
from io import StringIO
import numpy as np
import modules.scripts as scripts
from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, state
import modules.shared as shared
import modules.sd_samplers
import modules.sd_models
import modules.sd_vae
import re
from scripts.global_state import update_cn_models, cn_models_names, cn_preprocessor_modules
from scripts.external_code import ResizeMode, ControlMode
from modules.ui_components import ToolButton
# Ledger emoji constant; the name suggests it labels a "fill values" button
# (its use is not visible in this part of the file).
fill_values_symbol = "\U0001f4d2" # 📒
# Simple record type with two fields: ``axis`` and ``values``.
AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])
def apply_field(field):
    """Build an axis applier that stores the axis value on ``p.<field>``.

    The returned callable has the ``(p, x, xs)`` signature used by the axis
    options in this file; only ``p`` and ``x`` are used.
    """
    def fun(p, x, xs):
        # Plain attribute assignment, with the attribute name fixed at build time.
        setattr(p, field, x)

    return fun
def apply_prompt(p, x, xs):
    """Search-and-replace ``xs[0]`` with ``x`` in the prompt and negative prompt.

    Raises:
        RuntimeError: if ``xs[0]`` occurs in neither ``p.prompt`` nor
            ``p.negative_prompt``.
    """
    needle = xs[0]
    found = needle in p.prompt or needle in p.negative_prompt
    if not found:
        raise RuntimeError(f"Prompt S/R did not find {needle} in prompt or negative prompt.")

    for attr in ("prompt", "negative_prompt"):
        setattr(p, attr, getattr(p, attr).replace(needle, x))
def apply_order(p, x, xs):
    """Rearrange the tokens of ``x`` inside ``p.prompt``.

    The positions the tokens currently occupy in the prompt are kept, but the
    tokens are written back in the order given by ``x``.
    """
    # Tokens in the order they currently appear in the prompt (stable sort).
    tokens_by_position = sorted(x, key=p.prompt.find)

    # Cut the prompt into the text preceding each token, consuming as we go.
    remainder = p.prompt
    segments = []
    for token in tokens_by_position:
        cut = remainder.find(token)
        segments.append(remainder[:cut])
        remainder = remainder[cut + len(token):]

    # Re-assemble, interleaving the untouched segments with the new order.
    rebuilt = "".join(segment + token for segment, token in zip(segments, x))
    p.prompt = rebuilt + remainder
def confirm_samplers(p, xs):
    """Check that every name in ``xs`` is a key of ``sd_samplers.samplers_map``.

    The comparison is made on the lower-cased name.

    Raises:
        RuntimeError: for the first name that is not found.
    """
    for sampler_name in xs:
        is_known = sampler_name.lower() in sd_samplers.samplers_map
        if is_known:
            continue
        raise RuntimeError(f"Unknown sampler: {sampler_name}")
def apply_checkpoint(p, x, xs):
    """Record ``x`` as the ``sd_model_checkpoint`` override on ``p``.

    The name is stored as given; no lookup against the local checkpoint list
    is performed here.
    """
    overrides = p.override_settings
    overrides['sd_model_checkpoint'] = x
def apply_controlnet(p, x, xs):
    """Turn on ControlNet script control and store ``x`` as the model on ``p``.

    Sets ``shared.opts.data["control_net_allow_script_control"]`` to ``True``
    and assigns ``x`` to ``p.control_net_model``.
    """
    shared.opts.data["control_net_allow_script_control"] = True
    p.control_net_model = x
def apply_refiner(p, x, xs):
    """Store ``x`` on ``p.refiner_checkpoint``."""
    p.refiner_checkpoint = x
def confirm_checkpoints(p, xs):
    """Check that every name in ``xs`` resolves to a checkpoint.

    Raises:
        RuntimeError: for the first name for which
            ``modules.sd_models.get_closet_checkpoint_match`` returns ``None``.
    """
    for checkpoint_name in xs:
        match = modules.sd_models.get_closet_checkpoint_match(checkpoint_name)
        if match is None:
            raise RuntimeError(f"Unknown checkpoint: {checkpoint_name}")
def confirm_checkpoints_or_none(p, xs):
    """Like ``confirm_checkpoints`` but accept blank entries.

    ``None``, ``""``, ``"None"`` and ``"none"`` are skipped without a lookup.

    Raises:
        RuntimeError: for the first non-blank name for which
            ``modules.sd_models.get_closet_checkpoint_match`` returns ``None``.
    """
    blank_values = (None, "", "None", "none")
    for checkpoint_name in xs:
        if checkpoint_name not in blank_values and \
                modules.sd_models.get_closet_checkpoint_match(checkpoint_name) is None:
            raise RuntimeError(f"Unknown checkpoint: {checkpoint_name}")
def apply_clip_skip(p, x, xs):
    """Write ``x`` into ``opts.data["CLIP_stop_at_last_layers"]``."""
    settings = opts.data
    settings["CLIP_stop_at_last_layers"] = x
def apply_upscale_latent_space(p, x, xs):
    """Set ``opts.data["use_scale_latent_for_hires_fix"]`` from the string ``x``.

    The option is ``False`` only when ``x`` (lower-cased and stripped) is
    exactly ``'0'``; every other value turns it on.
    """
    enabled = x.lower().strip() != '0'
    opts.data["use_scale_latent_for_hires_fix"] = enabled
def find_vae(name: str):
    """Resolve a user-supplied VAE name to an entry usable for reloading.

    The name is stripped and lower-cased once, so the keyword checks and the
    substring search all see the same normalized text (previously only the
    substring search stripped whitespace, so ``" None "`` was not recognized).

    Args:
        name: ``"auto"``/``"automatic"``, ``"none"``, or (part of) a key of
            ``modules.sd_vae.vae_dict``; matching is case-insensitive.

    Returns:
        ``modules.sd_vae.unspecified`` for auto/automatic or when nothing
        matches, ``None`` for ``"none"``, otherwise the ``vae_dict`` value of
        the shortest key containing the name.
    """
    key = name.strip().lower()
    if key in ('auto', 'automatic'):
        return modules.sd_vae.unspecified
    if key == 'none':
        return None

    # Prefer the shortest matching key; sorted() is stable, so keys of equal
    # length keep their vae_dict order.
    candidates = sorted(
        (vae_name for vae_name in modules.sd_vae.vae_dict if key in vae_name.lower()),
        key=len,
    )
    if not candidates:
        print(f"No VAE found for {name}; using automatic")
        return modules.sd_vae.unspecified
    return modules.sd_vae.vae_dict[candidates[0]]
def apply_vae(p, x, xs):
    """Reload the VAE weights of ``shared.sd_model`` with the VAE named ``x``.

    The name is resolved through ``find_vae`` and passed as ``vae_file`` to
    ``modules.sd_vae.reload_vae_weights``.
    """
    reload_weights = modules.sd_vae.reload_vae_weights
    model = shared.sd_model
    reload_weights(model, vae_file=find_vae(x))
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
    """Append the comma-separated style names in ``x`` to ``p.styles``.

    The names are split on ``','`` and added as they are (no stripping).
    """
    style_names = x.split(',')
    p.styles.extend(style_names)
def apply_uni_pc_order(p, x, xs):
    """Set ``opts.data["uni_pc_order"]`` to ``x``, capped at ``p.steps - 1``."""
    highest_allowed = p.steps - 1
    opts.data["uni_pc_order"] = min(x, highest_allowed)
def apply_face_restore(p, opt, x):
    """Configure face restoration on ``p`` from the axis value ``opt``.

    ``opt`` is compared case-insensitively:

    * ``codeformer`` / ``gfpgan`` select that model and enable restoration;
    * ``true`` / ``yes`` / ``y`` / ``1`` enable restoration without touching
      the model;
    * anything else disables restoration.
    """
    choice = opt.lower()
    model_by_choice = {'codeformer': 'CodeFormer', 'gfpgan': 'GFPGAN'}

    model = model_by_choice.get(choice)
    if model is not None:
        p.face_restoration_model = model
        p.restore_faces = True
    else:
        p.restore_faces = choice in ('true', 'yes', 'y', '1')
def apply_override(field, boolean: bool = False):
    """Build an axis applier that writes into ``p.override_settings[field]``.

    Args:
        field: key to set in ``p.override_settings``.
        boolean: when true, the axis value is a string and is stored as
            ``True`` only if it equals ``"true"`` case-insensitively.
    """
    def fun(p, x, xs):
        value = (x.lower() == "true") if boolean else x
        p.override_settings[field] = value

    return fun
def boolean_choice(reverse: bool = False):
    """Return a callable yielding the two boolean labels as strings.

    The callable returns ``["True", "False"]``, or ``["False", "True"]`` when
    ``reverse`` is set.
    """
    def choice():
        labels = ["True", "False"]
        if reverse:
            labels.reverse()
        return labels

    return choice
def format_value_add_label(p, opt, x):
    """Format an axis value as ``"<opt.label>: <value>"``.

    Values whose type is exactly ``float`` are rounded to 8 decimals first.
    """
    shown = round(x, 8) if type(x) is float else x
    return f"{opt.label}: {shown}"
def format_value(p, opt, x):
    """Return the axis value, rounded to 8 decimals if it is exactly a ``float``."""
    return round(x, 8) if type(x) is float else x
def format_value_join_list(p, opt, x):
    """Join the strings in ``x`` with ``", "``."""
    separator = ", "
    return separator.join(x)
def do_nothing(p, x, xs):
    """Axis applier that intentionally leaves ``p`` untouched."""
    return None
def format_nothing(p, opt, x):
    """Formatter that always yields an empty label."""
    empty_label = ""
    return empty_label
def format_name(p, opt, x):
    """Formatter that returns the axis value ``x`` unchanged."""
    name = x
    return name
def format_remove_path(p, opt, x):
    """Formatter that keeps only the final path component of ``x``."""
    _directory, file_name = os.path.split(x)
    return file_name
def str_permutations(x):
    """Marker "type" for an AxisOption whose values are permutations.

    It does no conversion itself and returns its argument unchanged.
    """
    value = x
    return value
def list_to_csv_string(data_list):
    """Serialize ``data_list`` as a single CSV row, without the line ending."""
    with StringIO() as buffer:
        writer = csv.writer(buffer)
        writer.writerow(data_list)
        row_text = buffer.getvalue()
    return row_text.strip()
def csv_string_to_list_strip(data_str):
    """Parse CSV text into one flat list of whitespace-stripped cells.

    Cells from all rows are concatenated in reading order.
    """
    rows = csv.reader(StringIO(data_str))
    return [cell.strip() for row in rows for cell in row]
def identity(x):
    """Return the argument unchanged (used where no conversion is wanted)."""
    unchanged = x
    return unchanged
class ListParser:
    """Restore and normalize a value list that CSV splitting has flattened.

    The xyz_grid module splits the user's axis text with::

        valslist = [x.strip() for x in chain.from_iterable(
            csv.reader(StringIO(vals)))]

    which turns bracketed sub-lists such as ``[a, b], c`` into
    ``["[a", "b]", "c"]``.  This class rebuilds that nesting, expands numeric
    range notation (for ``int``/``float`` converters), converts element types
    and pads sub-lists to a common length.

    The list passed in is modified in place.
    """

    # Range notations, keyed by converter:
    #   "range":  "<start>-<end>" optionally followed by "(+step)" / "(-step)"
    #   "count":  "<start>-<end>" optionally followed by "[count]"
    numeric_pattern = {
        int: {
            "range": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*",
            "count": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*"
        },
        float: {
            "range": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*",
            "count": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*"
        }
    }

    ################################################
    #
    # Initialization method from here.
    #
    ################################################
    def __init__(self, my_list, converter=None, allow_blank=True, exclude_list=None, run=True):
        """Store the settings, compile the patterns and optionally normalize.

        Args:
            my_list: list of strings to normalize; mutated in place.
            converter: callable applied to every element. ``int`` and ``float``
                additionally enable range-notation expansion.
            allow_blank: when true, ``"None"`` and ``""`` become ``None``
                instead of being passed to ``converter``.
            exclude_list: names (inserted into regular expressions as-is) that,
                when written as ``[name]``, must not be read as list brackets.
            run: when true, call ``auto_normalize()`` immediately.
        """
        self.my_list = my_list
        self.converter = converter
        self.allow_blank = allow_blank
        self.exclude_list = exclude_list
        self.re_bracket_start = None
        self.re_bracket_start_precheck = None
        self.re_bracket_end = None
        self.re_bracket_end_precheck = None
        self.re_range = None
        self.re_count = None
        self.compile_regex()
        if run:
            self.auto_normalize()

    def compile_regex(self):
        """Compile the bracket patterns and, for numeric converters, the range patterns.

        Returns:
            ``self``.
        """
        if self.exclude_list:
            exclude_pattern = "|".join(self.exclude_list)
            self.re_bracket_start = re.compile(fr"^\[(?!(?:{exclude_pattern})\])")
            # One look-behind per excluded name: ``re`` only accepts
            # fixed-width look-behinds, so a single alternation of names with
            # different lengths would fail to compile.
            not_excluded = "".join(fr"(?<!\[(?:{name}))" for name in self.exclude_list)
            self.re_bracket_end = re.compile(fr"{not_excluded}\]$")
        else:
            self.re_bracket_start = re.compile(r"^\[")
            self.re_bracket_end = re.compile(r"\]$")

        if self.converter not in self.numeric_pattern:
            return self

        # If the converter is either int or float.
        self.re_range = re.compile(self.numeric_pattern[self.converter]["range"])
        self.re_count = re.compile(self.numeric_pattern[self.converter]["count"])
        self.re_bracket_start_precheck = None
        # A trailing "]" that belongs to count notation ("1-5 [3]") must not be
        # read as the end of a sub-list.
        self.re_bracket_end_precheck = self.re_count
        return self

    ################################################
    #
    # Public method from here.
    #
    ################################################
    ################################################
    # This method is executed at the time of initialization.
    #
    def auto_normalize(self):
        """Run the full normalization pipeline on ``self.my_list``.

        Ranges are always expanded and types converted; when bracket notation
        is present the nesting is rebuilt first and sub-lists are padded last.

        Returns:
            ``self``.
        """
        nested = self.has_list_notation()
        if nested:
            self.fix_structure()
        self.numeric_range_parser()
        self.type_convert()
        if nested:
            self.fill_to_longest()
        return self

    def has_list_notation(self):
        """Return whether any element starts with a list-opening bracket."""
        return any(self._search_bracket(s) for s in self.my_list)

    def numeric_range_parser(self, my_list=None, depth=0):
        """Expand range notation in place (numeric converters only).

        At the top level expanded values are spliced into the list.  Inside a
        sub-list each expanded range is kept as a list and the sub-list is
        transposed, so one sub-list becomes several.

        Returns:
            ``self`` at ``depth == 0`` (or when the converter is not numeric);
            for recursive calls, a list of sub-lists to splice into the parent.
        """
        if self.converter not in self.numeric_pattern:
            return self

        my_list = self.my_list if my_list is None else my_list
        result = []
        is_matched = False
        for s in my_list:
            if isinstance(s, list):
                result.extend(self.numeric_range_parser(s, depth+1))
                continue

            match = self._numeric_range_to_list(s)
            if s != match:
                is_matched = True
                result.extend(match if not depth else [match])
            else:
                result.append(s)

        if depth:
            return self._transpose(result) if is_matched else [result]

        my_list[:] = result
        return self

    def type_convert(self, my_list=None):
        """Convert every element in place, recursing into sub-lists.

        Blank values (``"None"`` / ``""``) become ``None`` when ``allow_blank``
        is set; otherwise the converter, if any, is applied.

        Returns:
            ``self``.
        """
        my_list = self.my_list if my_list is None else my_list
        for i, s in enumerate(my_list):
            if isinstance(s, list):
                self.type_convert(s)
            elif self.allow_blank and (str(s) in ["None", ""]):
                my_list[i] = None
            elif self.converter:
                my_list[i] = self.converter(s)
            else:
                my_list[i] = s
        return self

    def fix_structure(self):
        """Rebuild sub-lists from elements carrying ``[`` / ``]`` markers.

        Brackets are stripped from the elements and each start..end span is
        replaced by a single nested list.

        Returns:
            ``self``.

        Raises:
            ValueError: if an opening bracket is never closed.
        """
        def is_same_length(list1, list2):
            return len(list1) == len(list2)

        start_indices, end_indices = [], []
        for i, s in enumerate(self.my_list):
            # Outside a sub-list: look for an opening bracket.
            if is_same_length(start_indices, end_indices):
                replace_string = self._search_bracket(s, "[", replace="")
                if s != replace_string:
                    s = replace_string
                    start_indices.append(i)
            # Inside a sub-list (possibly opened on this very element):
            # look for the closing bracket.
            if not is_same_length(start_indices, end_indices):
                replace_string = self._search_bracket(s, "]", replace="")
                if s != replace_string:
                    s = replace_string
                    end_indices.append(i + 1)
            self.my_list[i] = s

        if not is_same_length(start_indices, end_indices):
            raise ValueError(f"Lengths of {start_indices} and {end_indices} are different.")

        # Restore the structure of a list.  Work backwards so earlier indices
        # stay valid while slices collapse.
        for i, j in zip(reversed(start_indices), reversed(end_indices)):
            self.my_list[i:j] = [self.my_list[i:j]]
        return self

    def fill_to_longest(self, my_list=None, value=None, index=None):
        """Pad every sub-list to the length of the longest one.

        Args:
            my_list: list to process; defaults to ``self.my_list``.
            value: fill value used when ``index`` is ``None``.
            index: if given, each sub-list is padded with its own element at
                this index instead of ``value``.

        Returns:
            ``self``.
        """
        my_list = self.my_list if my_list is None else my_list
        if not self.sublist_exists(my_list):
            return self

        max_length = max(len(sub_list) for sub_list in my_list if isinstance(sub_list, list))
        for i, sub_list in enumerate(my_list):
            if isinstance(sub_list, list):
                fill_value = value if index is None else sub_list[index]
                my_list[i] = sub_list + [fill_value] * (max_length-len(sub_list))
        return self

    def sublist_exists(self, my_list=None):
        """Return whether at least one element is a list."""
        my_list = self.my_list if my_list is None else my_list
        return any(isinstance(item, list) for item in my_list)

    def all_sublists(self, my_list=None):    # Unused method
        """Return whether every element is a list."""
        my_list = self.my_list if my_list is None else my_list
        return all(isinstance(item, list) for item in my_list)

    def get_list(self):                      # Unused method
        """Return the (in-place normalized) list."""
        return self.my_list

    ################################################
    #
    # Private method from here.
    #
    ################################################
    def _search_bracket(self, string, bracket="[", replace=None):
        """Find or strip a list bracket at the edge of ``string``.

        Args:
            string: element to inspect.
            bracket: ``"["`` (start of element) or ``"]"`` (end of element).
            replace: ``None`` to search; otherwise the replacement text.

        Returns:
            In search mode, a match object or ``None``; in replace mode, the
            resulting string.  If the precheck pattern fully matches (count
            notation), the bracket is ignored.

        Raises:
            ValueError: for any other ``bracket`` value.
        """
        if bracket == "[":
            pattern = self.re_bracket_start
            precheck = self.re_bracket_start_precheck  # None
        elif bracket == "]":
            pattern = self.re_bracket_end
            precheck = self.re_bracket_end_precheck
        else:
            raise ValueError(f"Invalid argument provided. (bracket: {bracket})")

        if precheck and precheck.fullmatch(string):
            return None if replace is None else string
        elif replace is None:
            return pattern.search(string)
        else:
            return pattern.sub(replace, string)

    def _numeric_range_to_list(self, string):
        """Expand ``string`` if it is range or count notation.

        Returns:
            A list of numbers when ``string`` matches either notation,
            otherwise ``string`` itself.
        """
        match = self.re_range.fullmatch(string)
        if match is not None:
            if self.converter == int:
                start = int(match.group(1))
                end = int(match.group(2))
                step = int(match.group(3)) if match.group(3) is not None else 1
                # Make the end bound inclusive in the direction of travel, so a
                # negative step ("5-1(-1)") reaches the end value just as the
                # float branch below does.
                stop = end + 1 if step > 0 else end - 1
                return list(range(start, stop, step))
            else:              # float
                start = float(match.group(1))
                end = float(match.group(2))
                step = float(match.group(3)) if match.group(3) is not None else 1
                return np.arange(start, end + step, step).tolist()

        match = self.re_count.fullmatch(string)
        if match is not None:
            if self.converter == int:
                start = int(match.group(1))
                end = int(match.group(2))
                num = int(match.group(3)) if match.group(3) is not None else 1
                return [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
            else:              # float
                start = float(match.group(1))
                end = float(match.group(2))
                num = int(match.group(3)) if match.group(3) is not None else 1
                return np.linspace(start=start, stop=end, num=num).tolist()
        return string

    def _transpose(self, my_list=None):
        """Turn a list of columns into a list of rows.

        Scalars are treated as one-element columns, and shorter columns are
        padded by repeating their last element before transposing.
        """
        my_list = self.my_list if my_list is None else my_list
        my_list = [item if isinstance(item, list) else [item] for item in my_list]
        self.fill_to_longest(my_list, index=-1)
        return np.array(my_list, dtype=object).T.tolist()

    ################################################
    #
    # The methods of ListParser class end here.
    #
    ################################################
def find_dict(dict_list, keyword, search_key="name", stop=False):
    """Return the first dict in ``dict_list`` whose ``search_key`` equals ``keyword``.

    Args:
        dict_list: iterable of dicts, each containing ``search_key``.
        keyword: value to look for.
        search_key: key to compare on.
        stop: when true, raise instead of returning ``None`` on no match.

    Raises:
        ValueError: if nothing matches and ``stop`` is true.
    """
    found = None
    for candidate in dict_list:
        if candidate[search_key] == keyword:
            found = candidate
            break

    if found or not stop:
        return found
    raise ValueError(f"Dictionary with value '{keyword}' in key '{search_key}' not found.")
def choices_bool():
    """Return the boolean choices as strings, ``False`` first."""
    return [str(flag) for flag in (False, True)]
def choices_model():
    """Call ``update_cn_models()`` and return the values of ``cn_models_names`` as a list."""
    update_cn_models()
    return [*cn_models_names.values()]
def choices_control_mode():
    """Return the ``value`` of every ``ControlMode`` member, in iteration order."""
    values = []
    for mode in ControlMode:
        values.append(mode.value)
    return values
def choices_resize_mode():
    """Return the ``value`` of every ``ResizeMode`` member, in iteration order."""
    values = []
    for mode in ResizeMode:
        values.append(mode.value)
    return values
def choices_preprocessor():
    """Return the items yielded by iterating ``cn_preprocessor_modules`` as a list."""
    return [*cn_preprocessor_modules]
def make_excluded_list():
    """Collect every ``[word]`` tag found in the model choices.

    Returns the text inside each pair of square brackets (word characters
    only) across all strings from ``choices_model()``, in order of appearance.
    """
    bracketed_word = re.compile(r"\[(\w+)\]")
    excluded = []
    for model_name in choices_model():
        excluded.extend(found.group(1) for found in bracketed_word.finditer(model_name))
    return excluded
# Validation rules looked up by "name" (see find_dict / confirm):
#   "type"    - converter handed to ListParser
#   "check"   - callable returning the list of allowed values
#   "exclude" - callable returning bracketed names that ListParser must not
#               treat as list brackets, or None when there are none
validation_data = [
    {"name": "model", "type": str, "check": choices_model, "exclude": make_excluded_list},
    {"name": "control_mode", "type": str, "check": choices_control_mode, "exclude": None},
    {"name": "resize_mode", "type": str, "check": choices_resize_mode, "exclude": None},
    {"name": "preprocessor", "type": str, "check": choices_preprocessor, "exclude": None},
]
def flatten(lst):
    """Return the non-list leaves of an arbitrarily nested list, in order.

    Only ``list`` instances are descended into; any other element (tuples
    included) is kept as a leaf.
    """
    flat = []
    # Stack of iterators: the top one is the list currently being walked.
    pending = [iter(lst)]
    while pending:
        for element in pending[-1]:
            if isinstance(element, list):
                # Descend; the parent iterator resumes once this one is done.
                pending.append(iter(element))
                break
            flat.append(element)
        else:
            pending.pop()
    return flat
def is_all_included(target_list, check_list, allow_blank=False, stop=False):
    """Check that every leaf of ``target_list`` is contained in ``check_list``.

    Args:
        target_list: possibly nested list of values (flattened first).
        check_list: container of allowed values.
        allow_blank: when true, elements whose ``str()`` is ``"None"`` or
            ``""`` are accepted without being checked.
        stop: when true, raise on the first offending element instead of
            returning ``False``.

    Raises:
        ValueError: if an element is not allowed and ``stop`` is true.
    """
    blank_markers = ["None", ""]
    for element in flatten(target_list):
        is_blank = allow_blank and str(element) in blank_markers
        if is_blank or element in check_list:
            continue
        if stop:
            raise ValueError(f"'{element}' is not included in check list.")
        return False
    return True
def bool_(string):
    """Convert a value to ``True`` / ``False`` / ``None`` via its string form.

    ``"None"`` and ``""`` give ``None``; ``"true"``/``"1"`` and
    ``"false"``/``"0"`` (case-insensitive) give the matching boolean.

    Raises:
        ValueError: for any other text.
    """
    text = str(string)
    if text in ("None", ""):
        return None

    truth_table = {"true": True, "1": True, "false": False, "0": False}
    lowered = text.lower()
    if lowered in truth_table:
        return truth_table[lowered]
    raise ValueError(f"Could not convert string to boolean: {text}")
def confirm(func_or_str):
    """Build a ``confirm(p, xs)`` validator for an axis.

    Args:
        func_or_str: either a converter callable (``xs`` is then only parsed
            and converted with ``ListParser``) or the ``"name"`` of an entry
            in ``validation_data`` (``xs`` is parsed and every value must be
            among the entry's allowed choices).

    The returned validator mutates ``xs`` in place through ``ListParser`` and
    raises ``TypeError`` if ``func_or_str`` is neither callable nor ``str``.
    """
    def confirm_(p, xs):
        if callable(func_or_str):
            # func_or_str is converter
            ListParser(xs, func_or_str, allow_blank=True)
            return

        if not isinstance(func_or_str, str):
            raise TypeError(f"Argument must be callable or str, not {type(func_or_str).__name__}.")

        # func_or_str is keyword
        rule = find_dict(validation_data, func_or_str, stop=True)
        make_exclusions = rule["exclude"]
        exclusions = make_exclusions() if make_exclusions else None
        allowed_values = rule["check"]()

        ListParser(xs, rule["type"], allow_blank=True, exclude_list=exclusions)
        is_all_included(xs, allowed_values, allow_blank=True, stop=True)

    return confirm_
from scripts.xyz_grid import AxisOption, AxisOptionTxt2Img, AxisOptionImg2Img
# Axis definitions for the X/Y/Z grid.  Each entry is
# (label, value type, apply function) plus optional keyword settings:
# format_value, confirm (validator), cost and choices (callable returning
# the selectable values).
# NOTE: code further down addresses entries of the host's
# ``xyz_grid.axis_options`` by position (13, 30, 38, 42); those positions
# coincide with "Checkpoint name", "VAE", "Refiner checkpoint" and
# "[ControlNet] Model" in this list, so keep the order in sync.
axis_options_aws = [
    AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
    AxisOption("Seed", int, apply_field("seed")),
    AxisOption("Var. seed", int, apply_field("subseed")),
    AxisOption("Var. strength", float, apply_field("subseed_strength")),
    AxisOption("Steps", int, apply_field("steps")),
    AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
    AxisOption("CFG Scale", float, apply_field("cfg_scale")),
    AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
    AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
    AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
    AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers if x.name not in opts.hide_samplers]),
    AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]),
    AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]),
    # Index 13.
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
    AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
    AxisOption("Sigma Churn", float, apply_field("s_churn")),
    AxisOption("Sigma min", float, apply_field("s_tmin")),
    AxisOption("Sigma max", float, apply_field("s_tmax")),
    AxisOption("Sigma noise", float, apply_field("s_noise")),
    AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)),
    AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
    AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
    AxisOption("Schedule rho", float, apply_override("rho")),
    AxisOption("Eta", float, apply_field("eta")),
    AxisOption("Clip skip", int, apply_clip_skip),
    AxisOption("Denoising", float, apply_field("denoising_strength")),
    AxisOption("Initial noise multiplier", float, apply_field("initial_noise_multiplier")),
    AxisOption("Extra noise", float, apply_override("img2img_extra_noise")),
    AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
    AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
    # Index 30.
    AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)),
    AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
    AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
    AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
    AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
    AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
    AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),
    AxisOption("SGM noise multiplier", str, apply_override('sgm_noise_multiplier', boolean=True), choices=boolean_choice(reverse=True)),
    # Index 38.
    AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
    AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
    AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
    # ControlNet axes: values are kept as-is (identity); parsing and
    # validation happen in the confirm step.
    AxisOption("[ControlNet] Enabled", identity, apply_field("control_net_enabled"), confirm=confirm(bool_), choices=choices_bool),
    # Index 42.
    AxisOption("[ControlNet] Model", identity, apply_field("control_net_model"), choices=choices_model, cost=0.9),
    AxisOption("[ControlNet] Weight", identity, apply_field("control_net_weight"), confirm=confirm(float)),
    AxisOption("[ControlNet] Guidance Start", identity, apply_field("control_net_guidance_start"), confirm=confirm(float)),
    AxisOption("[ControlNet] Guidance End", identity, apply_field("control_net_guidance_end"), confirm=confirm(float)),
    AxisOption("[ControlNet] Control Mode", identity, apply_field("control_net_control_mode"), confirm=confirm("control_mode"), choices=choices_control_mode),
    AxisOption("[ControlNet] Resize Mode", identity, apply_field("control_net_resize_mode"), confirm=confirm("resize_mode"), choices=choices_resize_mode),
    AxisOption("[ControlNet] Preprocessor", identity, apply_field("control_net_module"), confirm=confirm("preprocessor"), choices=choices_preprocessor),
    AxisOption("[ControlNet] Pre Resolution", identity, apply_field("control_net_pres"), confirm=confirm(int)),
    AxisOption("[ControlNet] Pre Threshold A", identity, apply_field("control_net_pthr_a"), confirm=confirm(float)),
    AxisOption("[ControlNet] Pre Threshold B", identity, apply_field("control_net_pthr_b"), confirm=confirm(float)),
]
# Publish the axis options on ``modules.shared`` so other code can reach
# them.  Entries that are not instances of the imported ``AxisOption`` class
# are rebuilt as ``AxisOption`` objects.
shared.axis_options_aws = []
for x in axis_options_aws:
    if isinstance(x, AxisOption):
        shared.axis_options_aws.append(x)
        continue
    try:
        # Clone from the attribute VALUES.  ``AxisOption(*x.__dict__)`` would
        # unpack the attribute NAMES ("label", "type", ...) instead.
        # NOTE(review): assumes the option stores its three positional
        # constructor arguments as ``label``/``type``/``apply`` -- confirm
        # against scripts.xyz_grid.AxisOption.
        attrs = dict(vars(x))
        converted_obj = AxisOption(attrs.pop("label"), attrs.pop("type"), attrs.pop("apply"))
        # Carry over every remaining attribute unchanged.
        for attr_name, attr_value in attrs.items():
            setattr(converted_obj, attr_name, attr_value)
        shared.axis_options_aws.append(converted_obj)
    except Exception as e:
        # Best effort: a single bad entry must not prevent the module from loading.
        print(f"Failed to convert object: {x}. Error: {e}")
def find_module(module_names):
    """Return the ``module`` of the first loaded script matching ``module_names``.

    Args:
        module_names: a list of names, or one comma-separated string of names
            (each stripped), compared against ``script_class.__module__`` of
            the entries in ``scripts.scripts_data``.

    Returns:
        ``data.module`` of the first matching entry that has a ``module``
        attribute, or ``None`` if there is none.
    """
    if isinstance(module_names, str):
        module_names = [part.strip() for part in module_names.split(",")]

    matching_modules = (
        data.module
        for data in scripts.scripts_data
        if data.script_class.__module__ in module_names and hasattr(data, "module")
    )
    return next(matching_modules, None)
def apply_checkpoint_detector(origin_fn, cloud_fn):
    """Combine two appliers: try ``cloud_fn`` first, fall back to ``origin_fn``.

    The returned callable forwards its positional arguments to ``cloud_fn``;
    if that raises any ``Exception`` the same arguments are passed to
    ``origin_fn`` and its result is returned instead.
    """
    def wrapper(*args):
        try:
            result = cloud_fn(*args)
        except Exception:
            # Deliberate fallback to the original applier.
            result = origin_fn(*args)
        return result

    return wrapper
def _patch_axis_option(axis_options, label, fallback_index, cloud_fn):
    """Patch one host axis option so it accepts cloud-side model names.

    The option is located by its ``label``; if no option carries that label,
    ``fallback_index`` is used when it is within range.  The patched option
    gets a ``confirm`` that accepts everything, ``format_name`` as formatter,
    and an ``apply`` that tries ``cloud_fn`` before the original applier.

    Returns:
        The patched option, or ``None`` if it could not be located (in which
        case a message is printed and nothing is changed).
    """
    option = next((opt for opt in axis_options if getattr(opt, "label", None) == label), None)
    if option is None and fallback_index < len(axis_options):
        option = axis_options[fallback_index]
    if option is None:
        print(f"xyz_grid axis option '{label}' not found; skipping cloud patch.")
        return None

    option.confirm = lambda *args: ""
    option.format_value = format_name
    option.apply = apply_checkpoint_detector(origin_fn=option.apply, cloud_fn=cloud_fn)
    return option


# Patch the host XYZ-grid script, if loaded.  Options are found by label
# rather than by bare position: a fixed index raises IndexError when the host
# list is shorter (for example when the ControlNet axes are not registered)
# and hits the wrong option if the host order changes.
xyz_grid = find_module("xyz_grid.py, xy_grid.py")
if xyz_grid:
    sd_models_xyz_option = _patch_axis_option(
        xyz_grid.axis_options, "Checkpoint name", 13, apply_checkpoint)
    vae_models_xyz_option = _patch_axis_option(
        xyz_grid.axis_options, "VAE", 30, apply_vae)
    refiner_models_xyz_option = _patch_axis_option(
        xyz_grid.axis_options, "Refiner checkpoint", 38, apply_refiner)
    controlnet_models_xyz_option = _patch_axis_option(
        xyz_grid.axis_options, "[ControlNet] Model", 42, apply_controlnet)
# Relax the validation done by the xyz_grid_support script, if loaded:
# whenever its ``is_all_included`` check raises, treat the values as valid.
xyz_grid_support = find_module("xyz_grid_support.py")
if xyz_grid_support:
    def origin_detector(func):
        """Wrap ``func`` so that any exception it raises yields ``True``."""
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception:
                return True
        return wrapper

    # Install the WRAPPED function.  Assigning the decorator itself would make
    # every later ``is_all_included(xs, check_list, ...)`` call fail with a
    # TypeError, because the decorator accepts a single argument.
    xyz_grid_support.is_all_included = origin_detector(xyz_grid_support.is_all_included)