diff --git a/scripts/tokenizer.py b/scripts/tokenizer.py index dc67e64..7a45587 100644 --- a/scripts/tokenizer.py +++ b/scripts/tokenizer.py @@ -1,30 +1,29 @@ import html - -from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder from modules import script_callbacks, shared -import open_clip.tokenizer import gradio as gr css = """ .tokenizer-token{ + opacity: 0.8; cursor: pointer; + --body-text-color: #000; } -.tokenizer-token-0 {background: rgba(255, 0, 0, 0.05);} -.tokenizer-token-0:hover {background: rgba(255, 0, 0, 0.15);} -.tokenizer-token-1 {background: rgba(0, 255, 0, 0.05);} -.tokenizer-token-1:hover {background: rgba(0, 255, 0, 0.15);} -.tokenizer-token-2 {background: rgba(0, 0, 255, 0.05);} -.tokenizer-token-2:hover {background: rgba(0, 0, 255, 0.15);} -.tokenizer-token-3 {background: rgba(255, 156, 0, 0.05);} -.tokenizer-token-3:hover {background: rgba(255, 156, 0, 0.15);} +.tokenizer-token:hover {opacity: 1;} +.tokenizer-token-0 {background: var(--primary-300);} +.tokenizer-token-1 {background: var(--primary-400);} +.tokenizer-token-2 {background: var(--primary-500);} +.tokenizer-token-3 {background: var(--primary-600);} """ class VanillaClip: def __init__(self, clip): self.clip = clip + assert hasattr(self.clip, "tokenizer"), "VanillaClip requires 'tokenizer' attribute" + assert hasattr(self.clip.tokenizer, "get_vocab"), "Tokenizer must have 'get_vocab'" + assert hasattr(self.clip.tokenizer, "byte_decoder"), "Tokenizer must have 'byte_decoder'" def vocab(self): return self.clip.tokenizer.get_vocab() @@ -35,7 +34,9 @@ class VanillaClip: class OpenClip: def __init__(self, clip): self.clip = clip - self.tokenizer = open_clip.tokenizer._tokenizer + assert hasattr(self.clip, "tokenizer"), "OpenClip requires 'tokenizer' attribute" + assert hasattr(self.clip.tokenizer, "_tokenizer"), "Tokenizer must have '_tokenizer'" + self.tokenizer = self.clip.tokenizer._tokenizer def vocab(self): return self.tokenizer.encoder @@ -43,15 +44,25 
@@ class OpenClip: def byte_decoder(self): return self.tokenizer.byte_decoder +def initialize_clip_instance(): + base = shared.sd_model.cond_stage_model + clip_candidates = [base.wrapped] + # For SDXL, it is FrozenCLIPEmbedderForSDXLWithCustomWords and some others which are initialized in sd_hijack_clip.py in the embedders + if hasattr(base, 'embedders'): + for embedder in base.embedders: + if hasattr(embedder, 'wrapped'): + clip_candidates.append(embedder.wrapped) + initializers = [VanillaClip, OpenClip] + for clip in clip_candidates: + for initializer in initializers: + try: + return initializer(clip) + except AssertionError: + continue + raise RuntimeError('Failed to initialize a compatible CLIP instance from any candidate') def tokenize(text, input_is_ids=False): - clip = shared.sd_model.cond_stage_model.wrapped - if isinstance(clip, FrozenCLIPEmbedder): - clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped) - elif isinstance(clip, FrozenOpenCLIPEmbedder): - clip = OpenClip(shared.sd_model.cond_stage_model.wrapped) - else: - raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}') + clip = initialize_clip_instance() if input_is_ids: tokens = [int(x.strip()) for x in text.split(",")]