# Merge 7e74e5de76 into ac6d541c70 (commit 3eba94cbca)
import html
|
||||
|
||||
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
|
||||
from modules import script_callbacks, shared
|
||||
import open_clip.tokenizer
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
# CSS injected into the Gradio UI to color each token span. The rgba-based
# rules and their :hover variants were the pre-merge styling; they were
# replaced by theme-aware var(--primary-*) backgrounds plus an opacity-based
# hover, so only the post-merge set is kept (duplicate selectors would have
# silently shadowed the earlier rules anyway).
css = """
.tokenizer-token{
    opacity: 0.8;
    cursor: pointer;
    --body-text-color: #000;
}

.tokenizer-token:hover {opacity: 1;}

.tokenizer-token-0 {background: var(--primary-300);}
.tokenizer-token-1 {background: var(--primary-400);}
.tokenizer-token-2 {background: var(--primary-500);}
.tokenizer-token-3 {background: var(--primary-600);}
"""
|
||||
|
||||
|
||||
class VanillaClip:
    """Adapter over a FrozenCLIPEmbedder-style model exposing its vocabulary.

    Construction deliberately validates with ``assert`` statements:
    initialize_clip_instance() catches AssertionError to fall through to the
    next candidate wrapper, so these must not be turned into raised exceptions.
    """

    def __init__(self, clip):
        self.clip = clip
        # Checks are sequential on purpose: each assert guards the attribute
        # access performed by the next one (AttributeError would escape the
        # caller's except-AssertionError handler).
        assert hasattr(self.clip, "tokenizer"), "VanillaClip requires 'tokenizer' attribute"
        assert hasattr(self.clip.tokenizer, "get_vocab"), "Tokenizer must have 'get_vocab'"
        assert hasattr(self.clip.tokenizer, "byte_decoder"), "Tokenizer must have 'byte_decoder'"

    def vocab(self):
        """Return the tokenizer's token -> id mapping."""
        return self.clip.tokenizer.get_vocab()
|
@ -35,7 +34,9 @@ class VanillaClip:
|
|||
class OpenClip:
    """Adapter over a FrozenOpenCLIPEmbedder-style model exposing its vocabulary.

    Like VanillaClip, validation uses ``assert`` so that
    initialize_clip_instance() can catch AssertionError and try the next
    candidate wrapper.
    """

    def __init__(self, clip):
        self.clip = clip
        # Sequential on purpose: the first assert guards the attribute access
        # in the second (an AttributeError would escape the caller's
        # except-AssertionError handler).
        assert hasattr(self.clip, "tokenizer"), "OpenClip requires 'tokenizer' attribute"
        assert hasattr(self.clip.tokenizer, "_tokenizer"), "Tokenizer must have '_tokenizer'"
        # Use the model's own tokenizer instance rather than the module-global
        # open_clip.tokenizer._tokenizer (the pre-merge assignment that this
        # span also contained) so multi-embedder models (SDXL) resolve the
        # correct tokenizer per embedder.
        self.tokenizer = self.clip.tokenizer._tokenizer

    def vocab(self):
        """Return the tokenizer's token -> id mapping."""
        return self.tokenizer.encoder

    def byte_decoder(self):
        """Return the tokenizer's byte-value -> character mapping."""
        return self.tokenizer.byte_decoder
|
||||
|
||||
def initialize_clip_instance():
    """Build the first CLIP wrapper compatible with the currently loaded model.

    Collects candidate CLIP modules from the model's cond_stage_model and,
    for SDXL-style models, from each of its embedders, then tries each
    wrapper class in turn, relying on the wrappers' AssertionError-based
    validation to reject incompatible candidates.

    Raises:
        RuntimeError: if no wrapper accepts any candidate.
    """
    base = shared.sd_model.cond_stage_model
    clip_candidates = [base.wrapped]

    # SDXL models (FrozenCLIPEmbedderForSDXLWithCustomWords and friends,
    # initialized in sd_hijack_clip.py) carry several embedders; each wrapped
    # CLIP inside them is a candidate too.
    if hasattr(base, 'embedders'):
        clip_candidates.extend(
            embedder.wrapped
            for embedder in base.embedders
            if hasattr(embedder, 'wrapped')
        )

    for candidate in clip_candidates:
        for wrapper in (VanillaClip, OpenClip):
            try:
                return wrapper(candidate)
            except AssertionError:
                # Wrapper rejected this candidate; try the next combination.
                continue

    raise RuntimeError('Failed to initialize a compatible CLIP instance from any candidate')
|
||||
|
||||
def tokenize(text, input_is_ids=False):
|
||||
clip = shared.sd_model.cond_stage_model.wrapped
|
||||
if isinstance(clip, FrozenCLIPEmbedder):
|
||||
clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped)
|
||||
elif isinstance(clip, FrozenOpenCLIPEmbedder):
|
||||
clip = OpenClip(shared.sd_model.cond_stage_model.wrapped)
|
||||
else:
|
||||
raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}')
|
||||
clip = initialize_clip_instance()
|
||||
|
||||
if input_is_ids:
|
||||
tokens = [int(x.strip()) for x in text.split(",")]
|
||||
|
# NOTE(review): diff truncated here — the remainder of tokenize() is not shown.