Fixed issue where LBW assignments were misaligned when LoRAs without CLIP or UNet were present.
pull/169/head
takahiro-nihei 2024-09-07 18:57:42 +09:00
parent 42d9b65551
commit 29b8600cbb
1 changed file with 49 additions and 54 deletions
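As the diff below shows, the old lbwf paired each LoRA's multipliers and block weights with the UNet and CLIP patch dicts purely by position (zip over dict order), so a LoRA that contributes only UNet or only CLIP patches shifted every subsequent weight onto the wrong LoRA. The rewritten lbwf looks entries up by their hash instead and skips LoRAs that have no patches in the dict being processed. A minimal, self-contained sketch of that difference, using made-up names and key shapes rather than the extension's real structures:

# Illustrative only: three LoRAs are configured, but "lora_b" has no CLIP
# component -- the situation this commit fixes.
weights = [("lora_a", 0.5), ("lora_b", 0.8), ("lora_c", 1.0)]
clip_patches = {("lora_a", "key1"): "patch A", ("lora_c", "key3"): "patch C"}

# Old behaviour: positional zip -- lora_c's patch silently receives lora_b's weight.
for (name, w), key in zip(weights, clip_patches):
    print("positional:", key[0], "<-", name, w)

# Fixed behaviour: match by the identifier stored in the key and skip LoRAs
# that contributed nothing to this patch dict.
remaining = dict(clip_patches)
for name, w in weights:
    key = next((k for k in remaining if k[0] == name), None)
    if key is None:
        continue  # e.g. a UNet-only LoRA
    remaining.pop(key)
    print("matched:", key[0], "<-", name, w)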


@@ -441,30 +441,30 @@ class Script(modules.scripts.Script):
         if forge and self.active:
             if params.sampling_step in self.startsf:
-                shared.sd_model.forge_objects.unet.forge_unpatch_model(device_to=devices.device)
-                for key, vals in shared.sd_model.forge_objects.unet.lora_patches.items():
-                    n_vals = []
-                    lvals = [val for val in vals if val[1][0] in LORAS]
-                    for s, v, m, l, e in zip(self.startsf, lvals, self.uf, self.lf, self.ef):
-                        if s is not None and s == params.sampling_step:
-                            ratio, errormodules = ratiodealer(key.replace(".","_"), l, e)
-                            n_vals.append((ratio * m, *v[1:]))
-                        else:
-                            n_vals.append(v)
-                    shared.sd_model.forge_objects.unet.lora_patches[key] = n_vals
+                shared.sd_model.forge_objects.unet.forge_unpatch_model(target_device=devices.device)
+                for m, l, e, s, lora_patches in zip(self.uf, self.lf, self.ef, self.startsf, list(shared.sd_model.forge_objects.unet.lora_patches.values())):
+                    for key, vals in lora_patches.items():
+                        n_vals = []
+                        for v in [v for v in vals if v[1][0] in LORAS]:
+                            if s is not None and s == params.sampling_step:
+                                ratio, _ = ratiodealer(key.replace(".","_"), l, e)
+                                n_vals.append((ratio * m, *v[1:]))
+                            else:
+                                n_vals.append(v)
+                        lora_patches[key] = n_vals
                 shared.sd_model.forge_objects.unet.forge_patch_model()
             if params.sampling_step in self.stopsf:
-                shared.sd_model.forge_objects.unet.forge_unpatch_model(device_to=devices.device)
-                for key, vals in shared.sd_model.forge_objects.unet.lora_patches.items():
-                    n_vals = []
-                    lvals = [val for val in vals if val[1][0] in LORAS]
-                    for s, v, m, l, e in zip(self.stopsf, lvals, self.uf, self.lf, self.ef):
-                        if s is not None and s == params.sampling_step:
-                            n_vals.append((0, *v[1:]))
-                        else:
-                            n_vals.append(v)
-                    shared.sd_model.forge_objects.unet.lora_patches[key] = n_vals
+                shared.sd_model.forge_objects.unet.forge_unpatch_model(target_device=devices.device)
+                for m, l, e, s, lora_patches in zip(self.uf, self.lf, self.ef, self.stopsf, list(shared.sd_model.forge_objects.unet.lora_patches.values())):
+                    for key, vals in lora_patches.items():
+                        n_vals = []
+                        for v in [v for v in vals if v[1][0] in LORAS]:
+                            if s is not None and s == params.sampling_step:
+                                n_vals.append((0, *v[1:]))
+                            else:
+                                n_vals.append(v)
+                        lora_patches[key] = n_vals
                 shared.sd_model.forge_objects.unet.forge_patch_model()
         elif self.active:
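For context, the hunk above appears to change only how the per-step handler walks Forge's patch structures (one dict of patches per LoRA instead of one flat list per key); the start/stop behaviour itself reads the same: when a LoRA's start step is reached, its block ratio times its multiplier is applied, and when its stop step is reached, its strength is set to 0. A rough sketch of that gating, detached from Forge's objects and using invented step numbers:

# Simplified sketch of the start/stop gating (assumed semantics, not Forge code).
def strength_at(step, start, stop, base):
    # base strength applies once `start` is reached (immediately if start is None)
    # and is dropped to zero once `stop` is reached.
    if start is not None and step < start:
        return 0.0
    if stop is not None and step >= stop:
        return 0.0
    return base

for step in range(6):
    print(step, strength_at(step, start=2, stop=4, base=0.8))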
@@ -972,7 +972,13 @@ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts
                 setall(loaded,te[n],unet[n])
     elif "forge" == ltype:
-        lbwf(te, unet, lwei, elements, starts)
+        lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches
+        lbwf(lora_patches, unet, lwei, elements, starts,
+             lambda r, m, s: r * m if s is None else 0)
+        lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches
+        lbwf(lora_patches, te, lwei, elements, starts,
+             lambda r, m, _: r * m)
     try:
         import lora_ctl_network as ctl
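The call-site change above is where the two patch dicts are now handled separately: lbwf receives the UNet (or CLIP) lora_patches dict plus a func_ratio callable that decides how the computed block ratio, the LoRA multiplier, and the start step combine. The UNet pass holds the weight at 0 when a start step is set, while the CLIP pass ignores the start value entirely. A toy demonstration of those two callables with invented numbers:

# r = block ratio from ratiodealer, m = LoRA multiplier, s = start step (or None).
unet_ratio = lambda r, m, s: r * m if s is None else 0   # deferred until the start step fires
clip_ratio = lambda r, m, _: r * m                        # applied unconditionally

print(unet_ratio(0.5, 0.8, None))  # 0.4 -- no start step, apply immediately
print(unet_ratio(0.5, 0.8, 10))    # 0   -- stays off; step handling happens elsewhere
print(clip_ratio(0.5, 0.8, 10))    # 0.4 -- CLIP patches are not step-gated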
@@ -1136,49 +1142,38 @@ def lbw(lora,lwei,elemental):
 LORAS = ["lora", "loha", "lokr"]
-def lbwf(mt, mu, lwei, elemental, starts):
+def lbwf(after_applying_lora_patches, ms, lwei, elements, starts, func_ratio):
     errormodules = []
-    after_applying_unet_lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches
-    hashes = []
-    for m, hash in zip(mu, after_applying_unet_lora_patches.keys()):
-        lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
-        hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
-    for hash, new_hash in hashes:
-        after_applying_unet_lora_patches[new_hash] = after_applying_unet_lora_patches[hash]
-        del after_applying_unet_lora_patches[hash]
-    for m, l, e, s, (hash, lora_patches) in zip(mu, lwei, elemental, starts, after_applying_unet_lora_patches.items()):
+    dict_lora_patches = dict(after_applying_lora_patches.items())
+    for m, l, e, s, hash in zip(ms, lwei, elements, starts, list(shared.sd_model.forge_objects.unet.lora_patches.keys())):
+        lora_patches = None
+        for k, v in dict_lora_patches.items():
+            if k[0] == hash[0]:
+                hash = k
+                lora_patches = v
+                del dict_lora_patches[k]
+                break
+        if lora_patches is None:
+            continue
         for key, vals in lora_patches.items():
             n_vals = []
             lvs = [v for v in vals if v[1][0] in LORAS]
             for v in lvs:
                 ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
-                n_vals.append([ratio * m if s is None else 0, *v[1:]])
-                if errormodule:errormodules.append(errormodule)
+                n_vals.append([func_ratio(ratio, m, s), *v[1:]])
+                if errormodule:
+                    errormodules.append(errormodule)
             lora_patches[key] = n_vals
+            # print("[DEBUG]", hash[0], *[n_val[0] for n_val in n_vals])
-    after_applying_clip_lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches
-    hashes = []
-    for m, hash in zip(mt, after_applying_clip_lora_patches.keys()):
         lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
-        hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
-    for hash, new_hash in hashes:
-        after_applying_clip_lora_patches[new_hash] = after_applying_clip_lora_patches[hash]
-        del after_applying_clip_lora_patches[hash]
+        new_hash = (hash[0], lbw_key, *hash[2:])
+        after_applying_lora_patches[new_hash] = after_applying_lora_patches[hash]
+        del after_applying_lora_patches[hash]
-    for m, l, e, s, (hash, lora_patches) in zip(mu, lwei, elemental, starts, after_applying_clip_lora_patches.items()):
-        for key, vals in lora_patches.items():
-            n_vals = []
-            lvs = [v for v in vals if v[1][0] in LORAS]
-            for v in lvs:
-                ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
-                n_vals.append([ratio * m, *v[1:]])
-                if errormodule:errormodules.append(errormodule)
-            lora_patches[key] = n_vals
     if len(errormodules) > 0:
-        print("Unknown modules:",errormodules)
+        print("Unknown modules:", errormodules)
 def ratiodealer(key, lwei, elemental):
     ratio = 1
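One detail of the rewritten lbwf worth noting: after scaling a LoRA's patches it still re-keys the entry so that the key's second element encodes the multiplier plus the block-weight list (lbw_key), presumably so that Forge's patch bookkeeping distinguishes otherwise-identical LoRA entries that differ only in their LBW settings. A small standalone sketch of that re-keying on a plain dict (the key shape is illustrative, not Forge's real tuple):

# Illustrative re-keying that mirrors the lbw_key construction in lbwf.
patches = {("filehash", "old_key", "extra"): {"some.module": ["patch"]}}

m = 0.8
lwei = [[1, 0.5, 0, 1]]  # block weights for this LoRA
lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])

old_hash = next(iter(patches))
new_hash = (old_hash[0], lbw_key, *old_hash[2:])
patches[new_hash] = patches.pop(old_hash)
print(new_hash)  # ('filehash', '0.8,1,0.5,0,1', 'extra')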