"""
Mish activation module.

Applies the mish function element-wise:
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
"""
|
|
|
|
# import pytorch
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
# import activation functions
|
|
from .Fmish import mish
|
|
|
|
|
|
class Mish(nn.Module):
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    The module holds no parameters or state; it only dispatches to a
    functional implementation of mish.

    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input

    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)

    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
    """

    def __init__(self):
        """
        Init method.
        """
        super().__init__()

    def forward(self, input):
        """
        Forward pass of the function.

        Args:
            input: value passed through unchanged to the mish
                implementation (a tensor, per the class docstring).

        Returns:
            The result of ``F.mish(input)`` when the installed PyTorch
            provides it, otherwise the result of the package-local
            ``mish(input)`` imported from ``.Fmish``.
        """
        # Detect the built-in by feature rather than by version string:
        # comparing ``torch.__version__ >= "1.9"`` is lexicographic, so
        # "1.10.0" sorts *below* "1.9" and releases 1.10-1.13 would be
        # routed to the fallback even though they ship ``F.mish``.
        if hasattr(F, "mish"):
            return F.mish(input)
        return mish(input)