From 4e815714d1cdece008201db891328e8a335c483c Mon Sep 17 00:00:00 2001 From: File_xor Date: Tue, 6 Jun 2023 20:39:33 +0900 Subject: [PATCH] Add Directive mode to Clip Input --- scripts/Clip_IO.py | 210 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 204 insertions(+), 6 deletions(-) diff --git a/scripts/Clip_IO.py b/scripts/Clip_IO.py index 82f92c0..46c58bf 100644 --- a/scripts/Clip_IO.py +++ b/scripts/Clip_IO.py @@ -1,4 +1,7 @@ import os, csv, warnings, datetime +import math as math +from collections import namedtuple +from enum import IntEnum import gradio import torch @@ -37,8 +40,8 @@ class Clip_IO(scripts.Script): with gradio.Accordion("Clip input", open = False): with gradio.Row(): enabled = gradio.Checkbox(label = "Enable") - mode_positive = gradio.Dropdown(["Disabled", "Simple"], value = "Disabled", max_choices = 1, label = "Positive prompt mode") - mode_negative = gradio.Dropdown(["Disabled", "Simple"], value = "Disabled", max_choices = 1, label = "Positive prompt mode") + mode_positive = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Positive prompt mode") + mode_negative = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Positive prompt mode") pass pass if not is_img2img: @@ -132,6 +135,188 @@ class Clip_IO(scripts.Script): pass pass + syntax_directive = r""" + start: (FILE | PROMPT | directive | SPACE)* + FILE: /"(?!"").+?"/ | /'(?!"").+?'/ | /[^?"'\s]+/ + PROMPT: /"{3}.*?"{3}|'{3}.*?'{3}/ + directive: "?" DIRECTIVE ("_" DIRECTIVE_ORDER)? "(" directive_inner ")" + DIRECTIVE: (/[0-9a-zA-Z]+/ | /_(?![0-9]+\()+/)+ + DIRECTIVE_ORDER: /[0-9]+/ + directive_inner: (DIRECTIVE_PLAIN | directive_parentheses)* + !directive_parentheses: "(" (DIRECTIVE_PLAIN | directive_parentheses)* ")" + DIRECTIVE_PLAIN: /[^()]+/ + SPACE: /\s+/ + """ + + class Directive: + class Names(IntEnum): + eval + pass + + def __init__(self, name: str, order: int, inner: str): + self.name = name.lower() + self.order = order + self.inner = inner + pass + + def __lt__(self, other) -> bool: + if type(self) != type(other): + raise TypeError() + pass + if self.order == other.order: + if self.Names[self.name] == self.Names[other.name]: + return True + pass + else: + return self.Names[self.name] < self.Names[other.name] + pass + pass + else: + return self.order < other.order + pass + pass + pass + + def get_cond_directive(model, input: str, is_negative: bool) -> torch.Tensor | None: + conds: list[torch.tensor] = [] + dirs: list[Clip_IO.Directive] = [] + class Process(lark.Transformer): + def FILE(self, token: lark.Token): + cond: torch.Tensor | None = None + filename_original = token.value + if filename_original.startswith('"') and filename_original.endswith('"') or filename_original.startswith("'") and filename_original.endswith("'"): + filename_original = filename_original[1:-1] + if filename_original in Clip_IO.conditioning_cache: + cond = Clip_IO.conditioning_cache[filename_original] + if cond is not None: + conds.append(cond) + pass + return + pass + filename = os.path.join(os.path.dirname(__file__), "../conditioning", filename_original) + filename = os.path.realpath(filename) + if filename.endswith(".csv"): + cond = Clip_IO.load_csv_conditioning(filename) + pass + elif filename.endswith(".pt"): + try: + cond = torch.load(filename) + pass + except Exception: + cond = None + pass + pass + else: + if os.path.exists(filename) and not os.path.isdir(filename): + cond = Clip_IO.load_csv_conditioning(filename) + if cond is None: + try: + cond = torch.load(filename) + pass + except Exception: + cond = None + pass + pass + pass + if cond is None and os.path.exists(filename + ".csv"): + cond = Clip_IO.load_csv_conditioning(filename + ".csv") + pass + if cond is None and not os.path.exists(filename + ".csv") and os.path.exists(filename + ".pt"): + try: + cond = torch.load(filename + ".pt") + pass + except Exception: + cond = None + pass + pass + pass + if cond is not None: + conds.append(cond.to(devices.device)) + pass + Clip_IO.conditioning_cache[filename_original] = cond + pass + def PROMPT(self, token: lark.Token): + string = token.value + if string.startswith('"""') and string.endswith('"""') or string.startswith("'''") and string.endswith("'''"): + string = string[3:-3] + pass + conds.append(model.get_learned_conditioning([string])[0].to(devices.device)) + pass + def directive(self, args: list[lark.Token | lark.Tree]): + def flatten(arg: lark.Token | lark.Tree | list[lark.Token | lark.Tree]) -> str | lark.Token: + if type(arg) == lark.Token: + return arg.value + pass + elif type(arg) == lark.Tree: + array = "" + for child in arg.children: + array += flatten(child) + pass + return array + pass + elif type(arg) == list: + array = "" + for component in arg: + array += flatten(component) + pass + return array + pass + else: + return arg + pass + pass + + dirs.append(Clip_IO.Directive(args[0], args[1] if len(args) == 3 else 0, flatten(args[-1]))) + pass + pass + + Process().transform(lark.Lark(Clip_IO.syntax_directive).parse(input)) + i = torch.vstack(conds) + o = i.clone() + dirs.sort() + for dir in dirs: + if dir.name == "eval": + try: + for t in range(i.shape[0]): + for d in range(i.shape[1]): + local = {"i": i, "o": o, "t": t, "d": d, "torch": torch.__dict__} | math.__dict__ + o[t, d] = eval(dir.inner, None, local) + pass + pass + pass + except Exception as e: + print(repr(e)) + o = i + pass + finally: + i = o.clone() + pass + elif dir.name == "exec": + try: + local = {"i": i, "o": o, "torch": torch.__dict__} | math.__dict__ + exec(dir.inner, None, local) + except Exception as e: + print(repr(e)) + o = i + pass + finally: + i = o.clone() + pass + else: + warnings.warn(f'Directive "{dir.name}" does not exist.') + pass + pass + cond = o + + if cond is not None and cond.shape[0] > 0 and (cond.shape[1] == 768 or cond.shape[1] == 1024): + return cond.to(devices.device) + pass + else: + warnings.warn(f"{'Negative prompt' if is_negative else 'Positive prompt'} is empty. Retrieving conditioning for empty string.") + return model.get_learned_conditioning([""])[0] + pass + pass + def my_get_learned_conditioning(model, prompts, steps, is_negative = True): """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), and the sampling step at which this condition is to be replaced by the next one. @@ -150,11 +335,17 @@ class Clip_IO(scripts.Script): ] ] """ + + if Clip_IO.enabled and (Clip_IO.mode_positive == "Directive" and not is_negative or Clip_IO.mode_negative == "Directive" and is_negative): + # TODO: Implement own parser + prompt_schedules = [[[steps, prompt]] for prompt in prompts] + pass + else: + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps) + pass + res = [] - - prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps) cache = {} - for prompt, prompt_schedule in zip(prompts, prompt_schedules): cached = cache.get(prompt, None) @@ -163,10 +354,17 @@ class Clip_IO(scripts.Script): continue texts: list[str] = [x[1] for x in prompt_schedule] - if Clip_IO.mode_positive == "Simple" and not is_negative or Clip_IO.mode_negative == "Simple" and is_negative: + if Clip_IO.enabled and (Clip_IO.mode_positive == "Simple" and not is_negative or Clip_IO.mode_negative == "Simple" and is_negative): conds = [] for text in texts: conds.append(Clip_IO.get_cond_simple(model, text, is_negative)) + pass + pass + elif Clip_IO.enabled and (Clip_IO.mode_positive == "Directive" and not is_negative or Clip_IO.mode_negative == "Directive" and is_negative): + conds = [] + for text in texts: + conds.append(Clip_IO.get_cond_directive(model, text, is_negative)) + pass pass else: conds = model.get_learned_conditioning(texts)