sd-webui-neutral-prompt/lib_neutral_prompt/prompt_parser_hijack.py

50 lines
1.6 KiB
Python

from typing import List
from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
from modules import script_callbacks, prompt_parser
# Module-level hijacker for webui's `prompt_parser` module, created at import
# time. '__neutral_prompt_hijacker' is passed as the hijacker attribute name and
# `script_callbacks.on_script_unloaded` as the uninstall hook.
# NOTE(review): presumably `install_or_get` returns an already-installed
# hijacker when the script is reloaded -- confirm in `hijacker.ModuleHijacker`.
prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
module=prompt_parser,
hijacker_attribute='__neutral_prompt_hijacker',
on_uninstall=script_callbacks.on_script_unloaded,
)
@prompt_parser_hijacker.hijack('get_multicond_prompt_list')
def get_multicond_prompt_list_hijack(prompts, original_function):
    """Hijack registered for ``'get_multicond_prompt_list'``.

    While the extension is disabled, ``prompts`` is forwarded untouched to
    ``original_function``. Otherwise the prompts are parsed, the parsed
    expressions are stored in ``global_state.prompt_exprs``, and the prompts
    transpiled from those expressions are forwarded instead.

    Args:
        prompts: The prompts handed to the hijacked function.
        original_function: The function being wrapped; called exactly once.

    Returns:
        Whatever ``original_function`` returns.
    """
    if global_state.is_enabled:
        # Keep the parsed expressions in global state, then turn them back
        # into plain prompt strings for the original function.
        global_state.prompt_exprs = parse_prompts(prompts)
        transpiled = transpile_exprs(global_state.prompt_exprs)

        # `SdConditioning` may be absent from `prompt_parser`; the NoneType
        # fallback makes the isinstance check fail for any real prompt list.
        conditioning_cls = getattr(prompt_parser, 'SdConditioning', type(None))
        if isinstance(prompts, conditioning_cls):
            # Re-wrap so the new list is built with `copy_from=prompts`.
            transpiled = prompt_parser.SdConditioning(transpiled, copy_from=prompts)

        return original_function(transpiled)

    return original_function(prompts)
def parse_prompts(prompts: List[str]) -> List[neutral_prompt_parser.PromptExpr]:
    """Parse every prompt with ``neutral_prompt_parser.parse_root``.

    Args:
        prompts: Prompt strings to parse.

    Returns:
        One parsed expression per prompt, in the same order as ``prompts``.
    """
    return [neutral_prompt_parser.parse_root(text) for text in prompts]
def transpile_exprs(exprs: List[neutral_prompt_parser.PromptExpr]) -> List[str]:
    """Render each parsed expression through a ``WebuiPromptVisitor``.

    Args:
        exprs: Parsed prompt expressions, e.g. the list returned by
            ``parse_prompts``.

    Returns:
        One result of ``expr.accept(visitor)`` per expression, in input
        order. Annotated as ``str`` because both visitor methods return
        ``str`` -- this assumes ``accept`` returns the visitor's result.
    """
    # The visitor holds no state, so one instance serves every expression.
    visitor = WebuiPromptVisitor()
    return [expr.accept(visitor) for expr in exprs]
class WebuiPromptVisitor:
    """Visitor that renders parsed prompt nodes as prompt strings."""

    def visit_leaf_prompt(self, that: neutral_prompt_parser.LeafPrompt) -> str:
        """Return the leaf formatted as ``'<prompt> :<weight>'``."""
        return f'{that.prompt} :{that.weight}'

    def visit_composite_prompt(self, that: neutral_prompt_parser.CompositePrompt) -> str:
        """Return the children's renderings joined with ``' AND '``.

        Each child is rendered by calling ``child.accept(self)``, so nested
        nodes are handled by this same visitor instance.
        """
        rendered_children = [child.accept(self) for child in that.children]
        return ' AND '.join(rendered_children)