149 lines
5.3 KiB
Python
149 lines
5.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .registry import CONV_LAYERS
|
|
|
|
|
|
def conv_ws_2d(input,
|
|
weight,
|
|
bias=None,
|
|
stride=1,
|
|
padding=0,
|
|
dilation=1,
|
|
groups=1,
|
|
eps=1e-5):
|
|
c_in = weight.size(0)
|
|
weight_flat = weight.view(c_in, -1)
|
|
mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
|
|
std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
|
|
weight = (weight - mean) / (std + eps)
|
|
return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
|
|
|
|
|
|
@CONV_LAYERS.register_module('ConvWS')
|
|
class ConvWS2d(nn.Conv2d):
|
|
|
|
def __init__(self,
|
|
in_channels,
|
|
out_channels,
|
|
kernel_size,
|
|
stride=1,
|
|
padding=0,
|
|
dilation=1,
|
|
groups=1,
|
|
bias=True,
|
|
eps=1e-5):
|
|
super(ConvWS2d, self).__init__(
|
|
in_channels,
|
|
out_channels,
|
|
kernel_size,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
groups=groups,
|
|
bias=bias)
|
|
self.eps = eps
|
|
|
|
def forward(self, x):
|
|
return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
|
|
self.dilation, self.groups, self.eps)
|
|
|
|
|
|
@CONV_LAYERS.register_module(name='ConvAWS')
|
|
class ConvAWS2d(nn.Conv2d):
|
|
"""AWS (Adaptive Weight Standardization)
|
|
|
|
This is a variant of Weight Standardization
|
|
(https://arxiv.org/pdf/1903.10520.pdf)
|
|
It is used in DetectoRS to avoid NaN
|
|
(https://arxiv.org/pdf/2006.02334.pdf)
|
|
|
|
Args:
|
|
in_channels (int): Number of channels in the input image
|
|
out_channels (int): Number of channels produced by the convolution
|
|
kernel_size (int or tuple): Size of the conv kernel
|
|
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
|
padding (int or tuple, optional): Zero-padding added to both sides of
|
|
the input. Default: 0
|
|
dilation (int or tuple, optional): Spacing between kernel elements.
|
|
Default: 1
|
|
groups (int, optional): Number of blocked connections from input
|
|
channels to output channels. Default: 1
|
|
bias (bool, optional): If set True, adds a learnable bias to the
|
|
output. Default: True
|
|
"""
|
|
|
|
def __init__(self,
|
|
in_channels,
|
|
out_channels,
|
|
kernel_size,
|
|
stride=1,
|
|
padding=0,
|
|
dilation=1,
|
|
groups=1,
|
|
bias=True):
|
|
super().__init__(
|
|
in_channels,
|
|
out_channels,
|
|
kernel_size,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
groups=groups,
|
|
bias=bias)
|
|
self.register_buffer('weight_gamma',
|
|
torch.ones(self.out_channels, 1, 1, 1))
|
|
self.register_buffer('weight_beta',
|
|
torch.zeros(self.out_channels, 1, 1, 1))
|
|
|
|
def _get_weight(self, weight):
|
|
weight_flat = weight.view(weight.size(0), -1)
|
|
mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
|
|
std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
|
|
weight = (weight - mean) / std
|
|
weight = self.weight_gamma * weight + self.weight_beta
|
|
return weight
|
|
|
|
def forward(self, x):
|
|
weight = self._get_weight(self.weight)
|
|
return F.conv2d(x, weight, self.bias, self.stride, self.padding,
|
|
self.dilation, self.groups)
|
|
|
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
|
missing_keys, unexpected_keys, error_msgs):
|
|
"""Override default load function.
|
|
|
|
AWS overrides the function _load_from_state_dict to recover
|
|
weight_gamma and weight_beta if they are missing. If weight_gamma and
|
|
weight_beta are found in the checkpoint, this function will return
|
|
after super()._load_from_state_dict. Otherwise, it will compute the
|
|
mean and std of the pretrained weights and store them in weight_beta
|
|
and weight_gamma.
|
|
"""
|
|
|
|
self.weight_gamma.data.fill_(-1)
|
|
local_missing_keys = []
|
|
super()._load_from_state_dict(state_dict, prefix, local_metadata,
|
|
strict, local_missing_keys,
|
|
unexpected_keys, error_msgs)
|
|
if self.weight_gamma.data.mean() > 0:
|
|
for k in local_missing_keys:
|
|
missing_keys.append(k)
|
|
return
|
|
weight = self.weight.data
|
|
weight_flat = weight.view(weight.size(0), -1)
|
|
mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
|
|
std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
|
|
self.weight_beta.data.copy_(mean)
|
|
self.weight_gamma.data.copy_(std)
|
|
missing_gamma_beta = [
|
|
k for k in local_missing_keys
|
|
if k.endswith('weight_gamma') or k.endswith('weight_beta')
|
|
]
|
|
for k in missing_gamma_beta:
|
|
local_missing_keys.remove(k)
|
|
for k in local_missing_keys:
|
|
missing_keys.append(k)
|