diff --git a/scripts/lora_block_weight.py b/scripts/lora_block_weight.py
index 0b0af70..71c5c35 100644
--- a/scripts/lora_block_weight.py
+++ b/scripts/lora_block_weight.py
@@ -1137,40 +1137,41 @@ def lbw(lora,lwei,elemental):
 LORAS = ["lora", "loha", "lokr"]
 
 def lbwf(mt, mu, lwei, elemental, starts):
+    errormodules = []
+
     after_applying_unet_lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches
-    unet_hashes = []
-    for hash, lora_patches in after_applying_unet_lora_patches.items():
-        lbw_key = ",".join([str(mu[0])] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
-        unet_hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
-    for hash, new_hash in unet_hashes:
+    hashes = []
+    for m, hash in zip(mu, after_applying_unet_lora_patches.keys()):
+        lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
+        hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
+    for hash, new_hash in hashes:
         after_applying_unet_lora_patches[new_hash] = after_applying_unet_lora_patches[hash]
         del after_applying_unet_lora_patches[hash]
 
-    for hash, lora_patches in after_applying_unet_lora_patches.items():
+    for m, l, e, s, (hash, lora_patches) in zip(mu, lwei, elemental, starts, after_applying_unet_lora_patches.items()):
         for key, vals in lora_patches.items():
             n_vals = []
-            errormodules = []
             lvs = [v for v in vals if v[1][0] in LORAS]
-            for v, m, l, e ,s in zip(lvs, mu, lwei, elemental, starts):
+            for v in lvs:
                 ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
                 n_vals.append([ratio * m if s is None else 0, *v[1:]])
                 if errormodule:errormodules.append(errormodule)
             lora_patches[key] = n_vals
 
     after_applying_clip_lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches
-    unet_hashes = []
-    for hash, lora_patches in after_applying_clip_lora_patches.items():
-        lbw_key = ",".join([str(mt[0])] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
-        unet_hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
-    for hash, new_hash in unet_hashes:
+    hashes = []
+    for m, hash in zip(mt, after_applying_clip_lora_patches.keys()):
+        lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
+        hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
+    for hash, new_hash in hashes:
         after_applying_clip_lora_patches[new_hash] = after_applying_clip_lora_patches[hash]
         del after_applying_clip_lora_patches[hash]
 
-    for hash, lora_patches in after_applying_clip_lora_patches.items():
+    for m, l, e, s, (hash, lora_patches) in zip(mu, lwei, elemental, starts, after_applying_clip_lora_patches.items()):
         for key, vals in lora_patches.items():
             n_vals = []
             lvs = [v for v in vals if v[1][0] in LORAS]
-            for v, m, l, e in zip(lvs, mt, lwei, elemental):
+            for v in lvs:
                 ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
                 n_vals.append([ratio * m, *v[1:]])
                 if errormodule:errormodules.append(errormodule)