add support for OpenCLIP

master
AUTOMATIC 2022-12-10 15:58:31 +03:00
parent c5e95a7233
commit ac6d541c70
1 changed file with 36 additions and 5 deletions

View File

@ -1,7 +1,8 @@
import html import html
from ldm.modules.encoders.modules import FrozenCLIPEmbedder from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
from modules import script_callbacks, shared from modules import script_callbacks, shared
import open_clip.tokenizer
import gradio as gr import gradio as gr
@ -21,15 +22,43 @@ css = """
""" """
class VanillaClip:
    """Adapter that exposes the transformers (vanilla) CLIP tokenizer through
    the uniform ``vocab()`` / ``byte_decoder()`` interface used by the
    tokenizer-inspection code.
    """

    def __init__(self, clip):
        # `clip` is the wrapped FrozenCLIPEmbedder; only its tokenizer is used.
        self.clip = clip

    def vocab(self):
        """Return the token -> id mapping of the underlying tokenizer."""
        tokenizer = self.clip.tokenizer
        return tokenizer.get_vocab()

    def byte_decoder(self):
        """Return the tokenizer's byte decoder (byte-token -> raw byte map)."""
        tokenizer = self.clip.tokenizer
        return tokenizer.byte_decoder
class OpenClip:
    """Adapter that exposes the OpenCLIP tokenizer through the uniform
    ``vocab()`` / ``byte_decoder()`` interface used by the
    tokenizer-inspection code.
    """

    def __init__(self, clip):
        # `clip` is the wrapped FrozenOpenCLIPEmbedder. OpenCLIP keeps a
        # module-level SimpleTokenizer instance, so grab that directly.
        self.clip = clip
        self.tokenizer = open_clip.tokenizer._tokenizer

    def vocab(self):
        """Return the token -> id mapping (OpenCLIP calls it ``encoder``)."""
        return self.tokenizer.encoder

    def byte_decoder(self):
        """Return the tokenizer's byte decoder (byte-token -> raw byte map)."""
        return self.tokenizer.byte_decoder
def tokenize(text, input_is_ids=False): def tokenize(text, input_is_ids=False):
clip: FrozenCLIPEmbedder = shared.sd_model.cond_stage_model.wrapped clip = shared.sd_model.cond_stage_model.wrapped
if isinstance(clip, FrozenCLIPEmbedder):
clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped)
elif isinstance(clip, FrozenOpenCLIPEmbedder):
clip = OpenClip(shared.sd_model.cond_stage_model.wrapped)
else:
raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}')
if input_is_ids: if input_is_ids:
tokens = [int(x.strip()) for x in text.split(",")] tokens = [int(x.strip()) for x in text.split(",")]
else: else:
tokens = clip.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] tokens = shared.sd_model.cond_stage_model.tokenize([text])[0]
vocab = {v: k for k, v in clip.tokenizer.get_vocab().items()} vocab = {v: k for k, v in clip.vocab().items()}
code = '' code = ''
ids = [] ids = []
@ -37,6 +66,8 @@ def tokenize(text, input_is_ids=False):
current_ids = [] current_ids = []
class_index = 0 class_index = 0
byte_decoder = clip.byte_decoder()
def dump(last=False): def dump(last=False):
nonlocal code, ids, current_ids nonlocal code, ids, current_ids
@ -49,7 +80,7 @@ def tokenize(text, input_is_ids=False):
return res return res
try: try:
word = bytearray([clip.tokenizer.byte_decoder[x] for x in ''.join(words)]).decode("utf-8") word = bytearray([byte_decoder[x] for x in ''.join(words)]).decode("utf-8")
except UnicodeDecodeError: except UnicodeDecodeError:
if last: if last:
word = "" * len(current_ids) word = "" * len(current_ids)