compatibility

latest-fix
aria1th 2023-01-21 21:54:54 +09:00
parent 82d66c5ced
commit 86d71904d4
1 changed files with 3 additions and 3 deletions

View File

@ -206,10 +206,10 @@ class HypernetworkModule(torch.nn.Module):
resnet_result = self.linear(x)
residual = resnet_result - x
if multiplier is None or not isinstance(multiplier, (int, float)):
multiplier = self.multiplier
return x + multiplier * residual # interpolate
multiplier = self.multiplier if not version_flag else HypernetworkModule.multiplier
return x + multiplier * residual # interpolate
if multiplier is None or not isinstance(multiplier, (int, float)):
return x + self.linear(x) * (self.multiplier if not self.training else 1)
return x + self.linear(x) * ((self.multiplier if not version_flag else HypernetworkModule.multiplier) if not self.training else 1)
return x + self.linear(x) * multiplier
def trainables(self, train=False):