type check
parent
668c647f81
commit
ad9ae1b72d
|
|
@ -29,8 +29,8 @@ class NetworkModuleLora(network.NetworkModule):
|
|||
if weight is None and none_ok:
|
||||
return None
|
||||
|
||||
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
|
||||
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
||||
is_linear = isinstance(self.sd_module, (torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention))
|
||||
is_conv = isinstance(self.sd_module, (torch.nn.Conv2d))
|
||||
|
||||
if is_linear:
|
||||
weight = weight.reshape(weight.shape[0], -1)
|
||||
|
|
|
|||
|
|
@ -36,9 +36,9 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||
# self.alpha is unused
|
||||
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
||||
|
||||
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
||||
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
||||
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
|
||||
is_linear = isinstance(self.sd_module, (torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear))
|
||||
is_conv = isinstance(self.sd_module, (torch.nn.Conv2d))
|
||||
is_other_linear = isinstance(self.sd_module, torch.nn.MultiheadAttention)  # unsupported
|
||||
|
||||
if is_linear:
|
||||
self.out_dim = self.sd_module.out_features
|
||||
|
|
|
|||
Loading…
Reference in New Issue