151 lines
5.0 KiB
Python
151 lines
5.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
|
from mmcv.runner import BaseModule
|
|
|
|
|
|
class DarknetBottleneck(BaseModule):
    """The basic bottleneck block used in Darknet.

    Each block consists of two ConvModules, and the input is optionally added
    to the final output. The first ConvModule uses a 1x1 kernel and the
    second a 3x3 kernel with stride 1 and padding 1, so the spatial size is
    preserved. Each ConvModule is built from ``conv_cfg``, ``norm_cfg`` and
    ``act_cfg`` (with the defaults: conv2d, BN and Swish).

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        expansion (float): Ratio of the hidden channels (the output of the
            first 1x1 conv) to ``out_channels``; the hidden width is
            ``int(out_channels * expansion)``. Default: 0.5
        add_identity (bool): Whether to add identity to the out. The
            identity is only added when ``in_channels == out_channels``.
            Default: True
        use_depthwise (bool): Whether to use depthwise separable convolution
            for the 3x3 conv (the 1x1 conv is always a plain ConvModule).
            Default: False
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='Swish').
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``. Default: None.

    Raises:
        ValueError: If ``int(out_channels * expansion)`` is less than 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expansion=0.5,
                 add_identity=True,
                 use_depthwise=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                 act_cfg=dict(type='Swish'),
                 init_cfg=None):
        # NOTE: the dict defaults above are shared between instances; this
        # block only forwards them and never mutates them.
        super().__init__(init_cfg)
        hidden_channels = int(out_channels * expansion)
        if hidden_channels < 1:
            # A zero-width hidden layer is never meaningful; fail early with a
            # clear message instead of deep inside the conv layers.
            raise ValueError(
                'DarknetBottleneck needs at least one hidden channel, got '
                'int(out_channels * expansion) = {} (out_channels={}, '
                'expansion={})'.format(hidden_channels, out_channels,
                                       expansion))
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
        # 1x1 conv: in_channels -> hidden_channels.
        self.conv1 = ConvModule(
            in_channels,
            hidden_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # 3x3 conv, stride 1 / padding 1 (spatial size unchanged):
        # hidden_channels -> out_channels.
        self.conv2 = conv(
            hidden_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # The residual add requires matching channel counts, so it is
        # disabled (not an error) when in_channels != out_channels.
        self.add_identity = \
            add_identity and in_channels == out_channels

    def forward(self, x):
        """Apply conv1 then conv2, adding ``x`` back when identity is on."""
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)

        if self.add_identity:
            return out + identity
        else:
            return out
|
|
|
|
|
|
class CSPLayer(BaseModule):
    """Cross Stage Partial Layer.

    The input is projected by two parallel 1x1 ConvModules into
    ``mid_channels = int(out_channels * expand_ratio)`` channels each. The
    main branch then runs through ``num_blocks`` :class:`DarknetBottleneck`
    blocks while the short branch is kept as is. The two branches are
    concatenated along dim 1 (giving ``2 * mid_channels`` channels) and fused
    to ``out_channels`` by a final 1x1 ConvModule.

    Args:
        in_channels (int): The input channels of the CSP layer.
        out_channels (int): The output channels of the CSP layer.
        expand_ratio (float): Ratio to adjust the number of channels of the
            hidden layer (each branch has ``int(out_channels * expand_ratio)``
            channels). Default: 0.5
        num_blocks (int): Number of blocks. Default: 1
        add_identity (bool): Whether to add identity in blocks.
            Default: True
        use_depthwise (bool): Whether to use depthwise separable convolution
            in blocks. Default: False
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.03, eps=0.001)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='Swish')
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``. Default: None.

    Raises:
        ValueError: If ``int(out_channels * expand_ratio)`` is less than 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expand_ratio=0.5,
                 num_blocks=1,
                 add_identity=True,
                 use_depthwise=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                 act_cfg=dict(type='Swish'),
                 init_cfg=None):
        # NOTE: the dict defaults above are shared between instances; this
        # block only forwards them and never mutates them.
        super().__init__(init_cfg)
        mid_channels = int(out_channels * expand_ratio)
        if mid_channels < 1:
            # Both branches would be zero-width; reject the config up front.
            raise ValueError(
                'CSPLayer needs at least one hidden channel, got '
                'int(out_channels * expand_ratio) = {} (out_channels={}, '
                'expand_ratio={})'.format(mid_channels, out_channels,
                                          expand_ratio))
        # Main branch entry: in_channels -> mid_channels.
        self.main_conv = ConvModule(
            in_channels,
            mid_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # Short (bypass) branch: in_channels -> mid_channels.
        self.short_conv = ConvModule(
            in_channels,
            mid_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # Fuses the concatenated branches: 2 * mid_channels -> out_channels.
        self.final_conv = ConvModule(
            2 * mid_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # Bottlenecks keep the width at mid_channels (expansion=1.0), so the
        # identity add inside each block stays enabled when requested.
        # Keyword arguments guard against DarknetBottleneck signature changes.
        self.blocks = nn.Sequential(*[
            DarknetBottleneck(
                mid_channels,
                mid_channels,
                expansion=1.0,
                add_identity=add_identity,
                use_depthwise=use_depthwise,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for _ in range(num_blocks)
        ])

    def forward(self, x):
        """Run both branches, concatenate them on dim 1 and fuse."""
        x_short = self.short_conv(x)

        x_main = self.main_conv(x)
        x_main = self.blocks(x_main)

        x_final = torch.cat((x_main, x_short), dim=1)
        return self.final_conv(x_final)
|