xyz grid (#54)

* external_code

* xyz grid

---------

Co-authored-by: ljleb <set>
pull/56/head
ljleb 2023-10-09 13:48:37 -04:00 committed by GitHub
parent f4bcaa078e
commit 35a8a8c5f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 4 deletions

View File

@ -0,0 +1,42 @@
import sys
from types import ModuleType
from typing import Optional
from modules import scripts
from lib_neutral_prompt import global_state
def patch():
xyz_module = find_xyz_module()
if xyz_module is None:
print("[sd-webui-neutral-prompt]", "xyz_grid.py not found.", file=sys.stderr)
return
xyz_module.axis_options.extend([
xyz_module.AxisOption("[Neutral Prompt] CFG Rescale", int_or_float, apply_cfg_rescale()),
])
class XyzFloat(float):
is_xyz: bool = True
def apply_cfg_rescale():
def callback(_p, v, _vs):
global_state.cfg_rescale = XyzFloat(v)
return callback
def int_or_float(string):
try:
return int(string)
except ValueError:
return float(string)
def find_xyz_module() -> Optional[ModuleType]:
for data in scripts.scripts_data:
if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"):
return data.module
return None

View File

@ -1,4 +1,4 @@
from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui
from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid
from modules import scripts, processing, shared
from typing import Dict
import functools
@ -43,16 +43,24 @@ class NeutralPromptScript(scripts.Script):
p.extra_generation_params.update(self.accordion_interface.get_extra_generation_params(args))
def update_global_state(self, args: Dict):
if shared.state.job_no > 0:
return
if shared.state.job_no == 0:
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
for k, v in args.items():
try:
getattr(global_state, k)
except AttributeError:
continue
if getattr(getattr(global_state, k), 'is_xyz', False):
xyz_attr = getattr(global_state, k)
xyz_attr.is_xyz = False
args[k] = xyz_attr
continue
if shared.state.job_no > 0:
continue
setattr(global_state, k, v)
def hijack_composable_lora(self, is_img2img):
@ -81,3 +89,6 @@ def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *arg
# restore original prompts
p.all_prompts = all_prompts
return res
xyz_grid.patch()