mirror of https://github.com/Filexor/Clip_IO.git
Add prompt directive
parent
a16feffde5
commit
034ec00a79
|
|
@ -2,6 +2,7 @@ import os, csv, warnings, datetime
|
||||||
import math as math
|
import math as math
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
from distutils.util import strtobool
|
||||||
|
|
||||||
import gradio
|
import gradio
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -148,6 +149,15 @@ class Clip_IO(scripts.Script):
|
||||||
SPACE: /\s+/
|
SPACE: /\s+/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
syntax_directive_prompt = r"""
|
||||||
|
start: PROMPT ("," [ARGUMENT])* ("," keyword_argument)*
|
||||||
|
PROMPT: /"{3}/ /.*?/ /"{3}/ | /'{3}/ /.*?/ /'{3}/
|
||||||
|
ARGUMENT: /[^=,]+/
|
||||||
|
keyword_argument: KEYWORD "=" VALUE
|
||||||
|
KEYWORD: /[^=,]+/
|
||||||
|
VALUE: /[^=,]+/
|
||||||
|
"""
|
||||||
|
|
||||||
class Directive:
|
class Directive:
|
||||||
class Names(IntEnum):
|
class Names(IntEnum):
|
||||||
eval
|
eval
|
||||||
|
|
@ -301,6 +311,51 @@ class Clip_IO(scripts.Script):
|
||||||
finally:
|
finally:
|
||||||
i = local["o"].clone()
|
i = local["o"].clone()
|
||||||
pass
|
pass
|
||||||
|
elif dir.name == "prompt":
|
||||||
|
# prompt(prompt: str, clip_skip: int|None=None, padding=True)
|
||||||
|
prompt: str
|
||||||
|
keyword_arguments: dict[str, str] = {}
|
||||||
|
class prompt_visiter(lark.visitors.Visitor):
|
||||||
|
keyword_position = 0
|
||||||
|
def PROMPT(self, token: lark.Token):
|
||||||
|
prompt = token
|
||||||
|
pass
|
||||||
|
def ARGUMENT(self, token: lark.Token):
|
||||||
|
match self.keyword_position:
|
||||||
|
case 0:
|
||||||
|
if token.strip(" ").lower() == "none":
|
||||||
|
keyword_arguments["clip_skip"] = None
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
keyword_arguments["clip_skip"] = int(token.strip(" "))
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
print(f'Given argument "{token}" is neither integer nor None.')
|
||||||
|
pass
|
||||||
|
self.keyword_position += 1
|
||||||
|
pass
|
||||||
|
case 1:
|
||||||
|
keyword_arguments["padding"] = strtobool(token.strip(" "))
|
||||||
|
self.keyword_position += 1
|
||||||
|
pass
|
||||||
|
case _:
|
||||||
|
self.keyword_position += 1
|
||||||
|
pass
|
||||||
|
pass
|
||||||
|
def keyword_argument(self, tree: lark.tree.Tree):
|
||||||
|
keyword_arguments[tree.children[0].strip(" ")] = tree.children[1].strip(" ")
|
||||||
|
pass
|
||||||
|
pass
|
||||||
|
prompt_visiter().visit(lark.Lark(Clip_IO.syntax_directive_prompt).parse(dir.inner))
|
||||||
|
evacuate_clip_skip = shared.opts.CLIP_stop_at_last_layers
|
||||||
|
if keyword_arguments["clip_skip"] is not None:
|
||||||
|
shared.opts.CLIP_stop_at_last_layers = keyword_arguments["clip_skip"]
|
||||||
|
pass
|
||||||
|
o = torch.vstack([o, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk= not keyword_arguments["padding"])])
|
||||||
|
i = o.clone()
|
||||||
|
shared.opts.CLIP_stop_at_last_layers = evacuate_clip_skip
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
warnings.warn(f'Directive "{dir.name}" does not exist.')
|
warnings.warn(f'Directive "{dir.name}" does not exist.')
|
||||||
pass
|
pass
|
||||||
|
|
@ -939,6 +994,31 @@ class Clip_IO(scripts.Script):
|
||||||
return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}'
|
return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}'
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, clip = shared.sd_model.cond_stage_model, manual_chunk = False) -> torch.Tensor:
|
||||||
|
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
|
||||||
|
chunk_count = max([len(x) for x in batch_chunks])
|
||||||
|
zs = []
|
||||||
|
for i in range(chunk_count):
|
||||||
|
batch_chunk = [chunks[i] if i < len(chunks) else clip.empty_chunk() for chunks in batch_chunks]
|
||||||
|
remade_batch_tokens = [x.tokens for x in batch_chunk]
|
||||||
|
tokens = torch.asarray([x.tokens for x in batch_chunk]).to(devices.device)
|
||||||
|
clip.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||||
|
|
||||||
|
if clip.id_end != clip.id_pad:
|
||||||
|
for batch_pos in range(len(remade_batch_tokens)):
|
||||||
|
index = remade_batch_tokens[batch_pos].index(clip.id_end)
|
||||||
|
tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad
|
||||||
|
|
||||||
|
z = clip.encode_with_transformers(tokens)
|
||||||
|
if True: # if not no_emphasis:
|
||||||
|
batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device)
|
||||||
|
original_mean = z.mean()
|
||||||
|
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||||
|
new_mean = z.mean()
|
||||||
|
z = z * (original_mean / new_mean) # z = z * (original_mean / new_mean) if not no_norm else z
|
||||||
|
zs.append(z[0])
|
||||||
|
return torch.vstack(zs)
|
||||||
|
|
||||||
def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
|
def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
|
||||||
try:
|
try:
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue