Reworked Clip Input

main
File_xor 2023-05-21 21:30:48 +09:00
parent 182d01c026
commit 3343f4a659
1 changed files with 111 additions and 128 deletions

View File

@ -2,7 +2,7 @@ import os, csv, warnings, datetime
import gradio
import torch
import pandas
import lark
from tkinter import filedialog
from modules import scripts, script_callbacks, shared, devices, processing, prompt_parser
@ -12,23 +12,10 @@ from modules.sd_hijack_clip import PromptChunkFix, PromptChunk, FrozenCLIPEmbedd
mode_types = ["replace", "concatenate", "command"]
class Clip_IO(scripts.Script):
class Ui_manager():
def __init__(self):
self.mode = "replace"
self.main_blocks: gradio.Blocks | None = None
pass
pass
ui_txt2img = Ui_manager()
ui_img2img = Ui_manager()
enabled = False
positive_filenames = []
negative_filenames = []
mode_positive = "Disabled"
mode_negative = "Disabled"
conditioning_cache = {}
positive_exist = False
negative_exist = False
evacuate_get_learned_conditioning = None
evacuate_get_multicond_learned_conditioning = None
@ -45,48 +32,105 @@ class Clip_IO(scripts.Script):
return scripts.AlwaysVisible
pass
def show_conditioning_open_dialog() -> str:
results = filedialog.askopenfilenames(filetypes = [("Comma-Separated Values", "*.csv"), ("Pytorch Tensor", "*.pt"), ("Any File", "*")])
output = ""
for result in results:
if ";" in str(result):
warnings.warn(f'In "{result}",\ninvalid character ";" was found when parsing the file name.\nFile name must not contain ";".\nContinue as if the file is not specified.')
continue
pass
output += str(result) + ";"
pass
return output[:-1]
pass
def ui(self, is_img2img):
with gradio.Accordion("Clip input", open = False):
with gradio.Row():
enabled = gradio.Checkbox(label = "Enable")
mode = gradio.Dropdown(choices = mode_types, value = "Replace", label = "Clip input mode")
pass
with gradio.Blocks(visible = True) as main_blocks:
with gradio.Row():
replace_positive = gradio.Textbox(label = "Replacement for Positive prompt")
replace_positive_button = gradio.Button("📂")
replace_negative = gradio.Textbox(label = "Replacement for Negative prompt")
replace_negative_button = gradio.Button("📂")
mode_positive = gradio.Dropdown(["Disabled", "Simple"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
mode_negative = gradio.Dropdown(["Disabled", "Simple"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
pass
pass
replace_positive_button.click(Clip_IO.show_conditioning_open_dialog, outputs = replace_positive, show_progress = False)
replace_negative_button.click(Clip_IO.show_conditioning_open_dialog, outputs = replace_negative, show_progress = False)
if not is_img2img:
if Clip_IO.ui_txt2img.mode == "replace":
return [enabled, mode, replace_positive, replace_negative]
pass
return [enabled, mode_positive, mode_negative]
pass
else:
if Clip_IO.ui_img2img.mode == "replace":
return [enabled, mode, replace_positive, replace_negative]
pass
return [enabled, mode_positive, mode_negative]
pass
return []
pass
syntax_simple = r"""
start: (FILE | PROMPT | SPACE)*
FILE: /"(?!"").+?"/ | /'(?!"").+?'/ | /[^"'\s]+/
PROMPT: /"{3}.*?"{3}|'{3}.*?'{3}/
SPACE: /[\s]+/
"""
def get_cond_simple(model, input: str, is_negative: bool) -> torch.Tensor | None:
conds = []
class Process(lark.Transformer):
def FILE(self, token: lark.Token):
cond: torch.Tensor | None = None
filename_original = token.value
if filename_original.startswith('"') and filename_original.endswith('"') or filename_original.startswith("'") and filename_original.endswith("'"):
filename_original = filename_original[1:-1]
if filename_original in Clip_IO.conditioning_cache:
cond = Clip_IO.conditioning_cache[filename_original]
if cond is not None:
conds.append(cond)
pass
return
pass
filename = os.path.join(os.path.dirname(__file__), "../conditioning", filename_original)
filename = os.path.realpath(filename)
if filename.endswith(".csv"):
cond = Clip_IO.load_csv_conditioning(filename)
pass
elif filename.endswith(".pt"):
try:
cond = torch.load(filename)
pass
except Exception:
cond = None
pass
pass
else:
if os.path.exists(filename) and not os.path.isdir(filename):
cond = Clip_IO.load_csv_conditioning(filename)
if cond is None:
try:
cond = torch.load(filename)
pass
except Exception:
cond = None
pass
pass
pass
if cond is None and os.path.exists(filename + ".csv"):
cond = Clip_IO.load_csv_conditioning(filename + ".csv")
pass
if cond is None and not os.path.exists(filename + ".csv") and os.path.exists(filename + ".pt"):
try:
cond = torch.load(filename + ".pt")
pass
except Exception:
cond = None
pass
pass
pass
if cond is not None:
conds.append(cond.to(devices.device))
pass
Clip_IO.conditioning_cache[filename_original] = cond
pass
def PROMPT(self, token: lark.Token):
string = token.value
if string.startswith('"""') and string.endswith('"""') or string.startswith("'''") and string.endswith("'''"):
string = string[3:-3]
pass
conds.append(model.get_learned_conditioning([string])[0].to(devices.device))
pass
pass
Process().transform(lark.Lark(Clip_IO.syntax_simple).parse(input))
if len(conds) != 0:
return torch.vstack(conds)
pass
else:
warnings.warn(f"{'Negative prompt' if is_negative else 'Positive prompt'} is empty. Retrieving conditioning for empty string.")
return model.get_learned_conditioning([""])[0]
pass
pass
def my_get_learned_conditioning(model, prompts, steps, is_negative = True):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@ -117,29 +161,19 @@ class Clip_IO(scripts.Script):
res.append(cached)
continue
texts = [x[1] for x in prompt_schedule]
conds = model.get_learned_conditioning(texts)
texts: list[str] = [x[1] for x in prompt_schedule]
if Clip_IO.mode_positive == "Simple" and not is_negative or Clip_IO.mode_negative == "Simple" and is_negative:
conds = []
for text in texts:
conds.append(Clip_IO.get_cond_simple(model, text, is_negative))
pass
else:
conds = model.get_learned_conditioning(texts)
pass
cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule):
if not Clip_IO.enabled:
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, conds[i]))
pass
else:
if not is_negative and Clip_IO.positive_exist:
if Clip_IO.ui_txt2img.mode == "replace":
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, torch.hstack([Clip_IO.conditioning_cache[filename].to(devices.device) for filename in Clip_IO.positive_filenames])))
pass
pass
elif is_negative and Clip_IO.negative_exist:
if Clip_IO.ui_txt2img.mode == "replace":
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, torch.hstack([Clip_IO.conditioning_cache[filename].to(devices.device) for filename in Clip_IO.negative_filenames])))
pass
pass
else:
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, conds[i]))
pass
pass
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, conds[i]))
cache[prompt] = cond_schedule
res.append(cond_schedule)
@ -354,69 +388,7 @@ class Clip_IO(scripts.Script):
def process(self, p: processing.StableDiffusionProcessing, *args):
if args[0]:
Clip_IO.enabled = True
Clip_IO.positive_filenames: list[str | os.PathLike] = str.split(args[2], ";")
Clip_IO.negative_filenames: list[str | os.PathLike] = str.split(args[3], ";")
Clip_IO.negative_exist = False
Clip_IO.positive_exist = False
for i, positive_filename in enumerate(Clip_IO.positive_filenames):
if positive_filename in Clip_IO.conditioning_cache or not os.path.exists(positive_filename) or os.path.isdir(positive_filename):
continue
pass
if positive_filename.endswith(".csv"):
conditioning = Clip_IO.load_csv_conditioning(positive_filename)
if conditioning is not None:
Clip_IO.conditioning_cache[positive_filename] = conditioning
Clip_IO.positive_exist = True
pass
else:
del Clip_IO.positive_filenames[i]
pass
pass
else:
try:
Clip_IO.conditioning_cache[positive_filename] = torch.load(positive_filename)
Clip_IO.positive_exist = True
pass
except Exception as e:
warnings.warn(f'In "{positive_filename}",\nsomething went wrong while loading pytorch Tensor.\nContinue as if the file is not specified.')
warnings.warn(repr(e))
return None
pass
pass
for i, negative_filename in enumerate(Clip_IO.negative_filenames):
if negative_filename in Clip_IO.conditioning_cache or not os.path.exists(negative_filename) or os.path.isdir(negative_filename):
continue
pass
if negative_filename.endswith(".csv"):
conditioning = Clip_IO.load_csv_conditioning(negative_filename)
if conditioning is not None:
Clip_IO.conditioning_cache[negative_filename] = conditioning
Clip_IO.negative_exist = True
pass
else:
del Clip_IO.negative_filenames[i]
pass
pass
else:
try:
Clip_IO.conditioning_cache[negative_filename] = torch.load(negative_filename)
Clip_IO.negative_exist = True
pass
except Exception as e:
warnings.warn(f'In "{negative_filename}",\nsomething went wrong while loading pytorch Tensor.\nContinue as if the file is not specified.')
warnings.warn(repr(e))
return None
pass
pass
if Clip_IO.positive_exist:
p.prompt = ""
pass
if Clip_IO.negative_exist:
p.negative_prompt = ""
pass
Clip_IO.conditioning_cache = {}
Clip_IO.evacuate_get_learned_conditioning = prompt_parser.get_learned_conditioning
Clip_IO.evacuate_get_multicond_learned_conditioning = prompt_parser.get_multicond_learned_conditioning
#Clip_IO.evacuate_get_conds_with_caching = Clip_IO.get_inner_function(processing.process_images_inner, Clip_IO.get_my_get_conds_with_caching()) # Flush cache in my_get_learned_conditioning instead.
@ -432,12 +404,23 @@ class Clip_IO(scripts.Script):
def postprocess(self, p: processing.StableDiffusionProcessing, processed, *args):
if args[0]:
Clip_IO.enabled = False
Clip_IO.conditioning_cache = {}
prompt_parser.get_learned_conditioning = Clip_IO.evacuate_get_learned_conditioning
prompt_parser.get_multicond_learned_conditioning = Clip_IO.evacuate_get_multicond_learned_conditioning
#Clip_IO.replace_inner_function(processing.process_images_inner, Clip_IO.evacuate_get_conds_with_caching)
pass
pass
def process_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
Clip_IO.mode_positive = args[1]
Clip_IO.mode_negative = args[2]
pass
def postprocess_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
Clip_IO.mode_positive = "Disabled"
Clip_IO.mode_negative = "Disabled"
pass
def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWordsBase) -> PromptChunk:
if opts.use_old_emphasis_implementation:
raise NotImplementedError