mirror of https://github.com/Filexor/Clip_IO.git
Add support for Open_Clip
parent
118654f900
commit
24979528fd
|
|
@ -3,7 +3,7 @@ import os, csv, warnings, datetime
|
|||
import gradio
|
||||
import torch
|
||||
import lark
|
||||
from tkinter import filedialog
|
||||
import open_clip
|
||||
|
||||
from modules import scripts, script_callbacks, shared, devices, processing, prompt_parser
|
||||
from modules.shared import opts
|
||||
|
|
@ -445,7 +445,17 @@ class Clip_IO(scripts.Script):
|
|||
clip.hijack.fixes = [fixes]
|
||||
input_ids_Tensor = torch.asarray([input_ids]).to(devices.device)
|
||||
|
||||
tokens = [clip.wrapped.tokenizer.decoder.get(input_id) for input_id in input_ids]
|
||||
decode: callable[any, str]
|
||||
if hasattr(clip.wrapped, "tokenizer"): # Ver.1.x
|
||||
decode = clip.wrapped.tokenizer.decoder.get
|
||||
is_open_clip = False
|
||||
pass
|
||||
else: # Ver.2.x
|
||||
decode = lambda t: open_clip.tokenizer._tokenizer.decoder.get(t)
|
||||
is_open_clip = True
|
||||
pass
|
||||
tokens = [decode(input_id) for input_id in input_ids]
|
||||
|
||||
for fix in fixes:
|
||||
tokens[fix.offset + 1] = fix.embedding.name
|
||||
for i in range(1, fix.embedding.vec.shape[0]):
|
||||
|
|
@ -453,7 +463,7 @@ class Clip_IO(scripts.Script):
|
|||
pass
|
||||
pass
|
||||
|
||||
return clip.wrapped.transformer.text_model.embeddings.token_embedding(input_ids_Tensor), tokens
|
||||
return clip.wrapped.model.token_embedding(input_ids_Tensor) if is_open_clip else clip.wrapped.transformer.text_model.embeddings.token_embedding(input_ids_Tensor), tokens
|
||||
pass
|
||||
|
||||
def on_save_embeddings_as_pt(prompt: str, filename: str, overwrite: bool):
|
||||
|
|
@ -546,7 +556,7 @@ class Clip_IO(scripts.Script):
|
|||
for batch_pos in range(len(remade_batch_tokens)):
|
||||
index = remade_batch_tokens[batch_pos].index(clip.id_end)
|
||||
tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad
|
||||
|
||||
|
||||
z = clip.encode_with_transformers(tokens)
|
||||
if not no_emphasis:
|
||||
batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device)
|
||||
|
|
|
|||
Loading…
Reference in New Issue