diff --git a/patches/hypernetworks.py b/patches/hypernetworks.py
index 928692f..2308d63 100644
--- a/patches/hypernetworks.py
+++ b/patches/hypernetworks.py
@@ -222,12 +222,12 @@ class SingularForward(Forward):
     def __call__(self, context_k, context_v=None, layer=None):
         if self.processor in available_opts:
             context_layers = available_opts[self.processor].layers.get(context_k.shape[2], None)
-            if context_layers is None:
-                return context_k, context_k
             if context_v is None:
                 context_v = context_k
+            if context_layers is None:
+                return context_k, context_v
             if layer is not None and hasattr(layer, 'hyper_k') and hasattr(layer, 'hyper_v'):
-                layer.hyper_v = context_layers[0], layer.hyper_k = context_layers[1]
+                layer.hyper_k, layer.hyper_v = context_layers[0], context_layers[1]
             return context_layers[0](context_k, multiplier=self.strength), context_layers[1](context_v, multiplier=self.strength)
 # define forward_strength, which invokes HNModule with specified strength.
 # Note : we share same HN if it is called multiple time, which means you might not be able to train it via this structure.
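
For reference, below is a minimal standalone sketch of the control flow after this patch: context_v is normalized before the early return, and hyper_k/hyper_v receive context_layers[0]/context_layers[1] respectively. FakeHNModule, singular_forward, and the list-based inputs are stand-ins invented for this example; they are not the project's actual HNModule / SingularForward / available_opts API, and the shape-based layer lookup is omitted.

# Toy stand-in for an HN layer: scales every element by scale * multiplier.
class FakeHNModule:
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x, multiplier=1.0):
        return [v * self.scale * multiplier for v in x]


def singular_forward(context_layers, context_k, context_v=None, layer=None, strength=1.0):
    # Normalize context_v first, so the fallback path returns
    # (context_k, context_v) instead of (context_k, context_k).
    if context_v is None:
        context_v = context_k
    if context_layers is None:
        return context_k, context_v
    # hyper_k takes context_layers[0] and hyper_v takes context_layers[1],
    # i.e. the attribute swap this patch fixes.
    if layer is not None and hasattr(layer, 'hyper_k') and hasattr(layer, 'hyper_v'):
        layer.hyper_k, layer.hyper_v = context_layers[0], context_layers[1]
    return (context_layers[0](context_k, multiplier=strength),
            context_layers[1](context_v, multiplier=strength))


if __name__ == "__main__":
    layers = (FakeHNModule(2.0), FakeHNModule(3.0))
    k_out, v_out = singular_forward(layers, [1.0, 2.0], [4.0, 5.0], strength=0.5)
    print(k_out, v_out)  # [1.0, 2.0] and [6.0, 7.5]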