Add support for Open_Clip

main
File_xor 2023-06-02 02:06:52 +09:00
parent 118654f900
commit 24979528fd
1 changed files with 14 additions and 4 deletions

View File

@ -3,7 +3,7 @@ import os, csv, warnings, datetime
import gradio
import torch
import lark
from tkinter import filedialog
import open_clip
from modules import scripts, script_callbacks, shared, devices, processing, prompt_parser
from modules.shared import opts
@ -445,7 +445,17 @@ class Clip_IO(scripts.Script):
clip.hijack.fixes = [fixes]
input_ids_Tensor = torch.asarray([input_ids]).to(devices.device)
tokens = [clip.wrapped.tokenizer.decoder.get(input_id) for input_id in input_ids]
decode: callable[any, str]
if hasattr(clip.wrapped, "tokenizer"): # Ver.1.x
decode = clip.wrapped.tokenizer.decoder.get
is_open_clip = False
pass
else: # Ver.2.x
decode = lambda t: open_clip.tokenizer._tokenizer.decoder.get(t)
is_open_clip = True
pass
tokens = [decode(input_id) for input_id in input_ids]
for fix in fixes:
tokens[fix.offset + 1] = fix.embedding.name
for i in range(1, fix.embedding.vec.shape[0]):
@ -453,7 +463,7 @@ class Clip_IO(scripts.Script):
pass
pass
return clip.wrapped.transformer.text_model.embeddings.token_embedding(input_ids_Tensor), tokens
return clip.wrapped.model.token_embedding(input_ids_Tensor) if is_open_clip else clip.wrapped.transformer.text_model.embeddings.token_embedding(input_ids_Tensor), tokens
pass
def on_save_embeddings_as_pt(prompt: str, filename: str, overwrite: bool):
@ -546,7 +556,7 @@ class Clip_IO(scripts.Script):
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(clip.id_end)
tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad
z = clip.encode_with_transformers(tokens)
if not no_emphasis:
batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device)