pull/9/merge
YSH 2024-06-13 18:45:53 -07:00 committed by GitHub
commit 3eba94cbca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 30 additions and 19 deletions

View File

@ -1,30 +1,29 @@
import html
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
from modules import script_callbacks, shared
import open_clip.tokenizer
import gradio as gr
css = """
.tokenizer-token{
opacity: 0.8;
cursor: pointer;
--body-text-color: #000;
}
.tokenizer-token-0 {background: rgba(255, 0, 0, 0.05);}
.tokenizer-token-0:hover {background: rgba(255, 0, 0, 0.15);}
.tokenizer-token-1 {background: rgba(0, 255, 0, 0.05);}
.tokenizer-token-1:hover {background: rgba(0, 255, 0, 0.15);}
.tokenizer-token-2 {background: rgba(0, 0, 255, 0.05);}
.tokenizer-token-2:hover {background: rgba(0, 0, 255, 0.15);}
.tokenizer-token-3 {background: rgba(255, 156, 0, 0.05);}
.tokenizer-token-3:hover {background: rgba(255, 156, 0, 0.15);}
.tokenizer-token:hover {opacity: 1;}
.tokenizer-token-0 {background: var(--primary-300);}
.tokenizer-token-1 {background: var(--primary-400);}
.tokenizer-token-2 {background: var(--primary-500);}
.tokenizer-token-3 {background: var(--primary-600);}
"""
class VanillaClip:
def __init__(self, clip):
self.clip = clip
assert hasattr(self.clip, "tokenizer"), "VanillaClip requires 'tokenizer' attribute"
assert hasattr(self.clip.tokenizer, "get_vocab"), "Tokenizer must have 'get_vocab'"
assert hasattr(self.clip.tokenizer, "byte_decoder"), "Tokenizer must have 'byte_decoder'"
def vocab(self):
return self.clip.tokenizer.get_vocab()
@ -35,7 +34,9 @@ class VanillaClip:
class OpenClip:
def __init__(self, clip):
self.clip = clip
self.tokenizer = open_clip.tokenizer._tokenizer
assert hasattr(self.clip, "tokenizer"), "OpenClip requires 'tokenizer' attribute"
assert hasattr(self.clip.tokenizer, "_tokenizer"), "Tokenizer must have '_tokenizer'"
self.tokenizer = self.clip.tokenizer._tokenizer
def vocab(self):
return self.tokenizer.encoder
@ -43,15 +44,25 @@ class OpenClip:
def byte_decoder(self):
return self.tokenizer.byte_decoder
def initialize_clip_instance():
base = shared.sd_model.cond_stage_model
clip_candidates = [base.wrapped]
# For SDXL, it is FrozenCLIPEmbedderForSDXLWithCustomWords and some others which are initalized in sd_hijack_clip.py in the embedders
if hasattr(base, 'embedders'):
for embedder in base.embedders:
if hasattr(embedder, 'wrapped'):
clip_candidates.append(embedder.wrapped)
initializers = [VanillaClip, OpenClip]
for clip in clip_candidates:
for initializer in initializers:
try:
return initializer(clip)
except AssertionError:
continue
raise RuntimeError('Failed to initialize a compatible CLIP instance from any candidate')
def tokenize(text, input_is_ids=False):
clip = shared.sd_model.cond_stage_model.wrapped
if isinstance(clip, FrozenCLIPEmbedder):
clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped)
elif isinstance(clip, FrozenOpenCLIPEmbedder):
clip = OpenClip(shared.sd_model.cond_stage_model.wrapped)
else:
raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}')
clip = initialize_clip_instance()
if input_is_ids:
tokens = [int(x.strip()) for x in text.split(",")]