# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from ..builder import NECKS


class Transition(BaseModule):
    """Base class for transition.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, out_channels, init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        # subclasses implement the actual transition
        pass
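

# Hypothetical helper (an assumption, not from the original FPG code): a
# minimal pass-through transition so that the default
# ``across_skip_trans=dict(type='identity')`` used by ``FPG`` below resolves
# to a concrete class in ``transition_types``. It also illustrates the
# ``Transition`` interface.
class Identity(Transition):
    """A transition that returns the input feature map unchanged."""

    def forward(self, x):
        return x
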
class UpInterpolationConv(Transition):
    """A transition used for up-sampling.

    Up-samples the input by interpolation, then refines the feature with
    a convolution layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Up-sampling factor. Default: 2.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool): Whether to align corners when interpolating.
            Default: None.
        kernel_size (int): Kernel size for the conv. Default: 3.
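
    Example:
        An illustrative sketch; the channel count and input shape below are
        assumptions, not values from any released config.

        >>> import torch
        >>> trans = UpInterpolationConv(64, 64, scale_factor=2)
        >>> x = torch.rand(1, 64, 16, 16)
        >>> trans(x).shape
        torch.Size([1, 64, 32, 32])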
"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=2,
                 mode='nearest',
                 align_corners=None,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.mode = mode
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            **kwargs)

    def forward(self, x):
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners)
        x = self.conv(x)
        return x


class LastConv(Transition):
    """A transition used for refining the output of the last stage.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_inputs (int): Number of inputs of the FPN features.
        kernel_size (int): Kernel size for the conv. Default: 3.
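
    Example:
        An illustrative sketch with assumed channel counts and shapes.

        >>> import torch
        >>> trans = LastConv(16, 16, num_inputs=3)
        >>> feats = [torch.rand(1, 16, 8, 8) for _ in range(3)]
        >>> trans(feats).shape
        torch.Size([1, 16, 8, 8])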
"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_inputs,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.num_inputs = num_inputs
        self.conv_out = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            **kwargs)

    def forward(self, inputs):
        assert len(inputs) == self.num_inputs
        return self.conv_out(inputs[-1])


@NECKS.register_module()
class FPG(BaseModule):
    """FPG.

    Implementation of `Feature Pyramid Grids (FPG)
    <https://arxiv.org/abs/2004.03580>`_.

    This implementation only gives the basic structure stated in the paper.
    Users can implement different types of transitions to fully explore the
    potential power of the structure of FPG.

    Args:
        in_channels (list[int]): Number of input channels of each scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        stack_times (int): The number of times the pyramid architecture will
            be stacked.
        paths (list[str]): Specify the path order of each stack level.
            Each element in the list should be either 'bu' (bottom-up) or
            'td' (top-down).
        inter_channels (int | list[int], optional): Number of internal
            channels at each scale. Defaults to ``out_channels`` for every
            scale when None.
        same_up_trans (dict): Transition used to go up one level within the
            same stage (bottom-up paths).
        same_down_trans (dict): Transition used to go down one level within
            the same stage (top-down paths).
        across_lateral_trans (dict): Across-pathway same-stage connection.
        across_down_trans (dict): Across-pathway bottom-up connection.
        across_up_trans (dict): Across-pathway top-down connection.
        across_skip_trans (dict): Across-pathway skip connection.
        output_trans (dict): Transition that refines the output of the
            last stage.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last
            level.
        add_extra_convs (bool): Whether to use conv layers (instead of max
            pooling) to produce the extra levels on top of the original
            feature maps. Default: False.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        skip_inds (list[tuple[int]]): For each output level, the stack
            indices at which that level is skipped (passed through
            unchanged). Must be specified when ``across_skip_trans`` is not
            None. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
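
    Example:
        An illustrative sketch; the input shapes, ``stack_times``, ``paths``
        and the ``interpolation_conv`` override below are assumptions, not
        values from a released FPG config. The default plain-conv
        ``across_down_trans`` does not resize its input, so an
        up-interpolating transition is used here to make the fused shapes
        match.

        >>> import torch
        >>> in_channels = [256, 512, 1024, 2048]
        >>> inputs = [torch.rand(1, c, 64 // 2**i, 64 // 2**i)
        ...           for i, c in enumerate(in_channels)]
        >>> self = FPG(
        ...     in_channels=in_channels,
        ...     out_channels=256,
        ...     num_outs=5,
        ...     stack_times=2,
        ...     paths=['bu', 'td'],
        ...     across_down_trans=dict(
        ...         type='interpolation_conv', mode='nearest', kernel_size=3),
        ...     skip_inds=[(), (), (), (), ()])
        >>> for out in self(inputs):
        ...     print(out.shape)
        torch.Size([1, 256, 64, 64])
        torch.Size([1, 256, 32, 32])
        torch.Size([1, 256, 16, 16])
        torch.Size([1, 256, 8, 8])
        torch.Size([1, 256, 4, 4])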
"""

    transition_types = {
        'conv': ConvModule,
        'interpolation_conv': UpInterpolationConv,
        'last_conv': LastConv,
        # assumed mapping for the default ``across_skip_trans`` type
        # (see the ``Identity`` sketch above)
        'identity': Identity,
    }

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 stack_times,
                 paths,
                 inter_channels=None,
                 same_down_trans=None,
                 same_up_trans=dict(
                     type='conv', kernel_size=3, stride=2, padding=1),
                 across_lateral_trans=dict(type='conv', kernel_size=1),
                 across_down_trans=dict(type='conv', kernel_size=3),
                 across_up_trans=None,
                 across_skip_trans=dict(type='identity'),
                 output_trans=dict(type='last_conv', kernel_size=3),
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 norm_cfg=None,
                 skip_inds=None,
                 init_cfg=[
                     dict(type='Caffe2Xavier', layer='Conv2d'),
                     dict(
                         type='Constant',
                         layer=[
                             '_BatchNorm', '_InstanceNorm', 'GroupNorm',
                             'LayerNorm'
                         ],
                         val=1.0)
                 ]):
        super(FPG, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        if inter_channels is None:
            self.inter_channels = [out_channels for _ in range(num_outs)]
        elif isinstance(inter_channels, int):
            self.inter_channels = [inter_channels for _ in range(num_outs)]
        else:
            assert isinstance(inter_channels, list)
            assert len(inter_channels) == num_outs
            self.inter_channels = inter_channels
        self.stack_times = stack_times
        self.paths = paths
        assert isinstance(paths, list) and len(paths) == stack_times
        for d in paths:
            assert d in ('bu', 'td')

        self.same_down_trans = same_down_trans
        self.same_up_trans = same_up_trans
        self.across_lateral_trans = across_lateral_trans
        self.across_down_trans = across_down_trans
        self.across_up_trans = across_up_trans
        self.output_trans = output_trans
        self.across_skip_trans = across_skip_trans

        self.with_bias = norm_cfg is None
        # skip_inds must be specified if across_skip_trans is not None
        if self.across_skip_trans is not None:
            assert skip_inds is not None
            self.skip_inds = skip_inds
            assert len(self.skip_inds[0]) <= self.stack_times

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        # build lateral 1x1 convs to reduce channels
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = nn.Conv2d(self.in_channels[i],
                               self.inter_channels[i - self.start_level], 1)
            self.lateral_convs.append(l_conv)

        # build extra downsample layers to reach ``num_outs`` levels
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            if self.add_extra_convs:
                fpn_idx = self.backbone_end_level - self.start_level + i
                extra_conv = nn.Conv2d(
                    self.inter_channels[fpn_idx - 1],
                    self.inter_channels[fpn_idx],
                    3,
                    stride=2,
                    padding=1)
                self.extra_downsamples.append(extra_conv)
            else:
                self.extra_downsamples.append(nn.MaxPool2d(1, stride=2))

        self.fpn_transitions = nn.ModuleList()  # stack times
        for s in range(self.stack_times):
            stage_trans = nn.ModuleList()  # num of feature levels
            for i in range(self.num_outs):
                # same, across_lateral, across_down, across_up
                trans = nn.ModuleDict()
                if s in self.skip_inds[i]:
                    stage_trans.append(trans)
                    continue
                # build same-stage up trans (used in bottom-up paths)
                if i == 0 or self.same_up_trans is None:
                    same_up_trans = None
                else:
                    same_up_trans = self.build_trans(
                        self.same_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['same_up'] = same_up_trans
                # build same-stage down trans (used in top-down paths)
                if i == self.num_outs - 1 or self.same_down_trans is None:
                    same_down_trans = None
                else:
                    same_down_trans = self.build_trans(
                        self.same_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['same_down'] = same_down_trans
                # build across lateral trans
                across_lateral_trans = self.build_trans(
                    self.across_lateral_trans, self.inter_channels[i],
                    self.inter_channels[i])
                trans['across_lateral'] = across_lateral_trans
                # build across down trans
                if i == self.num_outs - 1 or self.across_down_trans is None:
                    across_down_trans = None
                else:
                    across_down_trans = self.build_trans(
                        self.across_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['across_down'] = across_down_trans
                # build across up trans
                if i == 0 or self.across_up_trans is None:
                    across_up_trans = None
                else:
                    across_up_trans = self.build_trans(
                        self.across_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['across_up'] = across_up_trans
                # build across skip trans
                if self.across_skip_trans is None:
                    across_skip_trans = None
                else:
                    across_skip_trans = self.build_trans(
                        self.across_skip_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['across_skip'] = across_skip_trans
                stage_trans.append(trans)
            self.fpn_transitions.append(stage_trans)

        # build the per-level output transitions over the stacked stages
        self.output_transition = nn.ModuleList()  # output levels
        for i in range(self.num_outs):
            trans = self.build_trans(
                self.output_trans,
                self.inter_channels[i],
                self.out_channels,
                num_inputs=self.stack_times + 1)
            self.output_transition.append(trans)

        self.relu = nn.ReLU(inplace=True)

    def build_trans(self, cfg, in_channels, out_channels, **extra_args):
        """Instantiate a transition from its config dict."""
        cfg_ = cfg.copy()
        trans_type = cfg_.pop('type')
        trans_cls = self.transition_types[trans_type]
        return trans_cls(in_channels, out_channels, **cfg_, **extra_args)

    def fuse(self, fuse_dict):
        """Sum all non-None tensors in ``fuse_dict``."""
        out = None
        for item in fuse_dict.values():
            if item is not None:
                if out is None:
                    out = item
                else:
                    out = out + item
        return out

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # build all levels from original feature maps
        feats = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        for downsample in self.extra_downsamples:
            feats.append(downsample(feats[-1]))

        outs = [feats]

        for i in range(self.stack_times):
            current_outs = outs[-1]
            next_outs = []
            direction = self.paths[i]
            for j in range(self.num_outs):
                if i in self.skip_inds[j]:
                    next_outs.append(outs[-1][j])
                    continue
                # feature level
                if direction == 'td':
                    lvl = self.num_outs - j - 1
                else:
                    lvl = j
                # get transitions
                if direction == 'td':
                    same_trans = self.fpn_transitions[i][lvl]['same_down']
                else:
                    same_trans = self.fpn_transitions[i][lvl]['same_up']
                across_lateral_trans = self.fpn_transitions[i][lvl][
                    'across_lateral']
                across_down_trans = self.fpn_transitions[i][lvl]['across_down']
                across_up_trans = self.fpn_transitions[i][lvl]['across_up']
                across_skip_trans = self.fpn_transitions[i][lvl]['across_skip']
                # init output
                to_fuse = dict(
                    same=None,
                    lateral=None,
                    across_up=None,
                    across_down=None,
                    across_skip=None)
                # same-stage transition from the previous level in this path
                if same_trans is not None:
                    to_fuse['same'] = same_trans(next_outs[-1])
                # across-pathway lateral connection at the same level
                if across_lateral_trans is not None:
                    to_fuse['lateral'] = across_lateral_trans(
                        current_outs[lvl])
                # across-pathway connection from the lower level (lvl - 1)
                if lvl > 0 and across_up_trans is not None:
                    to_fuse['across_up'] = across_up_trans(
                        current_outs[lvl - 1])
                # across-pathway connection from the higher level (lvl + 1)
                if lvl < self.num_outs - 1 and across_down_trans is not None:
                    to_fuse['across_down'] = across_down_trans(
                        current_outs[lvl + 1])
                # skip connection from the original (stage-0) features
                if across_skip_trans is not None:
                    to_fuse['across_skip'] = across_skip_trans(outs[0][lvl])
                x = self.fuse(to_fuse)
                next_outs.append(x)

            if direction == 'td':
                outs.append(next_outs[::-1])
            else:
                outs.append(next_outs)

        # output trans: refine each level using the stacked stage outputs
        final_outs = []
        for i in range(self.num_outs):
            lvl_out_list = []
            for s in range(len(outs)):
                lvl_out_list.append(outs[s][i])
            lvl_out = self.output_transition[i](lvl_out_list)
            final_outs.append(lvl_out)

        return final_outs
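

# A minimal usage sketch (assumes an mmdet environment where ``build_neck``
# is available; the config values mirror the docstring example above and are
# illustrative, not from a released config):
#
#     from mmdet.models.builder import build_neck
#     neck = build_neck(
#         dict(
#             type='FPG',
#             in_channels=[256, 512, 1024, 2048],
#             out_channels=256,
#             num_outs=5,
#             stack_times=2,
#             paths=['bu', 'td'],
#             across_down_trans=dict(
#                 type='interpolation_conv', mode='nearest', kernel_size=3),
#             skip_inds=[(), (), (), (), ()]))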