Support new forge
Made initial changes to ensure compatibility with the new Forge. The code now runs, but further testing and improvements may be necessary.

branch pull/169/head
parent 4d94d247a6
commit f2b307d98a
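
For context on the first hunk: Forge detection rests on a guarded import, and the new Forge no longer ships ldm_patched, so the probe switches to modules_forge.forge_version. A minimal sketch of the pattern, assuming modules_forge is importable only under a Forge runtime (the patched code itself uses a bare except):

# Probe for a Forge runtime: the import succeeds only when Forge's own
# modules are on the path; any failure falls back to plain-WebUI mode.
try:
    from modules_forge import forge_version  # marker module shipped by Forge
    forge = True
except ImportError:  # the diff uses a bare except; ImportError is the idiomatic form
    forge = False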
@@ -44,7 +44,7 @@ xyelem = ""
 princ = False
 
 try:
-    from ldm_patched.modules import model_management
+    from modules_forge import forge_version
     forge = True
 except:
     forge = False
@@ -441,8 +441,8 @@ class Script(modules.scripts.Script):
 
         if forge and self.active:
             if params.sampling_step in self.startsf:
-                shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device)
-                for key, vals in shared.sd_model.forge_objects.unet.patches.items():
+                shared.sd_model.forge_objects.unet.forge_unpatch_model(device_to=devices.device)
+                for key, vals in shared.sd_model.forge_objects.unet.lora_patches.items():
                     n_vals = []
                     lvals = [val for val in vals if val[1][0] in LORAS]
                     for s, v, m, l, e in zip(self.startsf, lvals, self.uf, self.lf, self.ef):
@@ -451,12 +451,12 @@ class Script(modules.scripts.Script):
                             n_vals.append((ratio * m, *v[1:]))
                         else:
                             n_vals.append(v)
-                    shared.sd_model.forge_objects.unet.patches[key] = n_vals
-                shared.sd_model.forge_objects.unet.patch_model()
+                    shared.sd_model.forge_objects.unet.lora_patches[key] = n_vals
+                shared.sd_model.forge_objects.unet.forge_patch_model()
 
             if params.sampling_step in self.stopsf:
-                shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device)
-                for key, vals in shared.sd_model.forge_objects.unet.patches.items():
+                shared.sd_model.forge_objects.unet.forge_unpatch_model(device_to=devices.device)
+                for key, vals in shared.sd_model.forge_objects.unet.lora_patches.items():
                     n_vals = []
                     lvals = [val for val in vals if val[1][0] in LORAS]
                     for s, v, m, l, e in zip(self.stopsf, lvals, self.uf, self.lf, self.ef):
@@ -464,8 +464,8 @@ class Script(modules.scripts.Script):
                             n_vals.append((0, *v[1:]))
                         else:
                             n_vals.append(v)
-                    shared.sd_model.forge_objects.unet.patches[key] = n_vals
-                shared.sd_model.forge_objects.unet.patch_model()
+                    shared.sd_model.forge_objects.unet.lora_patches[key] = n_vals
+                shared.sd_model.forge_objects.unet.forge_patch_model()
 
         elif self.active:
             if self.starts and params.sampling_step == 0:
@@ -515,20 +515,20 @@ class Script(modules.scripts.Script):
         if not useblocks:
             return
         lora = importer(self)
-        emb_db = sd_hijack.model_hijack.embedding_db
+        emb_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
 
         for net in lora.loaded_loras:
             if hasattr(net,"bundle_embeddings"):
-                for emb_name, embedding in net.bundle_embeddings.items():
+                for embedding in net.bundle_embeddings.values():
                     if embedding.loaded:
-                        emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
+                        emb_db.register_embedding(embedding)
 
         lora.loaded_loras.clear()
 
         if forge:
             sd_models.model_data.get_sd_model().current_lora_hash = None
-            shared.sd_model.forge_objects_after_applying_lora.unet.unpatch_model()
-            shared.sd_model.forge_objects_after_applying_lora.clip.patcher.unpatch_model()
+            shared.sd_model.forge_objects_after_applying_lora.unet.forge_unpatch_model()
+            shared.sd_model.forge_objects_after_applying_lora.clip.patcher.forge_unpatch_model()
 
         global lxyz,lzyx,xyelem
         lxyz = lzyx = xyelem = ""
@@ -1137,24 +1137,26 @@ def lbw(lora,lwei,elemental):
 LORAS = ["lora", "loha", "lokr"]
 
 def lbwf(mt, mu, lwei, elemental, starts):
-    for key, vals in shared.sd_model.forge_objects_after_applying_lora.unet.patches.items():
-        n_vals = []
-        errormodules = []
-        lvals = [val for val in vals if val[1][0] in LORAS]
-        for v, m, l, e ,s in zip(lvals, mu, lwei, elemental, starts):
-            ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
-            n_vals.append((ratio * m if s is None else 0, *v[1:]))
-            if errormodule:errormodules.append(errormodule)
-        shared.sd_model.forge_objects_after_applying_lora.unet.patches[key] = n_vals
+    for lora_patches in shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches.values():
+        for key, vals in lora_patches.items():
+            n_vals = []
+            errormodules = []
+            lvs = [v for v in vals if v[1][0] in LORAS]
+            for v, m, l, e ,s in zip(lvs, mu, lwei, elemental, starts):
+                ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
+                n_vals.append([ratio * m if s is None else 0, *v[1:]])
+                if errormodule:errormodules.append(errormodule)
+            lora_patches[key] = n_vals
 
-    for key, vals in shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches.items():
-        n_vals = []
-        lvals = [val for val in vals if val[1][0] in LORAS]
-        for v, m, l, e in zip(lvals, mt, lwei, elemental):
-            ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
-            n_vals.append((ratio * m, *v[1:]))
-            if errormodule:errormodules.append(errormodule)
-        shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches[key] = n_vals
+    for lora_patches in shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches.values():
+        for key, vals in lora_patches.items():
+            n_vals = []
+            lvs = [v for v in vals if v[1][0] in LORAS]
+            for v, m, l, e in zip(lvs, mt, lwei, elemental):
+                ratio, errormodule = ratiodealer(key.replace(".","_"), l, e)
+                n_vals.append([ratio * m, *v[1:]])
+                if errormodule:errormodules.append(errormodule)
+            lora_patches[key] = n_vals
 
     if len(errormodules) > 0:
         print("Unknown modules:",errormodules)