diff --git a/scripts/lora_block_weight.py b/scripts/lora_block_weight.py index d766cff..10457c7 100644 --- a/scripts/lora_block_weight.py +++ b/scripts/lora_block_weight.py @@ -44,7 +44,7 @@ xyelem = "" princ = False try: - from ldm_patched.modules import model_management + from modules_forge import forge_version forge = True except: forge = False @@ -441,8 +441,8 @@ class Script(modules.scripts.Script): if forge and self.active: if params.sampling_step in self.startsf: - shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device) - for key, vals in shared.sd_model.forge_objects.unet.patches.items(): + shared.sd_model.forge_objects.unet.forge_unpatch_model(device_to=devices.device) + for key, vals in shared.sd_model.forge_objects.unet.lora_patches.items(): n_vals = [] lvals = [val for val in vals if val[1][0] in LORAS] for s, v, m, l, e in zip(self.startsf, lvals, self.uf, self.lf, self.ef): @@ -451,12 +451,12 @@ class Script(modules.scripts.Script): n_vals.append((ratio * m, *v[1:])) else: n_vals.append(v) - shared.sd_model.forge_objects.unet.patches[key] = n_vals - shared.sd_model.forge_objects.unet.patch_model() + shared.sd_model.forge_objects.unet.lora_patches[key] = n_vals + shared.sd_model.forge_objects.unet.forge_patch_model() if params.sampling_step in self.stopsf: - shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device) - for key, vals in shared.sd_model.forge_objects.unet.patches.items(): + shared.sd_model.forge_objects.unet.forge_unpatch_model(device_to=devices.device) + for key, vals in shared.sd_model.forge_objects.unet.lora_patches.items(): n_vals = [] lvals = [val for val in vals if val[1][0] in LORAS] for s, v, m, l, e in zip(self.stopsf, lvals, self.uf, self.lf, self.ef): @@ -464,8 +464,8 @@ class Script(modules.scripts.Script): n_vals.append((0, *v[1:])) else: n_vals.append(v) - shared.sd_model.forge_objects.unet.patches[key] = n_vals - shared.sd_model.forge_objects.unet.patch_model() + shared.sd_model.forge_objects.unet.lora_patches[key] = n_vals + shared.sd_model.forge_objects.unet.forge_patch_model() elif self.active: if self.starts and params.sampling_step == 0: @@ -515,20 +515,20 @@ class Script(modules.scripts.Script): if not useblocks: return lora = importer(self) - emb_db = sd_hijack.model_hijack.embedding_db + emb_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() for net in lora.loaded_loras: if hasattr(net,"bundle_embeddings"): - for emb_name, embedding in net.bundle_embeddings.items(): + for embedding in net.bundle_embeddings.values(): if embedding.loaded: - emb_db.register_embedding_by_name(None, shared.sd_model, emb_name) + emb_db.register_embedding(embedding) lora.loaded_loras.clear() if forge: sd_models.model_data.get_sd_model().current_lora_hash = None - shared.sd_model.forge_objects_after_applying_lora.unet.unpatch_model() - shared.sd_model.forge_objects_after_applying_lora.clip.patcher.unpatch_model() + shared.sd_model.forge_objects_after_applying_lora.unet.forge_unpatch_model() + shared.sd_model.forge_objects_after_applying_lora.clip.patcher.forge_unpatch_model() global lxyz,lzyx,xyelem lxyz = lzyx = xyelem = "" @@ -1137,24 +1137,26 @@ def lbw(lora,lwei,elemental): LORAS = ["lora", "loha", "lokr"] def lbwf(mt, mu, lwei, elemental, starts): - for key, vals in shared.sd_model.forge_objects_after_applying_lora.unet.patches.items(): - n_vals = [] - errormodules = [] - lvals = [val for val in vals if val[1][0] in LORAS] - for v, m, l, e ,s in zip(lvals, mu, lwei, elemental, starts): - ratio, errormodule = ratiodealer(key.replace(".","_"), l, e) - n_vals.append((ratio * m if s is None else 0, *v[1:])) - if errormodule:errormodules.append(errormodule) - shared.sd_model.forge_objects_after_applying_lora.unet.patches[key] = n_vals + for lora_patches in shared.sd_model.forge_objects_after_applying_lora.unet.lora_patches.values(): + for key, vals in lora_patches.items(): + n_vals = [] + errormodules = [] + lvs = [v for v in vals if v[1][0] in LORAS] + for v, m, l, e ,s in zip(lvs, mu, lwei, elemental, starts): + ratio, errormodule = ratiodealer(key.replace(".","_"), l, e) + n_vals.append([ratio * m if s is None else 0, *v[1:]]) + if errormodule:errormodules.append(errormodule) + lora_patches[key] = n_vals - for key, vals in shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches.items(): - n_vals = [] - lvals = [val for val in vals if val[1][0] in LORAS] - for v, m, l, e in zip(lvals, mt, lwei, elemental): - ratio, errormodule = ratiodealer(key.replace(".","_"), l, e) - n_vals.append((ratio * m, *v[1:])) - if errormodule:errormodules.append(errormodule) - shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches[key] = n_vals + for lora_patches in shared.sd_model.forge_objects_after_applying_lora.clip.patcher.lora_patches.values(): + for key, vals in lora_patches.items(): + n_vals = [] + lvs = [v for v in vals if v[1][0] in LORAS] + for v, m, l, e in zip(lvs, mt, lwei, elemental): + ratio, errormodule = ratiodealer(key.replace(".","_"), l, e) + n_vals.append([ratio * m, *v[1:]]) + if errormodule:errormodules.append(errormodule) + lora_patches[key] = n_vals if len(errormodules) > 0: print("Unknown modules:",errormodules)