mirror of https://github.com/Filexor/Clip_IO.git
Reworked Clip Input
parent
182d01c026
commit
3343f4a659
|
|
@ -2,7 +2,7 @@ import os, csv, warnings, datetime
|
|||
|
||||
import gradio
|
||||
import torch
|
||||
import pandas
|
||||
import lark
|
||||
from tkinter import filedialog
|
||||
|
||||
from modules import scripts, script_callbacks, shared, devices, processing, prompt_parser
|
||||
|
|
@ -12,23 +12,10 @@ from modules.sd_hijack_clip import PromptChunkFix, PromptChunk, FrozenCLIPEmbedd
|
|||
mode_types = ["replace", "concatenate", "command"]
|
||||
|
||||
class Clip_IO(scripts.Script):
|
||||
class Ui_manager():
|
||||
def __init__(self):
|
||||
self.mode = "replace"
|
||||
self.main_blocks: gradio.Blocks | None = None
|
||||
pass
|
||||
pass
|
||||
|
||||
ui_txt2img = Ui_manager()
|
||||
ui_img2img = Ui_manager()
|
||||
|
||||
enabled = False
|
||||
|
||||
positive_filenames = []
|
||||
negative_filenames = []
|
||||
mode_positive = "Disabled"
|
||||
mode_negative = "Disabled"
|
||||
conditioning_cache = {}
|
||||
positive_exist = False
|
||||
negative_exist = False
|
||||
|
||||
evacuate_get_learned_conditioning = None
|
||||
evacuate_get_multicond_learned_conditioning = None
|
||||
|
|
@ -45,48 +32,105 @@ class Clip_IO(scripts.Script):
|
|||
return scripts.AlwaysVisible
|
||||
pass
|
||||
|
||||
def show_conditioning_open_dialog() -> str:
|
||||
results = filedialog.askopenfilenames(filetypes = [("Comma-Separated Values", "*.csv"), ("Pytorch Tensor", "*.pt"), ("Any File", "*")])
|
||||
output = ""
|
||||
for result in results:
|
||||
if ";" in str(result):
|
||||
warnings.warn(f'In "{result}",\ninvalid character ";" was found when parsing the file name.\nFile name must not contain ";".\nContinue as if the file is not specified.')
|
||||
continue
|
||||
pass
|
||||
output += str(result) + ";"
|
||||
pass
|
||||
return output[:-1]
|
||||
pass
|
||||
|
||||
def ui(self, is_img2img):
|
||||
with gradio.Accordion("Clip input", open = False):
|
||||
with gradio.Row():
|
||||
enabled = gradio.Checkbox(label = "Enable")
|
||||
mode = gradio.Dropdown(choices = mode_types, value = "Replace", label = "Clip input mode")
|
||||
pass
|
||||
with gradio.Blocks(visible = True) as main_blocks:
|
||||
with gradio.Row():
|
||||
replace_positive = gradio.Textbox(label = "Replacement for Positive prompt")
|
||||
replace_positive_button = gradio.Button("📂")
|
||||
replace_negative = gradio.Textbox(label = "Replacement for Negative prompt")
|
||||
replace_negative_button = gradio.Button("📂")
|
||||
mode_positive = gradio.Dropdown(["Disabled", "Simple"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
|
||||
mode_negative = gradio.Dropdown(["Disabled", "Simple"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
|
||||
pass
|
||||
pass
|
||||
replace_positive_button.click(Clip_IO.show_conditioning_open_dialog, outputs = replace_positive, show_progress = False)
|
||||
replace_negative_button.click(Clip_IO.show_conditioning_open_dialog, outputs = replace_negative, show_progress = False)
|
||||
if not is_img2img:
|
||||
if Clip_IO.ui_txt2img.mode == "replace":
|
||||
return [enabled, mode, replace_positive, replace_negative]
|
||||
pass
|
||||
return [enabled, mode_positive, mode_negative]
|
||||
pass
|
||||
else:
|
||||
if Clip_IO.ui_img2img.mode == "replace":
|
||||
return [enabled, mode, replace_positive, replace_negative]
|
||||
pass
|
||||
return [enabled, mode_positive, mode_negative]
|
||||
pass
|
||||
return []
|
||||
pass
|
||||
|
||||
syntax_simple = r"""
|
||||
start: (FILE | PROMPT | SPACE)*
|
||||
FILE: /"(?!"").+?"/ | /'(?!"").+?'/ | /[^"'\s]+/
|
||||
PROMPT: /"{3}.*?"{3}|'{3}.*?'{3}/
|
||||
SPACE: /[\s]+/
|
||||
"""
|
||||
|
||||
def get_cond_simple(model, input: str, is_negative: bool) -> torch.Tensor | None:
|
||||
conds = []
|
||||
class Process(lark.Transformer):
|
||||
def FILE(self, token: lark.Token):
|
||||
cond: torch.Tensor | None = None
|
||||
filename_original = token.value
|
||||
if filename_original.startswith('"') and filename_original.endswith('"') or filename_original.startswith("'") and filename_original.endswith("'"):
|
||||
filename_original = filename_original[1:-1]
|
||||
if filename_original in Clip_IO.conditioning_cache:
|
||||
cond = Clip_IO.conditioning_cache[filename_original]
|
||||
if cond is not None:
|
||||
conds.append(cond)
|
||||
pass
|
||||
return
|
||||
pass
|
||||
filename = os.path.join(os.path.dirname(__file__), "../conditioning", filename_original)
|
||||
filename = os.path.realpath(filename)
|
||||
if filename.endswith(".csv"):
|
||||
cond = Clip_IO.load_csv_conditioning(filename)
|
||||
pass
|
||||
elif filename.endswith(".pt"):
|
||||
try:
|
||||
cond = torch.load(filename)
|
||||
pass
|
||||
except Exception:
|
||||
cond = None
|
||||
pass
|
||||
pass
|
||||
else:
|
||||
if os.path.exists(filename) and not os.path.isdir(filename):
|
||||
cond = Clip_IO.load_csv_conditioning(filename)
|
||||
if cond is None:
|
||||
try:
|
||||
cond = torch.load(filename)
|
||||
pass
|
||||
except Exception:
|
||||
cond = None
|
||||
pass
|
||||
pass
|
||||
pass
|
||||
if cond is None and os.path.exists(filename + ".csv"):
|
||||
cond = Clip_IO.load_csv_conditioning(filename + ".csv")
|
||||
pass
|
||||
if cond is None and not os.path.exists(filename + ".csv") and os.path.exists(filename + ".pt"):
|
||||
try:
|
||||
cond = torch.load(filename + ".pt")
|
||||
pass
|
||||
except Exception:
|
||||
cond = None
|
||||
pass
|
||||
pass
|
||||
pass
|
||||
if cond is not None:
|
||||
conds.append(cond.to(devices.device))
|
||||
pass
|
||||
Clip_IO.conditioning_cache[filename_original] = cond
|
||||
pass
|
||||
def PROMPT(self, token: lark.Token):
|
||||
string = token.value
|
||||
if string.startswith('"""') and string.endswith('"""') or string.startswith("'''") and string.endswith("'''"):
|
||||
string = string[3:-3]
|
||||
pass
|
||||
conds.append(model.get_learned_conditioning([string])[0].to(devices.device))
|
||||
pass
|
||||
pass
|
||||
Process().transform(lark.Lark(Clip_IO.syntax_simple).parse(input))
|
||||
if len(conds) != 0:
|
||||
return torch.vstack(conds)
|
||||
pass
|
||||
else:
|
||||
warnings.warn(f"{'Negative prompt' if is_negative else 'Positive prompt'} is empty. Retrieving conditioning for empty string.")
|
||||
return model.get_learned_conditioning([""])[0]
|
||||
pass
|
||||
pass
|
||||
|
||||
def my_get_learned_conditioning(model, prompts, steps, is_negative = True):
|
||||
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
||||
and the sampling step at which this condition is to be replaced by the next one.
|
||||
|
|
@ -117,29 +161,19 @@ class Clip_IO(scripts.Script):
|
|||
res.append(cached)
|
||||
continue
|
||||
|
||||
texts = [x[1] for x in prompt_schedule]
|
||||
conds = model.get_learned_conditioning(texts)
|
||||
texts: list[str] = [x[1] for x in prompt_schedule]
|
||||
if Clip_IO.mode_positive == "Simple" and not is_negative or Clip_IO.mode_negative == "Simple" and is_negative:
|
||||
conds = []
|
||||
for text in texts:
|
||||
conds.append(Clip_IO.get_cond_simple(model, text, is_negative))
|
||||
pass
|
||||
else:
|
||||
conds = model.get_learned_conditioning(texts)
|
||||
pass
|
||||
|
||||
cond_schedule = []
|
||||
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
||||
if not Clip_IO.enabled:
|
||||
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||
pass
|
||||
else:
|
||||
if not is_negative and Clip_IO.positive_exist:
|
||||
if Clip_IO.ui_txt2img.mode == "replace":
|
||||
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, torch.hstack([Clip_IO.conditioning_cache[filename].to(devices.device) for filename in Clip_IO.positive_filenames])))
|
||||
pass
|
||||
pass
|
||||
elif is_negative and Clip_IO.negative_exist:
|
||||
if Clip_IO.ui_txt2img.mode == "replace":
|
||||
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, torch.hstack([Clip_IO.conditioning_cache[filename].to(devices.device) for filename in Clip_IO.negative_filenames])))
|
||||
pass
|
||||
pass
|
||||
else:
|
||||
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||
pass
|
||||
pass
|
||||
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||
|
||||
cache[prompt] = cond_schedule
|
||||
res.append(cond_schedule)
|
||||
|
|
@ -354,69 +388,7 @@ class Clip_IO(scripts.Script):
|
|||
def process(self, p: processing.StableDiffusionProcessing, *args):
|
||||
if args[0]:
|
||||
Clip_IO.enabled = True
|
||||
Clip_IO.positive_filenames: list[str | os.PathLike] = str.split(args[2], ";")
|
||||
Clip_IO.negative_filenames: list[str | os.PathLike] = str.split(args[3], ";")
|
||||
Clip_IO.negative_exist = False
|
||||
Clip_IO.positive_exist = False
|
||||
|
||||
for i, positive_filename in enumerate(Clip_IO.positive_filenames):
|
||||
if positive_filename in Clip_IO.conditioning_cache or not os.path.exists(positive_filename) or os.path.isdir(positive_filename):
|
||||
continue
|
||||
pass
|
||||
if positive_filename.endswith(".csv"):
|
||||
conditioning = Clip_IO.load_csv_conditioning(positive_filename)
|
||||
if conditioning is not None:
|
||||
Clip_IO.conditioning_cache[positive_filename] = conditioning
|
||||
Clip_IO.positive_exist = True
|
||||
pass
|
||||
else:
|
||||
del Clip_IO.positive_filenames[i]
|
||||
pass
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
Clip_IO.conditioning_cache[positive_filename] = torch.load(positive_filename)
|
||||
Clip_IO.positive_exist = True
|
||||
pass
|
||||
except Exception as e:
|
||||
warnings.warn(f'In "{positive_filename}",\nsomething went wrong while loading pytorch Tensor.\nContinue as if the file is not specified.')
|
||||
warnings.warn(repr(e))
|
||||
return None
|
||||
pass
|
||||
pass
|
||||
for i, negative_filename in enumerate(Clip_IO.negative_filenames):
|
||||
if negative_filename in Clip_IO.conditioning_cache or not os.path.exists(negative_filename) or os.path.isdir(negative_filename):
|
||||
continue
|
||||
pass
|
||||
if negative_filename.endswith(".csv"):
|
||||
conditioning = Clip_IO.load_csv_conditioning(negative_filename)
|
||||
if conditioning is not None:
|
||||
Clip_IO.conditioning_cache[negative_filename] = conditioning
|
||||
Clip_IO.negative_exist = True
|
||||
pass
|
||||
else:
|
||||
del Clip_IO.negative_filenames[i]
|
||||
pass
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
Clip_IO.conditioning_cache[negative_filename] = torch.load(negative_filename)
|
||||
Clip_IO.negative_exist = True
|
||||
pass
|
||||
except Exception as e:
|
||||
warnings.warn(f'In "{negative_filename}",\nsomething went wrong while loading pytorch Tensor.\nContinue as if the file is not specified.')
|
||||
warnings.warn(repr(e))
|
||||
return None
|
||||
pass
|
||||
pass
|
||||
|
||||
if Clip_IO.positive_exist:
|
||||
p.prompt = ""
|
||||
pass
|
||||
if Clip_IO.negative_exist:
|
||||
p.negative_prompt = ""
|
||||
pass
|
||||
|
||||
Clip_IO.conditioning_cache = {}
|
||||
Clip_IO.evacuate_get_learned_conditioning = prompt_parser.get_learned_conditioning
|
||||
Clip_IO.evacuate_get_multicond_learned_conditioning = prompt_parser.get_multicond_learned_conditioning
|
||||
#Clip_IO.evacuate_get_conds_with_caching = Clip_IO.get_inner_function(processing.process_images_inner, Clip_IO.get_my_get_conds_with_caching()) # Flush cache in my_get_learned_conditioning instead.
|
||||
|
|
@ -432,12 +404,23 @@ class Clip_IO(scripts.Script):
|
|||
def postprocess(self, p: processing.StableDiffusionProcessing, processed, *args):
|
||||
if args[0]:
|
||||
Clip_IO.enabled = False
|
||||
Clip_IO.conditioning_cache = {}
|
||||
prompt_parser.get_learned_conditioning = Clip_IO.evacuate_get_learned_conditioning
|
||||
prompt_parser.get_multicond_learned_conditioning = Clip_IO.evacuate_get_multicond_learned_conditioning
|
||||
#Clip_IO.replace_inner_function(processing.process_images_inner, Clip_IO.evacuate_get_conds_with_caching)
|
||||
pass
|
||||
pass
|
||||
|
||||
def process_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
|
||||
Clip_IO.mode_positive = args[1]
|
||||
Clip_IO.mode_negative = args[2]
|
||||
pass
|
||||
|
||||
def postprocess_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
|
||||
Clip_IO.mode_positive = "Disabled"
|
||||
Clip_IO.mode_negative = "Disabled"
|
||||
pass
|
||||
|
||||
def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWordsBase) -> PromptChunk:
|
||||
if opts.use_old_emphasis_implementation:
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
Loading…
Reference in New Issue