diff --git a/scripts/tokenizer.py b/scripts/tokenizer.py index 02963b3..dc67e64 100644 --- a/scripts/tokenizer.py +++ b/scripts/tokenizer.py @@ -1,7 +1,8 @@ import html -from ldm.modules.encoders.modules import FrozenCLIPEmbedder +from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder from modules import script_callbacks, shared +import open_clip.tokenizer import gradio as gr @@ -21,15 +22,43 @@ css = """ """ +class VanillaClip: + def __init__(self, clip): + self.clip = clip + + def vocab(self): + return self.clip.tokenizer.get_vocab() + + def byte_decoder(self): + return self.clip.tokenizer.byte_decoder + +class OpenClip: + def __init__(self, clip): + self.clip = clip + self.tokenizer = open_clip.tokenizer._tokenizer + + def vocab(self): + return self.tokenizer.encoder + + def byte_decoder(self): + return self.tokenizer.byte_decoder + + def tokenize(text, input_is_ids=False): - clip: FrozenCLIPEmbedder = shared.sd_model.cond_stage_model.wrapped + clip = shared.sd_model.cond_stage_model.wrapped + if isinstance(clip, FrozenCLIPEmbedder): + clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped) + elif isinstance(clip, FrozenOpenCLIPEmbedder): + clip = OpenClip(shared.sd_model.cond_stage_model.wrapped) + else: + raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}') if input_is_ids: tokens = [int(x.strip()) for x in text.split(",")] else: - tokens = clip.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] + tokens = shared.sd_model.cond_stage_model.tokenize([text])[0] - vocab = {v: k for k, v in clip.tokenizer.get_vocab().items()} + vocab = {v: k for k, v in clip.vocab().items()} code = '' ids = [] @@ -37,6 +66,8 @@ def tokenize(text, input_is_ids=False): current_ids = [] class_index = 0 + byte_decoder = clip.byte_decoder() + def dump(last=False): nonlocal code, ids, current_ids @@ -49,7 +80,7 @@ def tokenize(text, input_is_ids=False): return res try: - word = bytearray([clip.tokenizer.byte_decoder[x] for x in ''.join(words)]).decode("utf-8") + word = bytearray([byte_decoder[x] for x in ''.join(words)]).decode("utf-8") except UnicodeDecodeError: if last: word = "❌" * len(current_ids)