add support for OpenCLIP
parent c5e95a7233
commit ac6d541c70
@@ -1,7 +1,8 @@
 import html
 
-from ldm.modules.encoders.modules import FrozenCLIPEmbedder
+from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
 from modules import script_callbacks, shared
+import open_clip.tokenizer
 
 import gradio as gr
 
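Note: FrozenCLIPEmbedder is the text encoder used by SD 1.x checkpoints (HuggingFace transformers CLIP), while FrozenOpenCLIPEmbedder is the SD 2.x encoder backed by open_clip. A quick way to see which one a loaded model uses (hypothetical REPL session inside the webui, assuming a model is loaded):

    >>> from modules import shared
    >>> type(shared.sd_model.cond_stage_model.wrapped).__name__
    'FrozenCLIPEmbedder'    # SD 1.x; an SD 2.x model prints 'FrozenOpenCLIPEmbedder'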
@@ -21,15 +22,43 @@ css = """
 """
 
 
+class VanillaClip:
+    def __init__(self, clip):
+        self.clip = clip
+
+    def vocab(self):
+        return self.clip.tokenizer.get_vocab()
+
+    def byte_decoder(self):
+        return self.clip.tokenizer.byte_decoder
+
+class OpenClip:
+    def __init__(self, clip):
+        self.clip = clip
+        self.tokenizer = open_clip.tokenizer._tokenizer
+
+    def vocab(self):
+        return self.tokenizer.encoder
+
+    def byte_decoder(self):
+        return self.tokenizer.byte_decoder
+
+
 def tokenize(text, input_is_ids=False):
-    clip: FrozenCLIPEmbedder = shared.sd_model.cond_stage_model.wrapped
+    clip = shared.sd_model.cond_stage_model.wrapped
+    if isinstance(clip, FrozenCLIPEmbedder):
+        clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped)
+    elif isinstance(clip, FrozenOpenCLIPEmbedder):
+        clip = OpenClip(shared.sd_model.cond_stage_model.wrapped)
+    else:
+        raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}')
 
     if input_is_ids:
         tokens = [int(x.strip()) for x in text.split(",")]
     else:
-        tokens = clip.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
+        tokens = shared.sd_model.cond_stage_model.tokenize([text])[0]
 
-    vocab = {v: k for k, v in clip.tokenizer.get_vocab().items()}
+    vocab = {v: k for k, v in clip.vocab().items()}
 
     code = ''
     ids = []
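The two wrappers above form a small duck-typed adapter: each exposes vocab() (token string -> id) and byte_decoder() (stand-in character -> raw byte), so the body of tokenize() no longer needs to know which backend is loaded. A minimal sketch of how that shared interface supports decoding, usable with either wrapper (decode_ids is hypothetical, not part of the commit):

    def decode_ids(wrapper, ids):
        # Invert the vocab to map id -> token string, as tokenize() does.
        vocab = {v: k for k, v in wrapper.vocab().items()}
        byte_decoder = wrapper.byte_decoder()
        joined = ''.join(vocab.get(i, '') for i in ids)
        # Every character of a BPE token string stands in for one raw byte.
        return bytearray(byte_decoder[c] for c in joined).decode('utf-8')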
@@ -37,6 +66,8 @@ def tokenize(text, input_is_ids=False):
     current_ids = []
     class_index = 0
 
+    byte_decoder = clip.byte_decoder()
+
     def dump(last=False):
         nonlocal code, ids, current_ids
 
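Fetching the mapping once here, instead of reaching through clip.tokenizer inside dump() as the old code did, keeps the nested closure backend-agnostic. Both tokenizers share the GPT-2-style byte-to-unicode table, so the dict looks the same either way; a few entries for illustration (values from that table, not from the diff):

    # byte_decoder['!'] == 33     printable ASCII maps to itself
    # byte_decoder['Ġ'] == 32     space is remapped to a printable stand-in
    # byte_decoder['Ã'] == 195    lead byte of many UTF-8 accented characters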
@@ -49,7 +80,7 @@ def tokenize(text, input_is_ids=False):
             return res
 
         try:
-            word = bytearray([clip.tokenizer.byte_decoder[x] for x in ''.join(words)]).decode("utf-8")
+            word = bytearray([byte_decoder[x] for x in ''.join(words)]).decode("utf-8")
         except UnicodeDecodeError:
             if last:
                 word = "❌" * len(current_ids)
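BPE works on that reversible byte-to-unicode encoding, so a single multi-byte UTF-8 character can be split across several token ids; joining a partial run of tokens may then fail to decode, which is why dump() catches UnicodeDecodeError and, on the final flush, renders one ❌ per undecodable id. A standalone round trip with the open_clip tokenizer, as an illustration (assumes the open_clip package is installed; not part of the commit):

    import open_clip.tokenizer

    tok = open_clip.tokenizer._tokenizer   # module-level SimpleTokenizer
    ids = tok.encode("café 🙂")
    id_to_word = {v: k for k, v in tok.encoder.items()}
    joined = ''.join(id_to_word[i] for i in ids)
    raw = bytearray(tok.byte_decoder[c] for c in joined)
    print(raw.decode("utf-8"))             # café</w>🙂</w> ; '</w>' marks word ends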