63 lines
2.5 KiB
Python
63 lines
2.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import math
|
|
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from .registry import CONV_LAYERS
|
|
|
|
|
|
@CONV_LAYERS.register_module()
class Conv2dAdaptivePadding(nn.Conv2d):
    """2D convolution emulating TensorFlow's ``padding="same"`` behaviour.

    The input is padded on the fly (only when required) so that the kernel,
    at the configured stride and dilation, covers the whole input. With a
    stride of 1 the output keeps the spatial size of the input; with a stride
    of 2 the output dimensions are halved (rounded up), and so on.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
        stride (int or tuple, optional): Stride of the convolution.
            Default: 1
        padding (int or tuple, optional): Accepted so the signature matches
            ``nn.Conv2d``, but not forwarded: the parent layer is always
            built with a padding of 0 and the actual padding is computed
            from the input size in ``forward``. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        # The parent is constructed with a literal padding of 0; the
        # ``padding`` argument is not used. Padding is applied explicitly
        # in ``forward`` instead.
        super().__init__(in_channels, out_channels, kernel_size, stride, 0,
                         dilation, groups, bias)

    def forward(self, x):
        """Pad ``x`` as needed for "same" output size, then convolve.

        Args:
            x (Tensor): Input whose last two dimensions are treated as
                height and width.

        Returns:
            Tensor: Result of ``F.conv2d`` applied to the (possibly padded)
            input with this layer's weight, bias, stride, dilation and
            groups.
        """
        in_h, in_w = x.size()[-2:]
        k_h, k_w = self.weight.size()[-2:]
        s_h, s_w = self.stride
        d_h, d_w = self.dilation

        # Target output size: one output position per stride step, rounded
        # up so that a trailing partial window still yields an output.
        out_h = math.ceil(in_h / s_h)
        out_w = math.ceil(in_w / s_w)

        # Extent of the input covered by one dilated kernel window.
        span_h = (k_h - 1) * d_h + 1
        span_w = (k_w - 1) * d_w + 1

        # Total padding per axis so the last window fits inside the input;
        # clamped at zero because the input may already be large enough.
        pad_h = max((out_h - 1) * s_h + span_h - in_h, 0)
        pad_w = max((out_w - 1) * s_w + span_w - in_w, 0)

        if pad_h > 0 or pad_w > 0:
            # Split each total as evenly as possible; when it is odd, the
            # extra element goes to the right / bottom side.
            left = pad_w // 2
            right = pad_w - left
            top = pad_h // 2
            bottom = pad_h - top
            # F.pad expects the last dimension first: (left, right, top,
            # bottom).
            x = F.pad(x, [left, right, top, bottom])

        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
|