diff --git a/lib_neutral_prompt/xyz_grid.py b/lib_neutral_prompt/xyz_grid.py new file mode 100644 index 0000000..94d456c --- /dev/null +++ b/lib_neutral_prompt/xyz_grid.py @@ -0,0 +1,42 @@ +import sys +from types import ModuleType +from typing import Optional +from modules import scripts +from lib_neutral_prompt import global_state + + +def patch(): + xyz_module = find_xyz_module() + if xyz_module is None: + print("[sd-webui-neutral-prompt]", "xyz_grid.py not found.", file=sys.stderr) + return + + xyz_module.axis_options.extend([ + xyz_module.AxisOption("[Neutral Prompt] CFG Rescale", int_or_float, apply_cfg_rescale()), + ]) + + +class XyzFloat(float): + is_xyz: bool = True + + +def apply_cfg_rescale(): + def callback(_p, v, _vs): + global_state.cfg_rescale = XyzFloat(v) + + return callback + + +def int_or_float(string): + try: + return int(string) + except ValueError: + return float(string) + + +def find_xyz_module() -> Optional[ModuleType]: + for data in scripts.scripts_data: + if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"): + return data.module + + return None diff --git a/scripts/neutral_prompt.py b/scripts/neutral_prompt.py index f962f75..b48613d 100644 --- a/scripts/neutral_prompt.py +++ b/scripts/neutral_prompt.py @@ -1,4 +1,4 @@ -from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui +from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid from modules import scripts, processing, shared from typing import Dict import functools @@ -43,16 +43,24 @@ class NeutralPromptScript(scripts.Script): p.extra_generation_params.update(self.accordion_interface.get_extra_generation_params(args)) def update_global_state(self, args: Dict): - if shared.state.job_no > 0: - return + if shared.state.job_no == 0: + global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True) - global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True) for k, v in args.items(): try: getattr(global_state, k) except AttributeError: continue + if getattr(getattr(global_state, k), 'is_xyz', False): + xyz_attr = getattr(global_state, k) + xyz_attr.is_xyz = False + args[k] = xyz_attr + continue + + if shared.state.job_no > 0: + continue + setattr(global_state, k, v) def hijack_composable_lora(self, is_img2img): @@ -81,3 +89,6 @@ def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *arg # restore original prompts p.all_prompts = all_prompts return res + + +xyz_grid.patch()