commit
2322d64101
|
|
@ -12,6 +12,7 @@ import numpy as np
|
|||
import gradio as gr
|
||||
import os.path
|
||||
import random
|
||||
import time
|
||||
from pprint import pprint
|
||||
import modules.ui
|
||||
import modules.scripts as scripts
|
||||
|
|
@ -440,32 +441,52 @@ class Script(modules.scripts.Script):
|
|||
sets.append(key)
|
||||
|
||||
if forge and self.active:
|
||||
if params.sampling_step in self.startsf:
|
||||
shared.sd_model.forge_objects.unet.forge_unpatch_model(target_device=devices.device)
|
||||
for m, l, e, s, lora_patches in zip(self.uf, self.lf, self.ef, self.startsf, list(shared.sd_model.forge_objects.unet.lora_patches.values())):
|
||||
for key, vals in lora_patches.items():
|
||||
def apply_weight(stop = False):
|
||||
if not stop:
|
||||
flag_step = self.startsf
|
||||
else:
|
||||
flag_step = self.stopsf
|
||||
|
||||
lora_patches = shared.sd_model.forge_objects.unet.lora_patches
|
||||
refresh_keys = {}
|
||||
for m, l, e, s, (patch_key, lora_patch) in zip(self.uf, self.lf, self.ef, flag_step, list(lora_patches.items())):
|
||||
refresh = False
|
||||
for key, vals in lora_patch.items():
|
||||
n_vals = []
|
||||
for v in [v for v in vals if v[1][0] in LORAS]:
|
||||
if s is not None and s == params.sampling_step:
|
||||
ratio, _ = ratiodealer(key.replace(".","_"), l, e)
|
||||
n_vals.append((ratio * m, *v[1:]))
|
||||
if not stop:
|
||||
ratio, _ = ratiodealer(key.replace(".","_"), l, e)
|
||||
n_vals.append((ratio * m, *v[1:]))
|
||||
else:
|
||||
n_vals.append((0, *v[1:]))
|
||||
refresh = True
|
||||
else:
|
||||
n_vals.append(v)
|
||||
lora_patches[key] = n_vals
|
||||
shared.sd_model.forge_objects.unet.forge_patch_model()
|
||||
lora_patch[key] = n_vals
|
||||
if refresh:
|
||||
refresh_keys[patch_key] = None
|
||||
|
||||
if len(refresh_keys):
|
||||
for refresh_key in list(refresh_keys.keys()):
|
||||
patch = lora_patches[refresh_key]
|
||||
del lora_patches[refresh_key]
|
||||
new_key = (f"{refresh_key[0]}_{str(time.time())}", *refresh_key[1:])
|
||||
refresh_keys[refresh_key] = new_key
|
||||
lora_patches[new_key] = patch
|
||||
|
||||
shared.sd_model.forge_objects.unet.refresh_loras()
|
||||
|
||||
for refresh_key, new_key in list(refresh_keys.items()):
|
||||
patch = lora_patches[new_key]
|
||||
del lora_patches[new_key]
|
||||
lora_patches[refresh_key] = patch
|
||||
|
||||
if params.sampling_step in self.startsf:
|
||||
apply_weight()
|
||||
|
||||
if params.sampling_step in self.stopsf:
|
||||
shared.sd_model.forge_objects.unet.forge_unpatch_model(target_device=devices.device)
|
||||
for m, l, e, s, lora_patches in zip(self.uf, self.lf, self.ef, self.stopsf, list(shared.sd_model.forge_objects.unet.lora_patches.values())):
|
||||
for key, vals in lora_patches.items():
|
||||
n_vals = []
|
||||
for v in [v for v in vals if v[1][0] in LORAS]:
|
||||
if s is not None and s == params.sampling_step:
|
||||
n_vals.append((0, *v[1:]))
|
||||
else:
|
||||
n_vals.append(v)
|
||||
lora_patches[key] = n_vals
|
||||
shared.sd_model.forge_objects.unet.forge_patch_model()
|
||||
apply_weight(stop=True)
|
||||
|
||||
elif self.active:
|
||||
if self.starts and params.sampling_step == 0:
|
||||
|
|
@ -973,12 +994,10 @@ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts
|
|||
|
||||
elif "forge" == ltype:
|
||||
lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches
|
||||
lbwf(lora_patches, unet, lwei, elements, starts,
|
||||
lambda r, m, s: r * m if s is None else 0)
|
||||
lbwf(lora_patches, unet, lwei, elements, starts)
|
||||
|
||||
lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches
|
||||
lbwf(lora_patches, te, lwei, elements, starts,
|
||||
lambda r, m, _: r * m)
|
||||
lbwf(lora_patches, te, lwei, elements, starts)
|
||||
|
||||
try:
|
||||
import lora_ctl_network as ctl
|
||||
|
|
@ -1142,7 +1161,7 @@ def lbw(lora,lwei,elemental):
|
|||
|
||||
LORAS = ["lora", "loha", "lokr"]
|
||||
|
||||
def lbwf(after_applying_lora_patches, ms, lwei, elements, starts, func_ratio):
|
||||
def lbwf(after_applying_lora_patches, ms, lwei, elements, starts):
|
||||
errormodules = []
|
||||
dict_lora_patches = dict(after_applying_lora_patches.items())
|
||||
for m, l, e, s, hash in zip(ms, lwei, elements, starts, list(shared.sd_model.forge_objects.unet.lora_patches.keys())):
|
||||
|
|
@ -1160,7 +1179,7 @@ def lbwf(after_applying_lora_patches, ms, lwei, elements, starts, func_ratio):
|
|||
lvs = [v for v in vals if v[1][0] in LORAS]
|
||||
for v in lvs:
|
||||
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
|
||||
n_vals.append([func_ratio(ratio, m, s), *v[1:]])
|
||||
n_vals.append((ratio * m if s is None or s == 0 else 0, *v[1:]))
|
||||
if errormodule:
|
||||
errormodules.append(errormodule)
|
||||
lora_patches[key] = n_vals
|
||||
|
|
|
|||
Loading…
Reference in New Issue