Add Directive mode to Clip Input

main
File_xor 2023-06-06 20:39:33 +09:00
parent 1dcc9a116a
commit 4e815714d1
1 changed files with 204 additions and 6 deletions

View File

@ -1,4 +1,7 @@
import os, csv, warnings, datetime import os, csv, warnings, datetime
import math as math
from collections import namedtuple
from enum import IntEnum
import gradio import gradio
import torch import torch
@ -37,8 +40,8 @@ class Clip_IO(scripts.Script):
with gradio.Accordion("Clip input", open = False): with gradio.Accordion("Clip input", open = False):
with gradio.Row(): with gradio.Row():
enabled = gradio.Checkbox(label = "Enable") enabled = gradio.Checkbox(label = "Enable")
mode_positive = gradio.Dropdown(["Disabled", "Simple"], value = "Disabled", max_choices = 1, label = "Positive prompt mode") mode_positive = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
mode_negative = gradio.Dropdown(["Disabled", "Simple"], value = "Disabled", max_choices = 1, label = "Positive prompt mode") mode_negative = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
pass pass
pass pass
if not is_img2img: if not is_img2img:
@ -132,6 +135,188 @@ class Clip_IO(scripts.Script):
pass pass
pass pass
syntax_directive = r"""
start: (FILE | PROMPT | directive | SPACE)*
FILE: /"(?!"").+?"/ | /'(?!"").+?'/ | /[^?"'\s]+/
PROMPT: /"{3}.*?"{3}|'{3}.*?'{3}/
directive: "?" DIRECTIVE ("_" DIRECTIVE_ORDER)? "(" directive_inner ")"
DIRECTIVE: (/[0-9a-zA-Z]+/ | /_(?![0-9]+\()+/)+
DIRECTIVE_ORDER: /[0-9]+/
directive_inner: (DIRECTIVE_PLAIN | directive_parentheses)*
!directive_parentheses: "(" (DIRECTIVE_PLAIN | directive_parentheses)* ")"
DIRECTIVE_PLAIN: /[^()]+/
SPACE: /\s+/
"""
class Directive:
class Names(IntEnum):
eval
pass
def __init__(self, name: str, order: int, inner: str):
self.name = name.lower()
self.order = order
self.inner = inner
pass
def __lt__(self, other) -> bool:
if type(self) != type(other):
raise TypeError()
pass
if self.order == other.order:
if self.Names[self.name] == self.Names[other.name]:
return True
pass
else:
return self.Names[self.name] < self.Names[other.name]
pass
pass
else:
return self.order < other.order
pass
pass
pass
def get_cond_directive(model, input: str, is_negative: bool) -> torch.Tensor | None:
conds: list[torch.tensor] = []
dirs: list[Clip_IO.Directive] = []
class Process(lark.Transformer):
def FILE(self, token: lark.Token):
cond: torch.Tensor | None = None
filename_original = token.value
if filename_original.startswith('"') and filename_original.endswith('"') or filename_original.startswith("'") and filename_original.endswith("'"):
filename_original = filename_original[1:-1]
if filename_original in Clip_IO.conditioning_cache:
cond = Clip_IO.conditioning_cache[filename_original]
if cond is not None:
conds.append(cond)
pass
return
pass
filename = os.path.join(os.path.dirname(__file__), "../conditioning", filename_original)
filename = os.path.realpath(filename)
if filename.endswith(".csv"):
cond = Clip_IO.load_csv_conditioning(filename)
pass
elif filename.endswith(".pt"):
try:
cond = torch.load(filename)
pass
except Exception:
cond = None
pass
pass
else:
if os.path.exists(filename) and not os.path.isdir(filename):
cond = Clip_IO.load_csv_conditioning(filename)
if cond is None:
try:
cond = torch.load(filename)
pass
except Exception:
cond = None
pass
pass
pass
if cond is None and os.path.exists(filename + ".csv"):
cond = Clip_IO.load_csv_conditioning(filename + ".csv")
pass
if cond is None and not os.path.exists(filename + ".csv") and os.path.exists(filename + ".pt"):
try:
cond = torch.load(filename + ".pt")
pass
except Exception:
cond = None
pass
pass
pass
if cond is not None:
conds.append(cond.to(devices.device))
pass
Clip_IO.conditioning_cache[filename_original] = cond
pass
def PROMPT(self, token: lark.Token):
string = token.value
if string.startswith('"""') and string.endswith('"""') or string.startswith("'''") and string.endswith("'''"):
string = string[3:-3]
pass
conds.append(model.get_learned_conditioning([string])[0].to(devices.device))
pass
def directive(self, args: list[lark.Token | lark.Tree]):
def flatten(arg: lark.Token | lark.Tree | list[lark.Token | lark.Tree]) -> str | lark.Token:
if type(arg) == lark.Token:
return arg.value
pass
elif type(arg) == lark.Tree:
array = ""
for child in arg.children:
array += flatten(child)
pass
return array
pass
elif type(arg) == list:
array = ""
for component in arg:
array += flatten(component)
pass
return array
pass
else:
return arg
pass
pass
dirs.append(Clip_IO.Directive(args[0], args[1] if len(args) == 3 else 0, flatten(args[-1])))
pass
pass
Process().transform(lark.Lark(Clip_IO.syntax_directive).parse(input))
i = torch.vstack(conds)
o = i.clone()
dirs.sort()
for dir in dirs:
if dir.name == "eval":
try:
for t in range(i.shape[0]):
for d in range(i.shape[1]):
local = {"i": i, "o": o, "t": t, "d": d, "torch": torch.__dict__} | math.__dict__
o[t, d] = eval(dir.inner, None, local)
pass
pass
pass
except Exception as e:
print(repr(e))
o = i
pass
finally:
i = o.clone()
pass
elif dir.name == "exec":
try:
local = {"i": i, "o": o, "torch": torch.__dict__} | math.__dict__
exec(dir.inner, None, local)
except Exception as e:
print(repr(e))
o = i
pass
finally:
i = o.clone()
pass
else:
warnings.warn(f'Directive "{dir.name}" does not exist.')
pass
pass
cond = o
if cond is not None and cond.shape[0] > 0 and (cond.shape[1] == 768 or cond.shape[1] == 1024):
return cond.to(devices.device)
pass
else:
warnings.warn(f"{'Negative prompt' if is_negative else 'Positive prompt'} is empty. Retrieving conditioning for empty string.")
return model.get_learned_conditioning([""])[0]
pass
pass
def my_get_learned_conditioning(model, prompts, steps, is_negative = True): def my_get_learned_conditioning(model, prompts, steps, is_negative = True):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one. and the sampling step at which this condition is to be replaced by the next one.
@ -150,11 +335,17 @@ class Clip_IO(scripts.Script):
] ]
] ]
""" """
if Clip_IO.enabled and (Clip_IO.mode_positive == "Directive" and not is_negative or Clip_IO.mode_negative == "Directive" and is_negative):
# TODO: Implement own parser
prompt_schedules = [[[steps, prompt]] for prompt in prompts]
pass
else:
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps)
pass
res = [] res = []
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps)
cache = {} cache = {}
for prompt, prompt_schedule in zip(prompts, prompt_schedules): for prompt, prompt_schedule in zip(prompts, prompt_schedules):
cached = cache.get(prompt, None) cached = cache.get(prompt, None)
@ -163,10 +354,17 @@ class Clip_IO(scripts.Script):
continue continue
texts: list[str] = [x[1] for x in prompt_schedule] texts: list[str] = [x[1] for x in prompt_schedule]
if Clip_IO.mode_positive == "Simple" and not is_negative or Clip_IO.mode_negative == "Simple" and is_negative: if Clip_IO.enabled and (Clip_IO.mode_positive == "Simple" and not is_negative or Clip_IO.mode_negative == "Simple" and is_negative):
conds = [] conds = []
for text in texts: for text in texts:
conds.append(Clip_IO.get_cond_simple(model, text, is_negative)) conds.append(Clip_IO.get_cond_simple(model, text, is_negative))
pass
pass
elif Clip_IO.enabled and (Clip_IO.mode_positive == "Directive" and not is_negative or Clip_IO.mode_negative == "Directive" and is_negative):
conds = []
for text in texts:
conds.append(Clip_IO.get_cond_directive(model, text, is_negative))
pass
pass pass
else: else:
conds = model.get_learned_conditioning(texts) conds = model.get_learned_conditioning(texts)