Restore partial support for Forge, see #22

sdxl
klimaleksus 2025-06-06 20:10:56 +05:00 committed by GitHub
parent 4512e7a250
commit 40d439de3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 85 additions and 34 deletions

View File

@@ -306,6 +306,8 @@ A cat is chasing a dog. <''-'road'-'grass'>
clip = shared.sd_model.cond_stage_model
if hasattr(clip,'embedders'):
clip = clip.embedders[0]
if clip is None:
return None
clip = clip.wrapped
typename = type(clip).__name__.split('.')[-1]
if typename=='FrozenOpenCLIPEmbedder':
@@ -361,7 +363,18 @@ A cat is chasing a dog. <''-'road'-'grass'>
return res
def get_model_clips():
clip = shared.sd_model.cond_stage_model
sd_model = shared.sd_model
clip = sd_model.cond_stage_model
if clip is None:
clip = sd_model.text_processing_engine if hasattr(sd_model,'text_processing_engine') else None
if clip is None:
clip_l = sd_model.text_processing_engine_l if hasattr(sd_model,'text_processing_engine_l') else None
clip_g = sd_model.text_processing_engine_g if hasattr(sd_model,'text_processing_engine_g') else None
if clip_l is not None:
if clip_g is not None:
return (clip_l,clip_g)
return (clip_l,)
raise Exception_From_EmbeddingMergeExtension('Could not find CLIP model!')
if(hasattr(clip,'embedders')):
try:
return (clip.embedders[0],clip.embedders[1]) # SDXL
@@ -369,13 +382,41 @@ A cat is chasing a dog. <''-'road'-'grass'>
pass
return (clip,) # SD1 or SD2
def get_embedding_db():
    """Locate the textual-inversion embedding database(s).

    Returns a non-empty sequence of embedding-database-like objects:
    - on stock WebUI, the single global ``model_hijack.embedding_db``;
    - on Forge builds (which may not expose that attribute), one
      ``.embeddings`` container per CLIP model from ``get_model_clips()``.

    Raises whatever ``get_model_clips()`` raises when no CLIP model can
    be found on the fallback path.
    """
    try:
        # Stock A1111 WebUI path: one global database.
        db = modules.sd_hijack.model_hijack.embedding_db
        if db is not None:
            return (db,)
    except Exception:
        # Forge may not expose model_hijack.embedding_db at all; fall
        # through to per-CLIP lookup. (Was a bare `except:`, which would
        # also swallow KeyboardInterrupt/SystemExit.)
        pass
    # Forge fallback: each CLIP wrapper carries its own embeddings store.
    # Return a tuple for consistency with the `(db,)` branch above.
    clips = get_model_clips()
    return tuple(c.embeddings for c in clips)
def tokenize_line(clip,text):
    # Tokenize `text` with the given CLIP wrapper, with emphasis/attention
    # syntax neutralized, returning whatever clip.tokenize_line() returns.
    #
    # WebUI-style wrappers are detected by the presence of
    # encode_embedding_init_text; for those, the text is pre-escaped
    # (presumably str_to_escape neutralizes prompt syntax — TODO confirm
    # against str_to_escape's definition elsewhere in this file).
    if hasattr(clip,'encode_embedding_init_text'):
        return clip.tokenize_line(str_to_escape(text))
    # Forge text-processing engines honor an `emphasis` mode object instead:
    # temporarily switch it to 'None' so (), [] etc. are not interpreted.
    old = clip.emphasis.name
    clip.emphasis.name = 'None'
    try:
        res = clip.tokenize_line(text)
    finally:
        # Always restore the caller's emphasis mode, even if tokenizing fails.
        clip.emphasis.name = old
    return res
def encode_embedding_init_text(clip,text,length=999):
    # Encode `text` into raw token-embedding vectors via the given CLIP
    # wrapper; `length` caps the number of vectors on the WebUI path.
    if hasattr(clip,'encode_embedding_init_text'):
        # WebUI CLIP wrappers provide this natively.
        return clip.encode_embedding_init_text(text,length)
    # Forge fallback: tokenize ourselves, then look the token ids up
    # directly in the wrapped (pre-hijack) token-embedding table.
    # NOTE(review): `length` is not honored on this path — the slice is
    # bounded by the used-token count instead; confirm this is intended.
    part = tokenize_line(clip,text)
    tokens = part[0][0].tokens
    # part[1] appears to be the used-token count; the slice skips the
    # leading BOS token at index 0 — TODO confirm tokenize_line's contract.
    return clip.text_encoder.transformer.text_model.embeddings.token_embedding.wrapped(torch.tensor(tokens[1:part[1]+1]))
def text_to_vectors(orig_text):
try:
both = []
for clip,lg in zip(get_model_clips(),('clip_l','clip_g')):
res = []
text = orig_text.lstrip().lower()
tokens = clip.tokenize_line(str_to_escape(text))
tokens = tokenize_line(clip,text)
count = tokens[1]
tokens = tokens[0][0]
fixes = tokens.fixes
@@ -401,13 +442,13 @@ A cat is chasing a dog. <''-'road'-'grass'>
return None
test = pos+lenname
sub = text[0:test]
part = clip.tokenize_line(str_to_escape(sub))
part = tokenize_line(clip,sub)
cnt = part[1]
part = part[0][0]
vec = off-start
need = tokens[start:off+num]
if part.tokens[1:cnt+1]==need:
trans = clip.encode_embedding_init_text(text,vec)
trans = encode_embedding_init_text(clip,text,vec)
t = trans[:vec].to(device=devices.device,dtype=torch.float32)
res.append((t,sub[:pos],need[:vec]))
text = text[pos:]
@@ -420,13 +461,13 @@ A cat is chasing a dog. <''-'road'-'grass'>
start += num
text = text[lenname:].lstrip()
if text!='':
part = clip.tokenize_line(str_to_escape(text))
part = tokenize_line(clip,text)
cnt = part[1]
part = part[0][0]
need = tokens[start:]
if part.tokens[1:cnt+1]!=need:
return None
trans = clip.encode_embedding_init_text(text,999)
trans = encode_embedding_init_text(clip,text,999)
trans = trans.to(device=devices.device,dtype=torch.float32)
res.append((trans,text,need))
both.append(res)
@@ -840,7 +881,7 @@ A cat is chasing a dog. <''-'road'-'grass'>
return (right,None)
def grab_embedding_cache():
db = modules.sd_hijack.model_hijack.embedding_db
db = get_embedding_db()[0]
field = '__embedding_merge_cache_'
if hasattr(db,field):
cache = getattr(db,field)
@@ -850,30 +891,33 @@ A cat is chasing a dog. <''-'road'-'grass'>
return cache
def register_embedding(name,embedding):
self = modules.sd_hijack.model_hijack.embedding_db
model = shared.sd_model
if hasattr(self,'register_embedding_by_name'):
return self.register_embedding_by_name(embedding,model,name)
# /modules/textual_inversion/textual_inversion.py
try:
ids = model.cond_stage_model.tokenize([name])[0]
first_id = ids[0]
except:
return
if embedding is None:
if self.word_embeddings[name] is None:
for self in get_embedding_db():
model = shared.sd_model
if hasattr(self,'register_embedding_by_name'):
try:
return self.register_embedding_by_name(embedding,model,name)
except TypeError:
return self.register_embedding_by_name(embedding,name)
# /modules/textual_inversion/textual_inversion.py
try:
ids = model.cond_stage_model.tokenize([name])[0]
first_id = ids[0]
except:
return
del self.word_embeddings[name]
else:
self.word_embeddings[name] = embedding
if first_id not in self.ids_lookup:
if embedding is None:
return
self.ids_lookup[first_id] = []
save = [(ids, embedding)] if embedding is not None else []
old = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
self.ids_lookup[first_id] = sorted(old + save, key=lambda x: len(x[0]), reverse=True)
return embedding
if self.word_embeddings[name] is None:
return
del self.word_embeddings[name]
else:
self.word_embeddings[name] = embedding
if first_id not in self.ids_lookup:
if embedding is None:
return
self.ids_lookup[first_id] = []
save = [(ids, embedding)] if embedding is not None else []
old = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
self.ids_lookup[first_id] = sorted(old + save, key=lambda x: len(x[0]), reverse=True)
return embedding
def make_temp_embedding(name,vectors,cache,fake):
embed = None
@@ -1045,6 +1089,10 @@ A cat is chasing a dog. <''-'road'-'grass'>
sd_models.reload_model_weights()
except:
pass
try:
sd_models.forge_model_reload()
except:
pass
gr_orig = gr_text
font = 'font-family:Consolas,Courier New,Courier,monospace;'
table = '<style>.webui_embedding_merge_table,.webui_embedding_merge_table td,.webui_embedding_merge_table th{border:1px solid gray;border-collapse:collapse}.webui_embedding_merge_table td,.webui_embedding_merge_table th{padding:2px 5px !important;text-align:center !important;vertical-align:middle;'+font+'font-weight:bold;}.webui_embedding_merge_table{margin:6px auto !important;}</style>'
@@ -1344,10 +1392,11 @@ A cat is chasing a dog. <''-'road'-'grass'>
if pt is not None:
from safetensors.torch import save_file
save_file(pt,target)
try:
modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
except:
modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
for db in get_embedding_db():
try:
db.load_textual_inversion_embeddings(force_reload=True)
except:
db.load_textual_inversion_embeddings()
return ''
except:
traceback.print_exc()
@@ -1467,6 +1516,8 @@ A cat is chasing a dog. <''-'road'-'grass'>
pretty_print(v,indent,dupl)
else:
print(tab + k + ': ' + str(v))
import code
code.interact(local=locals())
'''
def hook_infotext(hook):
@@ -1649,7 +1700,7 @@ A cat is chasing a dog. <''-'road'-'grass'>
except:
traceback.print_exc()
try:
db = modules.sd_hijack.model_hijack.embedding_db
db = get_embedding_db()[0]
field = '__embedding_merge_cache_'
if hasattr(db,field):
delattr(db,field)