Merge pull request #173 from nihedon/support_forge

Support forge
main
hako-mikan 2024-10-14 19:37:40 +09:00 committed by GitHub
commit 2322d64101
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 44 additions and 25 deletions

View File

@ -12,6 +12,7 @@ import numpy as np
import gradio as gr
import os.path
import random
import time
from pprint import pprint
import modules.ui
import modules.scripts as scripts
@ -440,32 +441,52 @@ class Script(modules.scripts.Script):
sets.append(key)
if forge and self.active:
if params.sampling_step in self.startsf:
shared.sd_model.forge_objects.unet.forge_unpatch_model(target_device=devices.device)
for m, l, e, s, lora_patches in zip(self.uf, self.lf, self.ef, self.startsf, list(shared.sd_model.forge_objects.unet.lora_patches.values())):
for key, vals in lora_patches.items():
def apply_weight(stop = False):
if not stop:
flag_step = self.startsf
else:
flag_step = self.stopsf
lora_patches = shared.sd_model.forge_objects.unet.lora_patches
refresh_keys = {}
for m, l, e, s, (patch_key, lora_patch) in zip(self.uf, self.lf, self.ef, flag_step, list(lora_patches.items())):
refresh = False
for key, vals in lora_patch.items():
n_vals = []
for v in [v for v in vals if v[1][0] in LORAS]:
if s is not None and s == params.sampling_step:
ratio, _ = ratiodealer(key.replace(".","_"), l, e)
n_vals.append((ratio * m, *v[1:]))
if not stop:
ratio, _ = ratiodealer(key.replace(".","_"), l, e)
n_vals.append((ratio * m, *v[1:]))
else:
n_vals.append((0, *v[1:]))
refresh = True
else:
n_vals.append(v)
lora_patches[key] = n_vals
shared.sd_model.forge_objects.unet.forge_patch_model()
lora_patch[key] = n_vals
if refresh:
refresh_keys[patch_key] = None
if len(refresh_keys):
for refresh_key in list(refresh_keys.keys()):
patch = lora_patches[refresh_key]
del lora_patches[refresh_key]
new_key = (f"{refresh_key[0]}_{str(time.time())}", *refresh_key[1:])
refresh_keys[refresh_key] = new_key
lora_patches[new_key] = patch
shared.sd_model.forge_objects.unet.refresh_loras()
for refresh_key, new_key in list(refresh_keys.items()):
patch = lora_patches[new_key]
del lora_patches[new_key]
lora_patches[refresh_key] = patch
if params.sampling_step in self.startsf:
apply_weight()
if params.sampling_step in self.stopsf:
shared.sd_model.forge_objects.unet.forge_unpatch_model(target_device=devices.device)
for m, l, e, s, lora_patches in zip(self.uf, self.lf, self.ef, self.stopsf, list(shared.sd_model.forge_objects.unet.lora_patches.values())):
for key, vals in lora_patches.items():
n_vals = []
for v in [v for v in vals if v[1][0] in LORAS]:
if s is not None and s == params.sampling_step:
n_vals.append((0, *v[1:]))
else:
n_vals.append(v)
lora_patches[key] = n_vals
shared.sd_model.forge_objects.unet.forge_patch_model()
apply_weight(stop=True)
elif self.active:
if self.starts and params.sampling_step == 0:
@ -973,12 +994,10 @@ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts
elif "forge" == ltype:
lora_patches = shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches
lbwf(lora_patches, unet, lwei, elements, starts,
lambda r, m, s: r * m if s is None else 0)
lbwf(lora_patches, unet, lwei, elements, starts)
lora_patches = shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches
lbwf(lora_patches, te, lwei, elements, starts,
lambda r, m, _: r * m)
lbwf(lora_patches, te, lwei, elements, starts)
try:
import lora_ctl_network as ctl
@ -1142,7 +1161,7 @@ def lbw(lora,lwei,elemental):
LORAS = ["lora", "loha", "lokr"]
def lbwf(after_applying_lora_patches, ms, lwei, elements, starts, func_ratio):
def lbwf(after_applying_lora_patches, ms, lwei, elements, starts):
errormodules = []
dict_lora_patches = dict(after_applying_lora_patches.items())
for m, l, e, s, hash in zip(ms, lwei, elements, starts, list(shared.sd_model.forge_objects.unet.lora_patches.keys())):
@ -1160,7 +1179,7 @@ def lbwf(after_applying_lora_patches, ms, lwei, elements, starts, func_ratio):
lvs = [v for v in vals if v[1][0] in LORAS]
for v in lvs:
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
n_vals.append([func_ratio(ratio, m, s), *v[1:]])
n_vals.append((ratio * m if s is None or s == 0 else 0, *v[1:]))
if errormodule:
errormodules.append(errormodule)
lora_patches[key] = n_vals