add support for OpenCLIP
parent
c5e95a7233
commit
ac6d541c70
|
|
@ -1,7 +1,8 @@
|
|||
import html
|
||||
|
||||
from ldm.modules.encoders.modules import FrozenCLIPEmbedder
|
||||
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
|
||||
from modules import script_callbacks, shared
|
||||
import open_clip.tokenizer
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
|
@ -21,15 +22,43 @@ css = """
|
|||
"""
|
||||
|
||||
|
||||
class VanillaClip:
    """Uniform accessor for the tokenizer data of a wrapped CLIP embedder.

    The wrapped object must expose a ``tokenizer`` attribute that provides a
    ``get_vocab()`` method and a ``byte_decoder`` attribute. ``tokenize`` in
    this file wraps ``FrozenCLIPEmbedder`` instances with this class.
    """

    def __init__(self, clip):
        # Only a reference is kept; the tokenizer is looked up on each call.
        self.clip = clip

    def vocab(self):
        """Return whatever the wrapped tokenizer's ``get_vocab()`` returns."""
        tokenizer = self.clip.tokenizer
        return tokenizer.get_vocab()

    def byte_decoder(self):
        """Return the wrapped tokenizer's ``byte_decoder`` object unchanged."""
        tokenizer = self.clip.tokenizer
        return tokenizer.byte_decoder
|
||||
|
||||
class OpenClip:
    """Uniform accessor for tokenizer data when the model uses OpenCLIP.

    ``tokenize`` in this file wraps ``FrozenOpenCLIPEmbedder`` instances with
    this class. The vocabulary and byte decoder are read from the tokenizer
    object held by the ``open_clip.tokenizer`` module, not from the wrapped
    embedder.
    """

    def __init__(self, clip):
        self.clip = clip
        # Reuse the tokenizer instance stored on the open_clip.tokenizer module.
        self.tokenizer = open_clip.tokenizer._tokenizer

    def vocab(self):
        """Return the tokenizer's ``encoder`` attribute unchanged."""
        tokenizer = self.tokenizer
        return tokenizer.encoder

    def byte_decoder(self):
        """Return the tokenizer's ``byte_decoder`` attribute unchanged."""
        tokenizer = self.tokenizer
        return tokenizer.byte_decoder
|
||||
|
||||
|
||||
def tokenize(text, input_is_ids=False):
|
||||
clip: FrozenCLIPEmbedder = shared.sd_model.cond_stage_model.wrapped
|
||||
clip = shared.sd_model.cond_stage_model.wrapped
|
||||
if isinstance(clip, FrozenCLIPEmbedder):
|
||||
clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped)
|
||||
elif isinstance(clip, FrozenOpenCLIPEmbedder):
|
||||
clip = OpenClip(shared.sd_model.cond_stage_model.wrapped)
|
||||
else:
|
||||
raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}')
|
||||
|
||||
if input_is_ids:
|
||||
tokens = [int(x.strip()) for x in text.split(",")]
|
||||
else:
|
||||
tokens = clip.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
tokens = shared.sd_model.cond_stage_model.tokenize([text])[0]
|
||||
|
||||
vocab = {v: k for k, v in clip.tokenizer.get_vocab().items()}
|
||||
vocab = {v: k for k, v in clip.vocab().items()}
|
||||
|
||||
code = ''
|
||||
ids = []
|
||||
|
|
@ -37,6 +66,8 @@ def tokenize(text, input_is_ids=False):
|
|||
current_ids = []
|
||||
class_index = 0
|
||||
|
||||
byte_decoder = clip.byte_decoder()
|
||||
|
||||
def dump(last=False):
|
||||
nonlocal code, ids, current_ids
|
||||
|
||||
|
|
@ -49,7 +80,7 @@ def tokenize(text, input_is_ids=False):
|
|||
return res
|
||||
|
||||
try:
|
||||
word = bytearray([clip.tokenizer.byte_decoder[x] for x in ''.join(words)]).decode("utf-8")
|
||||
word = bytearray([byte_decoder[x] for x in ''.join(words)]).decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
if last:
|
||||
word = "❌" * len(current_ids)
|
||||
|
|
|
|||
Loading…
Reference in New Issue