mirror of https://github.com/Filexor/Clip_IO.git
Add prompt directive
parent
a16feffde5
commit
034ec00a79
|
|
@ -2,6 +2,7 @@ import os, csv, warnings, datetime
|
|||
import math as math
|
||||
from collections import namedtuple
|
||||
from enum import IntEnum
|
||||
from distutils.util import strtobool
|
||||
|
||||
import gradio
|
||||
import torch
|
||||
|
|
@ -148,6 +149,15 @@ class Clip_IO(scripts.Script):
|
|||
SPACE: /\s+/
|
||||
"""
|
||||
|
||||
syntax_directive_prompt = r"""
|
||||
start: PROMPT ("," [ARGUMENT])* ("," keyword_argument)*
|
||||
PROMPT: /"{3}/ /.*?/ /"{3}/ | /'{3}/ /.*?/ /'{3}/
|
||||
ARGUMENT: /[^=,]+/
|
||||
keyword_argument: KEYWORD "=" VALUE
|
||||
KEYWORD: /[^=,]+/
|
||||
VALUE: /[^=,]+/
|
||||
"""
|
||||
|
||||
class Directive:
|
||||
class Names(IntEnum):
|
||||
eval
|
||||
|
|
@ -301,6 +311,51 @@ class Clip_IO(scripts.Script):
|
|||
finally:
|
||||
i = local["o"].clone()
|
||||
pass
|
||||
elif dir.name == "prompt":
|
||||
# prompt(prompt: str, clip_skip: int|None=None, padding=True)
|
||||
prompt: str
|
||||
keyword_arguments: dict[str, str] = {}
|
||||
class prompt_visiter(lark.visitors.Visitor):
|
||||
keyword_position = 0
|
||||
def PROMPT(self, token: lark.Token):
|
||||
prompt = token
|
||||
pass
|
||||
def ARGUMENT(self, token: lark.Token):
|
||||
match self.keyword_position:
|
||||
case 0:
|
||||
if token.strip(" ").lower() == "none":
|
||||
keyword_arguments["clip_skip"] = None
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
keyword_arguments["clip_skip"] = int(token.strip(" "))
|
||||
pass
|
||||
except Exception:
|
||||
print(f'Given argument "{token}" is neither integer nor None.')
|
||||
pass
|
||||
self.keyword_position += 1
|
||||
pass
|
||||
case 1:
|
||||
keyword_arguments["padding"] = strtobool(token.strip(" "))
|
||||
self.keyword_position += 1
|
||||
pass
|
||||
case _:
|
||||
self.keyword_position += 1
|
||||
pass
|
||||
pass
|
||||
def keyword_argument(self, tree: lark.tree.Tree):
|
||||
keyword_arguments[tree.children[0].strip(" ")] = tree.children[1].strip(" ")
|
||||
pass
|
||||
pass
|
||||
prompt_visiter().visit(lark.Lark(Clip_IO.syntax_directive_prompt).parse(dir.inner))
|
||||
evacuate_clip_skip = shared.opts.CLIP_stop_at_last_layers
|
||||
if keyword_arguments["clip_skip"] is not None:
|
||||
shared.opts.CLIP_stop_at_last_layers = keyword_arguments["clip_skip"]
|
||||
pass
|
||||
o = torch.vstack([o, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk= not keyword_arguments["padding"])])
|
||||
i = o.clone()
|
||||
shared.opts.CLIP_stop_at_last_layers = evacuate_clip_skip
|
||||
pass
|
||||
else:
|
||||
warnings.warn(f'Directive "{dir.name}" does not exist.')
|
||||
pass
|
||||
|
|
@ -939,6 +994,31 @@ class Clip_IO(scripts.Script):
|
|||
return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}'
|
||||
pass
|
||||
|
||||
def FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, clip = shared.sd_model.cond_stage_model, manual_chunk = False) -> torch.Tensor:
|
||||
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
|
||||
chunk_count = max([len(x) for x in batch_chunks])
|
||||
zs = []
|
||||
for i in range(chunk_count):
|
||||
batch_chunk = [chunks[i] if i < len(chunks) else clip.empty_chunk() for chunks in batch_chunks]
|
||||
remade_batch_tokens = [x.tokens for x in batch_chunk]
|
||||
tokens = torch.asarray([x.tokens for x in batch_chunk]).to(devices.device)
|
||||
clip.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||
|
||||
if clip.id_end != clip.id_pad:
|
||||
for batch_pos in range(len(remade_batch_tokens)):
|
||||
index = remade_batch_tokens[batch_pos].index(clip.id_end)
|
||||
tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad
|
||||
|
||||
z = clip.encode_with_transformers(tokens)
|
||||
if True: # if not no_emphasis:
|
||||
batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z = z * (original_mean / new_mean) # z = z * (original_mean / new_mean) if not no_norm else z
|
||||
zs.append(z[0])
|
||||
return torch.vstack(zs)
|
||||
|
||||
def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
|
||||
try:
|
||||
with devices.autocast():
|
||||
|
|
|
|||
Loading…
Reference in New Issue