From 8eec8dfe0973fc0e633a7149b6ad1ab80aa14887 Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Sun, 27 Feb 2022 11:35:29 +0800 Subject: [PATCH] [Feature] Support kenerl updation for some decoder heads. (#1299) * [Feature] Add kenerl updation for some decoder heads. * [Feature] Add kenerl updation for some decoder heads. * directly use forward_feature && modify other 3 decoder heads * remover kernel_update attr * delete unnecessary variables in forward function * delete kernel update function * delete kernel update function * delete unnecessary docstrings * modify comments in self._forward_feature() * modify docstrings in self._forward_feature() * fix docstring * modify uperhead --- mmseg/models/decode_heads/aspp_head.py | 20 +++++++++++++++++--- mmseg/models/decode_heads/fcn_head.py | 22 ++++++++++++++++++---- mmseg/models/decode_heads/psp_head.py | 20 +++++++++++++++++--- mmseg/models/decode_heads/uper_head.py | 19 ++++++++++++++++--- 4 files changed, 68 insertions(+), 13 deletions(-) diff --git a/mmseg/models/decode_heads/aspp_head.py b/mmseg/models/decode_heads/aspp_head.py index 1fbd1bc..7059aee 100644 --- a/mmseg/models/decode_heads/aspp_head.py +++ b/mmseg/models/decode_heads/aspp_head.py @@ -91,8 +91,17 @@ class ASPPHead(BaseDecodeHead): norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) - def forward(self, inputs): - """Forward function.""" + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ x = self._transform_inputs(inputs) aspp_outs = [ resize( @@ -103,6 +112,11 @@ class ASPPHead(BaseDecodeHead): ] aspp_outs.extend(self.aspp_modules(x)) aspp_outs = torch.cat(aspp_outs, dim=1) - output = self.bottleneck(aspp_outs) + feats = self.bottleneck(aspp_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) output = self.cls_seg(output) return output diff --git a/mmseg/models/decode_heads/fcn_head.py b/mmseg/models/decode_heads/fcn_head.py index 3c8de51..fb79a0d 100644 --- a/mmseg/models/decode_heads/fcn_head.py +++ b/mmseg/models/decode_heads/fcn_head.py @@ -72,11 +72,25 @@ class FCNHead(BaseDecodeHead): norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + feats = self.convs(x) + if self.concat_input: + feats = self.conv_cat(torch.cat([x, feats], dim=1)) + return feats + def forward(self, inputs): """Forward function.""" - x = self._transform_inputs(inputs) - output = self.convs(x) - if self.concat_input: - output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self._forward_feature(inputs) output = self.cls_seg(output) return output diff --git a/mmseg/models/decode_heads/psp_head.py b/mmseg/models/decode_heads/psp_head.py index a27ae4b..6990676 100644 --- a/mmseg/models/decode_heads/psp_head.py +++ b/mmseg/models/decode_heads/psp_head.py @@ -92,12 +92,26 @@ class PSPHead(BaseDecodeHead): norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) - def forward(self, inputs): - """Forward function.""" + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ x = self._transform_inputs(inputs) psp_outs = [x] psp_outs.extend(self.psp_modules(x)) psp_outs = torch.cat(psp_outs, dim=1) - output = self.bottleneck(psp_outs) + feats = self.bottleneck(psp_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) output = self.cls_seg(output) return output diff --git a/mmseg/models/decode_heads/uper_head.py b/mmseg/models/decode_heads/uper_head.py index 57d80be..06b152a 100644 --- a/mmseg/models/decode_heads/uper_head.py +++ b/mmseg/models/decode_heads/uper_head.py @@ -84,9 +84,17 @@ class UPerHead(BaseDecodeHead): return output - def forward(self, inputs): - """Forward function.""" + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ inputs = self._transform_inputs(inputs) # build laterals @@ -122,6 +130,11 @@ class UPerHead(BaseDecodeHead): mode='bilinear', align_corners=self.align_corners) fpn_outs = torch.cat(fpn_outs, dim=1) - output = self.fpn_bottleneck(fpn_outs) + feats = self.fpn_bottleneck(fpn_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) output = self.cls_seg(output) return output