Add Directive mode to Clip Input

main
File_xor 2023-06-06 20:39:33 +09:00
parent 1dcc9a116a
commit 4e815714d1
1 changed files with 204 additions and 6 deletions

View File

@ -1,4 +1,7 @@
import os, csv, warnings, datetime
import math as math
from collections import namedtuple
from enum import IntEnum
import gradio
import torch
@ -37,8 +40,8 @@ class Clip_IO(scripts.Script):
with gradio.Accordion("Clip input", open = False):
with gradio.Row():
enabled = gradio.Checkbox(label = "Enable")
mode_positive = gradio.Dropdown(["Disabled", "Simple"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
mode_negative = gradio.Dropdown(["Disabled", "Simple"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
mode_positive = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
mode_negative = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
pass
pass
if not is_img2img:
@ -132,6 +135,188 @@ class Clip_IO(scripts.Script):
pass
pass
syntax_directive = r"""
start: (FILE | PROMPT | directive | SPACE)*
FILE: /"(?!"").+?"/ | /'(?!"").+?'/ | /[^?"'\s]+/
PROMPT: /"{3}.*?"{3}|'{3}.*?'{3}/
directive: "?" DIRECTIVE ("_" DIRECTIVE_ORDER)? "(" directive_inner ")"
DIRECTIVE: (/[0-9a-zA-Z]+/ | /_(?![0-9]+\()+/)+
DIRECTIVE_ORDER: /[0-9]+/
directive_inner: (DIRECTIVE_PLAIN | directive_parentheses)*
!directive_parentheses: "(" (DIRECTIVE_PLAIN | directive_parentheses)* ")"
DIRECTIVE_PLAIN: /[^()]+/
SPACE: /\s+/
"""
class Directive:
class Names(IntEnum):
eval
pass
def __init__(self, name: str, order: int, inner: str):
self.name = name.lower()
self.order = order
self.inner = inner
pass
def __lt__(self, other) -> bool:
if type(self) != type(other):
raise TypeError()
pass
if self.order == other.order:
if self.Names[self.name] == self.Names[other.name]:
return True
pass
else:
return self.Names[self.name] < self.Names[other.name]
pass
pass
else:
return self.order < other.order
pass
pass
pass
def get_cond_directive(model, input: str, is_negative: bool) -> torch.Tensor | None:
conds: list[torch.tensor] = []
dirs: list[Clip_IO.Directive] = []
class Process(lark.Transformer):
def FILE(self, token: lark.Token):
cond: torch.Tensor | None = None
filename_original = token.value
if filename_original.startswith('"') and filename_original.endswith('"') or filename_original.startswith("'") and filename_original.endswith("'"):
filename_original = filename_original[1:-1]
if filename_original in Clip_IO.conditioning_cache:
cond = Clip_IO.conditioning_cache[filename_original]
if cond is not None:
conds.append(cond)
pass
return
pass
filename = os.path.join(os.path.dirname(__file__), "../conditioning", filename_original)
filename = os.path.realpath(filename)
if filename.endswith(".csv"):
cond = Clip_IO.load_csv_conditioning(filename)
pass
elif filename.endswith(".pt"):
try:
cond = torch.load(filename)
pass
except Exception:
cond = None
pass
pass
else:
if os.path.exists(filename) and not os.path.isdir(filename):
cond = Clip_IO.load_csv_conditioning(filename)
if cond is None:
try:
cond = torch.load(filename)
pass
except Exception:
cond = None
pass
pass
pass
if cond is None and os.path.exists(filename + ".csv"):
cond = Clip_IO.load_csv_conditioning(filename + ".csv")
pass
if cond is None and not os.path.exists(filename + ".csv") and os.path.exists(filename + ".pt"):
try:
cond = torch.load(filename + ".pt")
pass
except Exception:
cond = None
pass
pass
pass
if cond is not None:
conds.append(cond.to(devices.device))
pass
Clip_IO.conditioning_cache[filename_original] = cond
pass
def PROMPT(self, token: lark.Token):
string = token.value
if string.startswith('"""') and string.endswith('"""') or string.startswith("'''") and string.endswith("'''"):
string = string[3:-3]
pass
conds.append(model.get_learned_conditioning([string])[0].to(devices.device))
pass
def directive(self, args: list[lark.Token | lark.Tree]):
def flatten(arg: lark.Token | lark.Tree | list[lark.Token | lark.Tree]) -> str | lark.Token:
if type(arg) == lark.Token:
return arg.value
pass
elif type(arg) == lark.Tree:
array = ""
for child in arg.children:
array += flatten(child)
pass
return array
pass
elif type(arg) == list:
array = ""
for component in arg:
array += flatten(component)
pass
return array
pass
else:
return arg
pass
pass
dirs.append(Clip_IO.Directive(args[0], args[1] if len(args) == 3 else 0, flatten(args[-1])))
pass
pass
Process().transform(lark.Lark(Clip_IO.syntax_directive).parse(input))
i = torch.vstack(conds)
o = i.clone()
dirs.sort()
for dir in dirs:
if dir.name == "eval":
try:
for t in range(i.shape[0]):
for d in range(i.shape[1]):
local = {"i": i, "o": o, "t": t, "d": d, "torch": torch.__dict__} | math.__dict__
o[t, d] = eval(dir.inner, None, local)
pass
pass
pass
except Exception as e:
print(repr(e))
o = i
pass
finally:
i = o.clone()
pass
elif dir.name == "exec":
try:
local = {"i": i, "o": o, "torch": torch.__dict__} | math.__dict__
exec(dir.inner, None, local)
except Exception as e:
print(repr(e))
o = i
pass
finally:
i = o.clone()
pass
else:
warnings.warn(f'Directive "{dir.name}" does not exist.')
pass
pass
cond = o
if cond is not None and cond.shape[0] > 0 and (cond.shape[1] == 768 or cond.shape[1] == 1024):
return cond.to(devices.device)
pass
else:
warnings.warn(f"{'Negative prompt' if is_negative else 'Positive prompt'} is empty. Retrieving conditioning for empty string.")
return model.get_learned_conditioning([""])[0]
pass
pass
def my_get_learned_conditioning(model, prompts, steps, is_negative = True):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@ -150,11 +335,17 @@ class Clip_IO(scripts.Script):
]
]
"""
if Clip_IO.enabled and (Clip_IO.mode_positive == "Directive" and not is_negative or Clip_IO.mode_negative == "Directive" and is_negative):
# TODO: Implement own parser
prompt_schedules = [[[steps, prompt]] for prompt in prompts]
pass
else:
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps)
pass
res = []
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps)
cache = {}
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
cached = cache.get(prompt, None)
@ -163,10 +354,17 @@ class Clip_IO(scripts.Script):
continue
texts: list[str] = [x[1] for x in prompt_schedule]
if Clip_IO.mode_positive == "Simple" and not is_negative or Clip_IO.mode_negative == "Simple" and is_negative:
if Clip_IO.enabled and (Clip_IO.mode_positive == "Simple" and not is_negative or Clip_IO.mode_negative == "Simple" and is_negative):
conds = []
for text in texts:
conds.append(Clip_IO.get_cond_simple(model, text, is_negative))
pass
pass
elif Clip_IO.enabled and (Clip_IO.mode_positive == "Directive" and not is_negative or Clip_IO.mode_negative == "Directive" and is_negative):
conds = []
for text in texts:
conds.append(Clip_IO.get_cond_directive(model, text, is_negative))
pass
pass
else:
conds = model.get_learned_conditioning(texts)