From c23e90289616c4797e7f329e271ca1ae4740c751 Mon Sep 17 00:00:00 2001 From: Rockey <41846794+RockeyCoss@users.noreply.github.com> Date: Thu, 9 Dec 2021 15:17:43 +0800 Subject: [PATCH] [Fix] Fix the bug that mit cannot process init_cfg (#1102) * [Fix] Fix the bug that mit cannot process init_cfg * fix error --- mmseg/models/backbones/mit.py | 30 ++++------- tests/test_models/test_backbones/test_mit.py | 56 ++++++++++++++++++++ 2 files changed, 67 insertions(+), 19 deletions(-) diff --git a/mmseg/models/backbones/mit.py b/mmseg/models/backbones/mit.py index 8eb1011..c97213a 100644 --- a/mmseg/models/backbones/mit.py +++ b/mmseg/models/backbones/mit.py @@ -9,9 +9,8 @@ from mmcv.cnn.bricks.drop import build_dropout from mmcv.cnn.bricks.transformer import MultiheadAttention from mmcv.cnn.utils.weight_init import (constant_init, normal_init, trunc_normal_init) -from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint +from mmcv.runner import BaseModule, ModuleList, Sequential -from ...utils import get_root_logger from ..builder import BACKBONES from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw @@ -341,16 +340,18 @@ class MixVisionTransformer(BaseModule): norm_cfg=dict(type='LN', eps=1e-6), pretrained=None, init_cfg=None): - super().__init__(init_cfg=init_cfg) + super(MixVisionTransformer, self).__init__(init_cfg=init_cfg) - if isinstance(pretrained, str) or pretrained is None: - warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' 'please use "init_cfg" instead') - else: + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: raise TypeError('pretrained must be a str or None') self.embed_dims = embed_dims - self.num_stages = num_stages self.num_layers = num_layers self.num_heads = num_heads @@ 
-362,7 +363,6 @@ class MixVisionTransformer(BaseModule): self.out_indices = out_indices assert max(out_indices) < self.num_stages - self.pretrained = pretrained # transformer encoder dpr = [ @@ -401,7 +401,7 @@ class MixVisionTransformer(BaseModule): cur += num_layer def init_weights(self): - if self.pretrained is None: + if self.init_cfg is None: for m in self.modules(): if isinstance(m, nn.Linear): trunc_normal_init(m, std=.02, bias=0.) @@ -413,16 +413,8 @@ class MixVisionTransformer(BaseModule): fan_out //= m.groups normal_init( m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) - elif isinstance(self.pretrained, str): - logger = get_root_logger() - checkpoint = _load_checkpoint( - self.pretrained, logger=logger, map_location='cpu') - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - else: - state_dict = checkpoint - - self.load_state_dict(state_dict, False) + else: + super(MixVisionTransformer, self).init_weights() def forward(self, x): outs = [] diff --git a/tests/test_models/test_backbones/test_mit.py b/tests/test_models/test_backbones/test_mit.py index 6159d65..9eec1fa 100644 --- a/tests/test_models/test_backbones/test_mit.py +++ b/tests/test_models/test_backbones/test_mit.py @@ -55,3 +55,59 @@ def test_mit(): # Out identity outs = MHA(temp, hw_shape, temp) assert out.shape == (1, token_len, 64) + + +def test_mit_init(): + path = 'PATH_THAT_DO_NOT_EXIST' + # Test all combinations of pretrained and init_cfg + # pretrained=None, init_cfg=None + model = MixVisionTransformer(pretrained=None, init_cfg=None) + assert model.init_cfg is None + model.init_weights() + + # pretrained=None + # init_cfg loads pretrain from a non-existent file + model = MixVisionTransformer( + pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path)) + assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + # Test loading a checkpoint from a non-existent file + with pytest.raises(OSError): + model.init_weights() + + # pretrained=None + # 
init_cfg=123, whose type is unsupported + model = MixVisionTransformer(pretrained=None, init_cfg=123) + with pytest.raises(TypeError): + model.init_weights() + + # pretrained loads pretrain from a non-existent file + # init_cfg=None + model = MixVisionTransformer(pretrained=path, init_cfg=None) + assert model.init_cfg == dict(type='Pretrained', checkpoint=path) + # Test loading a checkpoint from a non-existent file + with pytest.raises(OSError): + model.init_weights() + + # pretrained loads pretrain from a non-existent file + # init_cfg loads pretrain from a non-existent file + with pytest.raises(AssertionError): + MixVisionTransformer( + pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path)) + with pytest.raises(AssertionError): + MixVisionTransformer(pretrained=path, init_cfg=123) + + # pretrained=123, whose type is unsupported + # init_cfg=None + with pytest.raises(TypeError): + MixVisionTransformer(pretrained=123, init_cfg=None) + + # pretrained=123, whose type is unsupported + # init_cfg loads pretrain from a non-existent file + with pytest.raises(AssertionError): + MixVisionTransformer( + pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path)) + + # pretrained=123, whose type is unsupported + # init_cfg=123, whose type is unsupported + with pytest.raises(AssertionError): + MixVisionTransformer(pretrained=123, init_cfg=123)