From 34ec4e39cb066fb8680d3bb8752c8f2a3e51601c Mon Sep 17 00:00:00 2001
From: AI-Casanova <54461896+AI-Casanova@users.noreply.github.com>
Date: Tue, 2 Jul 2024 11:35:01 -0500
Subject: [PATCH] Add Dora, fix NTC key names

---
 extensions-builtin/Lora/lora_convert.py |  2 ++
 extensions-builtin/Lora/network.py      | 25 +++++++++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/extensions-builtin/Lora/lora_convert.py b/extensions-builtin/Lora/lora_convert.py
index 827f97e3d..8432d8208 100644
--- a/extensions-builtin/Lora/lora_convert.py
+++ b/extensions-builtin/Lora/lora_convert.py
@@ -164,6 +164,8 @@ class KeyConvert:
 
     def diffusers(self, key):
         if self.is_sdxl:
+            if "diffusion_model" in key: # Fix NTC Slider naming error
+                key = key.replace("diffusion_model", "lora_unet")
             map_keys = list(self.UNET_CONVERSION_MAP.keys())  # prefix of U-Net modules
             map_keys.sort()
             search_key = key.replace(self.LORA_PREFIX_UNET, "").replace(self.OFT_PREFIX_UNET, "").replace(self.LORA_PREFIX_TEXT_ENCODER1, "").replace(self.LORA_PREFIX_TEXT_ENCODER2, "")
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index bc60389fd..dc9ec4c8a 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -88,6 +88,8 @@ class NetworkModule:
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
         self.scale = weights.w["scale"].item() if "scale" in weights.w else None
+        self.dora_scale = weights.w.get("dora_scale", None)
+        self.dora_norm_dims = len(self.shape) - 1
 
     def multiplier(self):
         unet_multiplier = 3 * [self.network.unet_multiplier] if not isinstance(self.network.unet_multiplier, list) else self.network.unet_multiplier
@@ -109,6 +109,27 @@ class NetworkModule:
                 return self.alpha / self.dim
         return 1.0
 
+    def apply_weight_decompose(self, updown, orig_weight):
+        # Match the device/dtype
+        orig_weight = orig_weight.to(updown.dtype)
+        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
+        updown = updown.to(orig_weight.device)
+
+        merged_scale1 = updown + orig_weight
+        merged_scale1_norm = (
+            merged_scale1.transpose(0, 1)
+            .reshape(merged_scale1.shape[1], -1)
+            .norm(dim=1, keepdim=True)
+            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
+            .transpose(0, 1)
+        )
+
+        dora_merged = (
+            merged_scale1 * (dora_scale / merged_scale1_norm)
+        )
+        final_updown = dora_merged - orig_weight
+        return final_updown
+
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
@@ -120,6 +143,8 @@ class NetworkModule:
             updown = updown.reshape(orig_weight.shape)
         if ex_bias is not None:
             ex_bias = ex_bias * self.multiplier()
+        if self.dora_scale is not None:
+            updown = self.apply_weight_decompose(updown, orig_weight)
         return updown * self.calc_scale() * self.multiplier(), ex_bias
 
     def calc_updown(self, target):