from collections import namedtuple
from itertools import chain
import csv
import os.path
from io import StringIO
import numpy as np
import modules.scripts as scripts
from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, state
import modules.shared as shared
import modules.sd_samplers
import modules.sd_models
import modules.sd_vae
import re
from scripts.global_state import update_cn_models, cn_models_names, cn_preprocessor_modules
from scripts.external_code import ResizeMode, ControlMode
from modules.ui_components import ToolButton

fill_values_symbol = "\U0001f4d2"  # 📒

AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])


def apply_field(field):
    """Return an applier that sets attribute *field* on ``p`` to the axis value."""
    def fun(p, x, xs):
        setattr(p, field, x)
    return fun


def apply_prompt(p, x, xs):
    """Prompt search/replace: substitute the search term ``xs[0]`` with *x*.

    The replacement is applied to both ``p.prompt`` and ``p.negative_prompt``.

    Raises:
        RuntimeError: if ``xs[0]`` occurs in neither prompt.
    """
    if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
        raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")

    p.prompt = p.prompt.replace(xs[0], x)
    p.negative_prompt = p.negative_prompt.replace(xs[0], x)


def apply_order(p, x, xs):
    """Rewrite ``p.prompt`` so the tokens in *x* appear in the order given by *x*.

    The tokens' current positions define the slots.  The text between the slots
    is kept in place, and the slots are refilled with the tokens of *x* in
    sequence.  ``p.prompt`` is only assigned once the whole result is built.

    Raises:
        RuntimeError: if a token of *x* cannot be located in the prompt
            (previously ``str.find`` returned -1 and the prompt was silently
            mangled).
    """
    # Order the tokens by where they currently sit in the prompt.
    token_order = sorted(((p.prompt.find(token), token) for token in x), key=lambda t: t[0])

    # Cut the prompt into the text segments that precede each token slot.
    # Work on a local copy so p.prompt is untouched if we have to bail out.
    remaining = p.prompt
    prompt_parts = []
    for _, token in token_order:
        n = remaining.find(token)
        if n == -1:
            raise RuntimeError(f"Prompt order did not find {token} in prompt.")
        prompt_parts.append(remaining[:n])
        remaining = remaining[n + len(token):]

    # Re-insert the tokens in the requested order, then append the tail.
    p.prompt = "".join(part + token for part, token in zip(prompt_parts, x)) + remaining


def confirm_samplers(p, xs):
    """Raise ``RuntimeError`` if any name in *xs* is not in ``sd_samplers.samplers_map``.

    The lookup is done on the lower-cased name.
    """
    for x in xs:
        if x.lower() not in sd_samplers.samplers_map:
            raise RuntimeError(f"Unknown sampler: {x}")


def apply_checkpoint(p, x, xs):
    """Request checkpoint *x* via ``p.override_settings``.

    The name is passed through unvalidated.
    """
    p.override_settings['sd_model_checkpoint'] = x


def apply_controlnet(p, x, xs):
    """Set ``p.control_net_model`` to *x*.

    Also turns on the global ``control_net_allow_script_control`` option.
    Presumably this is required for ControlNet to honour script-set
    attributes — confirm against the ControlNet extension.
    """
    shared.opts.data["control_net_allow_script_control"] = True
    setattr(p, 'control_net_model', x)


def apply_refiner(p, x, xs):
    """Set ``p.refiner_checkpoint`` to *x*."""
    setattr(p, 'refiner_checkpoint', x)


def confirm_checkpoints(p, xs):
    """Raise ``RuntimeError`` for any name in *xs* with no closest checkpoint match."""
    for x in xs:
        if modules.sd_models.get_closet_checkpoint_match(x) is None:
            raise RuntimeError(f"Unknown checkpoint: {x}")


def confirm_checkpoints_or_none(p, xs):
    """Like :func:`confirm_checkpoints`, but allow empty/"None" placeholders."""
    for x in xs:
        if x in (None, "", "None", "none"):
            continue

        if modules.sd_models.get_closet_checkpoint_match(x) is None:
            raise RuntimeError(f"Unknown checkpoint: {x}")


def apply_clip_skip(p, x, xs):
    """Set the global ``CLIP_stop_at_last_layers`` option.

    This mutates shared ``opts``, not ``p``.
    """
    opts.data["CLIP_stop_at_last_layers"] = x


def apply_upscale_latent_space(p, x, xs):
    """Set the global ``use_scale_latent_for_hires_fix`` option.

    The option is enabled for any value other than ``'0'``.
    """
    opts.data["use_scale_latent_for_hires_fix"] = x.lower().strip() != '0'


def find_vae(name: str):
    """Resolve a user-supplied VAE name.

    Returns:
        ``modules.sd_vae.unspecified`` for "auto"/"automatic" or when nothing
        matches; ``None`` for "none"; otherwise the ``vae_dict`` entry whose key
        is the shortest one containing *name* (case-insensitive).
    """
    # Normalise once so the keyword checks and the substring search agree
    # (previously " none" fell through to the substring search).
    key = name.strip().lower()
    if key in ('auto', 'automatic'):
        return modules.sd_vae.unspecified
    if key == 'none':
        return None

    # Shortest matching key first, i.e. the most specific match.
    choices = [v for v in sorted(modules.sd_vae.vae_dict, key=len) if key in v.lower()]
    if not choices:
        print(f"No VAE found for {name}; using automatic")
        return modules.sd_vae.unspecified
    return modules.sd_vae.vae_dict[choices[0]]


def apply_vae(p, x, xs):
    """Reload the VAE weights of ``shared.sd_model`` using ``find_vae(x)``."""
    modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))


def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
    """Append the comma-separated style names in *x* to ``p.styles``."""
    p.styles.extend(x.split(','))


def apply_uni_pc_order(p, x, xs):
    """Set the global ``uni_pc_order`` option to *x*, capped at ``p.steps - 1``."""
    opts.data["uni_pc_order"] = min(x, p.steps - 1)


def apply_face_restore(p, opt, x):
    """Configure face restoration from the axis value *opt*.

    "codeformer" and "gfpgan" enable restoration and select that model.
    Otherwise restoration is enabled only for true/yes/y/1.
    """
    opt = opt.lower()
    if opt == 'codeformer':
        is_active = True
        p.face_restoration_model = 'CodeFormer'
    elif opt == 'gfpgan':
        is_active = True
        p.face_restoration_model = 'GFPGAN'
    else:
        is_active = opt in ('true', 'yes', 'y', '1')

    p.restore_faces = is_active


def apply_override(field, boolean: bool = False):
    """Return an applier that stores the axis value in ``p.override_settings[field]``.

    With *boolean*, the value is stored as ``True`` only when it equals "true"
    (case-insensitive), and as ``False`` otherwise.
    """
    def fun(p, x, xs):
        if boolean:
            x = x.lower() == "true"
        p.override_settings[field] = x
    return fun


def boolean_choice(reverse: bool = False):
    """Return a callable that lists ``["True", "False"]`` (reversed if *reverse*)."""
    def choice():
        return ["False", "True"] if reverse else ["True", "False"]
    return choice


def format_value_add_label(p, opt, x):
    """Format *x* as ``"<opt.label>: <x>"``, rounding floats to 8 decimals."""
    # isinstance also catches numpy.float64, which subclasses float.
    if isinstance(x, float):
        x = round(x, 8)

    return f"{opt.label}: {x}"


def format_value(p, opt, x):
    """Return *x* unchanged, except floats are rounded to 8 decimals."""
    if isinstance(x, float):
        x = round(x, 8)
    return x


def format_value_join_list(p, opt, x):
    """Join the items of *x* with ", "."""
    return ", ".join(x)


def do_nothing(p, x, xs):
    """No-op applier."""
    pass


def format_nothing(p, opt, x):
    """Formatter that always returns an empty label."""
    return ""


def format_name(p, opt, x):
    """Formatter that returns *x* as-is."""
    return x


def format_remove_path(p, opt, x):
    """Formatter that returns only the final path component of *x*."""
    return os.path.basename(x)


def str_permutations(x):
    """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
    return x


def list_to_csv_string(data_list):
    """Serialise *data_list* as a single CSV row, without the trailing newline."""
    with StringIO() as o:
        csv.writer(o).writerow(data_list)
        return o.getvalue().strip()


def csv_string_to_list_strip(data_str):
    """Parse CSV text into one flat list of whitespace-stripped fields."""
    return list(map(str.strip, chain.from_iterable(csv.reader(StringIO(data_str)))))


def identity(x):
    """Return *x* unchanged."""
    return x


class ListParser():
    """This class restores a broken list caused by the following process
    in the xyz_grid module.
        -> valslist = [x.strip() for x in chain.from_iterable(
                                            csv.reader(StringIO(vals)))]
    It also performs type conversion,
    adjusts the number of elements in the list, and other operations.

    This class directly modifies the received list.
    """

    numeric_pattern = {
        int: {
            "range": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*",
            "count": r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*"
        },
        float: {
            "range": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*",
            "count": r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*"
        }
    }

    ################################################
    #
    # Initialization method from here.
# ################################################ def __init__(self, my_list, converter=None, allow_blank=True, exclude_list=None, run=True): self.my_list = my_list self.converter = converter self.allow_blank = allow_blank self.exclude_list = exclude_list self.re_bracket_start = None self.re_bracket_start_precheck = None self.re_bracket_end = None self.re_bracket_end_precheck = None self.re_range = None self.re_count = None self.compile_regex() if run: self.auto_normalize() def compile_regex(self): exclude_pattern = "|".join(self.exclude_list) if self.exclude_list else None if exclude_pattern is None: self.re_bracket_start = re.compile(r"^\[") self.re_bracket_end = re.compile(r"\]$") else: self.re_bracket_start = re.compile(fr"^\[(?!(?:{exclude_pattern})\])") self.re_bracket_end = re.compile(fr"(?