From 7e1d853f2b7195b39efb3eb7c4d8e912e3ca57ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Fri, 2 Jul 2021 17:58:35 +0800 Subject: [PATCH] [Fix] fix fast scnn (#606) * [Refactor] Match paddle seg weight * Match inference * fix exp setting * delete comment and rename config files * replace hard code with config parameters * fix ppm concat order * remove hardcode * update result * fix typo * complement docstring * complement FutureFusionModule docstring * modify log link --- configs/_base_/models/fast_scnn.py | 2 +- configs/fastscnn/README.md | 2 +- ...> fast_scnn_lr0.12_8x4_160k_cityscapes.py} | 4 +- mmseg/models/backbones/fast_scnn.py | 63 +++++++++++++------ mmseg/models/decode_heads/psp_head.py | 5 +- mmseg/models/decode_heads/sep_fcn_head.py | 14 +++-- mmseg/models/utils/inverted_residual.py | 12 ++-- 7 files changed, 70 insertions(+), 32 deletions(-) rename configs/fastscnn/{fast_scnn_4x8_80k_lr0.12_cityscapes.py => fast_scnn_lr0.12_8x4_160k_cityscapes.py} (82%) diff --git a/configs/_base_/models/fast_scnn.py b/configs/_base_/models/fast_scnn.py index 32fdeb6..8e89d91 100644 --- a/configs/_base_/models/fast_scnn.py +++ b/configs/_base_/models/fast_scnn.py @@ -25,7 +25,7 @@ model = dict( norm_cfg=norm_cfg, align_corners=False, loss_decode=dict( - type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1)), auxiliary_head=[ dict( type='FCNHead', diff --git a/configs/fastscnn/README.md b/configs/fastscnn/README.md index 9cea8d0..f81b4b8 100644 --- a/configs/fastscnn/README.md +++ b/configs/fastscnn/README.md @@ -19,4 +19,4 @@ | Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | | --------- | --------- | --------- | ------: | -------- | -------------- | ----: | ------------- | --------------------------------------------------------------------------------------- | 
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Fast-SCNN | Fast-SCNN | 512x1024 | 80000 | 8.4 | 63.61 | 69.06 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fast_scnn.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-f5096c79.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-20200807_165744.log.json) | +| Fast-SCNN | Fast-SCNN | 512x1024 | 160000 | 3.3 | 56.45 | 70.96 | 72.65 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fast_scnn.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_8x4_160k_lr0.12_cityscapes-0cec9937.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_8x4_160k_lr0.12_cityscapes-20210630_164853.log.json) | diff --git a/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py b/configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py similarity index 82% rename from configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py rename to configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py index 3d9c999..4698125 100644 --- a/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py +++ b/configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py @@ -1,10 +1,10 @@ _base_ = [ '../_base_/models/fast_scnn.py', '../_base_/datasets/cityscapes.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' + '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' ] # Re-config the data sampler. -data = dict(samples_per_gpu=2, workers_per_gpu=4) +data = dict(samples_per_gpu=4, workers_per_gpu=4) # Re-config the optimizer. 
optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-5) diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py index e8a8703..84289da 100644 --- a/mmseg/models/backbones/fast_scnn.py +++ b/mmseg/models/backbones/fast_scnn.py @@ -6,7 +6,7 @@ from mmcv.runner import BaseModule from mmseg.models.decode_heads.psp_head import PPM from mmseg.ops import resize from ..builder import BACKBONES -from ..utils.inverted_residual import InvertedResidual +from ..utils import InvertedResidual class LearningToDownsample(nn.Module): @@ -23,6 +23,9 @@ class LearningToDownsample(nn.Module): dict(type='BN') act_cfg (dict): Config of activation layers. Default: dict(type='ReLU') + dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config + of depthwise ConvModule. If it is 'default', it will be the same + as `act_cfg`. Default: None. """ def __init__(self, @@ -31,11 +34,13 @@ class LearningToDownsample(nn.Module): out_channels, conv_cfg=None, norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU')): + act_cfg=dict(type='ReLU'), + dw_act_cfg=None): super(LearningToDownsample, self).__init__() self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg + self.dw_act_cfg = dw_act_cfg dw_channels1 = dw_channels[0] dw_channels2 = dw_channels[1] @@ -44,23 +49,28 @@ class LearningToDownsample(nn.Module): dw_channels1, 3, stride=2, + padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) + self.dsconv1 = DepthwiseSeparableConvModule( dw_channels1, dw_channels2, kernel_size=3, stride=2, padding=1, - norm_cfg=self.norm_cfg) + norm_cfg=self.norm_cfg, + dw_act_cfg=self.dw_act_cfg) + self.dsconv2 = DepthwiseSeparableConvModule( dw_channels2, out_channels, kernel_size=3, stride=2, padding=1, - norm_cfg=self.norm_cfg) + norm_cfg=self.norm_cfg, + dw_act_cfg=self.dw_act_cfg) def forward(self, x): x = self.conv(x) @@ -136,10 +146,12 @@ class GlobalFeatureExtractor(nn.Module): norm_cfg=self.norm_cfg, 
act_cfg=self.act_cfg, align_corners=align_corners) + self.out = ConvModule( block_channels[2] * 2, out_channels, - 1, + 3, + padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) @@ -156,7 +168,8 @@ class GlobalFeatureExtractor(nn.Module): out_channels, stride, expand_ratio, - norm_cfg=self.norm_cfg) + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) ] for i in range(1, blocks): layers.append( @@ -165,7 +178,8 @@ class GlobalFeatureExtractor(nn.Module): out_channels, 1, expand_ratio, - norm_cfg=self.norm_cfg)) + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) return nn.Sequential(*layers) def forward(self, x): @@ -189,10 +203,12 @@ class FeatureFusionModule(nn.Module): conv_cfg (dict | None): Config of conv layers. Default: None norm_cfg (dict | None): Config of norm layers. Default: dict(type='BN') - act_cfg (dict): Config of activation layers. Default: - dict(type='ReLU') + dwconv_act_cfg (dict): Config of activation layers in 3x3 conv. + Default: dict(type='ReLU'). + conv_act_cfg (dict): Config of activation layers in the two 1x1 conv. + Default: None. align_corners (bool): align_corners argument of F.interpolate. - Default: False + Default: False. 
""" def __init__(self, @@ -201,34 +217,40 @@ class FeatureFusionModule(nn.Module): out_channels, conv_cfg=None, norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU'), + dwconv_act_cfg=dict(type='ReLU'), + conv_act_cfg=None, align_corners=False): super(FeatureFusionModule, self).__init__() self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg - self.act_cfg = act_cfg + self.dwconv_act_cfg = dwconv_act_cfg + self.conv_act_cfg = conv_act_cfg self.align_corners = align_corners self.dwconv = ConvModule( lower_in_channels, out_channels, - 1, + 3, + padding=1, + groups=out_channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.dwconv_act_cfg) self.conv_lower_res = ConvModule( out_channels, out_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=None) + act_cfg=self.conv_act_cfg) + self.conv_higher_res = ConvModule( higher_in_channels, out_channels, 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=None) + act_cfg=self.conv_act_cfg) + self.relu = nn.ReLU(True) def forward(self, higher_res_feature, lower_res_feature): @@ -290,6 +312,9 @@ class FastSCNN(BaseModule): dict(type='ReLU') align_corners (bool): align_corners argument of F.interpolate. Default: False + dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config + of depthwise ConvModule. If it is 'default', it will be the same + as `act_cfg`. Default: None. init_cfg (dict or list[dict], optional): Initialization config dict. 
Default: None """ @@ -309,6 +334,7 @@ class FastSCNN(BaseModule): norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), align_corners=False, + dw_act_cfg=None, init_cfg=None): super(FastSCNN, self).__init__(init_cfg) @@ -348,7 +374,8 @@ class FastSCNN(BaseModule): global_in_channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg) + act_cfg=self.act_cfg, + dw_act_cfg=dw_act_cfg) self.global_feature_extractor = GlobalFeatureExtractor( global_in_channels, global_block_channels, @@ -364,7 +391,7 @@ class FastSCNN(BaseModule): fusion_out_channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg, + dwconv_act_cfg=self.act_cfg, align_corners=self.align_corners) def forward(self, x): diff --git a/mmseg/models/decode_heads/psp_head.py b/mmseg/models/decode_heads/psp_head.py index bdbe2c8..4416199 100644 --- a/mmseg/models/decode_heads/psp_head.py +++ b/mmseg/models/decode_heads/psp_head.py @@ -22,7 +22,7 @@ class PPM(nn.ModuleList): """ def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, - act_cfg, align_corners): + act_cfg, align_corners, **kwargs): super(PPM, self).__init__() self.pool_scales = pool_scales self.align_corners = align_corners @@ -41,7 +41,8 @@ class PPM(nn.ModuleList): 1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg))) + act_cfg=self.act_cfg, + **kwargs))) def forward(self, x): """Forward function.""" diff --git a/mmseg/models/decode_heads/sep_fcn_head.py b/mmseg/models/decode_heads/sep_fcn_head.py index a636f70..39844c9 100644 --- a/mmseg/models/decode_heads/sep_fcn_head.py +++ b/mmseg/models/decode_heads/sep_fcn_head.py @@ -24,23 +24,28 @@ class DepthwiseSeparableFCNHead(FCNHead): Default: False. loss_decode(dict): Config of loss type and some relevant additional options. + dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is + 'default', it will be the same as `act_cfg`. Default: None. 
""" - def __init__(self, **kwargs): + def __init__(self, dw_act_cfg=None, **kwargs): super(DepthwiseSeparableFCNHead, self).__init__(**kwargs) self.convs[0] = DepthwiseSeparableConvModule( self.in_channels, self.channels, kernel_size=self.kernel_size, padding=self.kernel_size // 2, - norm_cfg=self.norm_cfg) + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) + for i in range(1, self.num_convs): self.convs[i] = DepthwiseSeparableConvModule( self.channels, self.channels, kernel_size=self.kernel_size, padding=self.kernel_size // 2, - norm_cfg=self.norm_cfg) + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) if self.concat_input: self.conv_cat = DepthwiseSeparableConvModule( @@ -48,4 +53,5 @@ class DepthwiseSeparableFCNHead(FCNHead): self.channels, kernel_size=self.kernel_size, padding=self.kernel_size // 2, - norm_cfg=self.norm_cfg) + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) diff --git a/mmseg/models/utils/inverted_residual.py b/mmseg/models/utils/inverted_residual.py index ede71a2..5a209a5 100644 --- a/mmseg/models/utils/inverted_residual.py +++ b/mmseg/models/utils/inverted_residual.py @@ -37,7 +37,8 @@ class InvertedResidual(nn.Module): conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU6'), - with_cp=False): + with_cp=False, + **kwargs): super(InvertedResidual, self).__init__() self.stride = stride assert stride in [1, 2], f'stride must in [1, 2]. ' \ @@ -55,7 +56,8 @@ class InvertedResidual(nn.Module): kernel_size=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + **kwargs)) layers.extend([ ConvModule( in_channels=hidden_dim, @@ -67,14 +69,16 @@ class InvertedResidual(nn.Module): groups=hidden_dim, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg), + act_cfg=act_cfg, + **kwargs), ConvModule( in_channels=hidden_dim, out_channels=out_channels, kernel_size=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=None) + act_cfg=None, + **kwargs) ]) self.conv = nn.Sequential(*layers)