Add option "Don't add bos / eos / pad tokens"

main
File_xor 2023-06-03 18:36:26 +09:00
parent 24979528fd
commit ac55afdbf2
1 changed files with 140 additions and 23 deletions

View File

@ -421,15 +421,121 @@ class Clip_IO(scripts.Script):
Clip_IO.mode_negative = "Disabled"
pass
def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWordsBase) -> PromptChunk:
def tokenize_line_manual_chunk(prompt: str, clip: FrozenCLIPEmbedderWithCustomWordsBase) -> list[PromptChunk]:
    """Tokenize one prompt into a list of PromptChunk objects without padding them.

    The prompt is parsed into (text, weight) pairs (only when
    ``opts.enable_emphasis`` is set; otherwise the whole prompt gets weight
    1.0), tokenized with ``clip.tokenize`` and distributed over chunks.
    Unlike a padded chunk, a finished chunk is stored exactly as it was
    collected: nothing is prepended, appended or filled in.

    Args:
        prompt: The prompt text to tokenize.
        clip: The text encoder wrapper supplying ``tokenize``, ``comma_token``,
            ``chunk_length`` and the textual inversion embedding database.

    Returns:
        The collected chunks, always at least one (possibly empty) chunk.
    """
    if opts.enable_emphasis:
        weighted_parts = prompt_parser.parse_prompt_attention(prompt)
    else:
        weighted_parts = [[prompt, 1.0]]
    token_lists = clip.tokenize([part for part, _ in weighted_parts])

    # A chunk is allowed to grow two tokens past clip.chunk_length
    # (presumably the slots bos / eos would otherwise occupy -- TODO confirm).
    capacity = clip.chunk_length + 2

    finished: list[PromptChunk] = []
    current = PromptChunk()
    comma_index = -1  # position just recorded for the latest comma, -1 if none

    def flush():
        # Store the current chunk as-is (no filling) and start a fresh one.
        nonlocal comma_index, current
        comma_index = -1
        finished.append(current)
        current = PromptChunk()

    for ids, (part, weight) in zip(token_lists, weighted_parts):
        # A 'BREAK' part carrying weight -1 forces a chunk boundary.
        if part == 'BREAK' and weight == -1:
            flush()
            continue

        cursor = 0
        while cursor < len(ids):
            token_id = ids[cursor]

            if token_id == clip.comma_token:
                comma_index = len(current.tokens)
            elif (
                opts.comma_padding_backtrack != 0
                and len(current.tokens) == capacity
                and comma_index != -1
                and len(current.tokens) - comma_index <= opts.comma_padding_backtrack
            ):
                # The chunk is full, the current token is not a comma, and a
                # comma lies within the allotted backtrack distance: move
                # everything after that comma into the next chunk.
                split_at = comma_index + 1
                carried_tokens = current.tokens[split_at:]
                carried_weights = current.multipliers[split_at:]
                current.tokens = current.tokens[:split_at]
                current.multipliers = current.multipliers[:split_at]
                flush()
                current.tokens = carried_tokens
                current.multipliers = carried_weights

            if len(current.tokens) == capacity:
                flush()

            embedding, consumed = clip.hijack.embedding_db.find_embedding_at_position(ids, cursor)
            if embedding is not None:
                # Textual inversion: reserve one placeholder token per
                # embedding vector and remember where to patch it in.
                emb_len = int(embedding.vec.shape[0])
                if len(current.tokens) + emb_len > capacity:
                    flush()
                current.fixes.append(PromptChunkFix(len(current.tokens), embedding))
                current.tokens += [0] * emb_len
                current.multipliers += [weight] * emb_len
                cursor += consumed
            else:
                # Ordinary token.
                current.tokens.append(token_id)
                current.multipliers.append(weight)
                cursor += 1

    # Keep a trailing non-empty chunk, and guarantee at least one chunk.
    if current.tokens or not finished:
        flush()
    return finished
def process_texts_manual_chunk(prompts: list[str], clip: FrozenCLIPEmbedderWithCustomWordsBase) -> list[list[PromptChunk]]:
    """Tokenize every prompt with ``Clip_IO.tokenize_line_manual_chunk``.

    Identical prompt strings are tokenized only once per call; later
    occurrences reuse the very same chunk list object.

    Args:
        prompts: Prompt texts to tokenize.
        clip: The text encoder wrapper handed on to the tokenizer.

    Returns:
        One chunk list per prompt, in the order of ``prompts``.
    """
    memo: dict[str, list[PromptChunk]] = {}
    per_prompt = []
    for text in prompts:
        if text not in memo:
            memo[text] = Clip_IO.tokenize_line_manual_chunk(text, clip)
        per_prompt.append(memo[text])
    return per_prompt
def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWordsBase, manual_chunk: bool) -> list[list[PromptChunk]]:
    """Tokenize a single prompt into a batch of chunk lists.

    Args:
        prompt: The prompt text to tokenize.
        clip: The text encoder wrapper used for tokenization.
        manual_chunk: When true, use ``Clip_IO.process_texts_manual_chunk``
            instead of ``clip.process_texts``.

    Returns:
        A one-element batch holding the chunk list for ``prompt``.

    Raises:
        NotImplementedError: If ``opts.use_old_emphasis_implementation`` is set.
    """
    if opts.use_old_emphasis_implementation:
        raise NotImplementedError

    # The manual_chunk decision must come before any call to
    # clip.process_texts; an unconditional return here would make the
    # manual path unreachable and silently ignore the option.
    if manual_chunk:
        return Clip_IO.process_texts_manual_chunk([prompt], clip)

    # clip.process_texts returns a pair; only the first element is used.
    batch_chunks, _ = clip.process_texts([prompt])
    return batch_chunks
def get_flat_embeddings(batch_chunks: PromptChunk, clip: FrozenCLIPEmbedderWithCustomWordsBase) -> tuple[torch.Tensor, list[str]]:
def get_flat_embeddings(batch_chunks: list[list[PromptChunk]], clip: FrozenCLIPEmbedderWithCustomWordsBase, manual_chunk: bool) -> tuple[torch.Tensor, list[str]]:
input_ids = []
fixes = []
offset = 0
@ -456,21 +562,31 @@ class Clip_IO(scripts.Script):
pass
tokens = [decode(input_id) for input_id in input_ids]
for fix in fixes:
tokens[fix.offset + 1] = fix.embedding.name
for i in range(1, fix.embedding.vec.shape[0]):
tokens[fix.offset + 1 + i] = ""
if manual_chunk:
for fix in fixes:
tokens[fix.offset] = fix.embedding.name
for i in range(1, fix.embedding.vec.shape[0]):
tokens[fix.offset + i] = ""
pass
pass
pass
else:
for fix in fixes:
tokens[fix.offset + 1] = fix.embedding.name
for i in range(1, fix.embedding.vec.shape[0]):
tokens[fix.offset + 1 + i] = ""
pass
pass
pass
return clip.wrapped.model.token_embedding(input_ids_Tensor) if is_open_clip else clip.wrapped.transformer.text_model.embeddings.token_embedding(input_ids_Tensor), tokens
pass
def on_save_embeddings_as_pt(prompt: str, filename: str, overwrite: bool):
def on_save_embeddings_as_pt(prompt: str, filename: str, overwrite: bool, manual_chunk: bool):
try:
clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip)
embeddings: torch.Tensor = Clip_IO.get_flat_embeddings(batch_chunks, clip)
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
embeddings: torch.Tensor = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
filename = os.path.join(os.path.dirname(__file__), "../conditioning", filename)
filename = os.path.realpath(filename)
@ -493,11 +609,11 @@ class Clip_IO(scripts.Script):
return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}'
pass
def on_save_embeddings_as_csv(prompt: str, filename: str, transpose: bool, add_token: bool, overwrite: bool):
def on_save_embeddings_as_csv(prompt: str, filename: str, transpose: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
try:
clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip)
embeddings, tokens = Clip_IO.get_flat_embeddings(batch_chunks, clip)
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
embeddings, tokens = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
embeddings: list[list[str]] = embeddings[0].tolist()
width = len(embeddings[0])
@ -539,11 +655,11 @@ class Clip_IO(scripts.Script):
return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}'
pass
def on_save_conditioning_as_pt(prompt: str, filename: str, no_emphasis: bool, no_norm: bool, overwrite: bool):
def on_save_conditioning_as_pt(prompt: str, filename: str, no_emphasis: bool, no_norm: bool, overwrite: bool, manual_chunk: bool):
try:
with devices.autocast():
clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip)
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
chunk_count = max([len(x) for x in batch_chunks])
zs = []
for i in range(chunk_count):
@ -588,12 +704,12 @@ class Clip_IO(scripts.Script):
return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}'
pass
def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool):
def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
try:
with devices.autocast():
clip: FrozenCLIPEmbedderWithCustomWordsBase = shared.sd_model.cond_stage_model
batch_chunks = Clip_IO.get_chunks(prompt, clip)
_, token_list = Clip_IO.get_flat_embeddings(batch_chunks, clip)
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
_, token_list = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
chunk_count = max([len(x) for x in batch_chunks])
zs = []
for i in range(chunk_count):
@ -663,6 +779,7 @@ class Clip_IO(scripts.Script):
prompt = gradio.TextArea(max_lines = 256, label = "Prompt")
with gradio.Row():
output_transpose = gradio.Checkbox(value = True, label = "Transpose matrix")
output_manual_chunk = gradio.Checkbox(value = False, label = "Don't add bos / eos / pad tokens")
output_ignore_emphasis = gradio.Checkbox(value = False, label = "Ignore emphasis")
output_bypass_conditioning_normalization = gradio.Checkbox(value = False, label = "Bypass conditioning normalization")
output_token_string = gradio.Checkbox(value = True, label = "Add token strings")
@ -676,10 +793,10 @@ class Clip_IO(scripts.Script):
output_conditioning_as_csv = gradio.Button("Save conditioning as .csv")
pass
output_notification = gradio.HTML()
output_embeddings_as_pt.click(Clip_IO.on_save_embeddings_as_pt, [prompt, output_name, output_overwrite], [output_notification])
output_embeddings_as_csv.click(Clip_IO.on_save_embeddings_as_csv, [prompt, output_name, output_transpose, output_token_string, output_overwrite], [output_notification])
output_conditioning_as_pt.click(Clip_IO.on_save_conditioning_as_pt, [prompt, output_name, output_ignore_emphasis, output_bypass_conditioning_normalization, output_overwrite], [output_notification])
output_conditioning_as_csv.click(Clip_IO.on_save_conditioning_as_csv, [prompt, output_name, output_transpose, output_ignore_emphasis, output_bypass_conditioning_normalization, output_token_string, output_overwrite], [output_notification])
output_embeddings_as_pt.click(Clip_IO.on_save_embeddings_as_pt, [prompt, output_name, output_overwrite, output_manual_chunk], [output_notification])
output_embeddings_as_csv.click(Clip_IO.on_save_embeddings_as_csv, [prompt, output_name, output_transpose, output_token_string, output_overwrite, output_manual_chunk], [output_notification])
output_conditioning_as_pt.click(Clip_IO.on_save_conditioning_as_pt, [prompt, output_name, output_ignore_emphasis, output_bypass_conditioning_normalization, output_overwrite, output_manual_chunk], [output_notification])
output_conditioning_as_csv.click(Clip_IO.on_save_conditioning_as_csv, [prompt, output_name, output_transpose, output_ignore_emphasis, output_bypass_conditioning_normalization, output_token_string, output_overwrite, output_manual_chunk], [output_notification])
pass
return [(tab, "Clip Output", "Clip_Output")]
pass