# All modification made by Kneron Corp.: Copyright (c) 2022 Kneron Corp.
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16

from ..builder import NECKS


@NECKS.register_module()
class FCOSNeckKneron(BaseModule):
    r"""Per-level neck with SPP-style max-pooling, used in front of FCOS heads.

    Unlike a Feature Pyramid Network, this neck has no top-down pathway and
    builds no extra levels. Every selected backbone level ``i`` (from
    ``start_level`` up to, but excluding, the backbone end level) is
    processed independently:

    1. a 1x1 lateral ``ConvModule`` maps ``in_channels[i]`` channels to
       ``mid_channels[i]`` channels;
    2. the lateral feature is concatenated along the channel dimension with
       stride-1, same-padded max-pooled copies of itself, one per entry of
       ``pool_kernel_sizes`` (spatial-pyramid-pooling style);
    3. an output ``ConvModule`` with kernel ``out_kernel`` maps the
       concatenation to ``out_channels[i]`` channels.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        mid_channels (List[int]): Output channels of the lateral 1x1 conv,
            indexed by backbone level (same indexing as ``in_channels``).
        out_channels (List[int]): Output channels per backbone level,
            indexed like ``in_channels``.
        num_outs (int): Number of output scales. It is only validated
            against ``start_level``/``end_level``; because this neck builds
            no extra levels, ``forward`` always returns exactly one output
            per used backbone level.
        out_kernel (int): Kernel size of the output conv. Default: 3.
        out_padding (int): Padding of the output conv. Default: 1.
        start_level (int): Index of the first backbone level used.
            Default: 0.
        end_level (int): Index of the end backbone level (exclusive).
            Default: -1, which means the last level.
        add_extra_convs (bool | str): Accepted and validated for config
            compatibility with FPN ('on_input', 'on_lateral', 'on_output';
            ``True`` is normalised to 'on_input'), but this neck builds no
            extra convs. Default: False.
        relu_before_extra_convs (bool): Stored only; unused by this neck.
            Default: False.
        no_norm_on_lateral (bool): If True, the lateral convs are built
            with ``norm_cfg=None``. Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
        upsample_cfg (dict): Stored (copied) only; unused by this neck.
            Default: `dict(mode='nearest')`.
        init_cfg (dict or list[dict], optional): Initialization config dict
            passed to ``BaseModule``.
        pool_kernel_sizes (Sequence[int]): Odd kernel sizes of the stride-1
            max-pools. The output conv consumes
            ``mid_channels[i] * (len(pool_kernel_sizes) + 1)`` channels.
            Default: (5, 9).

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5]
        >>> scales = [64, 32, 16]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = FCOSNeckKneron(in_channels, [4, 4, 4], [8, 8, 8], 3).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 8, 64, 64])
        outputs[1].shape = torch.Size([1, 8, 32, 32])
        outputs[2].shape = torch.Size([1, 8, 16, 16])
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 num_outs,
                 out_kernel=3,
                 out_padding=1,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=dict(mode='nearest'),
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform'),
                 *,
                 pool_kernel_sizes=(5, 9)):
        super(FCOSNeckKneron, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False
        # Copied so later mutation of the caller's dict cannot leak in.
        self.upsample_cfg = upsample_cfg.copy()

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level

        # mid/out channels are indexed by absolute backbone level below.
        assert len(mid_channels) >= self.backbone_end_level, \
            'mid_channels must provide a value for every used backbone level'
        assert len(out_channels) >= self.backbone_end_level, \
            'out_channels must provide a value for every used backbone level'

        self.add_extra_convs = add_extra_convs
        assert isinstance(add_extra_convs, (str, bool))
        if isinstance(add_extra_convs, str):
            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
        elif add_extra_convs:  # True
            self.add_extra_convs = 'on_input'

        self.pool_kernel_sizes = tuple(pool_kernel_sizes)
        # With stride 1, padding ks // 2 preserves the spatial size only for
        # odd kernels; an even kernel would grow the map by one pixel and
        # break the channel concatenation in forward().
        assert all(ks > 0 and ks % 2 == 1 for ks in self.pool_kernel_sizes), \
            'pool_kernel_sizes must be positive odd integers'
        # Identity branch + one pooled branch per kernel size.
        num_branches = len(self.pool_kernel_sizes) + 1

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        self.poolings = nn.ModuleList([
            nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
            for ks in self.pool_kernel_sizes
        ])

        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                mid_channels[i],
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                act_cfg=act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                mid_channels[i] * num_branches,
                out_channels[i],
                out_kernel,
                padding=out_padding,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    @auto_fp16()
    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor] | tuple[Tensor]): One feature map per entry
                of ``in_channels``.

        Returns:
            tuple[Tensor]: One output per backbone level in
            ``[start_level, backbone_end_level)``, in the same order.
        """
        assert len(inputs) == len(self.in_channels)

        used_inputs = inputs[self.start_level:self.backbone_end_level]
        outs = []
        for feat, lateral_conv, fpn_conv in zip(
                used_inputs, self.lateral_convs, self.fpn_convs):
            lateral = lateral_conv(feat)
            # Identity branch first, then one max-pooled copy per kernel;
            # the stride-1, same-padded pools keep the spatial size equal so
            # the channel-wise concatenation is valid.
            pyramid = torch.cat(
                [lateral] + [pooling(lateral) for pooling in self.poolings],
                dim=1)
            outs.append(fpn_conv(pyramid))
        return tuple(outs)