Add Dora, fix NTC key names

pull/3308/head
AI-Casanova 2024-07-02 11:35:01 -05:00
parent 1f7c23ba0d
commit 34ec4e39cb
2 changed files with 27 additions and 0 deletions


@@ -164,6 +164,8 @@ class KeyConvert:
     def diffusers(self, key):
         if self.is_sdxl:
+            if "diffusion_model" in key:  # Fix NTC Slider naming error
+                key = key.replace("diffusion_model", "lora_unet")
             map_keys = list(self.UNET_CONVERSION_MAP.keys())  # prefix of U-Net modules
             map_keys.sort()
             search_key = key.replace(self.LORA_PREFIX_UNET, "").replace(self.OFT_PREFIX_UNET, "").replace(self.LORA_PREFIX_TEXT_ENCODER1, "").replace(self.LORA_PREFIX_TEXT_ENCODER2, "")
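This hunk handles LoRAs exported by NTC Slider, whose U-Net keys are prefixed with `diffusion_model` instead of the `lora_unet` prefix that the surrounding converter strips and maps. A minimal sketch of the rename, using a hypothetical key name:

```python
# Hypothetical NTC Slider key; only the prefix matters here.
key = "diffusion_model.input_blocks_4_1_proj_in.lora_down.weight"
if "diffusion_model" in key:  # NTC Slider naming fix
    key = key.replace("diffusion_model", "lora_unet")
print(key)  # lora_unet.input_blocks_4_1_proj_in.lora_down.weight
```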


@@ -88,6 +88,8 @@ class NetworkModule:
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
         self.scale = weights.w["scale"].item() if "scale" in weights.w else None
+        self.dora_scale = weights.w.get("dora_scale", None)
+        self.dora_norm_dims = len(self.shape) - 1

     def multiplier(self):
         unet_multiplier = 3 * [self.network.unet_multiplier] if not isinstance(self.network.unet_multiplier, list) else self.network.unet_multiplier
@@ -109,6 +111,27 @@ class NetworkModule:
             return self.alpha / self.dim
         return 1.0

+    def apply_weight_decompose(self, updown, orig_weight):
+        # Match the device/dtype
+        orig_weight = orig_weight.to(updown.dtype)
+        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
+        updown = updown.to(orig_weight.device)
+
+        merged_scale1 = updown + orig_weight
+        merged_scale1_norm = (
+            merged_scale1.transpose(0, 1)
+            .reshape(merged_scale1.shape[1], -1)
+            .norm(dim=1, keepdim=True)
+            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
+            .transpose(0, 1)
+        )
+
+        dora_merged = (
+            merged_scale1 * (dora_scale / merged_scale1_norm)
+        )
+        final_updown = dora_merged - orig_weight
+        return final_updown
+
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
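`apply_weight_decompose` implements DoRA-style weight decomposition (Liu et al., 2024): the merged weight W + ΔW is renormalized along dim 0, grouped by dim 1, then rescaled by the learned magnitude tensor `dora_scale`, i.e. W′ = m · (W + ΔW) / ‖W + ΔW‖. Returning the result as a delta, W′ − W, lets callers keep treating it like an ordinary additive update. A standalone sketch for a 2D Linear weight, with made-up shapes and values:

```python
# Standalone sketch of the decomposition for a 2D weight (out x in);
# tensors, shapes, and the magnitude layout are made up for illustration.
import torch

torch.manual_seed(0)
orig_weight = torch.randn(4, 3)     # frozen base weight W
updown = 0.1 * torch.randn(4, 3)    # low-rank delta dW
dora_scale = torch.ones(1, 3)       # learned magnitude m, broadcastable over dim 0

merged = updown + orig_weight       # W + dW
dora_norm_dims = merged.dim() - 1   # = 1 for a Linear weight
norm = (
    merged.transpose(0, 1)
    .reshape(merged.shape[1], -1)
    .norm(dim=1, keepdim=True)
    .reshape(merged.shape[1], *[1] * dora_norm_dims)
    .transpose(0, 1)
)                                   # shape (1, 3): one norm per dim-1 slice
final_updown = merged * (dora_scale / norm) - orig_weight

# Sanity check: with m = 1, adding the delta back yields unit-norm columns.
print((orig_weight + final_updown).norm(dim=0))  # ~tensor([1., 1., 1.])
```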
@@ -120,6 +143,8 @@ class NetworkModule:
         updown = updown.reshape(orig_weight.shape)
         if ex_bias is not None:
             ex_bias = ex_bias * self.multiplier()
+        if self.dora_scale is not None:
+            updown = self.apply_weight_decompose(updown, orig_weight)

         return updown * self.calc_scale() * self.multiplier(), ex_bias

     def calc_updown(self, target):
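The final hunk gates the decomposition on the presence of a `dora_scale` tensor in the checkpoint, so ordinary LoRAs take the unchanged path. Storing `dora_norm_dims = len(self.shape) - 1` lets the same norm generalize beyond 2D weights; a sketch with a made-up 4D Conv2d shape:

```python
# Same norm for a made-up Conv2d weight (out=8, in=4, kh=3, kw=3),
# where dora_norm_dims = 4 - 1 = 3.
import torch

w = torch.randn(8, 4, 3, 3)         # stands in for merged_scale1
norm = (
    w.transpose(0, 1)               # (4, 8, 3, 3)
    .reshape(w.shape[1], -1)        # (4, 72): flatten everything per dim-1 slice
    .norm(dim=1, keepdim=True)      # (4, 1)
    .reshape(w.shape[1], *[1] * 3)  # (4, 1, 1, 1)
    .transpose(0, 1)                # (1, 4, 1, 1): broadcasts against w
)
print((w / norm).transpose(0, 1).reshape(4, -1).norm(dim=1))  # ~tensor([1., 1., 1., 1.])
```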