pull/169/head
takahiro-nihei 2024-09-06 20:35:17 +09:00
parent e3bcf1b9ca
commit 42d9b65551
1 changed files with 16 additions and 15 deletions

View File

@ -1137,40 +1137,41 @@ def lbw(lora,lwei,elemental):
LORAS = ["lora", "loha", "lokr"]
def lbwf(mt, mu, lwei, elemental, starts):
errormodules = []
after_applying_unet_lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches
unet_hashes = []
for hash, lora_patches in after_applying_unet_lora_patches.items():
lbw_key = ",".join([str(mu[0])] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
unet_hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
for hash, new_hash in unet_hashes:
hashes = []
for m, hash in zip(mu, after_applying_unet_lora_patches.keys()):
lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
for hash, new_hash in hashes:
after_applying_unet_lora_patches[new_hash] = after_applying_unet_lora_patches[hash]
del after_applying_unet_lora_patches[hash]
for hash, lora_patches in after_applying_unet_lora_patches.items():
for m, l, e, s, (hash, lora_patches) in zip(mu, lwei, elemental, starts, after_applying_unet_lora_patches.items()):
for key, vals in lora_patches.items():
n_vals = []
errormodules = []
lvs = [v for v in vals if v[1][0] in LORAS]
for v, m, l, e ,s in zip(lvs, mu, lwei, elemental, starts):
for v in lvs:
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
n_vals.append([ratio * m if s is None else 0, *v[1:]])
if errormodule:errormodules.append(errormodule)
lora_patches[key] = n_vals
after_applying_clip_lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches
unet_hashes = []
for hash, lora_patches in after_applying_clip_lora_patches.items():
lbw_key = ",".join([str(mt[0])] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
unet_hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
for hash, new_hash in unet_hashes:
hashes = []
for m, hash in zip(mt, after_applying_clip_lora_patches.keys()):
lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
for hash, new_hash in hashes:
after_applying_clip_lora_patches[new_hash] = after_applying_clip_lora_patches[hash]
del after_applying_clip_lora_patches[hash]
for hash, lora_patches in after_applying_clip_lora_patches.items():
for m, l, e, s, (hash, lora_patches) in zip(mu, lwei, elemental, starts, after_applying_clip_lora_patches.items()):
for key, vals in lora_patches.items():
n_vals = []
lvs = [v for v in vals if v[1][0] in LORAS]
for v, m, l, e in zip(lvs, mt, lwei, elemental):
for v in lvs:
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
n_vals.append([ratio * m, *v[1:]])
if errormodule:errormodules.append(errormodule)