diff --git a/scripts/lora_block_weight.py b/scripts/lora_block_weight.py index d9adf5d..fe66fc9 100644 --- a/scripts/lora_block_weight.py +++ b/scripts/lora_block_weight.py @@ -12,6 +12,7 @@ import numpy as np import gradio as gr import os.path import random +import time from pprint import pprint import modules.ui import modules.scripts as scripts @@ -440,32 +441,52 @@ class Script(modules.scripts.Script): sets.append(key) if forge and self.active: - if params.sampling_step in self.startsf: - shared.sd_model.forge_objects.unet.forge_unpatch_model(target_device=devices.device) - for m, l, e, s, lora_patches in zip(self.uf, self.lf, self.ef, self.startsf, list(shared.sd_model.forge_objects.unet.lora_patches.values())): - for key, vals in lora_patches.items(): + def apply_weight(stop = False): + if not stop: + flag_step = self.startsf + else: + flag_step = self.stopsf + + lora_patches = shared.sd_model.forge_objects.unet.lora_patches + refresh_keys = {} + for m, l, e, s, (patch_key, lora_patch) in zip(self.uf, self.lf, self.ef, flag_step, list(lora_patches.items())): + refresh = False + for key, vals in lora_patch.items(): n_vals = [] for v in [v for v in vals if v[1][0] in LORAS]: if s is not None and s == params.sampling_step: - ratio, _ = ratiodealer(key.replace(".","_"), l, e) - n_vals.append((ratio * m, *v[1:])) + if not stop: + ratio, _ = ratiodealer(key.replace(".","_"), l, e) + n_vals.append((ratio * m, *v[1:])) + else: + n_vals.append((0, *v[1:])) + refresh = True else: n_vals.append(v) - lora_patches[key] = n_vals - shared.sd_model.forge_objects.unet.forge_patch_model() + lora_patch[key] = n_vals + if refresh: + refresh_keys[patch_key] = None + + if len(refresh_keys): + for refresh_key in list(refresh_keys.keys()): + patch = lora_patches[refresh_key] + del lora_patches[refresh_key] + new_key = (f"{refresh_key[0]}_{str(time.time())}", *refresh_key[1:]) + refresh_keys[refresh_key] = new_key + lora_patches[new_key] = patch + + shared.sd_model.forge_objects.unet.refresh_loras() + + for refresh_key, new_key in list(refresh_keys.items()): + patch = lora_patches[new_key] + del lora_patches[new_key] + lora_patches[refresh_key] = patch + + if params.sampling_step in self.startsf: + apply_weight() if params.sampling_step in self.stopsf: - shared.sd_model.forge_objects.unet.forge_unpatch_model(target_device=devices.device) - for m, l, e, s, lora_patches in zip(self.uf, self.lf, self.ef, self.stopsf, list(shared.sd_model.forge_objects.unet.lora_patches.values())): - for key, vals in lora_patches.items(): - n_vals = [] - for v in [v for v in vals if v[1][0] in LORAS]: - if s is not None and s == params.sampling_step: - n_vals.append((0, *v[1:])) - else: - n_vals.append(v) - lora_patches[key] = n_vals - shared.sd_model.forge_objects.unet.forge_patch_model() + apply_weight(stop=True) elif self.active: if self.starts and params.sampling_step == 0: @@ -973,12 +994,10 @@ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts elif "forge" == ltype: lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches - lbwf(lora_patches, unet, lwei, elements, starts, - lambda r, m, s: r * m if s is None else 0) + lbwf(lora_patches, unet, lwei, elements, starts) lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches - lbwf(lora_patches, te, lwei, elements, starts, - lambda r, m, _: r * m) + lbwf(lora_patches, te, lwei, elements, starts) try: import lora_ctl_network as ctl @@ -1142,7 +1161,7 @@ def lbw(lora,lwei,elemental): LORAS = ["lora", "loha", "lokr"] -def lbwf(after_applying_lora_patches, ms, lwei, elements, starts, func_ratio): +def lbwf(after_applying_lora_patches, ms, lwei, elements, starts): errormodules = [] dict_lora_patches = dict(after_applying_lora_patches.items()) for m, l, e, s, hash in zip(ms, lwei, elements, starts, list(shared.sd_model.forge_objects.unet.lora_patches.keys())): @@ -1160,7 +1179,7 @@ def lbwf(after_applying_lora_patches, ms, lwei, elements, starts, func_ratio): lvs = [v for v in vals if v[1][0] in LORAS] for v in lvs: ratio, errormodule = ratiodealer(key.replace(".","_"), l, e) - n_vals.append([func_ratio(ratio, m, s), *v[1:]]) + n_vals.append((ratio * m if s is None or s == 0 else 0, *v[1:])) if errormodule: errormodules.append(errormodule) lora_patches[key] = n_vals