From 86d71904d42a89f1dd81ea810bef10d8145dd8d1 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Sat, 21 Jan 2023 21:54:54 +0900 Subject: [PATCH] compatibility --- patches/hypernetwork.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/patches/hypernetwork.py b/patches/hypernetwork.py index d6f1be4..02e670e 100644 --- a/patches/hypernetwork.py +++ b/patches/hypernetwork.py @@ -206,10 +206,10 @@ class HypernetworkModule(torch.nn.Module): resnet_result = self.linear(x) residual = resnet_result - x if multiplier is None or not isinstance(multiplier, (int, float)): - multiplier = self.multiplier - return x + multiplier * residual # interpolate + multiplier = self.multiplier if not version_flag else HypernetworkModule.multiplier + return x + multiplier * residual # interpolate if multiplier is None or not isinstance(multiplier, (int, float)): - return x + self.linear(x) * (self.multiplier if not self.training else 1) + return x + self.linear(x) * ((self.multiplier if not version_flag else HypernetworkModule.multiplier) if not self.training else 1) return x + self.linear(x) * multiplier def trainables(self, train=False):