#176, support flux
parent
5b9dc37fdb
commit
9f4b321af6
|
|
@ -54,8 +54,9 @@ BLOCKID26=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08"
|
|||
BLOCKID17=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
|
||||
BLOCKID12=["BASE","IN04","IN05","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05"]
|
||||
BLOCKID20=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08"]
|
||||
BLOCKNUMS = [12,17,20,26]
|
||||
BLOCKIDS=[BLOCKID12,BLOCKID17,BLOCKID20,BLOCKID26]
|
||||
BLOCKIDFLUX = ["CLIP", "T5", "IN"] + ["D{:002}".format(x) for x in range(19)] + ["S{:002}".format(x) for x in range(38)] + ["OUT"] # Len: 61
|
||||
BLOCKNUMS = [12,17,20,26, len(BLOCKIDFLUX)]
|
||||
BLOCKIDS=[BLOCKID12,BLOCKID17,BLOCKID20,BLOCKID26,BLOCKIDFLUX]
|
||||
|
||||
BLOCKS=["encoder",
|
||||
"diffusion_model_input_blocks_0_",
|
||||
|
|
@ -90,7 +91,7 @@ loopstopper = True
|
|||
|
||||
ATYPES =["none","Block ID","values","seed","Original Weights","elements"]
|
||||
|
||||
DEF_WEIGHT_PRESET = "\
|
||||
DEF_WEIGHT_PRESET = f"\
|
||||
NONE:0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
|
||||
ALL:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\n\
|
||||
INS:1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
|
||||
|
|
@ -100,7 +101,9 @@ MIDD:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0\n\
|
|||
OUTD:1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0\n\
|
||||
OUTS:1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1\n\
|
||||
OUTALL:1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1\n\
|
||||
ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
|
||||
ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5\n\
|
||||
FLUXALL:{','.join(['1']*61)}"
|
||||
|
||||
|
||||
scriptpath = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
|
@ -337,7 +340,7 @@ class Script(modules.scripts.Script):
|
|||
pass
|
||||
else:
|
||||
try:
|
||||
with open(extpath,encoding="utf-8") as f:
|
||||
with open(extpathe,encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except OSError as e:
|
||||
pass
|
||||
|
|
@ -811,10 +814,14 @@ def lorachecker(self):
|
|||
except:
|
||||
pass
|
||||
self.onlyco = (not self.islora) and self.islyco
|
||||
self.isxl = hasattr(shared.sd_model,"conditioner")
|
||||
model = shared.sd_model
|
||||
self.is_sdxl = type(model).__name__ == "StableDiffusionXL" or getattr(model,'is_sdxl', False)
|
||||
self.is_sd2 = type(model).__name__ == "StableDiffusion2" or getattr(model,'is_sd2', False)
|
||||
self.is_sd1 = type(model).__name__ == "StableDiffusion" or getattr(model,'is_sd1', False)
|
||||
self.is_flux = type(model).__name__ == "Flux" or getattr(model,'is_flux', False)
|
||||
|
||||
self.log["isnet"] = self.isnet
|
||||
self.log["isxl"] = self.isxl
|
||||
self.log["isxl"] = self.is_sdxl
|
||||
self.log["islora"] = self.islora
|
||||
|
||||
def resetmemory():
|
||||
|
|
@ -882,11 +889,11 @@ def loradealer(self, prompts,lratios,elementals, extra_network_data = None):
|
|||
else:
|
||||
ratios[i] = float(r)
|
||||
|
||||
if len(ratios) != 26:
|
||||
if not (len(ratios) == 26 or len(ratios) == 61):
|
||||
ratios = to26(ratios)
|
||||
setnow = True
|
||||
else:
|
||||
ratios = [1] * 26
|
||||
ratios = [1] * 61 if self.is_flux else [1] * 26
|
||||
|
||||
if elem in elementals:
|
||||
setnow = True
|
||||
|
|
@ -972,7 +979,7 @@ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts
|
|||
for loaded in lora.loaded_loras:
|
||||
for n, name in enumerate(names):
|
||||
if name == loaded.name:
|
||||
if lwei[n] == [1] * 26 and elements[n] == "": continue
|
||||
if (lwei[n] == [1] * 26 or lwei[n] == [1] * 61) and elements[n] == "": continue
|
||||
lbw(loaded,lwei[n],elements[n])
|
||||
setall(loaded,te[n],unet[n])
|
||||
newname = loaded.name +"_in_LBW_"+ str(round(random.random(),3))
|
||||
|
|
@ -1000,10 +1007,10 @@ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts
|
|||
|
||||
elif "forge" == ltype:
|
||||
lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches
|
||||
lbwf(lora_patches, unet, lwei, elements, starts)
|
||||
lbwf(lora_patches, unet, lwei, elements, starts, self.is_flux)
|
||||
|
||||
lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches
|
||||
lbwf(lora_patches, te, lwei, elements, starts)
|
||||
lbwf(lora_patches, te, lwei, elements, starts, self.is_flux)
|
||||
|
||||
try:
|
||||
import lora_ctl_network as ctl
|
||||
|
|
@ -1138,9 +1145,9 @@ def effectivechecker(imgs,ss,ls,diffcol,thresh,revxy):
|
|||
def lbw(lora,lwei,elemental):
|
||||
errormodules = []
|
||||
for key in lora.modules.keys():
|
||||
ratio, errormodule = ratiodealer(key, lwei, elemental)
|
||||
if errormodule:
|
||||
errormodules.append(errormodule)
|
||||
ratio, picked = ratiodealer(key, lwei, elemental)
|
||||
if not picked:
|
||||
errormodules.append(key)
|
||||
|
||||
ltype = type(lora.modules[key]).__name__
|
||||
set = False
|
||||
|
|
@ -1160,7 +1167,7 @@ def lbw(lora,lwei,elemental):
|
|||
#print("LoRA")
|
||||
set = True
|
||||
if not set :
|
||||
print("unkwon LoRA")
|
||||
print("UnKnown LoRA")
|
||||
|
||||
if len(errormodules) > 0:
|
||||
print(errormodules)
|
||||
|
|
@ -1168,7 +1175,7 @@ def lbw(lora,lwei,elemental):
|
|||
|
||||
LORAS = ["lora", "loha", "lokr"]
|
||||
|
||||
def lbwf(after_applying_lora_patches, ms, lwei, elements, starts):
|
||||
def lbwf(after_applying_lora_patches, ms, lwei, elements, starts, flux):
|
||||
errormodules = []
|
||||
dict_lora_patches = dict(after_applying_lora_patches.items())
|
||||
|
||||
|
|
@ -1186,14 +1193,14 @@ def lbwf(after_applying_lora_patches, ms, lwei, elements, starts):
|
|||
n_vals = []
|
||||
lvs = [v for v in vals if v[1][0] in LORAS]
|
||||
for v in lvs:
|
||||
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
|
||||
ratio, picked = ratiodealer(key.replace(".","_"), l, e, flux)
|
||||
n_vals.append((ratio * m if s is None or s == 0 else 0, *v[1:]))
|
||||
if errormodule:
|
||||
errormodules.append(errormodule)
|
||||
if not picked:
|
||||
errormodules.append(key)
|
||||
lora_patches[key] = n_vals
|
||||
# print("[DEBUG]", hash[0], *[n_val[0] for n_val in n_vals])
|
||||
|
||||
lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in l])
|
||||
lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in l]) + e
|
||||
new_hash = (hash[0], lbw_key, *hash[2:])
|
||||
|
||||
after_applying_lora_patches[new_hash] = after_applying_lora_patches[hash]
|
||||
|
|
@ -1203,29 +1210,33 @@ def lbwf(after_applying_lora_patches, ms, lwei, elements, starts):
|
|||
if len(errormodules) > 0:
|
||||
print("Unknown modules:", errormodules)
|
||||
|
||||
def ratiodealer(key, lwei, elemental):
|
||||
def ratiodealer(key, lwei, elemental, flux = False):
|
||||
ratio = 1
|
||||
picked = False
|
||||
errormodules = []
|
||||
currentblock = 0
|
||||
elemental = elemental.split(",")
|
||||
elemkey = ""
|
||||
|
||||
for i,block in enumerate(BLOCKS):
|
||||
if block in key:
|
||||
if i == 26 or i == 27:
|
||||
i = 0
|
||||
ratio = lwei[i]
|
||||
if flux:
|
||||
block = elemkey = get_flux_blocks(key)
|
||||
print(block, key)
|
||||
if block in BLOCKIDFLUX:
|
||||
ratio = lwei[BLOCKIDFLUX.index(block)]
|
||||
picked = True
|
||||
currentblock = i
|
||||
|
||||
if not picked:
|
||||
errormodules.append(key)
|
||||
print(key, block, BLOCKIDFLUX.index(block), ratio)
|
||||
else:
|
||||
for i,block in enumerate(BLOCKS):
|
||||
if block in key:
|
||||
if i == 26 or i == 27:
|
||||
i = 0
|
||||
ratio = lwei[i]
|
||||
picked = True
|
||||
elemkey = BLOCKID26[i]
|
||||
|
||||
if len(elemental) > 0:
|
||||
skey = key + BLOCKID26[currentblock]
|
||||
skey = key + elemkey
|
||||
for d in elemental:
|
||||
if d.count(":") != 2 :continue
|
||||
dbs,dws,dr = (hyphener(d.split(":")[0]),d.split(":")[1],d.split(":")[2])
|
||||
dbs,dws,dr = (hyphener(d.split(":")[0],BLOCKIDFLUX if flux else BLOCKID26),d.split(":")[1],d.split(":")[2])
|
||||
dbs,dws = (dbs.split(" "), dws.split(" "))
|
||||
dbn,dbs = (True,dbs[1:]) if dbs[0] == "NOT" else (False,dbs)
|
||||
dwn,dws = (True,dws[1:]) if dws[0] == "NOT" else (False,dws)
|
||||
|
|
@ -1243,7 +1254,7 @@ def ratiodealer(key, lwei, elemental):
|
|||
if princ :print(dbs,dws,key,dr)
|
||||
ratio = dr
|
||||
|
||||
return ratio, errormodules
|
||||
return ratio, picked
|
||||
|
||||
LORAANDSOON = {
|
||||
"LoraHadaModule" : "w1a",
|
||||
|
|
@ -1261,15 +1272,15 @@ LORAANDSOON = {
|
|||
"NetworkModuleOFT": "scale"
|
||||
}
|
||||
|
||||
def hyphener(t):
|
||||
def hyphener(t, blocks):
|
||||
t = t.split(" ")
|
||||
for i,e in enumerate(t):
|
||||
if "-" in e:
|
||||
e = e.split("-")
|
||||
if BLOCKID26.index(e[1]) > BLOCKID26.index(e[0]):
|
||||
t[i] = " ".join(BLOCKID26[BLOCKID26.index(e[0]):BLOCKID26.index(e[1])+1])
|
||||
if blocks.index(e[1]) > blocks.index(e[0]):
|
||||
t[i] = " ".join(blocks[blocks.index(e[0]):blocks.index(e[1])+1])
|
||||
else:
|
||||
t[i] = " ".join(BLOCKID26[BLOCKID26.index(e[1]):BLOCKID26.index(e[0])+1])
|
||||
t[i] = " ".join(blocks[blocks.index(e[1]):blocks.index(e[0])+1])
|
||||
return " ".join(t)
|
||||
|
||||
ELEMPRESETS="\
|
||||
|
|
@ -1298,3 +1309,22 @@ def checkloadcond(l:str)->bool:
|
|||
res=(":" not in l) or (not any(l.count(",") == x - 1 for x in BLOCKNUMS)) or ("#" in l)
|
||||
#print("[debug]", res,repr(l))
|
||||
return res
|
||||
|
||||
def get_flux_blocks(key):
|
||||
if "vae" in key:
|
||||
return "VAE"
|
||||
if "t5xxl" in key:
|
||||
return "T5"
|
||||
if "text_encoders.clip" in key:
|
||||
return "CLIP"
|
||||
|
||||
match = re.search(r'\_(\d+)\_', key)
|
||||
if "double_blocks" in key:
|
||||
return f"D{match.group(1).zfill(2) }"
|
||||
if "single_blocks" in key:
|
||||
return f"S{match.group(1).zfill(2) }"
|
||||
if "_in" in key:
|
||||
return "IN"
|
||||
if "final_layer" in key:
|
||||
return "OUT"
|
||||
return "Not Merge"
|
||||
Loading…
Reference in New Issue