Add prompt directive

pull/13/head
File_xor 2023-06-30 08:30:48 +09:00
parent a16feffde5
commit 034ec00a79
1 changed file with 80 additions and 0 deletions

@@ -2,6 +2,7 @@ import os, csv, warnings, datetime
import math as math
from collections import namedtuple
from enum import IntEnum
from distutils.util import strtobool
import gradio
import torch
@@ -148,6 +149,15 @@ class Clip_IO(scripts.Script):
        SPACE: /\s+/
    """
    syntax_directive_prompt = r"""
        start: PROMPT ("," [ARGUMENT])* ("," keyword_argument)*
        PROMPT: /"{3}/ /.*?/ /"{3}/ | /'{3}/ /.*?/ /'{3}/
        ARGUMENT: /[^=,]+/
        keyword_argument: KEYWORD "=" VALUE
        KEYWORD: /[^=,]+/
        VALUE: /[^=,]+/
    """
    class Directive:
        class Names(IntEnum):
            eval
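
For reference, a minimal sketch of an argument string this grammar is meant to accept (the prompt text and values below are illustrative, not taken from the commit):

    """a photo of a cat""", 2, padding=False

The triple-quoted text matches the PROMPT terminal, the bare value 2 is a positional ARGUMENT (interpreted as clip_skip by the handler added below), and padding=False is parsed as a keyword_argument.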
@@ -301,6 +311,51 @@ class Clip_IO(scripts.Script):
                finally:
                    i = local["o"].clone()
                    pass
            elif dir.name == "prompt":
                # prompt(prompt: str, clip_skip: int|None=None, padding=True)
                prompt: str = ""
                keyword_arguments: dict[str, object] = {}
                # lark's Visitor only visits Tree nodes, so a Transformer (with visit_tokens enabled)
                # is used here so that the PROMPT and ARGUMENT terminal callbacks actually fire.
                class prompt_visiter(lark.visitors.Transformer):
                    keyword_position = 0
                    def PROMPT(self, token: lark.Token):
                        nonlocal prompt
                        prompt = str(token)
                        return token
                    def ARGUMENT(self, token: lark.Token):
                        match self.keyword_position:
                            case 0:
                                if token.strip(" ").lower() == "none":
                                    keyword_arguments["clip_skip"] = None
                                    pass
                                else:
                                    try:
                                        keyword_arguments["clip_skip"] = int(token.strip(" "))
                                        pass
                                    except Exception:
                                        print(f'Given argument "{token}" is neither an integer nor None.')
                                        pass
                                    pass
                                self.keyword_position += 1
                                pass
                            case 1:
                                keyword_arguments["padding"] = strtobool(token.strip(" "))
                                self.keyword_position += 1
                                pass
                            case _:
                                self.keyword_position += 1
                                pass
                        return token
                    def keyword_argument(self, children: list):
                        keyword_arguments[children[0].strip(" ")] = children[1].strip(" ")
                        pass
                    pass
                prompt_visiter(visit_tokens=True).transform(lark.Lark(Clip_IO.syntax_directive_prompt).parse(dir.inner))
                evacuate_clip_skip = shared.opts.CLIP_stop_at_last_layers
                if keyword_arguments.get("clip_skip") is not None:
                    shared.opts.CLIP_stop_at_last_layers = keyword_arguments["clip_skip"]
                    pass
                o = torch.vstack([o, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk=not keyword_arguments.get("padding", True))])
                i = o.clone()
                shared.opts.CLIP_stop_at_last_layers = evacuate_clip_skip
                pass
            else:
                warnings.warn(f'Directive "{dir.name}" does not exist.')
                pass
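
A note on the clip_skip handling above: shared.opts.CLIP_stop_at_last_layers is overwritten before encoding and restored afterwards, but if the encode call raises, the option stays at the overridden value. A try/finally guard would make the restore unconditional; a minimal sketch, not part of the commit:

    evacuate_clip_skip = shared.opts.CLIP_stop_at_last_layers
    try:
        if keyword_arguments.get("clip_skip") is not None:
            shared.opts.CLIP_stop_at_last_layers = keyword_arguments["clip_skip"]
        o = torch.vstack([o, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk=not keyword_arguments.get("padding", True))])
        i = o.clone()
    finally:
        # always restore the user's original setting
        shared.opts.CLIP_stop_at_last_layers = evacuate_clip_skip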
@@ -939,6 +994,31 @@ class Clip_IO(scripts.Script):
            return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}'
        pass
    def FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, clip = shared.sd_model.cond_stage_model, manual_chunk = False) -> torch.Tensor:
        batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
        chunk_count = max([len(x) for x in batch_chunks])
        zs = []
        for i in range(chunk_count):
            batch_chunk = [chunks[i] if i < len(chunks) else clip.empty_chunk() for chunks in batch_chunks]
            remade_batch_tokens = [x.tokens for x in batch_chunk]
            tokens = torch.asarray([x.tokens for x in batch_chunk]).to(devices.device)
            clip.hijack.fixes = [x.fixes for x in batch_chunk]
            if clip.id_end != clip.id_pad:
                for batch_pos in range(len(remade_batch_tokens)):
                    index = remade_batch_tokens[batch_pos].index(clip.id_end)
                    tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad
            z = clip.encode_with_transformers(tokens)
            if True: # if not no_emphasis:
                batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device)
                original_mean = z.mean()
                z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
                new_mean = z.mean()
                z = z * (original_mean / new_mean) # z = z * (original_mean / new_mean) if not no_norm else z
            zs.append(z[0])
        return torch.vstack(zs)
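
A rough usage sketch for the helper above (assuming the webui model is loaded so that shared.sd_model.cond_stage_model is available; the prompt text is illustrative):

    cond = Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword("a photo of a cat")
    # cond has shape (chunk_count * chunk_length, embedding_dim) -- typically 77 tokens per chunk in
    # the webui tokenizer -- because each chunk is encoded separately and stacked with torch.vstack.
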
    def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
        try:
            with devices.autocast():