diff --git a/scripts/lora_block_weight.py b/scripts/lora_block_weight.py index d1bc688..c7ec593 100644 --- a/scripts/lora_block_weight.py +++ b/scripts/lora_block_weight.py @@ -54,8 +54,9 @@ BLOCKID26=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08" BLOCKID17=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"] BLOCKID12=["BASE","IN04","IN05","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05"] BLOCKID20=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08"] -BLOCKNUMS = [12,17,20,26] -BLOCKIDS=[BLOCKID12,BLOCKID17,BLOCKID20,BLOCKID26] +BLOCKIDFLUX = ["CLIP", "T5", "IN"] + ["D{:002}".format(x) for x in range(19)] + ["S{:002}".format(x) for x in range(38)] + ["OUT"] # Len: 61 +BLOCKNUMS = [12,17,20,26, len(BLOCKIDFLUX)] +BLOCKIDS=[BLOCKID12,BLOCKID17,BLOCKID20,BLOCKID26,BLOCKIDFLUX] BLOCKS=["encoder", "diffusion_model_input_blocks_0_", @@ -90,7 +91,7 @@ loopstopper = True ATYPES =["none","Block ID","values","seed","Original Weights","elements"] -DEF_WEIGHT_PRESET = "\ +DEF_WEIGHT_PRESET = f"\ NONE:0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n\ ALL:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\n\ INS:1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0\n\ @@ -100,7 +101,9 @@ MIDD:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0\n\ OUTD:1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0\n\ OUTS:1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1\n\ OUTALL:1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1\n\ -ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5" +ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5\n\ +FLUXALL:{','.join(['1']*61)}" + scriptpath = os.path.dirname(os.path.abspath(__file__)) @@ -337,7 +340,7 @@ class Script(modules.scripts.Script): pass else: try: - with open(extpath,encoding="utf-8") as f: + with open(extpathe,encoding="utf-8") as f: return f.read() except OSError as e: pass @@ -811,10 +814,14 @@ def lorachecker(self): except: pass self.onlyco = (not self.islora) and self.islyco - self.isxl = hasattr(shared.sd_model,"conditioner") + model = shared.sd_model + self.is_sdxl = type(model).__name__ == "StableDiffusionXL" or getattr(model,'is_sdxl', False) + self.is_sd2 = type(model).__name__ == "StableDiffusion2" or getattr(model,'is_sd2', False) + self.is_sd1 = type(model).__name__ == "StableDiffusion" or getattr(model,'is_sd1', False) + self.is_flux = type(model).__name__ == "Flux" or getattr(model,'is_flux', False) self.log["isnet"] = self.isnet - self.log["isxl"] = self.isxl + self.log["isxl"] = self.is_sdxl self.log["islora"] = self.islora def resetmemory(): @@ -882,11 +889,11 @@ def loradealer(self, prompts,lratios,elementals, extra_network_data = None): else: ratios[i] = float(r) - if len(ratios) != 26: + if not (len(ratios) == 26 or len(ratios) == 61): ratios = to26(ratios) setnow = True else: - ratios = [1] * 26 + ratios = [1] * 61 if self.is_flux else [1] * 26 if elem in elementals: setnow = True @@ -972,7 +979,7 @@ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts for loaded in lora.loaded_loras: for n, name in enumerate(names): if name == loaded.name: - if lwei[n] == [1] * 26 and elements[n] == "": continue + if (lwei[n] == [1] * 26 or lwei[n] == [1] * 61) and elements[n] == "": continue lbw(loaded,lwei[n],elements[n]) setall(loaded,te[n],unet[n]) newname = loaded.name +"_in_LBW_"+ str(round(random.random(),3)) @@ -1000,10 +1007,10 @@ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts elif "forge" == ltype: lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches - lbwf(lora_patches, unet, lwei, elements, starts) + lbwf(lora_patches, unet, lwei, elements, starts, self.is_flux) lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches - lbwf(lora_patches, te, lwei, elements, starts) + lbwf(lora_patches, te, lwei, elements, starts, self.is_flux) try: import lora_ctl_network as ctl @@ -1138,9 +1145,9 @@ def effectivechecker(imgs,ss,ls,diffcol,thresh,revxy): def lbw(lora,lwei,elemental): errormodules = [] for key in lora.modules.keys(): - ratio, errormodule = ratiodealer(key, lwei, elemental) - if errormodule: - errormodules.append(errormodule) + ratio, picked = ratiodealer(key, lwei, elemental) + if not picked: + errormodules.append(key) ltype = type(lora.modules[key]).__name__ set = False @@ -1160,7 +1167,7 @@ def lbw(lora,lwei,elemental): #print("LoRA") set = True if not set : - print("unkwon LoRA") + print("UnKnown LoRA") if len(errormodules) > 0: print(errormodules) @@ -1168,7 +1175,7 @@ def lbw(lora,lwei,elemental): LORAS = ["lora", "loha", "lokr"] -def lbwf(after_applying_lora_patches, ms, lwei, elements, starts): +def lbwf(after_applying_lora_patches, ms, lwei, elements, starts, flux): errormodules = [] dict_lora_patches = dict(after_applying_lora_patches.items()) @@ -1186,14 +1193,14 @@ def lbwf(after_applying_lora_patches, ms, lwei, elements, starts): n_vals = [] lvs = [v for v in vals if v[1][0] in LORAS] for v in lvs: - ratio, errormodule = ratiodealer(key.replace(".","_"), l, e) + ratio, picked = ratiodealer(key.replace(".","_"), l, e, flux) n_vals.append((ratio * m if s is None or s == 0 else 0, *v[1:])) - if errormodule: - errormodules.append(errormodule) + if not picked: + errormodules.append(key) lora_patches[key] = n_vals # print("[DEBUG]", hash[0], *[n_val[0] for n_val in n_vals]) - lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in l]) + lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in l]) + e new_hash = (hash[0], lbw_key, *hash[2:]) after_applying_lora_patches[new_hash] = after_applying_lora_patches[hash] @@ -1203,29 +1210,33 @@ def lbwf(after_applying_lora_patches, ms, lwei, elements, starts): if len(errormodules) > 0: print("Unknown modules:", errormodules) -def ratiodealer(key, lwei, elemental): +def ratiodealer(key, lwei, elemental, flux = False): ratio = 1 picked = False - errormodules = [] - currentblock = 0 elemental = elemental.split(",") + elemkey = "" - for i,block in enumerate(BLOCKS): - if block in key: - if i == 26 or i == 27: - i = 0 - ratio = lwei[i] + if flux: + block = elemkey = get_flux_blocks(key) + print(block, key) + if block in BLOCKIDFLUX: + ratio = lwei[BLOCKIDFLUX.index(block)] picked = True - currentblock = i - - if not picked: - errormodules.append(key) + print(key, block, BLOCKIDFLUX.index(block), ratio) + else: + for i,block in enumerate(BLOCKS): + if block in key: + if i == 26 or i == 27: + i = 0 + ratio = lwei[i] + picked = True + elemkey = BLOCKID26[i] if len(elemental) > 0: - skey = key + BLOCKID26[currentblock] + skey = key + elemkey for d in elemental: if d.count(":") != 2 :continue - dbs,dws,dr = (hyphener(d.split(":")[0]),d.split(":")[1],d.split(":")[2]) + dbs,dws,dr = (hyphener(d.split(":")[0],BLOCKIDFLUX if flux else BLOCKID26),d.split(":")[1],d.split(":")[2]) dbs,dws = (dbs.split(" "), dws.split(" ")) dbn,dbs = (True,dbs[1:]) if dbs[0] == "NOT" else (False,dbs) dwn,dws = (True,dws[1:]) if dws[0] == "NOT" else (False,dws) @@ -1243,7 +1254,7 @@ def ratiodealer(key, lwei, elemental): if princ :print(dbs,dws,key,dr) ratio = dr - return ratio, errormodules + return ratio, picked LORAANDSOON = { "LoraHadaModule" : "w1a", @@ -1261,15 +1272,15 @@ LORAANDSOON = { "NetworkModuleOFT": "scale" } -def hyphener(t): +def hyphener(t, blocks): t = t.split(" ") for i,e in enumerate(t): if "-" in e: e = e.split("-") - if BLOCKID26.index(e[1]) > BLOCKID26.index(e[0]): - t[i] = " ".join(BLOCKID26[BLOCKID26.index(e[0]):BLOCKID26.index(e[1])+1]) + if blocks.index(e[1]) > blocks.index(e[0]): + t[i] = " ".join(blocks[blocks.index(e[0]):blocks.index(e[1])+1]) else: - t[i] = " ".join(BLOCKID26[BLOCKID26.index(e[1]):BLOCKID26.index(e[0])+1]) + t[i] = " ".join(blocks[blocks.index(e[1]):blocks.index(e[0])+1]) return " ".join(t) ELEMPRESETS="\ @@ -1298,3 +1309,22 @@ def checkloadcond(l:str)->bool: res=(":" not in l) or (not any(l.count(",") == x - 1 for x in BLOCKNUMS)) or ("#" in l) #print("[debug]", res,repr(l)) return res + +def get_flux_blocks(key): + if "vae" in key: + return "VAE" + if "t5xxl" in key: + return "T5" + if "text_encoders.clip" in key: + return "CLIP" + + match = re.search(r'\_(\d+)\_', key) + if "double_blocks" in key: + return f"D{match.group(1).zfill(2) }" + if "single_blocks" in key: + return f"S{match.group(1).zfill(2) }" + if "_in" in key: + return "IN" + if "final_layer" in key: + return "OUT" + return "Not Merge" \ No newline at end of file