Merge pull request #13 from Filexor/dev

Update for WebUI 1.6.0
Filexor 2023-09-23 19:42:20 +09:00 committed by GitHub
commit b6c2e9627d
2 changed files with 378 additions and 48 deletions

View File

@ -52,16 +52,23 @@ If "DirectiveOrder" is absent, it will be treated as order 0.
Local objects for eval are:
i: torch.Tensor : input conditioning
o: torch.Tensor : output conditioning
c: dict : dict for carrying values over between directives within the current prompt
g: dict : dict for carrying values over globally (it persists across prompts and batches)
p: modules.processing.StableDiffusionProcessing : [See source code of Stable diffusion Web UI.](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/baf6946e06249c5af9851c60171692c44ef633e0/modules/processing.py#L105)
t: int : index along the 0th (token) dimension of the input conditioning
d: int : index along the 1st (embedding) dimension of the input conditioning
the torch module and all objects in the math module
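
As a rough illustration, this is how an eval directive gets applied, mirroring the per-element loop in the script below. This is a minimal standalone sketch: `apply_eval_directive` and the stand-in tensor are illustrative only, and `p` is omitted because it requires a running Web UI.

```python
import math
import torch

def apply_eval_directive(expression: str, i: torch.Tensor) -> torch.Tensor:
    o = i.clone()
    c, g = {}, {}  # per-prompt and global carry-over dicts ("c" and "g" above)
    for t in range(i.shape[0]):          # token-wise
        for d in range(i.shape[1]):      # dimension-wise
            local = {"i": i, "o": o, "c": c, "g": g, "t": t, "d": d, "torch": torch} | math.__dict__
            o[t, d] = eval(expression, None, local)
    return o

cond = torch.randn(77, 768)  # stand-in conditioning
doubled = apply_eval_directive("i[t, d] * 2", cond)
```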
##### exec
"exec" does component-wise python's exec.
Global objects for exec are:
i: torch.Tensor : input conditioning
o: torch.Tensor : output conditioning
c: dict : dict for carrying values over between directives within the current prompt
g: dict : dict for carrying values over globally (it persists across prompts and batches)
p: modules.processing.StableDiffusionProcessing : [See source code of Stable diffusion Web UI.](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/baf6946e06249c5af9851c60171692c44ef633e0/modules/processing.py#L105)
the torch and math modules
**NOTE: If you want to change the seed, change both `p.seed: int` and `p.seeds: list[int]`.**
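
A corresponding standalone sketch for the exec directive (`apply_exec_directive` is a hypothetical helper; `p` and `sd_model` are omitted because they need a running Web UI):

```python
import math
import torch

def apply_exec_directive(code: str, i: torch.Tensor) -> torch.Tensor:
    o = i.clone()
    g = {"i": i, "o": o, "c": {}, "g": {}, "torch": torch, "math": math}
    exec(code, g, None)  # the objects above are provided as globals
    return g["o"].clone()

normed = apply_exec_directive("o[:] = o / o.norm(dim=-1, keepdim=True)", torch.randn(77, 768))
```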
##### prompt
"prompt" is prompt with some additional options.
syntax example: `?prompt("""prompt must be triple single/double quoted""", clip_skip=2, no_padding=False)`
arguments can be written positionally or omitted: `?prompt('''this also works''',,True)`
prompt and clip_skip can be given with list or tuple syntax: `?prompt(("This part of clip_skip is 2", 'Here is 1', "1 is broadcasted because clip_skip is exhausted"),[2,1])`
`no_padding` has the same effect as `Don't add bos / eos / pad tokens` in Clip Output.
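
As the example above suggests, the clip_skip broadcasting rule amounts to reusing the last given value once the list is exhausted; a hedged sketch (`broadcast_clip_skips` is a hypothetical helper, not part of the extension):

```python
def broadcast_clip_skips(prompts: tuple[str, ...], clip_skips: list) -> list:
    if not clip_skips:
        clip_skips = [None]
    # reuse the last value once clip_skips runs out
    return [clip_skips[min(n, len(clip_skips) - 1)] for n in range(len(prompts))]

print(broadcast_clip_skips(("a", "b", "c"), [2, 1]))  # -> [2, 1, 1]
```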

View File

@ -4,7 +4,7 @@ from collections import namedtuple
from enum import IntEnum
import gradio
import torch
import torch as torch
import lark
import open_clip
@ -20,6 +20,7 @@ class Clip_IO(scripts.Script):
mode_positive = "Disabled"
mode_negative = "Disabled"
conditioning_cache = {}
global_carry = {}
evacuate_get_learned_conditioning = None
evacuate_get_multicond_learned_conditioning = None
@ -43,12 +44,16 @@ class Clip_IO(scripts.Script):
mode_positive = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
mode_negative = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Negative prompt mode")
pass
with gradio.Accordion("Pre/Post-process", open = False):
pre_batch_process = gradio.TextArea(max_lines=1024, label="Pre-batch-process")
post_batch_process = gradio.TextArea(max_lines=1024, label="Post-batch-process")
pass
pass
if not is_img2img:
return [enabled, mode_positive, mode_negative]
return [enabled, mode_positive, mode_negative, post_batch_process, pre_batch_process]
pass
else:
return [enabled, mode_positive, mode_negative]
return [enabled, mode_positive, mode_negative, post_batch_process, pre_batch_process]
pass
return []
pass
@ -148,6 +153,21 @@ class Clip_IO(scripts.Script):
SPACE: /\s+/
"""
syntax_directive_prompt = r"""
start: (prompt | prompts) ("," (argument | arguments))* ("," (keyword_argument | keyword_arguments))*
prompt: PROMPT
prompts: "(" SPACE? PROMPT SPACE? ( "," SPACE? PROMPT SPACE? )* ")" | "[" SPACE? PROMPT SPACE? ( "," SPACE? PROMPT SPACE? )* "]"
PROMPT: /"{3}/ /.*?/ /"{3}/ | /'{3}/ /.*?/ /'{3}/ | /"(?!"")[^"]*?"/ | /'(?!'')[^']*?'/
argument: [ARGUMENT]
arguments: "(" [ARGUMENT] ("," [ARGUMENT] )* ")" | "[" [ARGUMENT] ("," [ARGUMENT] )* "]"
ARGUMENT: /[^()\[\]=,]+/
keyword_argument: KEYWORD "=" VALUE
keyword_arguments: KEYWORD "=" ( "(" [VALUE] ("," [VALUE] )* ")" | "[" [VALUE] ("," [VALUE] )* "]" )
KEYWORD: /[^=,]+/
VALUE: /[^()\[\]=,]+/
SPACE: /\s+/
"""
class Directive:
class Names(IntEnum):
eval
@ -266,6 +286,10 @@ class Clip_IO(scripts.Script):
pass
Process().transform(lark.Lark(Clip_IO.syntax_directive).parse(input))
if len(conds) == 0:
tmp = Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(("a",), manual_chunk=False)
conds.append(torch.zeros(0, tmp.shape[1]).to(devices.device))
pass
i = torch.vstack(conds)
o = i.clone()
c = {}
@ -278,28 +302,161 @@ class Clip_IO(scripts.Script):
try:
for t in range(i.shape[0]):
for d in range(i.shape[1]):
local = {"i": i, "o": o, "c": c, "p": p, "t": t, "d": d, "torch": torch.__dict__} | math.__dict__
local = {"i": i, "o": o, "g": Clip_IO.global_carry, "c": c, "p": p, "t": t, "d": d, "sd_model": shared.sd_model, "torch": torch.__dict__} | math.__dict__
o[t, d] = eval(dir.inner, None, local)
pass
pass
pass
except Exception as e:
print(repr(e))
o = i
raise e
pass
finally:
i = o.clone()
pass
elif dir.name == "exec":
try:
local = {"i": i, "o": o, "c": c, "p": p, "torch": torch.__dict__} | math.__dict__
exec(dir.inner, None, local)
globals = {"i": i, "o": o, "g": Clip_IO.global_carry, "c": c, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math}
exec(dir.inner, globals, None)
except Exception as e:
print(repr(e))
o = i
raise e
pass
finally:
i = local["o"].clone()
i = globals["o"].clone()
pass
elif dir.name == "prompt":
# prompt(prompt: str, clip_skip: int|None=None, no_padding=False)
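# For example, the inner text '"""a cat""", 2, True' parses to
# prompt=("a cat",), keyword_arguments={"clip_skip": [2], "no_padding": [True]}.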
prompt: tuple[str]
keyword_arguments: dict = {"clip_skip": [None], "no_padding": False}
class prompt_transformer(lark.visitors.Transformer):
keyword_position = 0
def prompt(self, token: list[lark.Token]):
nonlocal prompt
prompt = (token[0],)
pass
def prompts(self, tokens: list[lark.Token]):
nonlocal prompt
prompt = tuple(tokens)
pass
def PROMPT(self, token: lark.Token):
if token.startswith('"""') and token.endswith('"""') or token.startswith("'''") and token.endswith("'''"):
token = token[3:-3]
pass
elif token.startswith('"') and token.endswith('"') or token.startswith("'") and token.endswith("'"):
token = token[1:-1]
pass
return token
pass
def argument(self, token: list[lark.Token]):
token = token[0]
if token is None:
self.keyword_position += 1
return
match self.keyword_position:
case 0:
if token.strip(" ").lower() == "none":
keyword_arguments["clip_skip"] = [[None]]
pass
else:
try:
keyword_arguments["clip_skip"] = [int(token.strip(" "))]
pass
except Exception:
print(f'Given argument "{token}" is neither integer nor None.')
pass
self.keyword_position += 1
pass
case 1:
if token.strip(" ").lower() == "true":
value = True
pass
elif token.strip(" ").lower() == "false":
value = False
pass
else:
raise RuntimeError(f'Given argument "{token}" is neither True nor False.')
pass
keyword_arguments["no_padding"] = [value]
self.keyword_position += 1
pass
case _:
self.keyword_position += 1
pass
pass
def arguments(self, tokens: list[lark.Token]):
match self.keyword_position:
case 0:
keyword_arguments["clip_skip"] = []
for token in tokens:
if token.strip(" ").lower() == "none" or token.strip(" ") == "":
keyword_arguments["clip_skip"].append(None)
pass
else:
try:
keyword_arguments["clip_skip"].append(int(token.strip(" ")))
pass
except Exception:
print(f'Given argument "{token}" is neither integer nor None.')
pass
self.keyword_position += 1
pass
pass
case 1:
keyword_arguments["no_padding"] = []
for token in tokens:
if token.strip(" ").lower() == "true":
value = True
pass
elif token.strip(" ").lower() == "false":
value = False
pass
else:
raise RuntimeError(f'Given argument "{token}" is neither True nor False.')
pass
keyword_arguments["no_padding"].append(value)
self.keyword_position += 1
pass
pass
case _:
self.keyword_position += 1
pass
pass
pass
def keyword_argument(self, children: list[lark.Token]):
keyword_arguments[children[0].strip(" ")] = children[1].strip(" ")
pass
pass
prompt_transformer().transform(lark.Lark(Clip_IO.syntax_directive_prompt).parse(dir.inner))
o = torch.vstack([o, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk= keyword_arguments["no_padding"], clip_skips=[keyword_arguments["clip_skip"]])])
i = o.clone()
pass
elif dir.name == "execfile":
try:
if dir.inner.startswith('"') and dir.inner.endswith('"') or dir.inner.startswith("'") and dir.inner.endswith("'"):
dir.inner = dir.inner[1:-1]
pass
if not os.path.exists(dir.inner):
dir.inner = os.path.join(os.path.dirname(__file__), "../program", dir.inner)
if not os.path.exists(dir.inner):
dir.inner = dir.inner + ".py"
pass
pass
with open(dir.inner) as program:
globals = {"i": i, "o": o, "g": Clip_IO.global_carry, "c": c, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math}
exec(program.read(), globals, None)
pass
pass
except Exception as e:
o = i
raise e
pass
pass
else:
warnings.warn(f'Directive "{dir.name}" does not exist.')
@ -316,7 +473,7 @@ class Clip_IO(scripts.Script):
pass
pass
def my_get_learned_conditioning(model, prompts, steps, p: processing.StableDiffusionProcessing = None, is_negative = True):
def my_get_learned_conditioning(model, prompts: prompt_parser.SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False, p: processing.StableDiffusionProcessing = None, is_negative = True):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
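# Illustration with stand-in values: for "a [cat:dog:10]" at 20 steps the
# schedule is roughly
#   [ScheduledPromptConditioning(end_at_step=10, cond=<cond of "a cat">),
#    ScheduledPromptConditioning(end_at_step=20, cond=<cond of "a dog">)]
# i.e. each entry's cond applies until its end_at_step, then the next takes over.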
@ -340,7 +497,7 @@ class Clip_IO(scripts.Script):
prompt_schedules = [[[steps, prompt]] for prompt in prompts]
pass
else:
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps)
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
pass
res = []
@ -352,7 +509,7 @@ class Clip_IO(scripts.Script):
res.append(cached)
continue
texts: list[str] = [x[1] for x in prompt_schedule]
texts = prompt_parser.SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
if Clip_IO.enabled and (Clip_IO.mode_positive == "Simple" and not is_negative or Clip_IO.mode_negative == "Simple" and is_negative):
conds = []
for text in texts:
@ -371,7 +528,15 @@ class Clip_IO(scripts.Script):
cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule):
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, conds[i].to(devices.device)))
if isinstance(conds, dict):
cond = {k: v[i] for k, v in conds.items()}
pass
else:
cond = conds[i]
pass
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, cond))
pass
cache[prompt] = cond_schedule
res.append(cond_schedule)
@ -379,7 +544,7 @@ class Clip_IO(scripts.Script):
return res
pass
def my_get_multicond_learned_conditioning(model, prompts, steps, p: processing.StableDiffusionProcessing = None) -> prompt_parser.MulticondLearnedConditioning:
def my_get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False, p: processing.StableDiffusionProcessing = None) -> prompt_parser.MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
@ -388,11 +553,12 @@ class Clip_IO(scripts.Script):
res_indexes, prompt_flat_list, prompt_indexes = prompt_parser.get_multicond_prompt_list(prompts)
learned_conditioning = prompt_parser.get_learned_conditioning(model, prompt_flat_list, steps, p, is_negative = False)
learned_conditioning = prompt_parser.get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling, p, is_negative = False)
res = []
for indexes in res_indexes:
res.append([prompt_parser.ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
pass
return prompt_parser.MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
pass
@ -474,7 +640,7 @@ class Clip_IO(scripts.Script):
pass
def get_my_get_conds_with_caching(p: processing.StableDiffusionProcessing):
def my_get_conds_with_caching(function, required_prompts, steps, caches, extra_network_data):
def my_get_conds_with_caching(function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before.
@ -486,18 +652,29 @@ class Clip_IO(scripts.Script):
caches is a list with items described above.
"""
if shared.opts.use_old_scheduling:
old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)
new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)
if old_schedules != new_schedules:
p.extra_generation_params["Old prompt editing timelines"] = True
pass
pass
cached_params = p.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)
for cache in caches:
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
if cache[0] is not None and cached_params == cache[0] and not Clip_IO.enabled:
return cache[1]
cache = caches[0]
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps, p)
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling, p)
cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
cache[0] = cached_params
return cache[1]
pass
return my_get_conds_with_caching
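# The caching contract above, reduced to a sketch (hypothetical helper, not
# part of this file; each cache is a [stored_key, stored_value] pair):
#   def cached_call(compute, key, cache):
#       if cache[0] is not None and cache[0] == key:
#           return cache[1]
#       cache[1] = compute()
#       cache[0] = key
#       return cache[1]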
def get_inner_function(outer, new_inner):
@ -596,7 +773,6 @@ class Clip_IO(scripts.Script):
prompt_parser.get_learned_conditioning = Clip_IO.my_get_learned_conditioning
prompt_parser.get_multicond_learned_conditioning = Clip_IO.my_get_multicond_learned_conditioning
#Clip_IO.replace_inner_function(processing.process_images_inner, Clip_IO.get_my_get_conds_with_caching())
pass
else:
Clip_IO.enabled = False
pass
@ -620,38 +796,75 @@ class Clip_IO(scripts.Script):
Clip_IO.evacuate_get_conds_with_caching = p.get_conds_with_caching
p.get_conds_with_caching = Clip_IO.get_my_get_conds_with_caching(p)
pass
try:
globals = {"g": Clip_IO.global_carry, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math}
exec(args[4], globals, None)
except Exception as e:
raise e
pass
pass
pass
pass
def postprocess_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
if Clip_IO.enabled:
Clip_IO.mode_positive = "Disabled"
Clip_IO.mode_negative = "Disabled"
if getattr(p, "get_conds_with_caching", None) is not None:
p.get_conds_with_caching = Clip_IO.evacuate_get_conds_with_caching
try:
globals = {"g": Clip_IO.global_carry, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math}
exec(args[3], globals, None)
except Exception as e:
raise e
pass
finally:
Clip_IO.mode_positive = "Disabled"
Clip_IO.mode_negative = "Disabled"
if getattr(p, "get_conds_with_caching", None) is not None:
p.get_conds_with_caching = Clip_IO.evacuate_get_conds_with_caching
pass
pass
pass
pass
def tokenize_line_manual_chunk(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[PromptChunk]:
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(prompt)
pass
else:
parsed = [[prompt, 1.0]]
def tokenize_line_manual_chunk(prompts: tuple[str, ...], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[PromptChunk],list[tuple[int,int]]]:
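# All prompts in the tuple are tokenized in one pass; the "SEPARATION"
# sentinel (weight -1) records where each prompt's tokens begin so that
# per-prompt settings such as clip_skip can be applied later.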
parsed = []
for prompt in prompts:
parsed.append(["SEPARATION", -1])
if opts.enable_emphasis:
to_appends = prompt_parser.parse_prompt_attention(prompt)
for to_append in to_appends:
parsed.append(to_append)
pass
else:
parsed.append([prompt, 1.0])
pass
pass
tokenized = clip.tokenize([text for text, _ in parsed])
chunks: list[PromptChunk] = []
chunk = PromptChunk()
token_count = 0
last_comma = -1
separation_starts: list[tuple[int,int]] = []
def next_chunk(is_last=False):
nonlocal token_count
nonlocal last_comma
nonlocal chunk
# In manual-chunk mode the chunk is left unfilled: no padding and no bos/eos tokens are added.
if not manual_chunk:
if is_last:
token_count += len(chunk.tokens)
else:
token_count += clip.chunk_length
to_add = clip.chunk_length - len(chunk.tokens)
if to_add > 0:
chunk.tokens += [clip.id_end] * to_add
chunk.multipliers += [1.0] * to_add
chunk.tokens = [clip.id_start] + chunk.tokens + [clip.id_end]
chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
pass
last_comma = -1
chunks.append(chunk)
@ -659,6 +872,11 @@ class Clip_IO(scripts.Script):
pass
for tokens, (text, weight) in zip(tokenized, parsed):
if text == 'SEPARATION' and weight == -1:
separation_starts.append((len(chunks), len(chunk.tokens)))
continue
pass
if text == 'BREAK' and weight == -1:
next_chunk()
continue
@ -674,7 +892,7 @@ class Clip_IO(scripts.Script):
# this is when we are at the end of alloted 77 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
# is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length + 2 and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
elif manual_chunk and opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length + 2 and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
break_location = last_comma + 1
reloc_tokens = chunk.tokens[break_location:]
@ -688,7 +906,25 @@ class Clip_IO(scripts.Script):
chunk.multipliers = reloc_mults
pass
if len(chunk.tokens) == clip.chunk_length + 2:
elif not manual_chunk and opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
break_location = last_comma + 1
reloc_tokens = chunk.tokens[break_location:]
reloc_mults = chunk.multipliers[break_location:]
chunk.tokens = chunk.tokens[:break_location]
chunk.multipliers = chunk.multipliers[:break_location]
next_chunk()
chunk.tokens = reloc_tokens
chunk.multipliers = reloc_mults
pass
if manual_chunk and len(chunk.tokens) == clip.chunk_length + 2:
next_chunk()
pass
elif not manual_chunk and len(chunk.tokens) == clip.chunk_length:
next_chunk()
pass
@ -718,19 +954,21 @@ class Clip_IO(scripts.Script):
next_chunk(is_last=True)
pass
return chunks
return chunks, separation_starts
def process_texts_manual_chunk(prompts: list[str], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[list[PromptChunk]]:
def process_texts_manual_chunk(prompts: list[tuple[str, ...]], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[list[PromptChunk]],list[list[tuple[int,int]]]]:
cache = {}
batch_chunks: list[list[PromptChunk]] = []
separation_starts_list: list[list[tuple[int,int]]] = []
for prompt in prompts:
if prompt in cache:
chunks, separation_starts = cache[prompt]
else:
chunks = Clip_IO.tokenize_line_manual_chunk(prompt, clip)
chunks, separation_starts = Clip_IO.tokenize_line_manual_chunk(prompt, clip, manual_chunk)
cache[prompt] = (chunks, separation_starts)
batch_chunks.append(chunks)
separation_starts_list.append(separation_starts)
if False:
# We have to ensure all chunk in batch_chunks have same length.
@ -755,14 +993,18 @@ class Clip_IO(scripts.Script):
pass
pass
return batch_chunks
return batch_chunks, separation_starts_list
def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> list[list[PromptChunk]]:
def get_chunks(prompt: str | tuple[str, ...], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[list[PromptChunk]],list[list[tuple[int,int]]]]:
"Return: PromptChunks, token separation position"
if isinstance(prompt, str):
prompt = (prompt,)
pass
if opts.use_old_emphasis_implementation:
raise NotImplementedError
pass
if manual_chunk:
return Clip_IO.process_texts_manual_chunk([prompt], clip)
if True: # manual_chunk:
return Clip_IO.process_texts_manual_chunk([prompt], clip, manual_chunk)
pass
else:
batch_chunks, _ = clip.process_texts([prompt])
@ -820,7 +1062,7 @@ class Clip_IO(scripts.Script):
def on_save_embeddings_as_pt(prompt: str, filename: str, overwrite: bool, manual_chunk: bool):
try:
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
embeddings: torch.Tensor = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
filename = os.path.join(os.path.dirname(__file__), "../conditioning", filename)
@ -847,7 +1089,7 @@ class Clip_IO(scripts.Script):
def on_save_embeddings_as_csv(prompt: str, filename: str, transpose: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
try:
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
embeddings, tokens = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
embeddings: list[list[str]] = embeddings[0].tolist()
@ -894,7 +1136,7 @@ class Clip_IO(scripts.Script):
try:
with devices.autocast():
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
chunk_count = max([len(x) for x in batch_chunks])
zs = []
for i in range(chunk_count):
@ -939,11 +1181,92 @@ class Clip_IO(scripts.Script):
return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}'
pass
def encode_with_transformers(clip, tokens, chunk_count: int, clip_skips: list[list[int | None]], separation_starts_list:list[list[tuple[int,int]]]):
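# Per-prompt clip skip: run the text encoder once, keep every hidden state,
# then for each prompt's token span (from separation_starts_list) overwrite
# the default output with the layer selected by that prompt's clip_skip;
# None falls back to opts.CLIP_stop_at_last_layers.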
clip.wrapped.transformer.to(shared.device)
outputs = clip.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
zs = []
for z in outputs.hidden_states:
zs.append(clip.wrapped.transformer.text_model.final_layer_norm(z))
pass
zs[-1] = outputs.last_hidden_state
if opts.CLIP_stop_at_last_layers > 1:
zo = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
zo = clip.wrapped.transformer.text_model.final_layer_norm(zo)
else:
zo = outputs.last_hidden_state
for i in range(len(tokens)):
chunk_start = separation_starts_list[i][0][0]
chunk_end = separation_starts_list[i][1][0] if len(separation_starts_list[i]) >= 2 else chunk_count + 1
while True:
while chunk_count < chunk_start:
separation_starts_list[i].pop(0)
clip_skips[i].pop(0)
pass
if chunk_count == chunk_start:
token_start = separation_starts_list[i][0][1]
pass
elif chunk_count > chunk_start:
token_start = 0
pass
if chunk_count < chunk_end:
token_end = zo.shape[1]
pass
elif chunk_count == chunk_end:
token_end = separation_starts_list[i][1][1] if len(separation_starts_list[i]) >= 2 else zo.shape[1]
pass
zo[i,token_start:token_end,:] = zs[-clip_skips[i][0]][i,token_start:token_end,:] if clip_skips[i][0] is not None else zo[i,token_start:token_end,:]
if token_end == zo.shape[1]:
break
pass
if chunk_count == chunk_end:
separation_starts_list[i].pop(0)
if len(clip_skips[i]) >=2:
clip_skips[i].pop(0)
pass
else:
clip_skips[i][0] = None
pass
pass
pass
return zo
def FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt: str | tuple[str, ...], clip = None, manual_chunk = False, clip_skips: list[list[int | None]] | None = None) -> torch.Tensor:
if clip is None:
clip = shared.sd_model.cond_stage_model
pass
if clip_skips is None:
clip_skips = [[None]]
pass
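# Encodes the prompt(s) chunk by chunk and stacks the results vertically, so
# multi-chunk or multi-prompt conditionings concatenate along the token axis.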
batch_chunks, separation_starts_list = Clip_IO.get_chunks(prompt, clip, manual_chunk)
chunk_count = max([len(x) for x in batch_chunks])
zs = []
for i in range(chunk_count):
batch_chunk = [chunks[i] if i < len(chunks) else clip.empty_chunk() for chunks in batch_chunks]
remade_batch_tokens = [x.tokens for x in batch_chunk]
tokens = torch.asarray([x.tokens for x in batch_chunk]).to(devices.device)
clip.hijack.fixes = [x.fixes for x in batch_chunk]
if clip.id_end != clip.id_pad and not manual_chunk:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(clip.id_end)
tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad
z = Clip_IO.encode_with_transformers(clip, tokens, i, clip_skips, separation_starts_list) if isinstance(clip, FrozenCLIPEmbedderWithCustomWords) else clip.encode_with_transformers(tokens)
if True: # if not no_emphasis:
batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device)
original_mean = z.mean()
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z = z * (original_mean / new_mean) # z = z * (original_mean / new_mean) if not no_norm else z
zs.append(z[0])
return torch.vstack(zs)
def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
try:
with devices.autocast():
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
_, token_list = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
chunk_count = max([len(x) for x in batch_chunks])
zs = []