diff --git a/scripts/Clip_IO.py b/scripts/Clip_IO.py index 77ae2b6..a447aac 100644 --- a/scripts/Clip_IO.py +++ b/scripts/Clip_IO.py @@ -3,7 +3,7 @@ import os, csv, warnings, datetime import gradio import torch import lark -from tkinter import filedialog +import open_clip from modules import scripts, script_callbacks, shared, devices, processing, prompt_parser from modules.shared import opts @@ -445,7 +445,17 @@ class Clip_IO(scripts.Script): clip.hijack.fixes = [fixes] input_ids_Tensor = torch.asarray([input_ids]).to(devices.device) - tokens = [clip.wrapped.tokenizer.decoder.get(input_id) for input_id in input_ids] + decode: callable[any, str] + if hasattr(clip.wrapped, "tokenizer"): # Ver.1.x + decode = clip.wrapped.tokenizer.decoder.get + is_open_clip = False + pass + else: # Ver.2.x + decode = lambda t: open_clip.tokenizer._tokenizer.decoder.get(t) + is_open_clip = True + pass + tokens = [decode(input_id) for input_id in input_ids] + for fix in fixes: tokens[fix.offset + 1] = fix.embedding.name for i in range(1, fix.embedding.vec.shape[0]): @@ -453,7 +463,7 @@ class Clip_IO(scripts.Script): pass pass - return clip.wrapped.transformer.text_model.embeddings.token_embedding(input_ids_Tensor), tokens + return clip.wrapped.model.token_embedding(input_ids_Tensor) if is_open_clip else clip.wrapped.transformer.text_model.embeddings.token_embedding(input_ids_Tensor), tokens pass def on_save_embeddings_as_pt(prompt: str, filename: str, overwrite: bool): @@ -546,7 +556,7 @@ class Clip_IO(scripts.Script): for batch_pos in range(len(remade_batch_tokens)): index = remade_batch_tokens[batch_pos].index(clip.id_end) tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad - + z = clip.encode_with_transformers(tokens) if not no_emphasis: batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device)