NNCF fix SD 1.5 Lora

pull/3302/head
Disty0 2024-06-24 22:35:47 +03:00
parent 65823a4016
commit 94f6f0dbcb
1 changed files with 1 additions and 1 deletions

View File

@ -26,7 +26,7 @@ class NetworkModuleLora(network.NetworkModule):
return None
linear_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ == "NNCFLinear"
is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv]
is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv] or self.sd_module.__class__.__name__ == "NNCFConv2d"
if is_linear:
weight = weight.reshape(weight.shape[0], -1)
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)