mirror of https://github.com/vladmandic/automatic
NNCF fix SD 1.5 Lora
parent
65823a4016
commit
94f6f0dbcb
|
|
@ -26,7 +26,7 @@ class NetworkModuleLora(network.NetworkModule):
|
|||
return None
|
||||
linear_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
|
||||
is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ == "NNCFLinear"
|
||||
is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv]
|
||||
is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv] or self.sd_module.__class__.__name__ == "NNCFConv2d"
|
||||
if is_linear:
|
||||
weight = weight.reshape(weight.shape[0], -1)
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
|
|
|
|||
Loading…
Reference in New Issue