add support for OpenCLIP

master
AUTOMATIC 2022-12-10 15:58:31 +03:00
parent c5e95a7233
commit ac6d541c70
1 changed file with 36 additions and 5 deletions

View File

@ -1,7 +1,8 @@
import html
from ldm.modules.encoders.modules import FrozenCLIPEmbedder
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
from modules import script_callbacks, shared
import open_clip.tokenizer
import gradio as gr
@ -21,15 +22,43 @@ css = """
"""
class VanillaClip:
    """Thin adapter around a CLIP embedder that owns its own ``tokenizer``.

    Gives callers a uniform ``vocab()`` / ``byte_decoder()`` interface so they
    do not need to know how the wrapped object lays out its tokenizer.
    """

    def __init__(self, clip):
        # Only the reference is stored; the tokenizer is looked up per call.
        self.clip = clip

    def vocab(self):
        """Return whatever the wrapped tokenizer's ``get_vocab()`` yields."""
        tokenizer = self.clip.tokenizer
        return tokenizer.get_vocab()

    def byte_decoder(self):
        """Return the wrapped tokenizer's ``byte_decoder`` attribute."""
        tokenizer = self.clip.tokenizer
        return tokenizer.byte_decoder
class OpenClip:
    """Thin adapter exposing the ``open_clip`` module-level tokenizer.

    Offers ``vocab()`` and ``byte_decoder()`` accessors, reading them from
    ``open_clip.tokenizer._tokenizer`` rather than from the wrapped embedder.
    """

    def __init__(self, clip):
        # The tokenizer comes from the open_clip module, not from ``clip``.
        self.tokenizer = open_clip.tokenizer._tokenizer
        self.clip = clip

    def vocab(self):
        """Return the tokenizer's ``encoder`` mapping."""
        tokenizer = self.tokenizer
        return tokenizer.encoder

    def byte_decoder(self):
        """Return the tokenizer's ``byte_decoder`` attribute."""
        tokenizer = self.tokenizer
        return tokenizer.byte_decoder
def tokenize(text, input_is_ids=False):
clip: FrozenCLIPEmbedder = shared.sd_model.cond_stage_model.wrapped
clip = shared.sd_model.cond_stage_model.wrapped
if isinstance(clip, FrozenCLIPEmbedder):
clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped)
elif isinstance(clip, FrozenOpenCLIPEmbedder):
clip = OpenClip(shared.sd_model.cond_stage_model.wrapped)
else:
raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}')
if input_is_ids:
tokens = [int(x.strip()) for x in text.split(",")]
else:
tokens = clip.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
tokens = shared.sd_model.cond_stage_model.tokenize([text])[0]
vocab = {v: k for k, v in clip.tokenizer.get_vocab().items()}
vocab = {v: k for k, v in clip.vocab().items()}
code = ''
ids = []
@ -37,6 +66,8 @@ def tokenize(text, input_is_ids=False):
current_ids = []
class_index = 0
byte_decoder = clip.byte_decoder()
def dump(last=False):
nonlocal code, ids, current_ids
@ -49,7 +80,7 @@ def tokenize(text, input_is_ids=False):
return res
try:
word = bytearray([clip.tokenizer.byte_decoder[x] for x in ''.join(words)]).decode("utf-8")
word = bytearray([byte_decoder[x] for x in ''.join(words)]).decode("utf-8")
except UnicodeDecodeError:
if last:
word = "" * len(current_ids)