diff --git a/README.md b/README.md index c054830..00b64f0 100644 --- a/README.md +++ b/README.md @@ -52,16 +52,23 @@ If "DirectiveOrder" is absent, it will be treated as order is 0. Local objects for eval are: i: torch.Tensor : input conditioning o: torch.Tensor : output conditioning -c: dict : dict for carrying over +g: dict : dict for carrying over p: modules.processing.StableDiffusionProcessing : [See source code of Stable diffusion Web UI.](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/baf6946e06249c5af9851c60171692c44ef633e0/modules/processing.py#L105) t: int : 0th dimension (token-wise) of index of input conditioning d: int : 1st dimension (dimension-wise) of index of input conditioning torch module and all objects in math module ##### exec "exec" does component-wise python's exec. -Local objects for exec are: +Global objects for exec are: i: torch.Tensor : input conditioning o: torch.Tensor : output conditioning -c: dict : dict for carrying over +g: dict : dict for carrying over p: modules.processing.StableDiffusionProcessing : [See source code of Stable diffusion Web UI.](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/baf6946e06249c5af9851c60171692c44ef633e0/modules/processing.py#L105) -torch module and all objects in math module \ No newline at end of file +**NOTE: If you want to change seed, change both p.seed: int and p.seeds: list[int] .** +torch module and all objects in math module +##### prompt +"prompt" is prompt with some additional options. +syntax example: `?prompt("""prompt must be triple single/double quoted""", clip_skip=2, no_padding=False)` +arguments can be written as positional or omitted: `?prompt('''this also work''',,True)` +prompt and clip_skip can be with list or tuple syntax: `?prompt(("This part of clip_skip is 2", 'Here is 1', "1 is broadcasted because clip_skip is exhausted"),[2,1])` +`no_padding` does same behavior of `Don't add bos / eos / pad tokens` in Clip Output. \ No newline at end of file diff --git a/scripts/Clip_IO.py b/scripts/Clip_IO.py index fbcedd5..493cbbe 100644 --- a/scripts/Clip_IO.py +++ b/scripts/Clip_IO.py @@ -4,7 +4,7 @@ from collections import namedtuple from enum import IntEnum import gradio -import torch +import torch as torch import lark import open_clip @@ -20,6 +20,7 @@ class Clip_IO(scripts.Script): mode_positive = "Disabled" mode_negative = "Disabled" conditioning_cache = {} + global_carry = {} evacuate_get_learned_conditioning = None evacuate_get_multicond_learned_conditioning = None @@ -43,12 +44,16 @@ class Clip_IO(scripts.Script): mode_positive = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Positive prompt mode") mode_negative = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Negative prompt mode") pass + with gradio.Accordion("Pre/Post-process", open = False): + pre_batch_process = gradio.TextArea(max_lines=1024, label="Pre-batch-process") + post_batch_process = gradio.TextArea(max_lines=1024, label="Post-batch-process-process") + pass pass if not is_img2img: - return [enabled, mode_positive, mode_negative] + return [enabled, mode_positive, mode_negative, post_batch_process, pre_batch_process] pass else: - return [enabled, mode_positive, mode_negative] + return [enabled, mode_positive, mode_negative, post_batch_process, pre_batch_process] pass return [] pass @@ -148,6 +153,21 @@ class Clip_IO(scripts.Script): SPACE: /\s+/ """ + syntax_directive_prompt = r""" + start: (prompt | prompts) ("," (argument | arguments))* ("," (keyword_argument | keyword_arguments))* + prompt: PROMPT + prompts: "(" SPACE? PROMPT SPACE? ( "," SPACE? PROMPT SPACE? )* ")" | "[" SPACE? PROMPT SPACE? ( "," SPACE? PROMPT SPACE? )* "]" + PROMPT: /"{3}/ /.*?/ /"{3}/ | /'{3}/ /.*?/ /'{3}/ | /"(?!"")[^"]*?"/ | /'(?!'')[^']*?'/ + argument: [ARGUMENT] + arguments: "(" [ARGUMENT] ("," [ARGUMENT] )* ")" | "[" [ARGUMENT] ("," [ARGUMENT] )* "]" + ARGUMENT: /[^()\[\]=,]+/ + keyword_argument: KEYWORD "=" VALUE + keyword_arguments: KEYWORD "=" ( "(" [VALUE] ("," [VALUE] )* ")" | "[" [VALUE] ("," [VALUE] )* "]" ) + KEYWORD: /[^=,]+/ + VALUE: /[^()\[\]=,]+/ + SPACE: /\s+/ + """ + class Directive: class Names(IntEnum): eval @@ -266,6 +286,10 @@ class Clip_IO(scripts.Script): pass Process().transform(lark.Lark(Clip_IO.syntax_directive).parse(input)) + if len(conds) == 0: + tmp = Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(("a",), manual_chunk=False) + conds.append(torch.zeros(0, tmp.shape[1]).to(devices.device)) + pass i = torch.vstack(conds) o = i.clone() c = {} @@ -278,28 +302,161 @@ class Clip_IO(scripts.Script): try: for t in range(i.shape[0]): for d in range(i.shape[1]): - local = {"i": i, "o": o, "c": c, "p": p, "t": t, "d": d, "torch": torch.__dict__} | math.__dict__ + local = {"i": i, "o": o, "g": Clip_IO.global_carry, "c": c, "p": p, "t": t, "d": d, "sd_model": shared.sd_model, "torch": torch.__dict__} | math.__dict__ o[t, d] = eval(dir.inner, None, local) pass pass pass except Exception as e: - print(repr(e)) o = i + raise e pass finally: i = o.clone() pass elif dir.name == "exec": try: - local = {"i": i, "o": o, "c": c, "p": p, "torch": torch.__dict__} | math.__dict__ - exec(dir.inner, None, local) + globals = {"i": i, "o": o, "g": Clip_IO.global_carry, "c": c, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math} + exec(dir.inner, globals, None) except Exception as e: - print(repr(e)) o = i + raise e pass finally: - i = local["o"].clone() + i = globals["o"].clone() + pass + elif dir.name == "prompt": + # prompt(prompt: str, clip_skip: int|None=None, no_padding=True) + prompt: tuple[str] + keyword_arguments: dict = {"clip_skip": [None], "no_padding": False} + class prompt_transformer(lark.visitors.Transformer): + keyword_position = 0 + + def prompt(self, token: list[lark.Token]): + nonlocal prompt + prompt = (token[0],) + pass + + def prompts(self, tokens: list[lark.Token]): + nonlocal prompt + prompt = tuple(tokens) + pass + + def PROMPT(self, token: lark.Token): + if token.startswith('"""') and token.endswith('"""') or token.startswith("'''") and token.endswith("'''"): + token = token[3:-3] + pass + elif token.startswith('"') and token.endswith('"') or token.startswith("'") and token.endswith("'"): + token = token[1:-1] + pass + return token + pass + + def argument(self, token: list[lark.Token]): + token = token[0] + if token is None: + self.keyword_position += 1 + return + match self.keyword_position: + case 0: + if token.strip(" ").lower() == "none": + keyword_arguments["clip_skip"] = [[None]] + pass + else: + try: + keyword_arguments["clip_skip"] = [int(token.strip(" "))] + pass + except Exception: + print(f'Given argument "{token}" is neither integer nor None.') + pass + self.keyword_position += 1 + pass + case 1: + if token.strip(" ").lower() == "true": + value = True + pass + elif token.strip(" ").lower() == "false": + value = False + pass + else: + raise RuntimeError(f'Given argument "{token}" is neither True nor False.') + pass + keyword_arguments["no_padding"] = [value] + self.keyword_position += 1 + pass + case _: + self.keyword_position += 1 + pass + pass + + def arguments(self, token: list[lark.Token]): + match self.keyword_position: + case 0: + keyword_arguments["clip_skip"] = [] + for token in token: + if token.strip(" ").lower() == "none" or token.strip(" ") == "": + keyword_arguments["clip_skip"].append(None) + pass + else: + try: + keyword_arguments["clip_skip"].append(int(token.strip(" "))) + pass + except Exception: + print(f'Given argument "{token}" is neither integer nor None.') + pass + self.keyword_position += 1 + pass + pass + case 1: + keyword_arguments["no_padding"] = [] + for token in token: + if token.strip(" ").lower() == "true": + value = True + pass + elif token.strip(" ").lower() == "false": + value = False + pass + else: + raise RuntimeError(f'Given argument "{token}" is neither True nor False.') + pass + keyword_arguments["no_padding"].append(value) + self.keyword_position += 1 + pass + pass + case _: + self.keyword_position += 1 + pass + pass + pass + + def keyword_argument(self, tree: lark.tree.Tree): + keyword_arguments[tree.children[0].strip(" ")] = tree.children[1].strip(" ") + pass + pass + prompt_transformer().transform(lark.Lark(Clip_IO.syntax_directive_prompt).parse(dir.inner)) + o = torch.vstack([o, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk= keyword_arguments["no_padding"], clip_skips=[keyword_arguments["clip_skip"]])]) + i = o.clone() + pass + elif dir.name == "execfile": + try: + if dir.inner.startswith('"') and dir.inner.endswith('"') or dir.inner.startswith("'") and dir.inner.endswith("'"): + dir.inner = dir.inner[1:-1] + pass + if not os.path.exists(dir.inner): + dir.inner = os.path.join(os.path.dirname(__file__), "../program", dir.inner) + if not os.path.exists(dir.inner): + dir.inner = dir.inner + ".py" + pass + pass + with open(dir.inner) as program: + globals = {"i": i, "o": o, "g": Clip_IO.global_carry, "c": c, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math} + exec(program, globals, None) + pass + pass + except Exception as e: + o = i + raise e + pass pass else: warnings.warn(f'Directive "{dir.name}" does not exist.') @@ -316,7 +473,7 @@ class Clip_IO(scripts.Script): pass pass - def my_get_learned_conditioning(model, prompts, steps, p: processing.StableDiffusionProcessing = None, is_negative = True): + def my_get_learned_conditioning(model, prompts: prompt_parser.SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False, p: processing.StableDiffusionProcessing = None, is_negative = True): """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), and the sampling step at which this condition is to be replaced by the next one. @@ -340,7 +497,7 @@ class Clip_IO(scripts.Script): prompt_schedules = [[[steps, prompt]] for prompt in prompts] pass else: - prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling) pass res = [] @@ -352,7 +509,7 @@ class Clip_IO(scripts.Script): res.append(cached) continue - texts: list[str] = [x[1] for x in prompt_schedule] + texts = prompt_parser.SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts) if Clip_IO.enabled and (Clip_IO.mode_positive == "Simple" and not is_negative or Clip_IO.mode_negative == "Simple" and is_negative): conds = [] for text in texts: @@ -371,7 +528,15 @@ class Clip_IO(scripts.Script): cond_schedule = [] for i, (end_at_step, text) in enumerate(prompt_schedule): - cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, conds[i].to(devices.device))) + if isinstance(conds, dict): + cond = {k: v[i] for k, v in conds.items()} + pass + else: + cond = conds[i] + pass + + cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, cond)) + pass cache[prompt] = cond_schedule res.append(cond_schedule) @@ -379,7 +544,7 @@ class Clip_IO(scripts.Script): return res pass - def my_get_multicond_learned_conditioning(model, prompts, steps, p: processing.StableDiffusionProcessing = None) -> prompt_parser.MulticondLearnedConditioning: + def my_get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False, p: processing.StableDiffusionProcessing = None) -> prompt_parser.MulticondLearnedConditioning: """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. For each prompt, the list is obtained by splitting the prompt using the AND separator. @@ -388,11 +553,12 @@ class Clip_IO(scripts.Script): res_indexes, prompt_flat_list, prompt_indexes = prompt_parser.get_multicond_prompt_list(prompts) - learned_conditioning = prompt_parser.get_learned_conditioning(model, prompt_flat_list, steps, p, is_negative = False) + learned_conditioning = prompt_parser.get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling, p, is_negative = False) res = [] for indexes in res_indexes: res.append([prompt_parser.ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes]) + pass return prompt_parser.MulticondLearnedConditioning(shape=(len(prompts),), batch=res) pass @@ -474,7 +640,7 @@ class Clip_IO(scripts.Script): pass def get_my_get_conds_with_caching(p: processing.StableDiffusionProcessing): - def my_get_conds_with_caching(function, required_prompts, steps, caches, extra_network_data): + def my_get_conds_with_caching(function, required_prompts, steps, caches, extra_network_data, hires_steps=None): """ Returns the result of calling function(shared.sd_model, required_prompts, steps) using a cache to store the result if the same arguments have been used before. @@ -486,18 +652,29 @@ class Clip_IO(scripts.Script): caches is a list with items described above. """ + if shared.opts.use_old_scheduling: + old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False) + new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True) + if old_schedules != new_schedules: + p.extra_generation_params["Old prompt editing timelines"] = True + pass + pass + + cached_params = p.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling) + for cache in caches: - if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]: + if cache[0] is not None and cached_params == cache[0] and not Clip_IO.enabled: return cache[1] cache = caches[0] with devices.autocast(): - cache[1] = function(shared.sd_model, required_prompts, steps, p) + cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling) - cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) + cache[0] = cached_params return cache[1] pass + return my_get_conds_with_caching def get_inner_function(outer, new_inner): @@ -596,7 +773,6 @@ class Clip_IO(scripts.Script): prompt_parser.get_learned_conditioning = Clip_IO.my_get_learned_conditioning prompt_parser.get_multicond_learned_conditioning = Clip_IO.my_get_multicond_learned_conditioning #Clip_IO.replace_inner_function(processing.process_images_inner, Clip_IO.get_my_get_conds_with_caching()) - pass else: Clip_IO.enabled = False pass @@ -620,38 +796,75 @@ class Clip_IO(scripts.Script): Clip_IO.evacuate_get_conds_with_caching = p.get_conds_with_caching p.get_conds_with_caching = Clip_IO.get_my_get_conds_with_caching(p) pass + try: + globals = {"g": Clip_IO.global_carry, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math} + exec(args[4], globals, None) + except Exception as e: + raise e + pass + pass pass pass def postprocess_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs): if Clip_IO.enabled: - Clip_IO.mode_positive = "Disabled" - Clip_IO.mode_negative = "Disabled" - if getattr(p, "get_conds_with_caching", None) is not None: - p.get_conds_with_caching = Clip_IO.evacuate_get_conds_with_caching + try: + globals = {"g": Clip_IO.global_carry, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math} + exec(args[3], globals, None) + except Exception as e: + raise e + pass + finally: + Clip_IO.mode_positive = "Disabled" + Clip_IO.mode_negative = "Disabled" + if getattr(p, "get_conds_with_caching", None) is not None: + p.get_conds_with_caching = Clip_IO.evacuate_get_conds_with_caching + pass pass pass pass - def tokenize_line_manual_chunk(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[PromptChunk]: - if opts.enable_emphasis: - parsed = prompt_parser.parse_prompt_attention(prompt) - pass - else: - parsed = [[prompt, 1.0]] + def tokenize_line_manual_chunk(prompts: tuple[str, ...], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[PromptChunk],list[tuple[int,int]]]: + parsed = [] + for prompt in prompts: + parsed.append(["SEPARATION", -1]) + if opts.enable_emphasis: + to_appends = prompt_parser.parse_prompt_attention(prompt) + for to_append in to_appends: + parsed.append(to_append) + pass + else: + parsed.append([prompt, 1.0]) + pass pass tokenized = clip.tokenize([text for text, _ in parsed]) chunks: list[PromptChunk] = [] chunk = PromptChunk() + token_count = 0 last_comma = -1 + separation_starts: list[tuple[int,int]] = [] def next_chunk(is_last=False): + nonlocal token_count nonlocal last_comma nonlocal chunk - # We don't have to fill the chunk. + if not manual_chunk: + if is_last: + token_count += len(chunk.tokens) + else: + token_count += clip.chunk_length + + to_add = clip.chunk_length - len(chunk.tokens) + if to_add > 0: + chunk.tokens += [clip.id_end] * to_add + chunk.multipliers += [1.0] * to_add + + chunk.tokens = [clip.id_start] + chunk.tokens + [clip.id_end] + chunk.multipliers = [1.0] + chunk.multipliers + [1.0] + pass last_comma = -1 chunks.append(chunk) @@ -659,6 +872,11 @@ class Clip_IO(scripts.Script): pass for tokens, (text, weight) in zip(tokenized, parsed): + if text == 'SEPARATION' and weight == -1: + separation_starts.append((len(chunks), len(chunk.tokens))) + continue + pass + if text == 'BREAK' and weight == -1: next_chunk() continue @@ -674,7 +892,7 @@ class Clip_IO(scripts.Script): # this is when we are at the end of alloted 77 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. - elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length + 2 and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: + elif manual_chunk and opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length + 2 and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: break_location = last_comma + 1 reloc_tokens = chunk.tokens[break_location:] @@ -688,7 +906,25 @@ class Clip_IO(scripts.Script): chunk.multipliers = reloc_mults pass - if len(chunk.tokens) == clip.chunk_length + 2: + elif not manual_chunk and opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: + break_location = last_comma + 1 + + reloc_tokens = chunk.tokens[break_location:] + reloc_mults = chunk.multipliers[break_location:] + + chunk.tokens = chunk.tokens[:break_location] + chunk.multipliers = chunk.multipliers[:break_location] + + next_chunk() + chunk.tokens = reloc_tokens + chunk.multipliers = reloc_mults + pass + + if manual_chunk and len(chunk.tokens) == clip.chunk_length + 2: + next_chunk() + pass + + elif not manual_chunk and len(chunk.tokens) == clip.chunk_length: next_chunk() pass @@ -718,19 +954,21 @@ class Clip_IO(scripts.Script): next_chunk(is_last=True) pass - return chunks + return chunks, separation_starts - def process_texts_manual_chunk(prompts: list[str], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[list[PromptChunk]]: + def process_texts_manual_chunk(prompts: list[tuple[str, ...]], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[list[PromptChunk]],list[list[tuple[int,int]]]]: cache = {} batch_chunks: list[list[PromptChunk]] = [] + separation_starts_list: list[list[tuple[int,int]]] = [] for prompt in prompts: if prompt in cache: chunks = cache[prompt] else: - chunks = Clip_IO.tokenize_line_manual_chunk(prompt, clip) + chunks, separation_starts = Clip_IO.tokenize_line_manual_chunk(prompt, clip, manual_chunk) cache[prompt] = chunks batch_chunks.append(chunks) + separation_starts_list.append(separation_starts) if False: # We have to ensure all chunk in batch_chunks have same length. @@ -755,14 +993,18 @@ class Clip_IO(scripts.Script): pass pass - return batch_chunks + return batch_chunks, separation_starts_list - def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> list[list[PromptChunk]]: + def get_chunks(prompt: str | tuple[str, ...], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[list[PromptChunk]],list[list[tuple[int,int]]]]: + "Return: PromptChunks, token separation position" + if isinstance(prompt, str): + prompt = (prompt,) + pass if opts.use_old_emphasis_implementation: raise NotImplementedError pass - if manual_chunk: - return Clip_IO.process_texts_manual_chunk([prompt], clip) + if True: # manual_chunk: + return Clip_IO.process_texts_manual_chunk([prompt], clip, manual_chunk) pass else: batch_chunks, _ = clip.process_texts([prompt]) @@ -820,7 +1062,7 @@ class Clip_IO(scripts.Script): def on_save_embeddings_as_pt(prompt: str, filename: str, overwrite: bool, manual_chunk: bool): try: clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model - batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk) + batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk) embeddings: torch.Tensor = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk) filename = os.path.join(os.path.dirname(__file__), "../conditioning", filename) @@ -847,7 +1089,7 @@ class Clip_IO(scripts.Script): def on_save_embeddings_as_csv(prompt: str, filename: str, transpose: bool, add_token: bool, overwrite: bool, manual_chunk: bool): try: clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model - batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk) + batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk) embeddings, tokens = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk) embeddings: list[list[str]] = embeddings[0].tolist() @@ -894,7 +1136,7 @@ class Clip_IO(scripts.Script): try: with devices.autocast(): clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model - batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk) + batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk) chunk_count = max([len(x) for x in batch_chunks]) zs = [] for i in range(chunk_count): @@ -939,11 +1181,92 @@ class Clip_IO(scripts.Script): return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}' pass + def encode_with_transformers(clip, tokens, chunk_count: int, clip_skips: list[list[int | None]], separation_starts_list:list[list[tuple[int,int]]]): + clip.wrapped.transformer.to(shared.device) + outputs = clip.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + + zs = [] + for z in outputs.hidden_states: + zs.append(clip.wrapped.transformer.text_model.final_layer_norm(z)) + pass + zs[-1] = outputs.last_hidden_state + + if opts.CLIP_stop_at_last_layers > 1: + zo = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] + zo = clip.wrapped.transformer.text_model.final_layer_norm(zo) + else: + zo = outputs.last_hidden_state + + for i in range(len(tokens)): + chunk_start = separation_starts_list[i][0][0] + chunk_end = separation_starts_list[i][1][0] if len(separation_starts_list[i]) >= 2 else chunk_count + 1 + while True: + while chunk_count < chunk_start: + separation_starts_list[i].pop(0) + clip_skips[i].pop(0) + pass + + if chunk_count == chunk_start: + token_start = separation_starts_list[i][0][1] + pass + elif chunk_count > chunk_start: + token_start = 0 + pass + + if chunk_count < chunk_end: + token_end = zo.shape[1] + pass + elif chunk_count == chunk_end: + token_end = separation_starts_list[i][1][1] if len(separation_starts_list[i]) >= 2 else zo.shape[1] + pass + + zo[i,token_start:token_end,:] = zs[-clip_skips[i][0]][i,token_start:token_end,:] if clip_skips[i][0] is not None else zo[i,token_start:token_end,:] + if token_end == zo.shape[1]: + break + pass + if chunk_count == chunk_end: + separation_starts_list[i].pop(0) + if len(clip_skips[i]) >=2: + clip_skips[i].pop(0) + pass + else: + clip_skips[i][0] = None + pass + pass + pass + + return zo + + def FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt: str | tuple[str, ...], clip = shared.sd_model.cond_stage_model, manual_chunk = False, clip_skips: list[list[int | None]] = [[None]]) -> torch.Tensor: + batch_chunks, separation_starts_list = Clip_IO.get_chunks(prompt, clip, manual_chunk) + chunk_count = max([len(x) for x in batch_chunks]) + zs = [] + for i in range(chunk_count): + batch_chunk = [chunks[i] if i < len(chunks) else clip.empty_chunk() for chunks in batch_chunks] + remade_batch_tokens = [x.tokens for x in batch_chunk] + tokens = torch.asarray([x.tokens for x in batch_chunk]).to(devices.device) + clip.hijack.fixes = [x.fixes for x in batch_chunk] + + if clip.id_end != clip.id_pad and not manual_chunk: + for batch_pos in range(len(remade_batch_tokens)): + index = remade_batch_tokens[batch_pos].index(clip.id_end) + tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad + + z = Clip_IO.encode_with_transformers(clip, tokens, i, clip_skips, separation_starts_list) if isinstance(clip, FrozenCLIPEmbedderWithCustomWords) else clip.encode_with_transformers(tokens) + if True: # if not no_emphasis: + batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device) + original_mean = z.mean() + z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + new_mean = z.mean() + z = z * (original_mean / new_mean) # z = z * (original_mean / new_mean) if not no_norm else z + zs.append(z[0]) + return torch.vstack(zs) + def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool): try: with devices.autocast(): clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model - batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk) + batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk) _, token_list = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk) chunk_count = max([len(x) for x in batch_chunks]) zs = []