parent
f4bcaa078e
commit
35a8a8c5f7
|
|
@ -0,0 +1,42 @@
|
|||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Optional
|
||||
from modules import scripts
|
||||
from lib_neutral_prompt import global_state
|
||||
|
||||
|
||||
def patch():
|
||||
xyz_module = find_xyz_module()
|
||||
if xyz_module is None:
|
||||
print("[sd-webui-neutral-prompt]", "xyz_grid.py not found.", file=sys.stderr)
|
||||
return
|
||||
|
||||
xyz_module.axis_options.extend([
|
||||
xyz_module.AxisOption("[Neutral Prompt] CFG Rescale", int_or_float, apply_cfg_rescale()),
|
||||
])
|
||||
|
||||
|
||||
class XyzFloat(float):
|
||||
is_xyz: bool = True
|
||||
|
||||
|
||||
def apply_cfg_rescale():
|
||||
def callback(_p, v, _vs):
|
||||
global_state.cfg_rescale = XyzFloat(v)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
def int_or_float(string):
|
||||
try:
|
||||
return int(string)
|
||||
except ValueError:
|
||||
return float(string)
|
||||
|
||||
|
||||
def find_xyz_module() -> Optional[ModuleType]:
|
||||
for data in scripts.scripts_data:
|
||||
if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"):
|
||||
return data.module
|
||||
|
||||
return None
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui
|
||||
from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid
|
||||
from modules import scripts, processing, shared
|
||||
from typing import Dict
|
||||
import functools
|
||||
|
|
@ -43,16 +43,24 @@ class NeutralPromptScript(scripts.Script):
|
|||
p.extra_generation_params.update(self.accordion_interface.get_extra_generation_params(args))
|
||||
|
||||
def update_global_state(self, args: Dict):
|
||||
if shared.state.job_no > 0:
|
||||
return
|
||||
if shared.state.job_no == 0:
|
||||
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
|
||||
|
||||
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
|
||||
for k, v in args.items():
|
||||
try:
|
||||
getattr(global_state, k)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if getattr(getattr(global_state, k), 'is_xyz', False):
|
||||
xyz_attr = getattr(global_state, k)
|
||||
xyz_attr.is_xyz = False
|
||||
args[k] = xyz_attr
|
||||
continue
|
||||
|
||||
if shared.state.job_no > 0:
|
||||
continue
|
||||
|
||||
setattr(global_state, k, v)
|
||||
|
||||
def hijack_composable_lora(self, is_img2img):
|
||||
|
|
@ -81,3 +89,6 @@ def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *arg
|
|||
# restore original prompts
|
||||
p.all_prompts = all_prompts
|
||||
return res
|
||||
|
||||
|
||||
xyz_grid.patch()
|
||||
|
|
|
|||
Loading…
Reference in New Issue