diff --git a/scripts/Clip_IO.py b/scripts/Clip_IO.py
index 36aec7b..e1af555 100644
--- a/scripts/Clip_IO.py
+++ b/scripts/Clip_IO.py
@@ -7,7 +7,8 @@
 import open_clip
 from modules import scripts, script_callbacks, shared, devices, processing, prompt_parser
 from modules.shared import opts
-from modules.sd_hijack_clip import PromptChunkFix, PromptChunk, FrozenCLIPEmbedderWithCustomWordsBase
+from modules.sd_hijack_clip import PromptChunkFix, PromptChunk, FrozenCLIPEmbedderWithCustomWords
+from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedderWithCustomWords
 
 mode_types = ["replace", "concatenate", "command"]
 
@@ -421,7 +422,7 @@ class Clip_IO(scripts.Script):
             Clip_IO.mode_negative = "Disabled"
             pass
 
-    def tokenize_line_manual_chunk(prompt: str, clip: FrozenCLIPEmbedderWithCustomWordsBase) -> list[PromptChunk]:
+    def tokenize_line_manual_chunk(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[PromptChunk]:
         if opts.enable_emphasis:
             parsed = prompt_parser.parse_prompt_attention(prompt)
             pass
@@ -508,9 +509,9 @@ class Clip_IO(scripts.Script):
 
         return chunks
 
-    def process_texts_manual_chunk(prompts: list[str], clip: FrozenCLIPEmbedderWithCustomWordsBase) -> list[list[PromptChunk]]:
+    def process_texts_manual_chunk(prompts: list[str], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[list[PromptChunk]]:
         cache = {}
-        batch_chunks = []
+        batch_chunks: list[list[PromptChunk]] = []
         for prompt in prompts:
             if prompt in cache:
                 chunks = cache[prompt]
@@ -520,9 +521,30 @@ class Clip_IO(scripts.Script):
 
             batch_chunks.append(chunks)
 
+        # Ensure that every chunk in batch_chunks has the same length.
+        # If lengths differ, pad with the padding token and emit a warning.
+        max_length = -1
+        warned = False
+        for chunks in batch_chunks:
+            for chunk in chunks:
+                if max_length != -1 and max_length != len(chunk.tokens) and not warned:
+                    warnings.warn("Not all chunks have the same length. For processing, shorter chunks will be padded with the padding token to match the longest.")
+                    warned = True
+                    pass
+                max_length = max(max_length, len(chunk.tokens))
+                pass
+            pass
+
+        for chunks in batch_chunks:
+            for chunk in chunks:
+                chunk.tokens += [clip.id_pad] * max(max_length - len(chunk.tokens), 0)
+                chunk.multipliers += [1.0] * max(max_length - len(chunk.multipliers), 0)
+                pass
+            pass
+
         return batch_chunks
 
-    def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWordsBase, manual_chunk: bool) -> list[list[PromptChunk]]:
+    def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> list[list[PromptChunk]]:
         if opts.use_old_emphasis_implementation:
             raise NotImplementedError
             pass
@@ -535,7 +557,7 @@ class Clip_IO(scripts.Script):
             pass
         pass
 
-    def get_flat_embeddings(batch_chunks: list[list[PromptChunk]], clip: FrozenCLIPEmbedderWithCustomWordsBase, manual_chunk: bool) -> tuple[torch.Tensor, list[str]]:
+    def get_flat_embeddings(batch_chunks: list[list[PromptChunk]], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[torch.Tensor, list[str]]:
         input_ids = []
         fixes = []
         offset = 0
@@ -584,7 +606,7 @@ class Clip_IO(scripts.Script):
 
     def on_save_embeddings_as_pt(prompt: str, filename: str, overwrite: bool, manual_chunk: bool):
         try:
-            clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
+            clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
             batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
             embeddings: torch.Tensor = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
 
@@ -611,7 +633,7 @@ class Clip_IO(scripts.Script):
 
     def on_save_embeddings_as_csv(prompt: str, filename: str, transpose: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
         try:
-            clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
+            clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
             batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
             embeddings, tokens = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
 
@@ -658,7 +680,7 @@ class Clip_IO(scripts.Script):
     def on_save_conditioning_as_pt(prompt: str, filename: str, no_emphasis: bool, no_norm: bool, overwrite: bool, manual_chunk: bool):
         try:
             with devices.autocast():
-                clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
+                clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
                 batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
                 chunk_count = max([len(x) for x in batch_chunks])
                 zs = []
@@ -707,7 +729,7 @@ class Clip_IO(scripts.Script):
     def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
         try:
             with devices.autocast():
-                clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
+                clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
                 batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
                 _, token_list = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
                 chunk_count = max([len(x) for x in batch_chunks])
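
For context, the padding normalization this patch adds to process_texts_manual_chunk can be exercised in isolation. Below is a minimal sketch, not the extension's actual code: PromptChunk here is a simplified stand-in for the webui class of the same name (only the tokens and multipliers fields are modeled), id_pad plays the role of clip.id_pad, and the sketch imports warnings itself since the patch assumes it is already available in Clip_IO.py.

import warnings
from dataclasses import dataclass, field

@dataclass
class PromptChunk:
    # Simplified stand-in: token ids and their per-token emphasis weights.
    tokens: list[int] = field(default_factory=list)
    multipliers: list[float] = field(default_factory=list)

def pad_chunks_to_max_length(batch_chunks: list[list[PromptChunk]], id_pad: int) -> list[list[PromptChunk]]:
    # Find the longest token list across the whole batch.
    max_length = max((len(c.tokens) for chunks in batch_chunks for c in chunks), default=0)
    # Warn once if any chunk is shorter than the longest.
    if any(len(c.tokens) != max_length for chunks in batch_chunks for c in chunks):
        warnings.warn("Not all chunks have the same length; padding shorter chunks to match.")
    # Right-pad tokens with the pad token id and multipliers with the neutral weight 1.0.
    for chunks in batch_chunks:
        for c in chunks:
            c.tokens += [id_pad] * (max_length - len(c.tokens))
            c.multipliers += [1.0] * (max_length - len(c.multipliers))
    return batch_chunks

# Usage: two prompts whose chunks came out with different lengths.
batch = [
    [PromptChunk(tokens=[320, 1125], multipliers=[1.0, 1.1])],
    [PromptChunk(tokens=[320, 1125, 539, 267], multipliers=[1.0, 1.0, 1.2, 1.0])],
]
pad_chunks_to_max_length(batch, id_pad=0)
assert all(len(c.tokens) == 4 for c in batch[0] + batch[1])

Padding multipliers with 1.0 rather than 0.0 keeps the padded positions emphasis-neutral, matching how the patch extends chunk.multipliers.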