From 40d439de3cda5965fc1f9cb4412ef329885d1d6d Mon Sep 17 00:00:00 2001 From: klimaleksus Date: Fri, 6 Jun 2025 20:10:56 +0500 Subject: [PATCH] Restore partial support for Forge, see #22 --- scripts/embedding_merge.py | 119 ++++++++++++++++++++++++++----------- 1 file changed, 85 insertions(+), 34 deletions(-) diff --git a/scripts/embedding_merge.py b/scripts/embedding_merge.py index d7ec1a5..18c793d 100644 --- a/scripts/embedding_merge.py +++ b/scripts/embedding_merge.py @@ -306,6 +306,8 @@ A cat is chasing a dog. <''-'road'-'grass'> clip = shared.sd_model.cond_stage_model if hasattr(clip,'embedders'): clip = clip.embedders[0] + if clip is None: + return None clip = clip.wrapped typename = type(clip).__name__.split('.')[-1] if typename=='FrozenOpenCLIPEmbedder': @@ -361,7 +363,18 @@ A cat is chasing a dog. <''-'road'-'grass'> return res def get_model_clips(): - clip = shared.sd_model.cond_stage_model + sd_model = shared.sd_model + clip = sd_model.cond_stage_model + if clip is None: + clip = sd_model.text_processing_engine if hasattr(sd_model,'text_processing_engine') else None + if clip is None: + clip_l = sd_model.text_processing_engine_l if hasattr(sd_model,'text_processing_engine_l') else None + clip_g = sd_model.text_processing_engine_g if hasattr(sd_model,'text_processing_engine_g') else None + if clip_l is not None: + if clip_g is not None: + return (clip_l,clip_g) + return (clip_l,) + raise Exception_From_EmbeddingMergeExtension('Could not find CLIP model!') if(hasattr(clip,'embedders')): try: return (clip.embedders[0],clip.embedders[1]) # SDXL @@ -369,13 +382,41 @@ A cat is chasing a dog. <''-'road'-'grass'> pass return (clip,) # SD1 or SD2 + def get_embedding_db(): + try: + db = modules.sd_hijack.model_hijack.embedding_db + if db is not None: + return (db,) + except: + pass + clips = get_model_clips() + return [c.embeddings for c in clips] + + def tokenize_line(clip,text): + if hasattr(clip,'encode_embedding_init_text'): + return clip.tokenize_line(str_to_escape(text)) + old = clip.emphasis.name + clip.emphasis.name = 'None' + try: + res = clip.tokenize_line(text) + finally: + clip.emphasis.name = old + return res + + def encode_embedding_init_text(clip,text,length=999): + if hasattr(clip,'encode_embedding_init_text'): + return clip.encode_embedding_init_text(text,length) + part = tokenize_line(clip,text) + tokens = part[0][0].tokens + return clip.text_encoder.transformer.text_model.embeddings.token_embedding.wrapped(torch.tensor(tokens[1:part[1]+1])) + def text_to_vectors(orig_text): try: both = [] for clip,lg in zip(get_model_clips(),('clip_l','clip_g')): res = [] text = orig_text.lstrip().lower() - tokens = clip.tokenize_line(str_to_escape(text)) + tokens = tokenize_line(clip,text) count = tokens[1] tokens = tokens[0][0] fixes = tokens.fixes @@ -401,13 +442,13 @@ A cat is chasing a dog. <''-'road'-'grass'> return None test = pos+lenname sub = text[0:test] - part = clip.tokenize_line(str_to_escape(sub)) + part = tokenize_line(clip,sub) cnt = part[1] part = part[0][0] vec = off-start need = tokens[start:off+num] if part.tokens[1:cnt+1]==need: - trans = clip.encode_embedding_init_text(text,vec) + trans = encode_embedding_init_text(clip,text,vec) t = trans[:vec].to(device=devices.device,dtype=torch.float32) res.append((t,sub[:pos],need[:vec])) text = text[pos:] @@ -420,13 +461,13 @@ A cat is chasing a dog. <''-'road'-'grass'> start += num text = text[lenname:].lstrip() if text!='': - part = clip.tokenize_line(str_to_escape(text)) + part = tokenize_line(clip,text) cnt = part[1] part = part[0][0] need = tokens[start:] if part.tokens[1:cnt+1]!=need: return None - trans = clip.encode_embedding_init_text(text,999) + trans = encode_embedding_init_text(clip,text,999) trans = trans.to(device=devices.device,dtype=torch.float32) res.append((trans,text,need)) both.append(res) @@ -840,7 +881,7 @@ A cat is chasing a dog. <''-'road'-'grass'> return (right,None) def grab_embedding_cache(): - db = modules.sd_hijack.model_hijack.embedding_db + db = get_embedding_db()[0] field = '__embedding_merge_cache_' if hasattr(db,field): cache = getattr(db,field) @@ -850,30 +891,33 @@ A cat is chasing a dog. <''-'road'-'grass'> return cache def register_embedding(name,embedding): - self = modules.sd_hijack.model_hijack.embedding_db - model = shared.sd_model - if hasattr(self,'register_embedding_by_name'): - return self.register_embedding_by_name(embedding,model,name) - # /modules/textual_inversion/textual_inversion.py - try: - ids = model.cond_stage_model.tokenize([name])[0] - first_id = ids[0] - except: - return - if embedding is None: - if self.word_embeddings[name] is None: + for self in get_embedding_db(): + model = shared.sd_model + if hasattr(self,'register_embedding_by_name'): + try: + return self.register_embedding_by_name(embedding,model,name) + except TypeError: + return self.register_embedding_by_name(embedding,name) + # /modules/textual_inversion/textual_inversion.py + try: + ids = model.cond_stage_model.tokenize([name])[0] + first_id = ids[0] + except: return - del self.word_embeddings[name] - else: - self.word_embeddings[name] = embedding - if first_id not in self.ids_lookup: if embedding is None: - return - self.ids_lookup[first_id] = [] - save = [(ids, embedding)] if embedding is not None else [] - old = [x for x in self.ids_lookup[first_id] if x[1].name!=name] - self.ids_lookup[first_id] = sorted(old + save, key=lambda x: len(x[0]), reverse=True) - return embedding + if self.word_embeddings[name] is None: + return + del self.word_embeddings[name] + else: + self.word_embeddings[name] = embedding + if first_id not in self.ids_lookup: + if embedding is None: + return + self.ids_lookup[first_id] = [] + save = [(ids, embedding)] if embedding is not None else [] + old = [x for x in self.ids_lookup[first_id] if x[1].name!=name] + self.ids_lookup[first_id] = sorted(old + save, key=lambda x: len(x[0]), reverse=True) + return embedding def make_temp_embedding(name,vectors,cache,fake): embed = None @@ -1045,6 +1089,10 @@ A cat is chasing a dog. <''-'road'-'grass'> sd_models.reload_model_weights() except: pass + try: + sd_models.forge_model_reload() + except: + pass gr_orig = gr_text font = 'font-family:Consolas,Courier New,Courier,monospace;' table = '' @@ -1344,10 +1392,11 @@ A cat is chasing a dog. <''-'road'-'grass'> if pt is not None: from safetensors.torch import save_file save_file(pt,target) - try: - modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) - except: - modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + for db in get_embedding_db(): + try: + db.load_textual_inversion_embeddings(force_reload=True) + except: + db.load_textual_inversion_embeddings() return '' except: traceback.print_exc() @@ -1467,6 +1516,8 @@ A cat is chasing a dog. <''-'road'-'grass'> pretty_print(v,indent,dupl) else: print(tab + k + ': ' + str(v)) + import code + code.interact(local=locals()) ''' def hook_infotext(hook): @@ -1649,7 +1700,7 @@ A cat is chasing a dog. <''-'road'-'grass'> except: traceback.print_exc() try: - db = modules.sd_hijack.model_hijack.embedding_db + db = get_embedding_db()[0] field = '__embedding_merge_cache_' if hasattr(db,field): delattr(db,field)