Restore partial support for Forge, see #22
parent
4512e7a250
commit
40d439de3c
|
|
@ -306,6 +306,8 @@ A cat is chasing a dog. <''-'road'-'grass'>
|
|||
clip = shared.sd_model.cond_stage_model
|
||||
if hasattr(clip,'embedders'):
|
||||
clip = clip.embedders[0]
|
||||
if clip is None:
|
||||
return None
|
||||
clip = clip.wrapped
|
||||
typename = type(clip).__name__.split('.')[-1]
|
||||
if typename=='FrozenOpenCLIPEmbedder':
|
||||
|
|
@ -361,7 +363,18 @@ A cat is chasing a dog. <''-'road'-'grass'>
|
|||
return res
|
||||
|
||||
def get_model_clips():
|
||||
clip = shared.sd_model.cond_stage_model
|
||||
sd_model = shared.sd_model
|
||||
clip = sd_model.cond_stage_model
|
||||
if clip is None:
|
||||
clip = sd_model.text_processing_engine if hasattr(sd_model,'text_processing_engine') else None
|
||||
if clip is None:
|
||||
clip_l = sd_model.text_processing_engine_l if hasattr(sd_model,'text_processing_engine_l') else None
|
||||
clip_g = sd_model.text_processing_engine_g if hasattr(sd_model,'text_processing_engine_g') else None
|
||||
if clip_l is not None:
|
||||
if clip_g is not None:
|
||||
return (clip_l,clip_g)
|
||||
return (clip_l,)
|
||||
raise Exception_From_EmbeddingMergeExtension('Could not find CLIP model!')
|
||||
if(hasattr(clip,'embedders')):
|
||||
try:
|
||||
return (clip.embedders[0],clip.embedders[1]) # SDXL
|
||||
|
|
@ -369,13 +382,41 @@ A cat is chasing a dog. <''-'road'-'grass'>
|
|||
pass
|
||||
return (clip,) # SD1 or SD2
|
||||
|
||||
def get_embedding_db():
|
||||
try:
|
||||
db = modules.sd_hijack.model_hijack.embedding_db
|
||||
if db is not None:
|
||||
return (db,)
|
||||
except:
|
||||
pass
|
||||
clips = get_model_clips()
|
||||
return [c.embeddings for c in clips]
|
||||
|
||||
def tokenize_line(clip,text):
|
||||
if hasattr(clip,'encode_embedding_init_text'):
|
||||
return clip.tokenize_line(str_to_escape(text))
|
||||
old = clip.emphasis.name
|
||||
clip.emphasis.name = 'None'
|
||||
try:
|
||||
res = clip.tokenize_line(text)
|
||||
finally:
|
||||
clip.emphasis.name = old
|
||||
return res
|
||||
|
||||
def encode_embedding_init_text(clip,text,length=999):
|
||||
if hasattr(clip,'encode_embedding_init_text'):
|
||||
return clip.encode_embedding_init_text(text,length)
|
||||
part = tokenize_line(clip,text)
|
||||
tokens = part[0][0].tokens
|
||||
return clip.text_encoder.transformer.text_model.embeddings.token_embedding.wrapped(torch.tensor(tokens[1:part[1]+1]))
|
||||
|
||||
def text_to_vectors(orig_text):
|
||||
try:
|
||||
both = []
|
||||
for clip,lg in zip(get_model_clips(),('clip_l','clip_g')):
|
||||
res = []
|
||||
text = orig_text.lstrip().lower()
|
||||
tokens = clip.tokenize_line(str_to_escape(text))
|
||||
tokens = tokenize_line(clip,text)
|
||||
count = tokens[1]
|
||||
tokens = tokens[0][0]
|
||||
fixes = tokens.fixes
|
||||
|
|
@ -401,13 +442,13 @@ A cat is chasing a dog. <''-'road'-'grass'>
|
|||
return None
|
||||
test = pos+lenname
|
||||
sub = text[0:test]
|
||||
part = clip.tokenize_line(str_to_escape(sub))
|
||||
part = tokenize_line(clip,sub)
|
||||
cnt = part[1]
|
||||
part = part[0][0]
|
||||
vec = off-start
|
||||
need = tokens[start:off+num]
|
||||
if part.tokens[1:cnt+1]==need:
|
||||
trans = clip.encode_embedding_init_text(text,vec)
|
||||
trans = encode_embedding_init_text(clip,text,vec)
|
||||
t = trans[:vec].to(device=devices.device,dtype=torch.float32)
|
||||
res.append((t,sub[:pos],need[:vec]))
|
||||
text = text[pos:]
|
||||
|
|
@ -420,13 +461,13 @@ A cat is chasing a dog. <''-'road'-'grass'>
|
|||
start += num
|
||||
text = text[lenname:].lstrip()
|
||||
if text!='':
|
||||
part = clip.tokenize_line(str_to_escape(text))
|
||||
part = tokenize_line(clip,text)
|
||||
cnt = part[1]
|
||||
part = part[0][0]
|
||||
need = tokens[start:]
|
||||
if part.tokens[1:cnt+1]!=need:
|
||||
return None
|
||||
trans = clip.encode_embedding_init_text(text,999)
|
||||
trans = encode_embedding_init_text(clip,text,999)
|
||||
trans = trans.to(device=devices.device,dtype=torch.float32)
|
||||
res.append((trans,text,need))
|
||||
both.append(res)
|
||||
|
|
@ -840,7 +881,7 @@ A cat is chasing a dog. <''-'road'-'grass'>
|
|||
return (right,None)
|
||||
|
||||
def grab_embedding_cache():
|
||||
db = modules.sd_hijack.model_hijack.embedding_db
|
||||
db = get_embedding_db()[0]
|
||||
field = '__embedding_merge_cache_'
|
||||
if hasattr(db,field):
|
||||
cache = getattr(db,field)
|
||||
|
|
@ -850,30 +891,33 @@ A cat is chasing a dog. <''-'road'-'grass'>
|
|||
return cache
|
||||
|
||||
def register_embedding(name,embedding):
|
||||
self = modules.sd_hijack.model_hijack.embedding_db
|
||||
model = shared.sd_model
|
||||
if hasattr(self,'register_embedding_by_name'):
|
||||
return self.register_embedding_by_name(embedding,model,name)
|
||||
# /modules/textual_inversion/textual_inversion.py
|
||||
try:
|
||||
ids = model.cond_stage_model.tokenize([name])[0]
|
||||
first_id = ids[0]
|
||||
except:
|
||||
return
|
||||
if embedding is None:
|
||||
if self.word_embeddings[name] is None:
|
||||
for self in get_embedding_db():
|
||||
model = shared.sd_model
|
||||
if hasattr(self,'register_embedding_by_name'):
|
||||
try:
|
||||
return self.register_embedding_by_name(embedding,model,name)
|
||||
except TypeError:
|
||||
return self.register_embedding_by_name(embedding,name)
|
||||
# /modules/textual_inversion/textual_inversion.py
|
||||
try:
|
||||
ids = model.cond_stage_model.tokenize([name])[0]
|
||||
first_id = ids[0]
|
||||
except:
|
||||
return
|
||||
del self.word_embeddings[name]
|
||||
else:
|
||||
self.word_embeddings[name] = embedding
|
||||
if first_id not in self.ids_lookup:
|
||||
if embedding is None:
|
||||
return
|
||||
self.ids_lookup[first_id] = []
|
||||
save = [(ids, embedding)] if embedding is not None else []
|
||||
old = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
|
||||
self.ids_lookup[first_id] = sorted(old + save, key=lambda x: len(x[0]), reverse=True)
|
||||
return embedding
|
||||
if self.word_embeddings[name] is None:
|
||||
return
|
||||
del self.word_embeddings[name]
|
||||
else:
|
||||
self.word_embeddings[name] = embedding
|
||||
if first_id not in self.ids_lookup:
|
||||
if embedding is None:
|
||||
return
|
||||
self.ids_lookup[first_id] = []
|
||||
save = [(ids, embedding)] if embedding is not None else []
|
||||
old = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
|
||||
self.ids_lookup[first_id] = sorted(old + save, key=lambda x: len(x[0]), reverse=True)
|
||||
return embedding
|
||||
|
||||
def make_temp_embedding(name,vectors,cache,fake):
|
||||
embed = None
|
||||
|
|
@ -1045,6 +1089,10 @@ A cat is chasing a dog. <''-'road'-'grass'>
|
|||
sd_models.reload_model_weights()
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
sd_models.forge_model_reload()
|
||||
except:
|
||||
pass
|
||||
gr_orig = gr_text
|
||||
font = 'font-family:Consolas,Courier New,Courier,monospace;'
|
||||
table = '<style>.webui_embedding_merge_table,.webui_embedding_merge_table td,.webui_embedding_merge_table th{border:1px solid gray;border-collapse:collapse}.webui_embedding_merge_table td,.webui_embedding_merge_table th{padding:2px 5px !important;text-align:center !important;vertical-align:middle;'+font+'font-weight:bold;}.webui_embedding_merge_table{margin:6px auto !important;}</style>'
|
||||
|
|
@ -1344,10 +1392,11 @@ A cat is chasing a dog. <''-'road'-'grass'>
|
|||
if pt is not None:
|
||||
from safetensors.torch import save_file
|
||||
save_file(pt,target)
|
||||
try:
|
||||
modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||
except:
|
||||
modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
for db in get_embedding_db():
|
||||
try:
|
||||
db.load_textual_inversion_embeddings(force_reload=True)
|
||||
except:
|
||||
db.load_textual_inversion_embeddings()
|
||||
return ''
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
|
@ -1467,6 +1516,8 @@ A cat is chasing a dog. <''-'road'-'grass'>
|
|||
pretty_print(v,indent,dupl)
|
||||
else:
|
||||
print(tab + k + ': ' + str(v))
|
||||
import code
|
||||
code.interact(local=locals())
|
||||
'''
|
||||
|
||||
def hook_infotext(hook):
|
||||
|
|
@ -1649,7 +1700,7 @@ A cat is chasing a dog. <''-'road'-'grass'>
|
|||
except:
|
||||
traceback.print_exc()
|
||||
try:
|
||||
db = modules.sd_hijack.model_hijack.embedding_db
|
||||
db = get_embedding_db()[0]
|
||||
field = '__embedding_merge_cache_'
|
||||
if hasattr(db,field):
|
||||
delattr(db,field)
|
||||
|
|
|
|||
Loading…
Reference in New Issue