type check

pull/66/head
continue revolution 2024-02-11 15:40:43 -06:00
parent 668c647f81
commit ad9ae1b72d
2 changed files with 5 additions and 5 deletions

@@ -29,8 +29,8 @@ class NetworkModuleLora(network.NetworkModule):
         if weight is None and none_ok:
             return None
-        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
-        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
+        is_linear = isinstance(self.sd_module, (torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention))
+        is_conv = isinstance(self.sd_module, (torch.nn.Conv2d))
         if is_linear:
             weight = weight.reshape(weight.shape[0], -1)

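The switch from type(...) in [...] to isinstance(...) is more than style: an exact-type membership test rejects any module that subclasses the listed layer types, while isinstance accepts subclasses. A minimal sketch of the difference, assuming a hypothetical PatchedLinear subclass (not part of this repo):

import torch

# Hypothetical stand-in for a wrapped or patched layer that third-party
# code may substitute for torch.nn.Linear.
class PatchedLinear(torch.nn.Linear):
    pass

layer = PatchedLinear(4, 4)

print(type(layer) in [torch.nn.Linear])    # False: exact-type check misses the subclass
print(isinstance(layer, torch.nn.Linear))  # True: isinstance matches subclasses too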
@@ -36,9 +36,9 @@ class NetworkModuleOFT(network.NetworkModule):
         # self.alpha is unused
         self.dim = self.oft_blocks.shape[1]  # (num_blocks, block_size, block_size)
-        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
-        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
-        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]  # unsupported
+        is_linear = isinstance(self.sd_module, (torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear))
+        is_conv = isinstance(self.sd_module, (torch.nn.Conv2d))
+        is_other_linear = isinstance(self.sd_module, torch.nn.MultiheadAttention)  # unsupported
         if is_linear:
             self.out_dim = self.sd_module.out_features
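One caveat the corrected is_other_linear line depends on: the second argument to isinstance must be a type or a tuple of types; passing a list (as the old type(...) in [...] idiom might tempt) raises TypeError at runtime. A small illustration:

import torch

mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)

print(isinstance(mha, torch.nn.MultiheadAttention))     # True: a single type is fine
print(isinstance(mha, (torch.nn.MultiheadAttention,)))  # True: a tuple of types is fine
try:
    isinstance(mha, [torch.nn.MultiheadAttention])      # a list is not accepted
except TypeError as err:
    print(err)  # isinstance() arg 2 must be a type or tuple of types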