hako-mikan 2024-04-06 16:32:18 +09:00 committed by GitHub
parent b629031419
commit 4d94d247a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 16 additions and 6 deletions

View File

@ -82,7 +82,8 @@ BLOCKS=["encoder",
"diffusion_model_output_blocks_9_",
"diffusion_model_output_blocks_10_",
"diffusion_model_output_blocks_11_",
"embedders"]
"embedders",
"transformer_resblocks"]
loopstopper = True
@ -1103,8 +1104,11 @@ def effectivechecker(imgs,ss,ls,diffcol,thresh,revxy):
def lbw(lora,lwei,elemental):
elemental = elemental.split(",")
errormodules = []
for key in lora.modules.keys():
ratio, errormodules = ratiodealer(key, lwei, elemental)
ratio, errormodule = ratiodealer(key, lwei, elemental)
if errormodule:
errormodules.append(errormodule)
ltype = type(lora.modules[key]).__name__
set = False
@ -1135,19 +1139,25 @@ LORAS = ["lora", "loha", "lokr"]
def lbwf(mt, mu, lwei, elemental, starts):
for key, vals in shared.sd_model.forge_objects_after_applying_lora.unet.patches.items():
n_vals = []
errormodules = []
lvals = [val for val in vals if val[1][0] in LORAS]
for v, m, l, e ,s in zip(lvals, mu, lwei, elemental, starts):
ratio, errormodules = ratiodealer(key.replace(".","_"), l, e)
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
n_vals.append((ratio * m if s is None else 0, *v[1:]))
if errormodule:errormodules.append(errormodule)
shared.sd_model.forge_objects_after_applying_lora.unet.patches[key] = n_vals
for key, vals in shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches.items():
n_vals = []
lvals = [val for val in vals if val[1][0] in LORAS]
for v, m, l, e in zip(lvals, mt, lwei, elemental):
ratio, errormodules = ratiodealer(key.replace(".","_"), l, e)
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
n_vals.append((ratio * m, *v[1:]))
shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches[key] = n_vals
if errormodule:errormodules.append(errormodule)
shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches[key] = n_vals
if len(errormodules) > 0:
print("Unknown modules:",errormodules)
def ratiodealer(key, lwei, elemental):
ratio = 1
@ -1157,7 +1167,7 @@ def ratiodealer(key, lwei, elemental):
for i,block in enumerate(BLOCKS):
if block in key:
if i == 26:
if i == 26 or i == 27:
i = 0
ratio = lwei[i]
picked = True