# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.utils.checkpoint import checkpoint

from ..builder import NECKS


@NECKS.register_module()
class HRFPN(BaseModule):
    """HRFPN (High Resolution Feature Pyramids).

    paper: `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`_.

    Args:
        in_channels (list): Number of channels for each branch.
        out_channels (int): Output channels of the feature pyramids.
        num_outs (int): Number of output stages. Default: 5.
        pooling_type (str): Pooling used to generate the feature pyramids,
            one of {'MAX', 'AVG'}. Default: 'AVG'.
        conv_cfg (dict): Dictionary to construct and config conv layers.
        norm_cfg (dict): Dictionary to construct and config norm layers.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        stride (int): Stride of the 3x3 convolutional layers. Default: 1.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs=5,
                 pooling_type='AVG',
                 conv_cfg=None,
                 norm_cfg=None,
                 with_cp=False,
                 stride=1,
                 init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
        super(HRFPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # 1x1 conv that fuses the concatenated multi-branch features into
        # `out_channels` channels.
        self.reduction_conv = ConvModule(
            sum(in_channels),
            out_channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            act_cfg=None)

        # One 3x3 conv per output pyramid level.
        self.fpn_convs = nn.ModuleList()
        for i in range(self.num_outs):
            self.fpn_convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    stride=stride,
                    conv_cfg=self.conv_cfg,
                    act_cfg=None))

        if pooling_type == 'MAX':
            self.pooling = F.max_pool2d
        else:
            self.pooling = F.avg_pool2d

    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == self.num_ins
        # Upsample every branch to the resolution of the highest-resolution
        # branch and concatenate along the channel dimension.
        outs = [inputs[0]]
        for i in range(1, self.num_ins):
            outs.append(
                F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
        out = torch.cat(outs, dim=1)
        if out.requires_grad and self.with_cp:
            out = checkpoint(self.reduction_conv, out)
        else:
            out = self.reduction_conv(out)
        # Build the pyramid by progressively pooling the fused feature map,
        # then refine each level with its own 3x3 conv.
        outs = [out]
        for i in range(1, self.num_outs):
            outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
        outputs = []
        for i in range(self.num_outs):
            if outs[i].requires_grad and self.with_cp:
                tmp_out = checkpoint(self.fpn_convs[i], outs[i])
            else:
                tmp_out = self.fpn_convs[i](outs[i])
            outputs.append(tmp_out)
        return tuple(outputs)
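

# A minimal usage sketch, not part of the original file: it assumes the
# HRNet-W18 branch widths [18, 36, 72, 144] purely for illustration, builds
# the neck, and runs dummy multi-scale features through it. The dummy
# spatial sizes mimic feature maps at strides 4/8/16/32 of a 256x256 input.
# With the relative import above, this demo would be run as a module, e.g.
# `python -m mmdet.models.necks.hrfpn`, assuming the usual mmdet layout.
if __name__ == '__main__':
    in_channels = [18, 36, 72, 144]  # assumed HRNet-W18 widths (illustrative)
    neck = HRFPN(in_channels=in_channels, out_channels=256, num_outs=5)
    feats = [
        torch.randn(1, c, 64 // 2**i, 64 // 2**i)
        for i, c in enumerate(in_channels)
    ]
    outs = neck(feats)
    # Expect 5 levels, each with 256 channels, halving in resolution:
    # (1, 256, 64, 64), (1, 256, 32, 32), ..., (1, 256, 4, 4).
    for i, o in enumerate(outs):
        print(i, tuple(o.shape))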