parent
138a25b69f
commit
4dd988c5b7
|
|
@ -222,12 +222,12 @@ class SingularForward(Forward):
|
|||
def __call__(self, context_k, context_v=None, layer=None):
|
||||
if self.processor in available_opts:
|
||||
context_layers = available_opts[self.processor].layers.get(context_k.shape[2], None)
|
||||
if context_layers is None:
|
||||
return context_k, context_k
|
||||
if context_v is None:
|
||||
context_v = context_k
|
||||
if context_layers is None:
|
||||
return context_k, context_v
|
||||
if layer is not None and hasattr(layer, 'hyper_k') and hasattr(layer, 'hyper_v'):
|
||||
layer.hyper_v = context_layers[0], layer.hyper_k = context_layers[1]
|
||||
layer.hyper_k = context_layers[0], layer.hyper_v = context_layers[1]
|
||||
return context_layers[0](context_k, multiplier=self.strength), context_layers[1](context_v,
|
||||
multiplier=self.strength) # define forward_strength, which invokes HNModule with specified strength.
|
||||
# Note : we share same HN if it is called multiple time, which means you might not be able to train it via this structure.
|
||||
|
|
|
|||
Loading…
Reference in New Issue