17 lines
557 B
Python
17 lines
557 B
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
class DepthWiseSeperableConv(nn.Module):
    """Depthwise-separable 2D convolution.

    A depthwise ``nn.Conv2d`` (``groups=in_dim``, one filter per input
    channel) is followed by a 1x1 pointwise ``nn.Conv2d`` that mixes the
    ``in_dim`` channels into ``out_dim`` channels.

    The class name keeps its original spelling ("Seperable") so existing
    imports continue to work.
    """

    def __init__(self, in_dim, out_dim, *args, **kwargs):
        """Build the depthwise and pointwise convolutions.

        Args:
            in_dim: Number of input channels; also the number of channels
                produced by the depthwise conv.
            out_dim: Number of output channels produced by the pointwise conv.
            *args: Extra positional arguments forwarded to the depthwise
                ``nn.Conv2d`` after its two channel counts (i.e.
                ``kernel_size, stride, padding, dilation, groups, ...``).
            **kwargs: Extra keyword arguments forwarded to the depthwise
                ``nn.Conv2d``. If ``device`` and/or ``dtype`` are given they
                are also applied to the pointwise conv.

        Any ``groups`` value supplied by the caller, positionally or by
        keyword, is ignored: the depthwise conv always uses ``groups=in_dim``.
        """
        super().__init__()

        # Index of ``groups`` among nn.Conv2d's positional parameters that
        # follow in_channels/out_channels (kernel_size, stride, padding,
        # dilation, groups, ...).
        groups_pos = 4

        args = list(args)
        # Depthwise separability requires groups == in_dim, so a
        # caller-supplied ``groups`` is discarded instead of clashing with ours.
        kwargs.pop("groups", None)
        if len(args) > groups_pos:
            # ``groups`` was given positionally: overwrite that slot so it is
            # not passed twice (which would raise TypeError).
            args[groups_pos] = in_dim
        else:
            kwargs["groups"] = in_dim

        self.depthwise = nn.Conv2d(in_dim, in_dim, *args, **kwargs)

        # Keep both convs on the same device/dtype. Only forward the factory
        # kwargs the caller actually passed, so nothing new is sent to
        # nn.Conv2d when they are absent.
        factory_kwargs = {
            key: kwargs[key] for key in ("device", "dtype") if key in kwargs
        }
        self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1, **factory_kwargs)

    def forward(self, x):
        """Apply the depthwise conv, then the 1x1 pointwise conv.

        ``x`` is passed directly to ``nn.Conv2d`` (typically shaped
        ``(N, in_dim, H, W)``); the result has ``out_dim`` channels.
        """
        return self.pointwise(self.depthwise(x))