diff --git a/scripts/Clip_IO.py b/scripts/Clip_IO.py
index fbcedd5..8f6f65d 100644
--- a/scripts/Clip_IO.py
+++ b/scripts/Clip_IO.py
@@ -2,6 +2,7 @@ import os, csv, warnings, datetime
 import math as math
 from collections import namedtuple
 from enum import IntEnum
+from distutils.util import strtobool
 
 import gradio
 import torch
@@ -148,6 +149,15 @@ class Clip_IO(scripts.Script):
         SPACE: /\s+/
     """
 
+    syntax_directive_prompt = r"""
+        start: PROMPT ("," [ARGUMENT])* ("," keyword_argument)*
+        PROMPT: /"{3}/ /.*?/ /"{3}/ | /'{3}/ /.*?/ /'{3}/
+        ARGUMENT: /[^=,]+/
+        keyword_argument: KEYWORD "=" VALUE
+        KEYWORD: /[^=,]+/
+        VALUE: /[^=,]+/
+    """
+
     class Directive:
         class Names(IntEnum):
             eval
@@ -301,6 +311,51 @@ class Clip_IO(scripts.Script):
                 finally:
                     i = local["o"].clone()
                     pass
+            elif dir.name == "prompt":
+                # prompt(prompt: str, clip_skip: int|None=None, padding=True)
+                prompt: str
+                keyword_arguments: dict[str, str] = {}
+                class prompt_visiter(lark.visitors.Visitor):
+                    keyword_position = 0
+                    def PROMPT(self, token: lark.Token):
+                        prompt = token
+                        pass
+                    def ARGUMENT(self, token: lark.Token):
+                        match self.keyword_position:
+                            case 0:
+                                if token.strip(" ").lower() == "none":
+                                    keyword_arguments["clip_skip"] = None
+                                    pass
+                                else:
+                                    try:
+                                        keyword_arguments["clip_skip"] = int(token.strip(" "))
+                                        pass
+                                    except Exception:
+                                        print(f'Given argument "{token}" is neither integer nor None.')
+                                        pass
+                                self.keyword_position += 1
+                                pass
+                            case 1:
+                                keyword_arguments["padding"] = strtobool(token.strip(" "))
+                                self.keyword_position += 1
+                                pass
+                            case _:
+                                self.keyword_position += 1
+                                pass
+                        pass
+                    def keyword_argument(self, tree: lark.tree.Tree):
+                        keyword_arguments[tree.children[0].strip(" ")] = tree.children[1].strip(" ")
+                        pass
+                    pass
+                prompt_visiter().visit(lark.Lark(Clip_IO.syntax_directive_prompt).parse(dir.inner))
+                evacuate_clip_skip = shared.opts.CLIP_stop_at_last_layers
+                if keyword_arguments["clip_skip"] is not None:
+                    shared.opts.CLIP_stop_at_last_layers = keyword_arguments["clip_skip"]
+                    pass
+                o = torch.vstack([o, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk= not keyword_arguments["padding"])])
+                i = o.clone()
+                shared.opts.CLIP_stop_at_last_layers = evacuate_clip_skip
+                pass
             else:
                 warnings.warn(f'Directive "{dir.name}" does not exist.')
                 pass
@@ -939,6 +994,31 @@ class Clip_IO(scripts.Script):
             return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}'
         pass
 
+    def FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, clip = shared.sd_model.cond_stage_model, manual_chunk = False) -> torch.Tensor:
+        batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
+        chunk_count = max([len(x) for x in batch_chunks])
+        zs = []
+        for i in range(chunk_count):
+            batch_chunk = [chunks[i] if i < len(chunks) else clip.empty_chunk() for chunks in batch_chunks]
+            remade_batch_tokens = [x.tokens for x in batch_chunk]
+            tokens = torch.asarray([x.tokens for x in batch_chunk]).to(devices.device)
+            clip.hijack.fixes = [x.fixes for x in batch_chunk]
+
+            if clip.id_end != clip.id_pad:
+                for batch_pos in range(len(remade_batch_tokens)):
+                    index = remade_batch_tokens[batch_pos].index(clip.id_end)
+                    tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad
+
+            z = clip.encode_with_transformers(tokens)
+            if True: # if not no_emphasis:
+                batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device)
+                original_mean = z.mean()
+                z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+                new_mean = z.mean()
+                z = z * (original_mean / new_mean) # z = z * (original_mean / new_mean) if not no_norm else z
+            zs.append(z[0])
+        return torch.vstack(zs)
+
     def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
         try:
             with devices.autocast():
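
With the grammar added above, the body of the new prompt directive presumably consists of a triple-quoted prompt, an optional positional clip_skip and padding argument, and optional trailing key=value pairs. The sketch below is a minimal, standalone illustration of the clip-skip handling the directive performs around encoding: save the global CLIP_stop_at_last_layers option, apply the per-directive override, encode, then restore. It reuses names that appear in this patch (shared.opts.CLIP_stop_at_last_layers, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword); the wrapper name encode_prompt_with_clip_skip is illustrative only and not part of the patch.

    # Illustrative sketch, not part of the patch. Assumes the stable-diffusion-webui
    # runtime (modules.shared) and that the Clip_IO class from this script is in scope.
    import torch
    from modules import shared

    def encode_prompt_with_clip_skip(prompt: str, clip_skip: int | None = None, padding: bool = True) -> torch.Tensor:
        saved = shared.opts.CLIP_stop_at_last_layers            # remember the user's global setting
        if clip_skip is not None:
            shared.opts.CLIP_stop_at_last_layers = clip_skip    # per-directive override
        try:
            # padding=False corresponds to manual_chunk=True in the helper added above
            return Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk=not padding)
        finally:
            shared.opts.CLIP_stop_at_last_layers = saved        # always restore the global option

Restoring inside try/finally is a design choice of this sketch, so the global option is put back even if encoding raises; the patch itself restores it unconditionally after stacking the result.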