Add prompt directive

pull/13/head
File_xor 2023-06-30 08:30:48 +09:00
parent a16feffde5
commit 034ec00a79
1 changed files with 80 additions and 0 deletions

View File

@ -2,6 +2,7 @@ import os, csv, warnings, datetime
import math as math import math as math
from collections import namedtuple from collections import namedtuple
from enum import IntEnum from enum import IntEnum
from distutils.util import strtobool
import gradio import gradio
import torch import torch
@ -148,6 +149,15 @@ class Clip_IO(scripts.Script):
SPACE: /\s+/ SPACE: /\s+/
""" """
syntax_directive_prompt = r"""
start: PROMPT ("," [ARGUMENT])* ("," keyword_argument)*
PROMPT: /"{3}/ /.*?/ /"{3}/ | /'{3}/ /.*?/ /'{3}/
ARGUMENT: /[^=,]+/
keyword_argument: KEYWORD "=" VALUE
KEYWORD: /[^=,]+/
VALUE: /[^=,]+/
"""
class Directive: class Directive:
class Names(IntEnum): class Names(IntEnum):
eval eval
@ -301,6 +311,51 @@ class Clip_IO(scripts.Script):
finally: finally:
i = local["o"].clone() i = local["o"].clone()
pass pass
elif dir.name == "prompt":
# prompt(prompt: str, clip_skip: int|None=None, padding=True)
prompt: str
keyword_arguments: dict[str, str] = {}
class prompt_visiter(lark.visitors.Visitor):
keyword_position = 0
def PROMPT(self, token: lark.Token):
prompt = token
pass
def ARGUMENT(self, token: lark.Token):
match self.keyword_position:
case 0:
if token.strip(" ").lower() == "none":
keyword_arguments["clip_skip"] = None
pass
else:
try:
keyword_arguments["clip_skip"] = int(token.strip(" "))
pass
except Exception:
print(f'Given argument "{token}" is neither integer nor None.')
pass
self.keyword_position += 1
pass
case 1:
keyword_arguments["padding"] = strtobool(token.strip(" "))
self.keyword_position += 1
pass
case _:
self.keyword_position += 1
pass
pass
def keyword_argument(self, tree: lark.tree.Tree):
keyword_arguments[tree.children[0].strip(" ")] = tree.children[1].strip(" ")
pass
pass
prompt_visiter().visit(lark.Lark(Clip_IO.syntax_directive_prompt).parse(dir.inner))
evacuate_clip_skip = shared.opts.CLIP_stop_at_last_layers
if keyword_arguments["clip_skip"] is not None:
shared.opts.CLIP_stop_at_last_layers = keyword_arguments["clip_skip"]
pass
o = torch.vstack([o, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk= not keyword_arguments["padding"])])
i = o.clone()
shared.opts.CLIP_stop_at_last_layers = evacuate_clip_skip
pass
else: else:
warnings.warn(f'Directive "{dir.name}" does not exist.') warnings.warn(f'Directive "{dir.name}" does not exist.')
pass pass
@ -939,6 +994,31 @@ class Clip_IO(scripts.Script):
return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}' return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}'
pass pass
def FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, clip = shared.sd_model.cond_stage_model, manual_chunk = False) -> torch.Tensor:
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
chunk_count = max([len(x) for x in batch_chunks])
zs = []
for i in range(chunk_count):
batch_chunk = [chunks[i] if i < len(chunks) else clip.empty_chunk() for chunks in batch_chunks]
remade_batch_tokens = [x.tokens for x in batch_chunk]
tokens = torch.asarray([x.tokens for x in batch_chunk]).to(devices.device)
clip.hijack.fixes = [x.fixes for x in batch_chunk]
if clip.id_end != clip.id_pad:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(clip.id_end)
tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad
z = clip.encode_with_transformers(tokens)
if True: # if not no_emphasis:
batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device)
original_mean = z.mean()
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z = z * (original_mean / new_mean) # z = z * (original_mean / new_mean) if not no_norm else z
zs.append(z[0])
return torch.vstack(zs)
def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool): def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
try: try:
with devices.autocast(): with devices.autocast():