#176, support flux

main
hako-mikan 2025-01-23 20:32:45 +09:00
parent 5b9dc37fdb
commit 9f4b321af6
1 changed files with 70 additions and 40 deletions

View File

@ -54,8 +54,9 @@ BLOCKID26=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08"
BLOCKID17=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
BLOCKID12=["BASE","IN04","IN05","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05"]
BLOCKID20=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08"]
BLOCKNUMS = [12,17,20,26]
BLOCKIDS=[BLOCKID12,BLOCKID17,BLOCKID20,BLOCKID26]
BLOCKIDFLUX = ["CLIP", "T5", "IN"] + ["D{:002}".format(x) for x in range(19)] + ["S{:002}".format(x) for x in range(38)] + ["OUT"] # Len: 61
BLOCKNUMS = [12,17,20,26, len(BLOCKIDFLUX)]
BLOCKIDS=[BLOCKID12,BLOCKID17,BLOCKID20,BLOCKID26,BLOCKIDFLUX]
BLOCKS=["encoder",
"diffusion_model_input_blocks_0_",
@ -90,7 +91,7 @@ loopstopper = True
ATYPES =["none","Block ID","values","seed","Original Weights","elements"]
DEF_WEIGHT_PRESET = "\
DEF_WEIGHT_PRESET = f"\
NONE:0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
ALL:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\n\
INS:1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
@ -100,7 +101,9 @@ MIDD:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0\n\
OUTD:1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0\n\
OUTS:1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1\n\
OUTALL:1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1\n\
ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5\n\
FLUXALL:{','.join(['1']*61)}"
scriptpath = os.path.dirname(os.path.abspath(__file__))
@ -337,7 +340,7 @@ class Script(modules.scripts.Script):
pass
else:
try:
with open(extpath,encoding="utf-8") as f:
with open(extpathe,encoding="utf-8") as f:
return f.read()
except OSError as e:
pass
@ -811,10 +814,14 @@ def lorachecker(self):
except:
pass
self.onlyco = (not self.islora) and self.islyco
self.isxl = hasattr(shared.sd_model,"conditioner")
model = shared.sd_model
self.is_sdxl = type(model).__name__ == "StableDiffusionXL" or getattr(model,'is_sdxl', False)
self.is_sd2 = type(model).__name__ == "StableDiffusion2" or getattr(model,'is_sd2', False)
self.is_sd1 = type(model).__name__ == "StableDiffusion" or getattr(model,'is_sd1', False)
self.is_flux = type(model).__name__ == "Flux" or getattr(model,'is_flux', False)
self.log["isnet"] = self.isnet
self.log["isxl"] = self.isxl
self.log["isxl"] = self.is_sdxl
self.log["islora"] = self.islora
def resetmemory():
@ -882,11 +889,11 @@ def loradealer(self, prompts,lratios,elementals, extra_network_data = None):
else:
ratios[i] = float(r)
if len(ratios) != 26:
if not (len(ratios) == 26 or len(ratios) == 61):
ratios = to26(ratios)
setnow = True
else:
ratios = [1] * 26
ratios = [1] * 61 if self.is_flux else [1] * 26
if elem in elementals:
setnow = True
@ -972,7 +979,7 @@ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts
for loaded in lora.loaded_loras:
for n, name in enumerate(names):
if name == loaded.name:
if lwei[n] == [1] * 26 and elements[n] == "": continue
if (lwei[n] == [1] * 26 or lwei[n] == [1] * 61) and elements[n] == "": continue
lbw(loaded,lwei[n],elements[n])
setall(loaded,te[n],unet[n])
newname = loaded.name +"_in_LBW_"+ str(round(random.random(),3))
@ -1000,10 +1007,10 @@ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts
elif "forge" == ltype:
lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches
lbwf(lora_patches, unet, lwei, elements, starts)
lbwf(lora_patches, unet, lwei, elements, starts, self.is_flux)
lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches
lbwf(lora_patches, te, lwei, elements, starts)
lbwf(lora_patches, te, lwei, elements, starts, self.is_flux)
try:
import lora_ctl_network as ctl
@ -1138,9 +1145,9 @@ def effectivechecker(imgs,ss,ls,diffcol,thresh,revxy):
def lbw(lora,lwei,elemental):
errormodules = []
for key in lora.modules.keys():
ratio, errormodule = ratiodealer(key, lwei, elemental)
if errormodule:
errormodules.append(errormodule)
ratio, picked = ratiodealer(key, lwei, elemental)
if not picked:
errormodules.append(key)
ltype = type(lora.modules[key]).__name__
set = False
@ -1160,7 +1167,7 @@ def lbw(lora,lwei,elemental):
#print("LoRA")
set = True
if not set :
print("unkwon LoRA")
print("UnKnown LoRA")
if len(errormodules) > 0:
print(errormodules)
@ -1168,7 +1175,7 @@ def lbw(lora,lwei,elemental):
LORAS = ["lora", "loha", "lokr"]
def lbwf(after_applying_lora_patches, ms, lwei, elements, starts):
def lbwf(after_applying_lora_patches, ms, lwei, elements, starts, flux):
errormodules = []
dict_lora_patches = dict(after_applying_lora_patches.items())
@ -1186,14 +1193,14 @@ def lbwf(after_applying_lora_patches, ms, lwei, elements, starts):
n_vals = []
lvs = [v for v in vals if v[1][0] in LORAS]
for v in lvs:
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
ratio, picked = ratiodealer(key.replace(".","_"), l, e, flux)
n_vals.append((ratio * m if s is None or s == 0 else 0, *v[1:]))
if errormodule:
errormodules.append(errormodule)
if not picked:
errormodules.append(key)
lora_patches[key] = n_vals
# print("[DEBUG]", hash[0], *[n_val[0] for n_val in n_vals])
lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in l])
lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in l]) + e
new_hash = (hash[0], lbw_key, *hash[2:])
after_applying_lora_patches[new_hash] = after_applying_lora_patches[hash]
@ -1203,29 +1210,33 @@ def lbwf(after_applying_lora_patches, ms, lwei, elements, starts):
if len(errormodules) > 0:
print("Unknown modules:", errormodules)
def ratiodealer(key, lwei, elemental):
def ratiodealer(key, lwei, elemental, flux = False):
ratio = 1
picked = False
errormodules = []
currentblock = 0
elemental = elemental.split(",")
elemkey = ""
for i,block in enumerate(BLOCKS):
if block in key:
if i == 26 or i == 27:
i = 0
ratio = lwei[i]
if flux:
block = elemkey = get_flux_blocks(key)
print(block, key)
if block in BLOCKIDFLUX:
ratio = lwei[BLOCKIDFLUX.index(block)]
picked = True
currentblock = i
if not picked:
errormodules.append(key)
print(key, block, BLOCKIDFLUX.index(block), ratio)
else:
for i,block in enumerate(BLOCKS):
if block in key:
if i == 26 or i == 27:
i = 0
ratio = lwei[i]
picked = True
elemkey = BLOCKID26[i]
if len(elemental) > 0:
skey = key + BLOCKID26[currentblock]
skey = key + elemkey
for d in elemental:
if d.count(":") != 2 :continue
dbs,dws,dr = (hyphener(d.split(":")[0]),d.split(":")[1],d.split(":")[2])
dbs,dws,dr = (hyphener(d.split(":")[0],BLOCKIDFLUX if flux else BLOCKID26),d.split(":")[1],d.split(":")[2])
dbs,dws = (dbs.split(" "), dws.split(" "))
dbn,dbs = (True,dbs[1:]) if dbs[0] == "NOT" else (False,dbs)
dwn,dws = (True,dws[1:]) if dws[0] == "NOT" else (False,dws)
@ -1243,7 +1254,7 @@ def ratiodealer(key, lwei, elemental):
if princ :print(dbs,dws,key,dr)
ratio = dr
return ratio, errormodules
return ratio, picked
LORAANDSOON = {
"LoraHadaModule" : "w1a",
@ -1261,15 +1272,15 @@ LORAANDSOON = {
"NetworkModuleOFT": "scale"
}
def hyphener(t):
def hyphener(t, blocks):
t = t.split(" ")
for i,e in enumerate(t):
if "-" in e:
e = e.split("-")
if BLOCKID26.index(e[1]) > BLOCKID26.index(e[0]):
t[i] = " ".join(BLOCKID26[BLOCKID26.index(e[0]):BLOCKID26.index(e[1])+1])
if blocks.index(e[1]) > blocks.index(e[0]):
t[i] = " ".join(blocks[blocks.index(e[0]):blocks.index(e[1])+1])
else:
t[i] = " ".join(BLOCKID26[BLOCKID26.index(e[1]):BLOCKID26.index(e[0])+1])
t[i] = " ".join(blocks[blocks.index(e[1]):blocks.index(e[0])+1])
return " ".join(t)
ELEMPRESETS="\
@ -1298,3 +1309,22 @@ def checkloadcond(l:str)->bool:
res=(":" not in l) or (not any(l.count(",") == x - 1 for x in BLOCKNUMS)) or ("#" in l)
#print("[debug]", res,repr(l))
return res
def get_flux_blocks(key):
if "vae" in key:
return "VAE"
if "t5xxl" in key:
return "T5"
if "text_encoders.clip" in key:
return "CLIP"
match = re.search(r'\_(\d+)\_', key)
if "double_blocks" in key:
return f"D{match.group(1).zfill(2) }"
if "single_blocks" in key:
return f"S{match.group(1).zfill(2) }"
if "_in" in key:
return "IN"
if "final_layer" in key:
return "OUT"
return "Not Merge"