fix for lora block weight
parent
fdf99f04dd
commit
8acfd26036
|
|
@ -470,6 +470,7 @@ class LoRARegioner:
|
|||
self.stop_hr = stop_hr
|
||||
self.stopped = False
|
||||
self.stopped_hr = False
|
||||
self.orig_weight = {}
|
||||
|
||||
try:
|
||||
import lora_ctl_network as ctl
|
||||
|
|
@ -568,6 +569,7 @@ class LoRARegioner:
|
|||
self.u_count = 0
|
||||
self.stopped = False
|
||||
self.stopped_hr = False
|
||||
self.orig_weight = {}
|
||||
|
||||
def te_start_f(self):
|
||||
self.mlist = self.te_llist[self.te_count % len(self.te_llist)]
|
||||
|
|
@ -585,9 +587,12 @@ class LoRARegioner:
|
|||
for lora_key, patch in lora_patches.items():
|
||||
for list_key in self.mlist:
|
||||
if list_key in lora_key[0]:
|
||||
if labug: print(f"LoRA {lora_key} detected in {self.mlist}")
|
||||
if labug:
|
||||
print(f"LoRA {lora_key} detected in {self.mlist}")
|
||||
for patch_key in patch:
|
||||
patch[patch_key][0][0] = self.mlist[list_key]
|
||||
if patch_key + list_key not in self.orig_weight:
|
||||
self.orig_weight[patch_key + list_key] = patch[patch_key][0][0]
|
||||
patch[patch_key][0][0] = self.orig_weight[patch_key + list_key] * self.mlist[list_key]
|
||||
|
||||
refresh(lora_lorader, lora_patches=lora_patches, offload_device=offload_device)
|
||||
|
||||
|
|
@ -601,21 +606,18 @@ class LoRARegioner:
|
|||
strengths = list(self.mlist.values())
|
||||
|
||||
def set_strengths(strengths):
|
||||
all = strengths == 0
|
||||
for name, module in shared.sd_model.forge_objects.unet.model.named_modules():
|
||||
patches = getattr(module, 'forge_online_loras', None)
|
||||
weight_patches, bias_patches = None, None
|
||||
if patches is not None:
|
||||
weight_patches = patches.get('weight', None)
|
||||
if weight_patches:
|
||||
if all:
|
||||
for i in range(len(weight_patches)):
|
||||
weight_patches[i][0] = 0
|
||||
else:
|
||||
if len(weight_patches) != len(strengths) :
|
||||
continue
|
||||
for i in range(len(strengths)):
|
||||
weight_patches[i][0] = strengths[i]
|
||||
if len(weight_patches) != len(strengths) :
|
||||
continue
|
||||
for i in range(len(strengths)):
|
||||
if name not in self.orig_weight:
|
||||
self.orig_weight[name] = [x[0] for x in weight_patches]
|
||||
weight_patches[i][0] = strengths[i] * self.orig_weight[name][i]
|
||||
|
||||
stopstep = self.stop_hr if in_hr else self.stop
|
||||
if self.step >= stopstep:
|
||||
|
|
|
|||
|
|
@ -710,7 +710,7 @@ def commondealer(p, usecom, usencom, flip):
|
|||
ppl = ppl[1:]
|
||||
if flip:
|
||||
ppl = ppl[::-1]
|
||||
prompt = f"{KEYBRK} ".join(ppl)
|
||||
prompt = f" {KEYBRK} ".join(ppl)
|
||||
return prompt
|
||||
|
||||
if usecom:
|
||||
|
|
|
|||
Loading…
Reference in New Issue