diff --git a/patches/hypernetwork.py b/patches/hypernetwork.py index d6f1be4..02e670e 100644 --- a/patches/hypernetwork.py +++ b/patches/hypernetwork.py @@ -206,10 +206,10 @@ class HypernetworkModule(torch.nn.Module): resnet_result = self.linear(x) residual = resnet_result - x if multiplier is None or not isinstance(multiplier, (int, float)): - multiplier = self.multiplier - return x + multiplier * residual # interpolate + multiplier = self.multiplier if not version_flag else HypernetworkModule.multiplier + return x + multiplier * residual # interpolate if multiplier is None or not isinstance(multiplier, (int, float)): - return x + self.linear(x) * (self.multiplier if not self.training else 1) + return x + self.linear(x) * ((self.multiplier if not version_flag else HypernetworkModule.multiplier) if not self.training else 1) return x + self.linear(x) * multiplier def trainables(self, train=False):