Make chunk size same if not

main
File_xor 2023-06-03 20:36:23 +09:00
parent ac55afdbf2
commit c65919d233
1 changed files with 32 additions and 10 deletions

View File

@ -7,7 +7,8 @@ import open_clip
from modules import scripts, script_callbacks, shared, devices, processing, prompt_parser
from modules.shared import opts
from modules.sd_hijack_clip import PromptChunkFix, PromptChunk, FrozenCLIPEmbedderWithCustomWordsBase
from modules.sd_hijack_clip import PromptChunkFix, PromptChunk, FrozenCLIPEmbedderWithCustomWords
from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedderWithCustomWords
mode_types = ["replace", "concatenate", "command"]
@ -421,7 +422,7 @@ class Clip_IO(scripts.Script):
Clip_IO.mode_negative = "Disabled"
pass
def tokenize_line_manual_chunk(prompt: str, clip: FrozenCLIPEmbedderWithCustomWordsBase) -> list[PromptChunk]:
def tokenize_line_manual_chunk(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[PromptChunk]:
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(prompt)
pass
@ -508,9 +509,9 @@ class Clip_IO(scripts.Script):
return chunks
def process_texts_manual_chunk(prompts: list[str], clip: FrozenCLIPEmbedderWithCustomWordsBase) -> list[list[PromptChunk]]:
def process_texts_manual_chunk(prompts: list[str], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[list[PromptChunk]]:
cache = {}
batch_chunks = []
batch_chunks: list[list[PromptChunk]] = []
for prompt in prompts:
if prompt in cache:
chunks = cache[prompt]
@ -520,9 +521,30 @@ class Clip_IO(scripts.Script):
batch_chunks.append(chunks)
# We have to ensure all chunk in batch_chunks have same length.
# If not, fill with padding token and raise warning.
max_length = -1
warned = False
for chunks in batch_chunks:
for chunk in chunks:
if max_length != -1 and max_length != len(chunk.tokens) and not warned:
warnings.warn("All chunk doesn't have same length. For processing, we'll fill with padding token to match same length.")
warned = True
pass
max_length = max(max_length, len(chunk.tokens))
pass
pass
for chunks in batch_chunks:
for chunk in chunks:
chunk.tokens += [clip.id_pad] * max(max_length - len(chunk.tokens), 0)
chunk.multipliers += [1.0] * max(max_length - len(chunk.multipliers), 0)
pass
pass
return batch_chunks
def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWordsBase, manual_chunk: bool) -> list[list[PromptChunk]]:
def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> list[list[PromptChunk]]:
if opts.use_old_emphasis_implementation:
raise NotImplementedError
pass
@ -535,7 +557,7 @@ class Clip_IO(scripts.Script):
pass
pass
def get_flat_embeddings(batch_chunks: list[list[PromptChunk]], clip: FrozenCLIPEmbedderWithCustomWordsBase, manual_chunk: bool) -> tuple[torch.Tensor, list[str]]:
def get_flat_embeddings(batch_chunks: list[list[PromptChunk]], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[torch.Tensor, list[str]]:
input_ids = []
fixes = []
offset = 0
@ -584,7 +606,7 @@ class Clip_IO(scripts.Script):
def on_save_embeddings_as_pt(prompt: str, filename: str, overwrite: bool, manual_chunk: bool):
try:
clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
embeddings: torch.Tensor = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
@ -611,7 +633,7 @@ class Clip_IO(scripts.Script):
def on_save_embeddings_as_csv(prompt: str, filename: str, transpose: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
try:
clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
embeddings, tokens = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
@ -658,7 +680,7 @@ class Clip_IO(scripts.Script):
def on_save_conditioning_as_pt(prompt: str, filename: str, no_emphasis: bool, no_norm: bool, overwrite: bool, manual_chunk: bool):
try:
with devices.autocast():
clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
chunk_count = max([len(x) for x in batch_chunks])
zs = []
@ -707,7 +729,7 @@ class Clip_IO(scripts.Script):
def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
try:
with devices.autocast():
clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
_, token_list = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
chunk_count = max([len(x) for x in batch_chunks])