diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index 5194222a0..b31536e7d 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -26,7 +26,7 @@ class NetworkModuleLora(network.NetworkModule):
             return None
         linear_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
         is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ == "NNCFLinear"
-        is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv]
+        is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv] or self.sd_module.__class__.__name__ == "NNCFConv2d"
         if is_linear:
             weight = weight.reshape(weight.shape[0], -1)
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
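
Note on the change above: the patch extends is_conv with the same class-name fallback that is_linear already uses for "NNCFLinear". The sketch below is a minimal, hypothetical standalone version of that detection pattern, assuming only torch is available; the helper name is_conv_module is invented for illustration, and the diffusers_lora.LoRACompatibleConv entry from the real list is omitted so the sketch runs without diffusers installed.

import torch

def is_conv_module(sd_module) -> bool:
    # Plain torch convolutions match by exact type identity.
    if type(sd_module) in [torch.nn.Conv2d]:
        return True
    # NNCF swaps layers for its own wrapper classes, so an identity check
    # against torch.nn.Conv2d misses them; fall back to comparing the
    # class name, exactly as the patch does for "NNCFConv2d".
    return sd_module.__class__.__name__ == "NNCFConv2d"

The string comparison is deliberately loose: it matches the wrapper by name without importing NNCF, which keeps the extension working whether or not NNCF is present in the environment.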