bug fix
Fixed issue where LBW assignments were misaligned when LoRAs without CLIP or UNet were present.pull/169/head
parent
42d9b65551
commit
29b8600cbb
|
|
@ -441,30 +441,30 @@ class Script(modules.scripts.Script):
|
|||
|
||||
if forge and self.active:
|
||||
if params.sampling_step in self.startsf:
|
||||
shared.sd_model.forge_objects.unet.forge_unpatch_model(device_to=devices.device)
|
||||
for key, vals in shared.sd_model.forge_objects.unet.lora_patches.items():
|
||||
n_vals = []
|
||||
lvals = [val for val in vals if val[1][0] in LORAS]
|
||||
for s, v, m, l, e in zip(self.startsf, lvals, self.uf, self.lf, self.ef):
|
||||
if s is not None and s == params.sampling_step:
|
||||
ratio, errormodules = ratiodealer(key.replace(".","_"), l, e)
|
||||
n_vals.append((ratio * m, *v[1:]))
|
||||
else:
|
||||
n_vals.append(v)
|
||||
shared.sd_model.forge_objects.unet.lora_patches[key] = n_vals
|
||||
shared.sd_model.forge_objects.unet.forge_unpatch_model(target_device=devices.device)
|
||||
for m, l, e, s, lora_patches in zip(self.uf, self.lf, self.ef, self.startsf, list(shared.sd_model.forge_objects.unet.lora_patches.values())):
|
||||
for key, vals in lora_patches.items():
|
||||
n_vals = []
|
||||
for v in [v for v in vals if v[1][0] in LORAS]:
|
||||
if s is not None and s == params.sampling_step:
|
||||
ratio, _ = ratiodealer(key.replace(".","_"), l, e)
|
||||
n_vals.append((ratio * m, *v[1:]))
|
||||
else:
|
||||
n_vals.append(v)
|
||||
lora_patches[key] = n_vals
|
||||
shared.sd_model.forge_objects.unet.forge_patch_model()
|
||||
|
||||
if params.sampling_step in self.stopsf:
|
||||
shared.sd_model.forge_objects.unet.forge_unpatch_model(device_to=devices.device)
|
||||
for key, vals in shared.sd_model.forge_objects.unet.lora_patches.items():
|
||||
n_vals = []
|
||||
lvals = [val for val in vals if val[1][0] in LORAS]
|
||||
for s, v, m, l, e in zip(self.stopsf, lvals, self.uf, self.lf, self.ef):
|
||||
if s is not None and s == params.sampling_step:
|
||||
n_vals.append((0, *v[1:]))
|
||||
else:
|
||||
n_vals.append(v)
|
||||
shared.sd_model.forge_objects.unet.lora_patches[key] = n_vals
|
||||
shared.sd_model.forge_objects.unet.forge_unpatch_model(target_device=devices.device)
|
||||
for m, l, e, s, lora_patches in zip(self.uf, self.lf, self.ef, self.stopsf, list(shared.sd_model.forge_objects.unet.lora_patches.values())):
|
||||
for key, vals in lora_patches.items():
|
||||
n_vals = []
|
||||
for v in [v for v in vals if v[1][0] in LORAS]:
|
||||
if s is not None and s == params.sampling_step:
|
||||
n_vals.append((0, *v[1:]))
|
||||
else:
|
||||
n_vals.append(v)
|
||||
lora_patches[key] = n_vals
|
||||
shared.sd_model.forge_objects.unet.forge_patch_model()
|
||||
|
||||
elif self.active:
|
||||
|
|
@ -972,7 +972,13 @@ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts
|
|||
setall(loaded,te[n],unet[n])
|
||||
|
||||
elif "forge" == ltype:
|
||||
lbwf(te, unet, lwei, elements, starts)
|
||||
lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches
|
||||
lbwf(lora_patches, unet, lwei, elements, starts,
|
||||
lambda r, m, s: r * m if s is None else 0)
|
||||
|
||||
lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches
|
||||
lbwf(lora_patches, te, lwei, elements, starts,
|
||||
lambda r, m, _: r * m)
|
||||
|
||||
try:
|
||||
import lora_ctl_network as ctl
|
||||
|
|
@ -1136,49 +1142,38 @@ def lbw(lora,lwei,elemental):
|
|||
|
||||
LORAS = ["lora", "loha", "lokr"]
|
||||
|
||||
def lbwf(mt, mu, lwei, elemental, starts):
|
||||
def lbwf(after_applying_lora_patches, ms, lwei, elements, starts, func_ratio):
|
||||
errormodules = []
|
||||
|
||||
after_applying_unet_lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches
|
||||
hashes = []
|
||||
for m, hash in zip(mu, after_applying_unet_lora_patches.keys()):
|
||||
lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
|
||||
hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
|
||||
for hash, new_hash in hashes:
|
||||
after_applying_unet_lora_patches[new_hash] = after_applying_unet_lora_patches[hash]
|
||||
del after_applying_unet_lora_patches[hash]
|
||||
|
||||
for m, l, e, s, (hash, lora_patches) in zip(mu, lwei, elemental, starts, after_applying_unet_lora_patches.items()):
|
||||
dict_lora_patches = dict(after_applying_lora_patches.items())
|
||||
for m, l, e, s, hash in zip(ms, lwei, elements, starts, list(shared.sd_model.forge_objects.unet.lora_patches.keys())):
|
||||
lora_patches = None
|
||||
for k, v in dict_lora_patches.items():
|
||||
if k[0] == hash[0]:
|
||||
hash = k
|
||||
lora_patches = v
|
||||
del dict_lora_patches[k]
|
||||
break
|
||||
if lora_patches is None:
|
||||
continue
|
||||
for key, vals in lora_patches.items():
|
||||
n_vals = []
|
||||
lvs = [v for v in vals if v[1][0] in LORAS]
|
||||
for v in lvs:
|
||||
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
|
||||
n_vals.append([ratio * m if s is None else 0, *v[1:]])
|
||||
if errormodule:errormodules.append(errormodule)
|
||||
n_vals.append([func_ratio(ratio, m, s), *v[1:]])
|
||||
if errormodule:
|
||||
errormodules.append(errormodule)
|
||||
lora_patches[key] = n_vals
|
||||
# print("[DEBUG]", hash[0], *[n_val[0] for n_val in n_vals])
|
||||
|
||||
after_applying_clip_lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches
|
||||
hashes = []
|
||||
for m, hash in zip(mt, after_applying_clip_lora_patches.keys()):
|
||||
lbw_key = ",".join([str(m)] + [str(int(w) if type(w) is int or w.is_integer() else float(w)) for w in lwei[0]])
|
||||
hashes.append((hash, (hash[0], lbw_key, *hash[2:])))
|
||||
for hash, new_hash in hashes:
|
||||
after_applying_clip_lora_patches[new_hash] = after_applying_clip_lora_patches[hash]
|
||||
del after_applying_clip_lora_patches[hash]
|
||||
new_hash = (hash[0], lbw_key, *hash[2:])
|
||||
|
||||
after_applying_lora_patches[new_hash] = after_applying_lora_patches[hash]
|
||||
del after_applying_lora_patches[hash]
|
||||
|
||||
for m, l, e, s, (hash, lora_patches) in zip(mu, lwei, elemental, starts, after_applying_clip_lora_patches.items()):
|
||||
for key, vals in lora_patches.items():
|
||||
n_vals = []
|
||||
lvs = [v for v in vals if v[1][0] in LORAS]
|
||||
for v in lvs:
|
||||
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
|
||||
n_vals.append([ratio * m, *v[1:]])
|
||||
if errormodule:errormodules.append(errormodule)
|
||||
lora_patches[key] = n_vals
|
||||
|
||||
if len(errormodules) > 0:
|
||||
print("Unknown modules:",errormodules)
|
||||
print("Unknown modules:", errormodules)
|
||||
|
||||
def ratiodealer(key, lwei, elemental):
|
||||
ratio = 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue