fix for lora block weight

main
hako-mikan 2025-01-31 19:07:35 +09:00
parent fdf99f04dd
commit 8acfd26036
2 changed files with 14 additions and 12 deletions

View File

@ -470,6 +470,7 @@ class LoRARegioner:
self.stop_hr = stop_hr
self.stopped = False
self.stopped_hr = False
self.orig_weight = {}
try:
import lora_ctl_network as ctl
@ -568,6 +569,7 @@ class LoRARegioner:
self.u_count = 0
self.stopped = False
self.stopped_hr = False
self.orig_weight = {}
def te_start_f(self):
self.mlist = self.te_llist[self.te_count % len(self.te_llist)]
@ -585,9 +587,12 @@ class LoRARegioner:
for lora_key, patch in lora_patches.items():
for list_key in self.mlist:
if list_key in lora_key[0]:
if labug: print(f"LoRA {lora_key} detected in {self.mlist}")
if labug:
print(f"LoRA {lora_key} detected in {self.mlist}")
for patch_key in patch:
patch[patch_key][0][0] = self.mlist[list_key]
if patch_key + list_key not in self.orig_weight:
self.orig_weight[patch_key + list_key] = patch[patch_key][0][0]
patch[patch_key][0][0] = self.orig_weight[patch_key + list_key] * self.mlist[list_key]
refresh(lora_lorader, lora_patches=lora_patches, offload_device=offload_device)
@ -601,21 +606,18 @@ class LoRARegioner:
strengths = list(self.mlist.values())
def set_strengths(strengths):
all = strengths == 0
for name, module in shared.sd_model.forge_objects.unet.model.named_modules():
patches = getattr(module, 'forge_online_loras', None)
weight_patches, bias_patches = None, None
if patches is not None:
weight_patches = patches.get('weight', None)
if weight_patches:
if all:
for i in range(len(weight_patches)):
weight_patches[i][0] = 0
else:
if len(weight_patches) != len(strengths) :
continue
for i in range(len(strengths)):
weight_patches[i][0] = strengths[i]
if len(weight_patches) != len(strengths) :
continue
for i in range(len(strengths)):
if name not in self.orig_weight:
self.orig_weight[name] = [x[0] for x in weight_patches]
weight_patches[i][0] = strengths[i] * self.orig_weight[name][i]
stopstep = self.stop_hr if in_hr else self.stop
if self.step >= stopstep:

View File

@ -710,7 +710,7 @@ def commondealer(p, usecom, usencom, flip):
ppl = ppl[1:]
if flip:
ppl = ppl[::-1]
prompt = f"{KEYBRK} ".join(ppl)
prompt = f" {KEYBRK} ".join(ppl)
return prompt
if usecom: