diff --git a/scripts/Clip_IO.py b/scripts/Clip_IO.py
index d1a8380..6f9f9a4 100644
--- a/scripts/Clip_IO.py
+++ b/scripts/Clip_IO.py
@@ -149,13 +149,18 @@ class Clip_IO(scripts.Script):
     """
     syntax_directive_prompt = r"""
-        start: PROMPT ("," argument)* ("," keyword_argument)*
-        PROMPT: /"{3}/ /.*?/ /"{3}/ | /'{3}/ /.*?/ /'{3}/
+        start: (prompt | prompts) ("," (argument | arguments))* ("," (keyword_argument | keyword_arguments))*
+        prompt: PROMPT
+        prompts: "(" SPACE? PROMPT SPACE? ( "," SPACE? PROMPT SPACE? )* ")" | "[" SPACE? PROMPT SPACE? ( "," SPACE? PROMPT SPACE? )* "]"
+        PROMPT: /"{3}/ /.*?/ /"{3}/ | /'{3}/ /.*?/ /'{3}/ | /"(?!"")[^"]*?"/ | /'(?!'')[^']*?'/
         argument: [ARGUMENT]
-        ARGUMENT: /[^=,]+/
+        arguments: "(" [ARGUMENT] ("," [ARGUMENT] )* ")" | "[" [ARGUMENT] ("," [ARGUMENT] )* "]"
+        ARGUMENT: /[^()\[\]=,]+/
         keyword_argument: KEYWORD "=" VALUE
+        keyword_arguments: KEYWORD "=" ( "(" [VALUE] ("," [VALUE] )* ")" | "[" [VALUE] ("," [VALUE] )* "]" )
         KEYWORD: /[^=,]+/
-        VALUE: /[^=,]+/
+        VALUE: /[^()\[\]=,]+/
+        SPACE: /\s+/
     """

     class Directive:
@@ -277,7 +282,7 @@ class Clip_IO(scripts.Script):
             Process().transform(lark.Lark(Clip_IO.syntax_directive).parse(input))
             if len(conds) == 0:
-                tmp = Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword("a", manual_chunk=True)
+                tmp = Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(("a",), manual_chunk=True)
                 conds.append(torch.zeros(0, tmp.shape[1]).to(devices.device))
                 pass
             i = torch.vstack(conds)
@@ -317,18 +322,32 @@ class Clip_IO(scripts.Script):
                 pass
-            elif dir.name == "prompt": # prompt(prompt: str, clip_skip: int|None=None, padding=True)
-                prompt: str
-                keyword_arguments: dict = {"clip_skip": None, "padding": True}
+            elif dir.name == "prompt": # prompt(prompt: str | tuple, clip_skip: int | None | tuple = None, padding: bool | tuple = True)
+                prompt: tuple[str, ...]
+                keyword_arguments: dict = {"clip_skip": [None], "padding": [True]} # defaults are lists, matching what argument()/arguments() store below
                 class prompt_transformer(lark.visitors.Transformer):
                     keyword_position = 0
-                    def PROMPT(self, token: lark.Token):
+
+                    def prompt(self, token: list[lark.Token]):
                         nonlocal prompt
+                        prompt = (token[0],)
+                        pass
+
+                    def prompts(self, tokens: list[lark.Token]):
+                        nonlocal prompt
+                        prompt = tuple(t for t in tokens if getattr(t, "type", None) != "SPACE") # lark keeps named terminals as children, so drop SPACE tokens
+                        pass
+
+                    def PROMPT(self, token: lark.Token):
                         if token.startswith('"""') and token.endswith('"""') or token.startswith("'''") and token.endswith("'''"):
                             token = token[3:-3]
                             pass
-                        prompt = token
+                        elif token.startswith('"') and token.endswith('"') or token.startswith("'") and token.endswith("'"):
+                            token = token[1:-1]
+                            pass
+                        return token
                         pass
-                    def argument(self, token: lark.Token):
+
+                    def argument(self, token: list[lark.Token]):
                         token = token[0]
                         if token is None:
                             self.keyword_position += 1
@@ -336,11 +355,11 @@ class Clip_IO(scripts.Script):
                         match self.keyword_position:
                             case 0:
                                 if token.strip(" ").lower() == "none":
-                                    keyword_arguments["clip_skip"] = None
+                                    keyword_arguments["clip_skip"] = [None]
                                     pass
                                 else:
                                     try:
-                                        keyword_arguments["clip_skip"] = int(token.strip(" "))
+                                        keyword_arguments["clip_skip"] = [int(token.strip(" "))]
                                         pass
                                     except Exception:
                                         print(f'Given argument "{token}" is neither integer nor None.')
@@ -357,25 +376,61 @@ class Clip_IO(scripts.Script):
                                     else:
                                         raise RuntimeError(f'Given argument "{token}" is neither True nor False.')
                                         pass
-                                    keyword_arguments["padding"] = value
+                                    keyword_arguments["padding"] = [value]
                                     self.keyword_position += 1
                                     pass
                                 case _:
                                     self.keyword_position += 1
                                     pass
                                 pass
+
+                    def arguments(self, token: list[lark.Token]):
+                        match self.keyword_position:
+                            case 0:
+                                keyword_arguments["clip_skip"] = []
+                                for token in token:
+                                    if token.strip(" ").lower() == "none" or token.strip(" ") == "":
+                                        keyword_arguments["clip_skip"].append(None)
+                                        pass
+                                    else:
+                                        try:
keyword_arguments["clip_skip"].append(int(token.strip(" "))) + pass + except Exception: + print(f'Given argument "{token}" is neither integer nor None.') + pass + self.keyword_position += 1 + pass + pass + case 1: + keyword_arguments["padding"] = [] + for token in token: + if token.strip(" ").lower() == "true": + value = True + pass + elif token.strip(" ").lower() == "false": + value = False + pass + else: + raise RuntimeError(f'Given argument "{token}" is neither True nor False.') + pass + keyword_arguments["padding"].append(value) + self.keyword_position += 1 + pass + pass + case _: + self.keyword_position += 1 + pass + pass + pass + def keyword_argument(self, tree: lark.tree.Tree): keyword_arguments[tree.children[0].strip(" ")] = tree.children[1].strip(" ") pass pass prompt_transformer().transform(lark.Lark(Clip_IO.syntax_directive_prompt).parse(dir.inner)) - evacuate_clip_skip = shared.opts.CLIP_stop_at_last_layers - if keyword_arguments["clip_skip"] is not None: - shared.opts.CLIP_stop_at_last_layers = keyword_arguments["clip_skip"] - pass - o = torch.vstack([o, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk= not keyword_arguments["padding"])]) + o = torch.vstack([o, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk= not keyword_arguments["padding"], clip_skips=[keyword_arguments["clip_skip"]])]) i = o.clone() - shared.opts.CLIP_stop_at_last_layers = evacuate_clip_skip pass else: warnings.warn(f'Directive "{dir.name}" does not exist.') @@ -709,25 +764,47 @@ class Clip_IO(scripts.Script): pass pass - def tokenize_line_manual_chunk(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[PromptChunk]: - if opts.enable_emphasis: - parsed = prompt_parser.parse_prompt_attention(prompt) - pass - else: - parsed = [[prompt, 1.0]] + def tokenize_line_manual_chunk(prompts: tuple[str, ...], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[PromptChunk],list[tuple[int,int]]]: + parsed = [] + for prompt in prompts: + parsed.append(["SEPARATION", -1]) + if opts.enable_emphasis: + to_appends = prompt_parser.parse_prompt_attention(prompt) + for to_append in to_appends: + parsed.append(to_append) + pass + else: + parsed.append([prompt, 1.0]) + pass pass tokenized = clip.tokenize([text for text, _ in parsed]) chunks: list[PromptChunk] = [] chunk = PromptChunk() + token_count = 0 last_comma = -1 + separation_starts: list[tuple[int,int]] = [] def next_chunk(is_last=False): + nonlocal token_count nonlocal last_comma nonlocal chunk - # We don't have to fill the chunk. + if not manual_chunk: + if is_last: + token_count += len(chunk.tokens) + else: + token_count += clip.chunk_length + + to_add = clip.chunk_length - len(chunk.tokens) + if to_add > 0: + chunk.tokens += [clip.id_end] * to_add + chunk.multipliers += [1.0] * to_add + + chunk.tokens = [clip.id_start] + chunk.tokens + [clip.id_end] + chunk.multipliers = [1.0] + chunk.multipliers + [1.0] + pass last_comma = -1 chunks.append(chunk) @@ -735,6 +812,9 @@ class Clip_IO(scripts.Script): pass for tokens, (text, weight) in zip(tokenized, parsed): + if text == 'SEPARATION' and weight == -1: + separation_starts.append((len(chunks), len(chunk.tokens))) + if text == 'BREAK' and weight == -1: next_chunk() continue @@ -750,7 +830,7 @@ class Clip_IO(scripts.Script): # this is when we are at the end of alloted 77 tokens for the current chunk, and the current token is not a comma. 
             # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
-            elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length + 2 and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+            elif manual_chunk and opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length + 2 and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
                 break_location = last_comma + 1

                 reloc_tokens = chunk.tokens[break_location:]
@@ -764,7 +844,25 @@ class Clip_IO(scripts.Script):
                 chunk.multipliers = reloc_mults
                 pass

-            if len(chunk.tokens) == clip.chunk_length + 2:
+            elif not manual_chunk and opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+                break_location = last_comma + 1
+
+                reloc_tokens = chunk.tokens[break_location:]
+                reloc_mults = chunk.multipliers[break_location:]
+
+                chunk.tokens = chunk.tokens[:break_location]
+                chunk.multipliers = chunk.multipliers[:break_location]
+
+                next_chunk()
+                chunk.tokens = reloc_tokens
+                chunk.multipliers = reloc_mults
+                pass
+
+            if manual_chunk and len(chunk.tokens) == clip.chunk_length + 2:
+                next_chunk()
+                pass
+
+            elif not manual_chunk and len(chunk.tokens) == clip.chunk_length:
                 next_chunk()
                 pass

@@ -794,19 +892,21 @@ class Clip_IO(scripts.Script):
             next_chunk(is_last=True)
             pass

-        return chunks
+        return chunks, separation_starts

-    def process_texts_manual_chunk(prompts: list[str], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[list[PromptChunk]]:
+    def process_texts_manual_chunk(prompts: list[tuple[str, ...]], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[list[PromptChunk]], list[list[tuple[int, int]]]]:
        cache = {}
         batch_chunks: list[list[PromptChunk]] = []
+        separation_starts_list: list[list[tuple[int, int]]] = []
         for prompt in prompts:
             if prompt in cache:
-                chunks = cache[prompt]
+                chunks, separation_starts = cache[prompt] # cache both results so a hit returns matching values
             else:
-                chunks = Clip_IO.tokenize_line_manual_chunk(prompt, clip)
-                cache[prompt] = chunks
+                chunks, separation_starts = Clip_IO.tokenize_line_manual_chunk(prompt, clip, manual_chunk)
+                cache[prompt] = (chunks, separation_starts)

             batch_chunks.append(chunks)
+            separation_starts_list.append(separation_starts)

         if False:
             # We have to ensure all chunk in batch_chunks have same length.
@@ -831,14 +931,18 @@ class Clip_IO(scripts.Script):
             pass
         pass

-        return batch_chunks
+        return batch_chunks, separation_starts_list

-    def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> list[list[PromptChunk]]:
+    def get_chunks(prompt: str | tuple[str, ...], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[list[PromptChunk]], list[list[tuple[int, int]]]]:
+        "Return: PromptChunks, token separation positions"
+        if isinstance(prompt, str):
+            prompt = (prompt,) # a tuple, so the prompt stays hashable for the cache in process_texts_manual_chunk
+            pass
         if opts.use_old_emphasis_implementation:
             raise NotImplementedError
             pass
-        if manual_chunk:
-            return Clip_IO.process_texts_manual_chunk([prompt], clip)
+        if True: # manual_chunk:
+            return Clip_IO.process_texts_manual_chunk([prompt], clip, manual_chunk)
             pass
         else:
             batch_chunks, _ = clip.process_texts([prompt])
@@ -896,7 +1000,7 @@ class Clip_IO(scripts.Script):
     def on_save_embeddings_as_pt(prompt: str, filename: str, overwrite: bool, manual_chunk: bool):
         try:
             clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
-            batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
+            batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
             embeddings: torch.Tensor = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)

             filename = os.path.join(os.path.dirname(__file__), "../conditioning", filename)
@@ -923,7 +1027,7 @@ class Clip_IO(scripts.Script):
     def on_save_embeddings_as_csv(prompt: str, filename: str, transpose: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
         try:
             clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
-            batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
+            batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
             embeddings, tokens = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)

             embeddings: list[list[str]] = embeddings[0].tolist()
@@ -970,7 +1074,7 @@ class Clip_IO(scripts.Script):
         try:
             with devices.autocast():
                 clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
-                batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
+                batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
                 chunk_count = max([len(x) for x in batch_chunks])
                 zs = []
                 for i in range(chunk_count):
@@ -1015,8 +1119,63 @@ class Clip_IO(scripts.Script):
         return f'File {filename} is successfully saved.\n{datetime.datetime.now().isoformat()}'
         pass

-    def FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, clip = shared.sd_model.cond_stage_model, manual_chunk = False) -> torch.Tensor:
-        batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
+    def encode_with_transformers(clip, tokens, chunk_count: int, clip_skips: list[list[int | None]], separation_starts_list: list[list[tuple[int, int]]]):
+        outputs = clip.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+
+        zs = []
+        for z in outputs.hidden_states:
+            zs.append(clip.wrapped.transformer.text_model.final_layer_norm(z))
+            pass
+        zs[-1] = outputs.last_hidden_state
+
+        if opts.CLIP_stop_at_last_layers > 1:
+            zo = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
+            zo = clip.wrapped.transformer.text_model.final_layer_norm(zo)
+        else:
+            zo = outputs.last_hidden_state
+
+        for i in range(len(tokens)):
+            chunk_start = separation_starts_list[i][0][0]
+            chunk_end = separation_starts_list[i][1][0] if len(separation_starts_list[i]) >= 2 else chunk_count + 1
+            while True:
+                while chunk_count < chunk_start:
+                    separation_starts_list[i].pop(0)
+                    clip_skips[i].pop(0)
+                    pass
+
+                if chunk_count == chunk_start:
+                    token_start = separation_starts_list[i][0][1]
+                    pass
+                elif chunk_count > chunk_start:
+                    token_start = 0
+                    pass
+
+                if chunk_count < chunk_end:
+                    token_end = zo.shape[1]
+                    pass
+                elif chunk_count == chunk_end:
+                    token_end = separation_starts_list[i][1][1] if len(separation_starts_list[i]) >= 2 else zo.shape[1]
+                    pass
+
+                zo[i, token_start:token_end, :] = zs[-clip_skips[i][0]][i, token_start:token_end, :] if clip_skips[i][0] is not None else zo[i, token_start:token_end, :]
+                if token_end == zo.shape[1]:
+                    break
+                    pass
+                if chunk_count == chunk_end:
+                    separation_starts_list[i].pop(0)
+                    if len(clip_skips[i]) >= 2:
+                        clip_skips[i].pop(0)
+                        pass
+                    else:
+                        clip_skips[i][0] = None
+                        pass
+                    pass
+                pass

+        return zo
+
+    def FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt: str | tuple[str, ...], clip = shared.sd_model.cond_stage_model, manual_chunk = False, clip_skips: list[list[int | None]] = [[None]]) -> torch.Tensor:
+        batch_chunks, separation_starts_list = Clip_IO.get_chunks(prompt, clip, manual_chunk)
         chunk_count = max([len(x) for x in batch_chunks])
         zs = []
         for i in range(chunk_count):
@@ -1030,7 +1189,7 @@ class Clip_IO(scripts.Script):
                 index = remade_batch_tokens[batch_pos].index(clip.id_end)
                 tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad

-            z = clip.encode_with_transformers(tokens)
+            z = Clip_IO.encode_with_transformers(clip, tokens, i, clip_skips, separation_starts_list) if isinstance(clip, FrozenCLIPEmbedderWithCustomWords) else clip.encode_with_transformers(tokens)

             if True: # if not no_emphasis:
                 batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device)
                 original_mean = z.mean()
@@ -1044,7 +1203,7 @@ class Clip_IO(scripts.Script):
         try:
             with devices.autocast():
                 clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
-                batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
+                batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
                 _, token_list = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
                 chunk_count = max([len(x) for x in batch_chunks])
                 zs = []
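
Below is a minimal sketch (not part of the patch) of what the extended syntax_directive_prompt grammar accepts, assuming only that the lark package is installed. The grammar string is copied verbatim from the first hunk; the two directive bodies are hypothetical examples, and lark's default Earley parser is used, matching how the extension calls lark.Lark(...).parse(...):

    import lark

    grammar = r"""
        start: (prompt | prompts) ("," (argument | arguments))* ("," (keyword_argument | keyword_arguments))*
        prompt: PROMPT
        prompts: "(" SPACE? PROMPT SPACE? ( "," SPACE? PROMPT SPACE? )* ")" | "[" SPACE? PROMPT SPACE? ( "," SPACE? PROMPT SPACE? )* "]"
        PROMPT: /"{3}/ /.*?/ /"{3}/ | /'{3}/ /.*?/ /'{3}/ | /"(?!"")[^"]*?"/ | /'(?!'')[^']*?'/
        argument: [ARGUMENT]
        arguments: "(" [ARGUMENT] ("," [ARGUMENT] )* ")" | "[" [ARGUMENT] ("," [ARGUMENT] )* "]"
        ARGUMENT: /[^()\[\]=,]+/
        keyword_argument: KEYWORD "=" VALUE
        keyword_arguments: KEYWORD "=" ( "(" [VALUE] ("," [VALUE] )* ")" | "[" [VALUE] ("," [VALUE] )* "]" )
        KEYWORD: /[^=,]+/
        VALUE: /[^()\[\]=,]+/
        SPACE: /\s+/
    """
    parser = lark.Lark(grammar)

    # Old single-prompt form still parses: one PROMPT, one positional
    # argument, one keyword argument.
    print(parser.parse('"a cat", 2, padding=True').pretty())

    # New form: a tuple of prompts with per-prompt clip_skip values and a
    # per-prompt padding tuple.
    print(parser.parse('("a cat", "a dog"),(2, None),padding=(True, False)').pretty())

Note that nothing in the grammar skips whitespace between a separating comma and an opening parenthesis or bracket (SPACE is only permitted around the PROMPT entries), so argument tuples have to follow their comma directly, as in the second example.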
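The span arithmetic in the new encode_with_transformers is easier to see on a toy tensor. The following sketch uses made-up shapes and a hand-written span table rather than the extension's real separation_starts_list bookkeeping; it shows only the core idea, that every token range owned by a prompt is overwritten with the hidden state selected by that prompt's clip_skip, while ranges whose clip_skip is None keep the default output:

    import torch

    layers, batch, n_tokens, dim = 12, 1, 77, 8
    hidden_states = [torch.randn(batch, n_tokens, dim) for _ in range(layers)]

    zo = hidden_states[-1].clone()            # default output: last hidden state
    spans = [((0, 40), 2), ((40, 77), None)]  # hypothetical (token range, clip_skip) per prompt

    for (token_start, token_end), clip_skip in spans:
        if clip_skip is not None:
            # rows for this prompt come from the clip_skip-th layer from the
            # end, mirroring zs[-clip_skips[i][0]] in the patch
            zo[:, token_start:token_end, :] = hidden_states[-clip_skip][:, token_start:token_end, :]

The real code additionally applies final_layer_norm to the selected hidden states and walks the spans one 77-token chunk at a time, popping entries from separation_starts_list and clip_skips as each prompt's last chunk is consumed; the sketch collapses all of that to a single chunk.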