103 lines
3.7 KiB
Python
103 lines
3.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch.nn.functional as F
|
|
from mmcv.cnn import ConvModule
|
|
from mmcv.cnn.bricks import NonLocal2d
|
|
from mmcv.runner import BaseModule
|
|
|
|
from ..builder import NECKS
|
|
|
|
|
|
@NECKS.register_module()
class BFP(BaseModule):
    """Balanced Feature Pyramids (BFP) neck.

    The module works in three stages: every input level is resized to the
    spatial size of one chosen level and the resized maps are averaged into
    a single balanced feature; that feature is optionally refined; finally
    the result is resized back to each level's own size and added to the
    corresponding input as a residual. It was introduced in
    `Libra R-CNN: Towards Balanced Learning for Object Detection
    <https://arxiv.org/abs/1904.02701>`_ (CVPR 2019).

    Args:
        in_channels (int): Channel count of the input feature maps. All
            levels are expected to share this number of channels.
        num_levels (int): How many feature levels are passed to
            ``forward``.
        refine_level (int): Index (counted from the bottom of the pyramid)
            of the level whose spatial size is used for gathering and
            refining. Must satisfy ``0 <= refine_level < num_levels``.
            Defaults to 2.
        refine_type (str | None): Which refine op to apply to the gathered
            feature; one of ``None``, ``'conv'`` or ``'non_local'``.
            ``None`` skips the refine step. Defaults to None.
        conv_cfg (dict | None): Config dict forwarded to the convolution
            layers of the refine op. Defaults to None.
        norm_cfg (dict | None): Config dict forwarded to the normalization
            layers of the refine op. Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization config
            dict passed to the base module. Defaults to uniform Xavier
            initialization of ``Conv2d`` layers.
    """

    def __init__(self,
                 in_channels,
                 num_levels,
                 refine_level=2,
                 refine_type=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super(BFP, self).__init__(init_cfg)
        assert refine_type in (None, 'conv', 'non_local')

        self.in_channels = in_channels
        self.num_levels = num_levels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.refine_level = refine_level
        self.refine_type = refine_type
        assert 0 <= self.refine_level < self.num_levels

        # Both refine variants receive the same conv / norm configuration.
        layer_cfg = dict(conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg)
        if self.refine_type == 'conv':
            # 3x3 convolution that keeps both channel count and spatial size.
            self.refine = ConvModule(
                self.in_channels,
                self.in_channels,
                3,
                padding=1,
                **layer_cfg)
        elif self.refine_type == 'non_local':
            self.refine = NonLocal2d(
                self.in_channels,
                reduction=1,
                use_scale=False,
                **layer_cfg)
        # With refine_type=None no ``self.refine`` attribute is created;
        # ``forward`` checks ``refine_type`` before using it.

    def forward(self, inputs):
        """Gather, refine and scatter the multi-level features.

        Args:
            inputs (Sequence[Tensor]): One feature map per level; its
                length must equal ``num_levels``. The last two dimensions
                of each map are treated as the spatial size.

        Returns:
            tuple[Tensor]: One output per input level, each being the
            input plus the balanced feature resized to that level's
            spatial size.
        """
        assert len(inputs) == self.num_levels

        pivot = self.refine_level
        level_ids = range(self.num_levels)

        # step 1: bring every level to the pivot level's spatial size and
        # average them. Levels below the pivot are shrunk with adaptive max
        # pooling, the others are resized with nearest interpolation.
        target_size = inputs[pivot].size()[2:]
        resized = [
            F.adaptive_max_pool2d(inputs[lvl], output_size=target_size)
            if lvl < pivot else
            F.interpolate(inputs[lvl], size=target_size, mode='nearest')
            for lvl in level_ids
        ]
        bsf = sum(resized) / len(resized)

        # step 2: optional refinement of the balanced feature
        if self.refine_type is not None:
            bsf = self.refine(bsf)

        # step 3: resize the balanced feature back to each level (mirror of
        # step 1) and add it to the original input as a residual
        outs = []
        for lvl in level_ids:
            level_feat = inputs[lvl]
            level_size = level_feat.size()[2:]
            if lvl < pivot:
                residual = F.interpolate(bsf, size=level_size, mode='nearest')
            else:
                residual = F.adaptive_max_pool2d(bsf, output_size=level_size)
            outs.append(residual + level_feat)

        return tuple(outs)
|