diff --git a/scripts/lora_block_weight.py b/scripts/lora_block_weight.py index 999e40a..d766cff 100644 --- a/scripts/lora_block_weight.py +++ b/scripts/lora_block_weight.py @@ -82,7 +82,8 @@ BLOCKS=["encoder", "diffusion_model_output_blocks_9_", "diffusion_model_output_blocks_10_", "diffusion_model_output_blocks_11_", -"embedders"] +"embedders", +"transformer_resblocks"] loopstopper = True @@ -1103,8 +1104,11 @@ def effectivechecker(imgs,ss,ls,diffcol,thresh,revxy): def lbw(lora,lwei,elemental): elemental = elemental.split(",") + errormodules = [] for key in lora.modules.keys(): - ratio, errormodules = ratiodealer(key, lwei, elemental) + ratio, errormodule = ratiodealer(key, lwei, elemental) + if errormodule: + errormodules.append(errormodule) ltype = type(lora.modules[key]).__name__ set = False @@ -1135,19 +1139,25 @@ LORAS = ["lora", "loha", "lokr"] def lbwf(mt, mu, lwei, elemental, starts): for key, vals in shared.sd_model.forge_objects_after_applying_lora.unet.patches.items(): n_vals = [] + errormodules = [] lvals = [val for val in vals if val[1][0] in LORAS] for v, m, l, e ,s in zip(lvals, mu, lwei, elemental, starts): - ratio, errormodules = ratiodealer(key.replace(".","_"), l, e) + ratio, errormodule = ratiodealer(key.replace(".","_"), l, e) n_vals.append((ratio * m if s is None else 0, *v[1:])) + if errormodule:errormodules.append(errormodule) shared.sd_model.forge_objects_after_applying_lora.unet.patches[key] = n_vals for key, vals in shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches.items(): n_vals = [] lvals = [val for val in vals if val[1][0] in LORAS] for v, m, l, e in zip(lvals, mt, lwei, elemental): - ratio, errormodules = ratiodealer(key.replace(".","_"), l, e) + ratio, errormodule = ratiodealer(key.replace(".","_"), l, e) n_vals.append((ratio * m, *v[1:])) - shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches[key] = n_vals + if errormodule:errormodules.append(errormodule) + shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches[key] = n_vals + + if len(errormodules) > 0: + print("Unknown modules:",errormodules) def ratiodealer(key, lwei, elemental): ratio = 1 @@ -1157,7 +1167,7 @@ def ratiodealer(key, lwei, elemental): for i,block in enumerate(BLOCKS): if block in key: - if i == 26: + if i == 26 or i == 27: i = 0 ratio = lwei[i] picked = True