STDC/mmseg/models/decode_heads/setr_up_head.py
Sixiao Zheng ec91893931
[Feature] Official implementation of SETR (#531)
* Adjust vision transformer backbone architectures;

* Add DropPath, trunc_normal_ for VisionTransformer implementation;

* Add class token during intermediate period and remove it during final period;

* Fix some parameters loss bug;

* * Store intermediate token features and impose no processes on them;

* Remove class token and reshape entire token feature from NLC to NCHW;

* Fix some doc errors

* Add an arg for VisionTransformer backbone to control if input class token into transformer;

* Add stochastic depth decay rule for DropPath;

* * Fix output bug when input_cls_token=False;

* Add related unit test;

* Re-implement SETR

* Add two heads: SETRUPHead (Naive, PUP) & SETRMLAHead (MLA);

* * Modify some docs of heads of SETR;

* Add MLA auxiliary head of SETR;

* * Modify some arg of setr heads;

* Add unit test for setr heads;

* * Add 768x768 cityscapes dataset config;

* Add Backbone: SETR -- Backbone: MLA, PUP, Naive;

* Add SETR cityscapes training & testing config;

* * Fix the low code coverage of unit test about heads of setr;

* Remove some redundant error capture;

* * Add pascal context dataset & ade20k dataset config;

* Modify auxiliary head relative config;

* Modify folder structure.

* add setr

* modify vit

* Fix the test_cfg arg position;

* Fix some learning schedule bug;

* optimize setr code

* Add arg: final_reshape to control if converting output feature information from NLC to NCHW;

* Fix the default value of final_reshape;

* Modify arg: final_reshape to arg: out_shape;

* Fix some unit test bug;

* Add MLA neck;

* Modify setr configs to add MLA neck;

* Modify MLA decode head to remove redundant structure;

* Remove some redundant files.

* * Fix the code style bug;

* Remove some redundant files;

* Modify some unit tests of SETR;

* Ignoring CityscapesCoarseDataset and MapillaryDataset.

* Fix the activation function loss bug;

* Fix the img_size bug of SETR_PUP_ADE20K

* * Fix the lint bug of transformers.py;

* Add mla neck unit test;

* Convert vit of setr out shape from NLC to NCHW.

* * Modify Resize action of data pipeline;

* Fix deit related bug;

* Set find_unused_parameters=False for pascal context dataset;

* Remove arg: find_unused_parameters which is False by default.

* Error auxiliary head of PUP deit

* Remove the minimal restrict of slide inference.

* Modify doc string of Resize

* Separate this part of code to a new PR #544

* * Remove some redundant code;

* Modify unit tests of SETR heads;

* Fix the tuple in_channels of mla_deit.

* Modify code style

* Move detailed definition of auxiliary head into model config dict;

* Add some setr config for default cityscapes.py;

* Fix the doc string of SETR head;

* Modify implementation of SETR Heads

* Remove setr aux head and use fcn head to replace it;

* Remove arg: img_size and remove last interpolate op of heads;

* Rename arg: conv3x3_conv1x1 to kernel_size of SETRUPHead;

* non-square input support for setr heads

* Modify config argument for above commits

* Remove norm_layer argument of SETRMLAHead

* Add mla_align_corners for MLAModule interpolate

* [Refactor]Refactor of SETRMLAHead

* Modify Head implementation;

* Modify Head unit test;

* Modify related config file;

* [Refactor]MLA Neck

* Fix config bug

* [Refactor]SETR Naive Head and SETR PUP Head

* [Fix]Fix the lack of arg: act_cfg and arg: norm_cfg

* Fix config error

* Refactor of SETR MLA, Naive, PUP heads.

* Modify some attribute name of SETR Heads.

* Modify setr configs to adapt new vit code.

* Fix trunc_normal_ bug

* Parameters init adjustment.

* Remove redundant doc string of SETRUPHead

* Fix pretrained bug

* [Fix] Fix vit init bug

* Add some vit unit tests

* Modify module import

* Remove norm from PatchEmbed

* Fix pretrain weights bug

* Modify pretrained judge

* Fix some gradient backward bugs.

* Add some unit tests to improve code cov

* Fix init_weights of setr up head

* Add DropPath in FFN

* Finish benchmark of SETR

1. Add benchmark information into README.MD of SETR;

2. Fix some name bugs of vit;

* Remove DropPath implementation and use DropPath from mmcv.

* Modify out_indices arg

* Fix out_indices bug.

* Remove cityscapes base dataset config.

Co-authored-by: sennnnn <201730271412@mail.scut.edu.cn>
Co-authored-by: CuttlefishXuan <zhaoxinxuan1997@gmail.com>
2021-06-23 09:39:29 -07:00
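Two of the bullets above (removing the last interpolate op of the heads, and non-square input support) mean the head's output resolution is determined only by the token grid and the head's hyperparameters. The following is a plain-Python sketch of that bookkeeping, not part of mmseg: the helper names `conv_out_size` and `setr_up_head_out_shape` are hypothetical, and the example settings (a 16-pixel patch size, and the `num_convs`/`up_scale`/`kernel_size` values used for a Naive-style and a PUP-style head) are assumptions chosen for illustration rather than the shipped configs.

```python
from typing import Tuple


def conv_out_size(size: int, kernel_size: int, stride: int = 1,
                  padding: int = 0) -> int:
    """Standard output-size formula for a convolution along one axis."""
    return (size + 2 * padding - kernel_size) // stride + 1


def setr_up_head_out_shape(h: int, w: int, num_convs: int, up_scale: int,
                           kernel_size: int) -> Tuple[int, int]:
    """Hypothetical helper: spatial size after SETRUPHead-style stages.

    Each stage is a stride-1 conv with padding (kernel_size - 1) // 2,
    which keeps the size unchanged for kernel_size 1 or 3, followed by an
    upsample by an integer up_scale. The final 1x1 classifier conv does
    not change the spatial size, so it is not modelled here.
    """
    assert kernel_size in (1, 3), 'kernel_size must be 1 or 3.'
    padding = (kernel_size - 1) // 2
    for _ in range(num_convs):
        # Height and width are scaled independently, so non-square maps
        # keep their aspect ratio.
        h = conv_out_size(h, kernel_size, padding=padding) * up_scale
        w = conv_out_size(w, kernel_size, padding=padding) * up_scale
    return h, w


# Assumed example: a 768x768 image split into 16x16 patches -> 48x48 tokens.
print(setr_up_head_out_shape(48, 48, num_convs=1, up_scale=4, kernel_size=1))  # (192, 192)
print(setr_up_head_out_shape(48, 48, num_convs=4, up_scale=2, kernel_size=3))  # (768, 768)
# Non-square: a 512x1024 image -> 32x64 tokens.
print(setr_up_head_out_shape(32, 64, num_convs=4, up_scale=2, kernel_size=3))  # (512, 1024)
```

Under these assumed settings, a single 4x stage leaves the logits at a quarter of the input resolution, while four 2x stages recover the full resolution; in both cases the aspect ratio of the token grid is preserved.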


import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer, constant_init

from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class SETRUPHead(BaseDecodeHead):
    """Naive upsampling head and Progressive upsampling head of SETR.

    Naive or PUP head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`_.

    Args:
        norm_layer (dict): Config dict for input normalization.
            Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
        num_convs (int): Number of decoder stages, each a convolution
            followed by an upsampling. Default: 1.
        up_scale (int): Scale factor of the bilinear upsampling applied
            after each convolution. Default: 4.
        kernel_size (int): The kernel size of convolution when decoding
            feature information from backbone. Must be 1 or 3. Default: 3.
    """

    def __init__(self,
                 norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
                 num_convs=1,
                 up_scale=4,
                 kernel_size=3,
                 **kwargs):

        assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'

        super(SETRUPHead, self).__init__(**kwargs)

        assert isinstance(self.in_channels, int)

        # Normalization applied to the backbone feature before decoding.
        _, self.norm = build_norm_layer(norm_layer, self.in_channels)

        self.up_convs = nn.ModuleList()
        in_channels = self.in_channels
        out_channels = self.channels
        for _ in range(num_convs):
            # One stage: conv (norm/activation from norm_cfg/act_cfg) with
            # padding that preserves the spatial size for kernel_size 1 or 3,
            # then bilinear upsampling by ``up_scale``.
            self.up_convs.append(
                nn.Sequential(
                    ConvModule(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=1,
                        padding=(kernel_size - 1) // 2,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    nn.Upsample(
                        scale_factor=up_scale,
                        mode='bilinear',
                        align_corners=self.align_corners)))
            in_channels = out_channels
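
    def init_weights(self):
        # mmcv's ``constant_init(module, val, bias)`` expects a module and
        # initializes its ``weight`` and ``bias``; passing the tensors
        # ``m.bias`` / ``m.weight`` does not initialize them, so pass the
        # LayerNorm module itself (weight = 1, bias = 0).
        for m in self.modules():
            if isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0)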
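
    def forward(self, x):
        x = self._transform_inputs(x)

        # Normalize over channels at every spatial location:
        # NCHW -> N(HW)C, apply the norm, then back to NCHW.
        n, c, h, w = x.shape
        x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
        x = self.norm(x)
        x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()

        # Decode and upsample stage by stage, then classify per pixel.
        for up_conv in self.up_convs:
            x = up_conv(x)
        out = self.cls_seg(x)
        return out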
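`forward` normalizes the backbone feature with a LayerNorm over the channel dimension, applied separately at every pixel; the reshape/transpose round trip is needed because `nn.LayerNorm` normalizes over the trailing dimension(s). The sketch below spells this out in plain Python for a single CHW feature map stored as nested lists. It is illustrative only: the function name `layer_norm_over_channels` and the list-based layout are assumptions, and the default affine parameters (weight 1, bias 0) correspond to the values `init_weights` is meant to set.

```python
import math
from typing import List, Optional

# One feature map in CHW layout, indexed as x[channel][row][col].
FeatureMap = List[List[List[float]]]


def layer_norm_over_channels(x: FeatureMap, eps: float = 1e-6,
                             weight: Optional[List[float]] = None,
                             bias: Optional[List[float]] = None) -> FeatureMap:
    """LayerNorm across channels, applied independently at every pixel.

    For a single image this matches reshaping CHW to (H*W, C), applying
    a LayerNorm over C to each of the H*W tokens, and reshaping back.
    """
    c, h, w = len(x), len(x[0]), len(x[0][0])
    weight = [1.0] * c if weight is None else weight
    bias = [0.0] * c if bias is None else bias
    out = [[[0.0] * w for _ in range(h)] for _ in range(c)]
    for row in range(h):
        for col in range(w):
            # The C values at this position form one "token" (the NLC view).
            token = [x[k][row][col] for k in range(c)]
            mean = sum(token) / c
            # LayerNorm uses the biased (population) variance.
            var = sum((v - mean) ** 2 for v in token) / c
            inv_std = 1.0 / math.sqrt(var + eps)
            for k in range(c):
                out[k][row][col] = (token[k] - mean) * inv_std * weight[k] + bias[k]
    return out


# Two channels, one row, two pixels.
y = layer_norm_over_channels([[[1.0, 3.0]], [[3.0, 3.0]]])
# First pixel: channels become approximately -1 and +1.
# Second pixel: both channels are equal, so both become 0.0.
print(y)
```

Each pixel is normalized using only its own C channel values, so neighboring pixels never influence each other's statistics; this is the property the NCHW to N(HW)C transpose in `forward` preserves.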