Support new forge

Made initial changes to ensure compatibility with new forge. The code now runs, but further testing and improvements may be necessary.
pull/169/head
takahiro-nihei 2024-09-04 16:44:13 +09:00
parent 4d94d247a6
commit f2b307d98a
1 changed files with 33 additions and 31 deletions

View File

@ -44,7 +44,7 @@ xyelem = ""
princ = False
try:
from ldm_patched.modules import model_management
from modules_forge import forge_version
forge = True
except:
forge = False
@ -441,8 +441,8 @@ class Script(modules.scripts.Script):
if forge and self.active:
if params.sampling_step in self.startsf:
shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device)
for key, vals in shared.sd_model.forge_objects.unet.patches.items():
shared.sd_model.forge_objects.unet.forge_unpatch_model(device_to=devices.device)
for key, vals in shared.sd_model.forge_objects.unet.lora_patches.items():
n_vals = []
lvals = [val for val in vals if val[1][0] in LORAS]
for s, v, m, l, e in zip(self.startsf, lvals, self.uf, self.lf, self.ef):
@ -451,12 +451,12 @@ class Script(modules.scripts.Script):
n_vals.append((ratio * m, *v[1:]))
else:
n_vals.append(v)
shared.sd_model.forge_objects.unet.patches[key] = n_vals
shared.sd_model.forge_objects.unet.patch_model()
shared.sd_model.forge_objects.unet.lora_patches[key] = n_vals
shared.sd_model.forge_objects.unet.forge_patch_model()
if params.sampling_step in self.stopsf:
shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device)
for key, vals in shared.sd_model.forge_objects.unet.patches.items():
shared.sd_model.forge_objects.unet.forge_unpatch_model(device_to=devices.device)
for key, vals in shared.sd_model.forge_objects.unet.lora_patches.items():
n_vals = []
lvals = [val for val in vals if val[1][0] in LORAS]
for s, v, m, l, e in zip(self.stopsf, lvals, self.uf, self.lf, self.ef):
@ -464,8 +464,8 @@ class Script(modules.scripts.Script):
n_vals.append((0, *v[1:]))
else:
n_vals.append(v)
shared.sd_model.forge_objects.unet.patches[key] = n_vals
shared.sd_model.forge_objects.unet.patch_model()
shared.sd_model.forge_objects.unet.lora_patches[key] = n_vals
shared.sd_model.forge_objects.unet.forge_patch_model()
elif self.active:
if self.starts and params.sampling_step == 0:
@ -515,20 +515,20 @@ class Script(modules.scripts.Script):
if not useblocks:
return
lora = importer(self)
emb_db = sd_hijack.model_hijack.embedding_db
emb_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
for net in lora.loaded_loras:
if hasattr(net,"bundle_embeddings"):
for emb_name, embedding in net.bundle_embeddings.items():
for embedding in net.bundle_embeddings.values():
if embedding.loaded:
emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
emb_db.register_embedding(embedding)
lora.loaded_loras.clear()
if forge:
sd_models.model_data.get_sd_model().current_lora_hash = None
shared.sd_model.forge_objects_after_applying_lora.unet.unpatch_model()
shared.sd_model.forge_objects_after_applying_lora.clip.patcher.unpatch_model()
shared.sd_model.forge_objects_after_applying_lora.unet.forge_unpatch_model()
shared.sd_model.forge_objects_after_applying_lora.clip.patcher.forge_unpatch_model()
global lxyz,lzyx,xyelem
lxyz = lzyx = xyelem = ""
@ -1137,24 +1137,26 @@ def lbw(lora,lwei,elemental):
LORAS = ["lora", "loha", "lokr"]
def lbwf(mt, mu, lwei, elemental, starts):
for key, vals in shared.sd_model.forge_objects_after_applying_lora.unet.patches.items():
n_vals = []
errormodules = []
lvals = [val for val in vals if val[1][0] in LORAS]
for v, m, l, e ,s in zip(lvals, mu, lwei, elemental, starts):
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
n_vals.append((ratio * m if s is None else 0, *v[1:]))
if errormodule:errormodules.append(errormodule)
shared.sd_model.forge_objects_after_applying_lora.unet.patches[key] = n_vals
for lora_patches in shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches.values():
for key, vals in lora_patches.items():
n_vals = []
errormodules = []
lvs = [v for v in vals if v[1][0] in LORAS]
for v, m, l, e ,s in zip(lvs, mu, lwei, elemental, starts):
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
n_vals.append([ratio * m if s is None else 0, *v[1:]])
if errormodule:errormodules.append(errormodule)
lora_patches[key] = n_vals
for key, vals in shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches.items():
n_vals = []
lvals = [val for val in vals if val[1][0] in LORAS]
for v, m, l, e in zip(lvals, mt, lwei, elemental):
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
n_vals.append((ratio * m, *v[1:]))
if errormodule:errormodules.append(errormodule)
shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches[key] = n_vals
for lora_patches in shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches.values():
for key, vals in lora_patches.items():
n_vals = []
lvs = [v for v in vals if v[1][0] in LORAS]
for v, m, l, e in zip(lvs, mt, lwei, elemental):
ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
n_vals.append([ratio * m, *v[1:]])
if errormodule:errormodules.append(errormodule)
lora_patches[key] = n_vals
if len(errormodules) > 0:
print("Unknown modules:",errormodules)