[Refactor]: Unified parameter initialization (#567)

* [Refactor]: Unified parameter initialization

* fixed pretrained
This commit is contained in:
Jerry Jiarui XU 2021-06-16 21:41:29 -07:00 committed by GitHub
parent 5d46314844
commit 0c5b026db1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 329 additions and 298 deletions

View File

@ -1,12 +1,12 @@
import warnings
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer, from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
constant_init, kaiming_init) from mmcv.runner import BaseModule
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger
from ..builder import BACKBONES from ..builder import BACKBONES
@ -183,7 +183,7 @@ class InputInjection(nn.Module):
@BACKBONES.register_module() @BACKBONES.register_module()
class CGNet(nn.Module): class CGNet(BaseModule):
"""CGNet backbone. """CGNet backbone.
A Light-weight Context Guided Network for Semantic Segmentation A Light-weight Context Guided Network for Semantic Segmentation
@ -210,6 +210,9 @@ class CGNet(nn.Module):
and its variants only. Default: False. and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False. memory while slowing down the training speed. Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
""" """
def __init__(self, def __init__(self,
@ -222,9 +225,31 @@ class CGNet(nn.Module):
norm_cfg=dict(type='BN', requires_grad=True), norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='PReLU'), act_cfg=dict(type='PReLU'),
norm_eval=False, norm_eval=False,
with_cp=False): with_cp=False,
pretrained=None,
init_cfg=None):
super(CGNet, self).__init__(init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer=['Conv2d', 'Linear']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm']),
dict(type='Constant', val=0, layer='PReLU')
]
else:
raise TypeError('pretrained must be a str or None')
super(CGNet, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.num_channels = num_channels self.num_channels = num_channels
assert isinstance(self.num_channels, tuple) and len( assert isinstance(self.num_channels, tuple) and len(
@ -335,27 +360,6 @@ class CGNet(nn.Module):
return output return output
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
elif isinstance(m, nn.PReLU):
constant_init(m, 0)
else:
raise TypeError('pretrained must be a str or None')
def train(self, mode=True): def train(self, mode=True):
"""Convert the model into training mode will keeping the normalization """Convert the model into training mode will keeping the normalization
layer freezed.""" layer freezed."""

View File

@ -1,8 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init, from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
kaiming_init) from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.models.decode_heads.psp_head import PPM from mmseg.models.decode_heads.psp_head import PPM
from mmseg.ops import resize from mmseg.ops import resize
@ -247,7 +246,7 @@ class FeatureFusionModule(nn.Module):
@BACKBONES.register_module() @BACKBONES.register_module()
class FastSCNN(nn.Module): class FastSCNN(BaseModule):
"""Fast-SCNN Backbone. """Fast-SCNN Backbone.
Args: Args:
@ -291,6 +290,8 @@ class FastSCNN(nn.Module):
dict(type='ReLU') dict(type='ReLU')
align_corners (bool): align_corners argument of F.interpolate. align_corners (bool): align_corners argument of F.interpolate.
Default: False Default: False
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
""" """
def __init__(self, def __init__(self,
@ -307,9 +308,18 @@ class FastSCNN(nn.Module):
conv_cfg=None, conv_cfg=None,
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'), act_cfg=dict(type='ReLU'),
align_corners=False): align_corners=False,
init_cfg=None):
super(FastSCNN, self).__init__(init_cfg)
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]
super(FastSCNN, self).__init__()
if global_in_channels != higher_in_channels: if global_in_channels != higher_in_channels:
raise AssertionError('Global Input Channels must be the same \ raise AssertionError('Global Input Channels must be the same \
with Higher Input Channels!') with Higher Input Channels!')
@ -357,13 +367,6 @@ class FastSCNN(nn.Module):
act_cfg=self.act_cfg, act_cfg=self.act_cfg,
align_corners=self.align_corners) align_corners=self.align_corners)
def init_weights(self, pretrained=None):
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
def forward(self, x): def forward(self, x):
higher_res_features = self.learning_to_downsample(x) higher_res_features = self.learning_to_downsample(x)
lower_res_features = self.global_feature_extractor(higher_res_features) lower_res_features = self.global_feature_extractor(higher_res_features)

View File

@ -1,16 +1,16 @@
import warnings
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init, from mmcv.cnn import build_conv_layer, build_norm_layer
kaiming_init) from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.ops import Upsample, resize from mmseg.ops import Upsample, resize
from mmseg.utils import get_root_logger
from ..builder import BACKBONES from ..builder import BACKBONES
from .resnet import BasicBlock, Bottleneck from .resnet import BasicBlock, Bottleneck
class HRModule(nn.Module): class HRModule(BaseModule):
"""High-Resolution Module for HRNet. """High-Resolution Module for HRNet.
In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
@ -26,8 +26,11 @@ class HRModule(nn.Module):
multiscale_output=True, multiscale_output=True,
with_cp=False, with_cp=False,
conv_cfg=None, conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True)): norm_cfg=dict(type='BN', requires_grad=True),
super(HRModule, self).__init__() block_init_cfg=None,
init_cfg=None):
super(HRModule, self).__init__(init_cfg)
self.block_init_cfg = block_init_cfg
self._check_branches(num_branches, num_blocks, in_channels, self._check_branches(num_branches, num_blocks, in_channels,
num_channels) num_channels)
@ -92,7 +95,8 @@ class HRModule(nn.Module):
downsample=downsample, downsample=downsample,
with_cp=self.with_cp, with_cp=self.with_cp,
norm_cfg=self.norm_cfg, norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg)) conv_cfg=self.conv_cfg,
init_cfg=self.block_init_cfg))
self.in_channels[branch_index] = \ self.in_channels[branch_index] = \
num_channels[branch_index] * block.expansion num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]): for i in range(1, num_blocks[branch_index]):
@ -102,9 +106,10 @@ class HRModule(nn.Module):
num_channels[branch_index], num_channels[branch_index],
with_cp=self.with_cp, with_cp=self.with_cp,
norm_cfg=self.norm_cfg, norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg)) conv_cfg=self.conv_cfg,
init_cfg=self.block_init_cfg))
return nn.Sequential(*layers) return Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels): def _make_branches(self, num_branches, block, num_blocks, num_channels):
"""Build multiple branch.""" """Build multiple branch."""
@ -114,7 +119,7 @@ class HRModule(nn.Module):
branches.append( branches.append(
self._make_one_branch(i, block, num_blocks, num_channels)) self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches) return ModuleList(branches)
def _make_fuse_layers(self): def _make_fuse_layers(self):
"""Build fuse layer.""" """Build fuse layer."""
@ -209,7 +214,7 @@ class HRModule(nn.Module):
@BACKBONES.register_module() @BACKBONES.register_module()
class HRNet(nn.Module): class HRNet(BaseModule):
"""HRNet backbone. """HRNet backbone.
High-Resolution Representations for Labeling Pixels and Regions High-Resolution Representations for Labeling Pixels and Regions
@ -227,6 +232,9 @@ class HRNet(nn.Module):
memory while slowing down the training speed. memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity. in resblocks to let them behave as identity.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
Example: Example:
>>> from mmseg.models import HRNet >>> from mmseg.models import HRNet
@ -277,14 +285,36 @@ class HRNet(nn.Module):
norm_cfg=dict(type='BN', requires_grad=True), norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False, norm_eval=False,
with_cp=False, with_cp=False,
zero_init_residual=False): zero_init_residual=False,
super(HRNet, self).__init__() pretrained=None,
init_cfg=None):
super(HRNet, self).__init__(init_cfg)
self.pretrained = pretrained
self.zero_init_residual = zero_init_residual
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
self.extra = extra self.extra = extra
self.conv_cfg = conv_cfg self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg self.norm_cfg = norm_cfg
self.norm_eval = norm_eval self.norm_eval = norm_eval
self.with_cp = with_cp self.with_cp = with_cp
self.zero_init_residual = zero_init_residual
# stem net # stem net
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@ -430,6 +460,16 @@ class HRNet(nn.Module):
build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
layers = [] layers = []
block_init_cfg = None
if self.pretrained is None and not hasattr(
self, 'init_cfg') and self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm3'))
layers.append( layers.append(
block( block(
inplanes, inplanes,
@ -438,7 +478,8 @@ class HRNet(nn.Module):
downsample=downsample, downsample=downsample,
with_cp=self.with_cp, with_cp=self.with_cp,
norm_cfg=self.norm_cfg, norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg)) conv_cfg=self.conv_cfg,
init_cfg=block_init_cfg))
inplanes = planes * block.expansion inplanes = planes * block.expansion
for i in range(1, blocks): for i in range(1, blocks):
layers.append( layers.append(
@ -447,9 +488,10 @@ class HRNet(nn.Module):
planes, planes,
with_cp=self.with_cp, with_cp=self.with_cp,
norm_cfg=self.norm_cfg, norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg)) conv_cfg=self.conv_cfg,
init_cfg=block_init_cfg))
return nn.Sequential(*layers) return Sequential(*layers)
def _make_stage(self, layer_config, in_channels, multiscale_output=True): def _make_stage(self, layer_config, in_channels, multiscale_output=True):
"""Make each stage.""" """Make each stage."""
@ -460,6 +502,16 @@ class HRNet(nn.Module):
block = self.blocks_dict[layer_config['block']] block = self.blocks_dict[layer_config['block']]
hr_modules = [] hr_modules = []
block_init_cfg = None
if self.pretrained is None and not hasattr(
self, 'init_cfg') and self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm3'))
for i in range(num_modules): for i in range(num_modules):
# multi_scale_output is only used for the last module # multi_scale_output is only used for the last module
if not multiscale_output and i == num_modules - 1: if not multiscale_output and i == num_modules - 1:
@ -477,35 +529,10 @@ class HRNet(nn.Module):
reset_multiscale_output, reset_multiscale_output,
with_cp=self.with_cp, with_cp=self.with_cp,
norm_cfg=self.norm_cfg, norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg)) conv_cfg=self.conv_cfg,
block_init_cfg=block_init_cfg))
return nn.Sequential(*hr_modules), in_channels return Sequential(*hr_modules), in_channels
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x): def forward(self, x):
"""Forward function.""" """Forward function."""

View File

@ -1,8 +1,8 @@
import logging import warnings
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init from mmcv.cnn import ConvModule
from mmcv.runner import load_checkpoint from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES from ..builder import BACKBONES
@ -10,7 +10,7 @@ from ..utils import InvertedResidual, make_divisible
@BACKBONES.register_module() @BACKBONES.register_module()
class MobileNetV2(nn.Module): class MobileNetV2(BaseModule):
"""MobileNetV2 backbone. """MobileNetV2 backbone.
Args: Args:
@ -35,6 +35,9 @@ class MobileNetV2(nn.Module):
and its variants only. Default: False. and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False. memory while slowing down the training speed. Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
""" """
# Parameters to build layers. 3 parameters are needed to construct a # Parameters to build layers. 3 parameters are needed to construct a
@ -52,8 +55,30 @@ class MobileNetV2(nn.Module):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'), act_cfg=dict(type='ReLU6'),
norm_eval=False, norm_eval=False,
with_cp=False): with_cp=False,
super(MobileNetV2, self).__init__() pretrained=None,
init_cfg=None):
super(MobileNetV2, self).__init__(init_cfg)
self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
self.widen_factor = widen_factor self.widen_factor = widen_factor
self.strides = strides self.strides = strides
self.dilations = dilations self.dilations = dilations
@ -133,19 +158,6 @@ class MobileNetV2(nn.Module):
return nn.Sequential(*layers) return nn.Sequential(*layers)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)

View File

@ -1,10 +1,9 @@
import logging import warnings
import mmcv import mmcv
import torch.nn as nn from mmcv.cnn import ConvModule
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.cnn.bricks import Conv2dAdaptivePadding from mmcv.cnn.bricks import Conv2dAdaptivePadding
from mmcv.runner import load_checkpoint from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES from ..builder import BACKBONES
@ -12,7 +11,7 @@ from ..utils import InvertedResidualV3 as InvertedResidual
@BACKBONES.register_module() @BACKBONES.register_module()
class MobileNetV3(nn.Module): class MobileNetV3(BaseModule):
"""MobileNetV3 backbone. """MobileNetV3 backbone.
This backbone is the improved implementation of `Searching for MobileNetV3 This backbone is the improved implementation of `Searching for MobileNetV3
@ -35,6 +34,9 @@ class MobileNetV3(nn.Module):
with_cp (bool): Use checkpoint or not. Using checkpoint will save with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. some memory while slowing down the training speed.
Default: False. Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
""" """
# Parameters to build each block: # Parameters to build each block:
# [kernel size, mid channels, out channels, with_se, act type, stride] # [kernel size, mid channels, out channels, with_se, act type, stride]
@ -75,8 +77,30 @@ class MobileNetV3(nn.Module):
frozen_stages=-1, frozen_stages=-1,
reduction_factor=1, reduction_factor=1,
norm_eval=False, norm_eval=False,
with_cp=False): with_cp=False,
super(MobileNetV3, self).__init__() pretrained=None,
init_cfg=None):
super(MobileNetV3, self).__init__(init_cfg)
self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
assert arch in self.arch_settings assert arch in self.arch_settings
assert isinstance(reduction_factor, int) and reduction_factor > 0 assert isinstance(reduction_factor, int) and reduction_factor > 0
assert mmcv.is_tuple_of(out_indices, int) assert mmcv.is_tuple_of(out_indices, int)
@ -217,19 +241,6 @@ class MobileNetV3(nn.Module):
return layers return layers
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x): def forward(self, x):
outs = [] outs = []
for i, layer_name in enumerate(self.layers): for i, layer_name in enumerate(self.layers):

View File

@ -1,16 +1,16 @@
import warnings
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer, from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
constant_init, kaiming_init) from mmcv.runner import BaseModule
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger
from ..builder import BACKBONES from ..builder import BACKBONES
from ..utils import ResLayer from ..utils import ResLayer
class BasicBlock(nn.Module): class BasicBlock(BaseModule):
"""Basic block for ResNet.""" """Basic block for ResNet."""
expansion = 1 expansion = 1
@ -26,8 +26,9 @@ class BasicBlock(nn.Module):
conv_cfg=None, conv_cfg=None,
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
dcn=None, dcn=None,
plugins=None): plugins=None,
super(BasicBlock, self).__init__() init_cfg=None):
super(BasicBlock, self).__init__(init_cfg)
assert dcn is None, 'Not implemented yet.' assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.' assert plugins is None, 'Not implemented yet.'
@ -94,7 +95,7 @@ class BasicBlock(nn.Module):
return out return out
class Bottleneck(nn.Module): class Bottleneck(BaseModule):
"""Bottleneck block for ResNet. """Bottleneck block for ResNet.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
@ -114,8 +115,9 @@ class Bottleneck(nn.Module):
conv_cfg=None, conv_cfg=None,
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
dcn=None, dcn=None,
plugins=None): plugins=None,
super(Bottleneck, self).__init__() init_cfg=None):
super(Bottleneck, self).__init__(init_cfg)
assert style in ['pytorch', 'caffe'] assert style in ['pytorch', 'caffe']
assert dcn is None or isinstance(dcn, dict) assert dcn is None or isinstance(dcn, dict)
assert plugins is None or isinstance(plugins, list) assert plugins is None or isinstance(plugins, list)
@ -305,7 +307,7 @@ class Bottleneck(nn.Module):
@BACKBONES.register_module() @BACKBONES.register_module()
class ResNet(nn.Module): class ResNet(BaseModule):
"""ResNet backbone. """ResNet backbone.
Args: Args:
@ -346,6 +348,9 @@ class ResNet(nn.Module):
memory while slowing down the training speed. memory while slowing down the training speed.
zero_init_residual (bool): Whether to use zero init for last norm layer zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. in resblocks to let them behave as identity.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
Example: Example:
>>> from mmseg.models import ResNet >>> from mmseg.models import ResNet
@ -392,10 +397,46 @@ class ResNet(nn.Module):
multi_grid=None, multi_grid=None,
contract_dilation=False, contract_dilation=False,
with_cp=False, with_cp=False,
zero_init_residual=True): zero_init_residual=True,
pretrained=None,
init_cfg=None):
super(ResNet, self).__init__() super(ResNet, self).__init__()
if depth not in self.arch_settings: if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet') raise KeyError(f'invalid depth {depth} for resnet')
self.pretrained = pretrained
self.zero_init_residual = zero_init_residual
block_init_cfg = None
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
block = self.arch_settings[depth][0]
if self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant',
val=0,
override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant',
val=0,
override=dict(name='norm3'))
else:
raise TypeError('pretrained must be a str or None')
self.depth = depth self.depth = depth
self.stem_channels = stem_channels self.stem_channels = stem_channels
self.base_channels = base_channels self.base_channels = base_channels
@ -421,7 +462,6 @@ class ResNet(nn.Module):
self.plugins = plugins self.plugins = plugins
self.multi_grid = multi_grid self.multi_grid = multi_grid
self.contract_dilation = contract_dilation self.contract_dilation = contract_dilation
self.zero_init_residual = zero_init_residual
self.block, stage_blocks = self.arch_settings[depth] self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages] self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = stem_channels self.inplanes = stem_channels
@ -456,7 +496,8 @@ class ResNet(nn.Module):
dcn=dcn, dcn=dcn,
plugins=stage_plugins, plugins=stage_plugins,
multi_grid=stage_multi_grid, multi_grid=stage_multi_grid,
contract_dilation=contract_dilation) contract_dilation=contract_dilation,
init_cfg=block_init_cfg)
self.inplanes = planes * self.block.expansion self.inplanes = planes * self.block.expansion
layer_name = f'layer{i+1}' layer_name = f'layer{i+1}'
self.add_module(layer_name, res_layer) self.add_module(layer_name, res_layer)
@ -597,38 +638,6 @@ class ResNet(nn.Module):
for param in m.parameters(): for param in m.parameters():
param.requires_grad = False param.requires_grad = False
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
if self.dcn is not None:
for m in self.modules():
if isinstance(m, Bottleneck) and hasattr(
m, 'conv2_offset'):
constant_init(m.conv2_offset, 0)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x): def forward(self, x):
"""Forward function.""" """Forward function."""
if self.deep_stem: if self.deep_stem:

View File

@ -1,11 +1,12 @@
import warnings
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
build_norm_layer, constant_init, kaiming_init) build_norm_layer)
from mmcv.runner import load_checkpoint from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger
from ..builder import BACKBONES from ..builder import BACKBONES
from ..utils import UpConvBlock from ..utils import UpConvBlock
@ -219,7 +220,7 @@ class InterpConv(nn.Module):
@BACKBONES.register_module() @BACKBONES.register_module()
class UNet(nn.Module): class UNet(BaseModule):
"""UNet backbone. """UNet backbone.
U-Net: Convolutional Networks for Biomedical Image Segmentation. U-Net: Convolutional Networks for Biomedical Image Segmentation.
https://arxiv.org/pdf/1505.04597.pdf https://arxiv.org/pdf/1505.04597.pdf
@ -266,6 +267,9 @@ class UNet(nn.Module):
dcn (bool): Use deformable convolution in convolutional layer or not. dcn (bool): Use deformable convolution in convolutional layer or not.
Default: None. Default: None.
plugins (dict): plugins for convolutional layers. Default: None. plugins (dict): plugins for convolutional layers. Default: None.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
Notice: Notice:
The input image size should be divisible by the whole downsample rate The input image size should be divisible by the whole downsample rate
@ -291,8 +295,30 @@ class UNet(nn.Module):
upsample_cfg=dict(type='InterpConv'), upsample_cfg=dict(type='InterpConv'),
norm_eval=False, norm_eval=False,
dcn=None, dcn=None,
plugins=None): plugins=None,
super(UNet, self).__init__() pretrained=None,
init_cfg=None):
super(UNet, self).__init__(init_cfg)
self.pretrained = pretrained
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')
assert dcn is None, 'Not implemented yet.' assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.' assert plugins is None, 'Not implemented yet.'
assert len(strides) == num_stages, \ assert len(strides) == num_stages, \
@ -408,22 +434,3 @@ class UNet(nn.Module):
f'downsample rate {whole_downsample_rate}, when num_stages is '\ f'downsample rate {whole_downsample_rate}, when num_stages is '\
f'{self.num_stages}, strides is {self.strides}, and downsamples '\ f'{self.num_stages}, strides is {self.strides}, and downsamples '\
f'is {self.downsamples}.' f'is {self.downsamples}.'
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')

View File

@ -9,7 +9,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer, from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
constant_init, kaiming_init, normal_init) constant_init, kaiming_init, normal_init)
from mmcv.runner import _load_checkpoint from mmcv.runner import BaseModule, _load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger from mmseg.utils import get_root_logger
@ -203,7 +203,7 @@ class PatchEmbed(nn.Module):
@BACKBONES.register_module() @BACKBONES.register_module()
class VisionTransformer(nn.Module): class VisionTransformer(BaseModule):
"""Vision transformer backbone. """Vision transformer backbone.
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
@ -243,6 +243,9 @@ class VisionTransformer(nn.Module):
with_cp (bool): Use checkpoint or not. Using checkpoint with_cp (bool): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed. will save some memory while slowing down the training speed.
Default: False. Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
""" """
def __init__(self, def __init__(self,
@ -266,8 +269,12 @@ class VisionTransformer(nn.Module):
out_shape='NCHW', out_shape='NCHW',
with_cls_token=True, with_cls_token=True,
interpolate_mode='bicubic', interpolate_mode='bicubic',
with_cp=False): with_cp=False,
super(VisionTransformer, self).__init__() pretrained=None,
init_cfg=None):
super(VisionTransformer, self).__init__(init_cfg)
self.pretrained = pretrained
self.img_size = img_size self.img_size = img_size
self.patch_size = patch_size self.patch_size = patch_size
self.features = self.embed_dim = embed_dim self.features = self.embed_dim = embed_dim
@ -319,7 +326,8 @@ class VisionTransformer(nn.Module):
self.norm_eval = norm_eval self.norm_eval = norm_eval
self.with_cp = with_cp self.with_cp = with_cp
def init_weights(self, pretrained=None): def init_weights(self):
pretrained = self.pretrained
if isinstance(pretrained, str): if isinstance(pretrained, str):
logger = get_root_logger() logger = get_root_logger()
checkpoint = _load_checkpoint(pretrained, logger=logger) checkpoint = _load_checkpoint(pretrained, logger=logger)

View File

@ -2,8 +2,7 @@ from abc import ABCMeta, abstractmethod
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import normal_init from mmcv.runner import BaseModule, auto_fp16, force_fp32
from mmcv.runner import auto_fp16, force_fp32
from mmseg.core import build_pixel_sampler from mmseg.core import build_pixel_sampler
from mmseg.ops import resize from mmseg.ops import resize
@ -11,7 +10,7 @@ from ..builder import build_loss
from ..losses import accuracy from ..losses import accuracy
class BaseDecodeHead(nn.Module, metaclass=ABCMeta): class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for BaseDecodeHead. """Base class for BaseDecodeHead.
Args: Args:
@ -41,6 +40,7 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
Default: None. Default: None.
align_corners (bool): align_corners argument of F.interpolate. align_corners (bool): align_corners argument of F.interpolate.
Default: False. Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
""" """
def __init__(self, def __init__(self,
@ -60,8 +60,10 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
loss_weight=1.0), loss_weight=1.0),
ignore_index=255, ignore_index=255,
sampler=None, sampler=None,
align_corners=False): align_corners=False,
super(BaseDecodeHead, self).__init__() init_cfg=dict(
type='Normal', std=0.01, override=dict(name='conv_seg'))):
super(BaseDecodeHead, self).__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform) self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels self.channels = channels
self.num_classes = num_classes self.num_classes = num_classes
@ -130,10 +132,6 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
assert isinstance(in_index, int) assert isinstance(in_index, int)
self.in_channels = in_channels self.in_channels = in_channels
def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.conv_seg, mean=0, std=0.01)
def _transform_inputs(self, inputs): def _transform_inputs(self, inputs):
"""Transform inputs for decoder. """Transform inputs for decoder.

View File

@ -2,7 +2,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init from mmcv.cnn import ConvModule
from mmcv.ops import point_sample from mmcv.ops import point_sample
from mmseg.models.builder import HEADS from mmseg.models.builder import HEADS
@ -69,6 +69,8 @@ class PointHead(BaseCascadeDecodeHead):
conv_cfg=conv_cfg, conv_cfg=conv_cfg,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
act_cfg=act_cfg, act_cfg=act_cfg,
init_cfg=dict(
type='Normal', std=0.01, override=dict(name='fc_seg')),
**kwargs) **kwargs)
self.num_fcs = num_fcs self.num_fcs = num_fcs
@ -101,10 +103,6 @@ class PointHead(BaseCascadeDecodeHead):
self.dropout = nn.Dropout(self.dropout_ratio) self.dropout = nn.Dropout(self.dropout_ratio)
delattr(self, 'conv_seg') delattr(self, 'conv_seg')
def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.fc_seg, std=0.001)
def cls_seg(self, feat): def cls_seg(self, feat):
"""Classify each pixel with fc.""" """Classify each pixel with fc."""
if self.dropout is not None: if self.dropout is not None:

View File

@ -1,12 +1,13 @@
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16
from ..builder import NECKS from ..builder import NECKS
@NECKS.register_module() @NECKS.register_module()
class FPN(nn.Module): class FPN(BaseModule):
"""Feature Pyramid Network. """Feature Pyramid Network.
This is an implementation of - Feature Pyramid Networks for Object This is an implementation of - Feature Pyramid Networks for Object
@ -43,6 +44,7 @@ class FPN(nn.Module):
Default: None. Default: None.
upsample_cfg (dict): Config dict for interpolate layer. upsample_cfg (dict): Config dict for interpolate layer.
Default: `dict(mode='nearest')` Default: `dict(mode='nearest')`
init_cfg (dict or list[dict], optional): Initialization config dict.
Example: Example:
>>> import torch >>> import torch
@ -73,8 +75,10 @@ class FPN(nn.Module):
conv_cfg=None, conv_cfg=None,
norm_cfg=None, norm_cfg=None,
act_cfg=None, act_cfg=None,
upsample_cfg=dict(mode='nearest')): upsample_cfg=dict(mode='nearest'),
super(FPN, self).__init__() init_cfg=dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
super(FPN, self).__init__(init_cfg)
assert isinstance(in_channels, list) assert isinstance(in_channels, list)
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -153,12 +157,7 @@ class FPN(nn.Module):
inplace=False) inplace=False)
self.fpn_convs.append(extra_fpn_conv) self.fpn_convs.append(extra_fpn_conv)
# default init_weights for conv(msra) and norm in ConvModule @auto_fp16()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m, distribution='uniform')
def forward(self, inputs): def forward(self, inputs):
assert len(inputs) == len(self.in_channels) assert len(inputs) == len(self.in_channels)

View File

@ -1,4 +1,3 @@
import logging
import warnings import warnings
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import OrderedDict from collections import OrderedDict
@ -7,17 +6,14 @@ import mmcv
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn from mmcv.runner import BaseModule, auto_fp16
from mmcv.runner import auto_fp16
class BaseSegmentor(nn.Module): class BaseSegmentor(BaseModule, metaclass=ABCMeta):
"""Base class for segmentors.""" """Base class for segmentors."""
__metaclass__ = ABCMeta def __init__(self, init_cfg=None):
super(BaseSegmentor, self).__init__(init_cfg)
def __init__(self):
super(BaseSegmentor, self).__init__()
self.fp16_enabled = False self.fp16_enabled = False
@property @property
@ -62,17 +58,6 @@ class BaseSegmentor(nn.Module):
"""Placeholder for augmentation test.""" """Placeholder for augmentation test."""
pass pass
def init_weights(self, pretrained=None):
"""Initialize the weights in segmentor.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if pretrained is not None:
logger = logging.getLogger()
logger.info(f'load model from: {pretrained}')
def forward_test(self, imgs, img_metas, **kwargs): def forward_test(self, imgs, img_metas, **kwargs):
""" """
Args: Args:

View File

@ -24,7 +24,8 @@ class CascadeEncoderDecoder(EncoderDecoder):
auxiliary_head=None, auxiliary_head=None,
train_cfg=None, train_cfg=None,
test_cfg=None, test_cfg=None,
pretrained=None): pretrained=None,
init_cfg=None):
self.num_stages = num_stages self.num_stages = num_stages
super(CascadeEncoderDecoder, self).__init__( super(CascadeEncoderDecoder, self).__init__(
backbone=backbone, backbone=backbone,
@ -33,7 +34,8 @@ class CascadeEncoderDecoder(EncoderDecoder):
auxiliary_head=auxiliary_head, auxiliary_head=auxiliary_head,
train_cfg=train_cfg, train_cfg=train_cfg,
test_cfg=test_cfg, test_cfg=test_cfg,
pretrained=pretrained) pretrained=pretrained,
init_cfg=init_cfg)
def _init_decode_head(self, decode_head): def _init_decode_head(self, decode_head):
"""Initialize ``decode_head``""" """Initialize ``decode_head``"""
@ -45,23 +47,6 @@ class CascadeEncoderDecoder(EncoderDecoder):
self.align_corners = self.decode_head[-1].align_corners self.align_corners = self.decode_head[-1].align_corners
self.num_classes = self.decode_head[-1].num_classes self.num_classes = self.decode_head[-1].num_classes
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone and heads.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
self.backbone.init_weights(pretrained=pretrained)
for i in range(self.num_stages):
self.decode_head[i].init_weights()
if self.with_auxiliary_head:
if isinstance(self.auxiliary_head, nn.ModuleList):
for aux_head in self.auxiliary_head:
aux_head.init_weights()
else:
self.auxiliary_head.init_weights()
def encode_decode(self, img, img_metas): def encode_decode(self, img, img_metas):
"""Encode images with backbone and decode into a semantic segmentation """Encode images with backbone and decode into a semantic segmentation
map of the same size as input.""" map of the same size as input."""

View File

@ -25,8 +25,13 @@ class EncoderDecoder(BaseSegmentor):
auxiliary_head=None, auxiliary_head=None,
train_cfg=None, train_cfg=None,
test_cfg=None, test_cfg=None,
pretrained=None): pretrained=None,
super(EncoderDecoder, self).__init__() init_cfg=None):
super(EncoderDecoder, self).__init__(init_cfg)
if pretrained is not None:
assert backbone.get('pretrained') is None, \
'both backbone and segmentor set pretrained weight'
backbone.pretrained = pretrained
self.backbone = builder.build_backbone(backbone) self.backbone = builder.build_backbone(backbone)
if neck is not None: if neck is not None:
self.neck = builder.build_neck(neck) self.neck = builder.build_neck(neck)
@ -36,8 +41,6 @@ class EncoderDecoder(BaseSegmentor):
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
assert self.with_decode_head assert self.with_decode_head
def _init_decode_head(self, decode_head): def _init_decode_head(self, decode_head):
@ -56,24 +59,6 @@ class EncoderDecoder(BaseSegmentor):
else: else:
self.auxiliary_head = builder.build_head(auxiliary_head) self.auxiliary_head = builder.build_head(auxiliary_head)
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone and heads.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super(EncoderDecoder, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
self.decode_head.init_weights()
if self.with_auxiliary_head:
if isinstance(self.auxiliary_head, nn.ModuleList):
for aux_head in self.auxiliary_head:
aux_head.init_weights()
else:
self.auxiliary_head.init_weights()
def extract_feat(self, img): def extract_feat(self, img):
"""Extract features from images.""" """Extract features from images."""
x = self.backbone(img) x = self.backbone(img)

View File

@ -1,8 +1,9 @@
from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import Sequential
from torch import nn as nn from torch import nn as nn
class ResLayer(nn.Sequential): class ResLayer(Sequential):
"""ResLayer to build ResNet style backbone. """ResLayer to build ResNet style backbone.
Args: Args:

View File

@ -300,8 +300,8 @@ def test_resnet_backbone():
with pytest.raises(TypeError): with pytest.raises(TypeError):
# pretrained must be a string path # pretrained must be a string path
model = ResNet(50) model = ResNet(50, pretrained=0)
model.init_weights(pretrained=0) model.init_weights()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe'] # Style must be in ['pytorch', 'caffe']
@ -314,8 +314,9 @@ def test_resnet_backbone():
assert check_norm_state(model.modules(), False) assert check_norm_state(model.modules(), False)
# Test ResNet50 with torchvision pretrained weight # Test ResNet50 with torchvision pretrained weight
model = ResNet(depth=50, norm_eval=True) model = ResNet(
model.init_weights('torchvision://resnet50') depth=50, norm_eval=True, pretrained='torchvision://resnet50')
model.init_weights()
model.train() model.train()
assert check_norm_state(model.modules(), False) assert check_norm_state(model.modules(), False)

View File

@ -734,7 +734,6 @@ def test_unet():
downsamples=(True, True, True, True), downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1), enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1)) dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128) x = torch.randn(2, 3, 128, 128)
x_outs = unet(x) x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 8, 8]) assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
@ -754,7 +753,6 @@ def test_unet():
downsamples=(True, True, True, False), downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1), enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1)) dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128) x = torch.randn(2, 3, 128, 128)
x_outs = unet(x) x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16]) assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
@ -774,7 +772,6 @@ def test_unet():
downsamples=(True, True, True, False), downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1), enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1)) dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128) x = torch.randn(2, 3, 128, 128)
x_outs = unet(x) x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16]) assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
@ -794,7 +791,6 @@ def test_unet():
downsamples=(True, True, False, False), downsamples=(True, True, False, False),
enc_dilations=(1, 1, 1, 1, 1), enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1)) dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128) x = torch.randn(2, 3, 128, 128)
x_outs = unet(x) x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32]) assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
@ -813,9 +809,9 @@ def test_unet():
dec_num_convs=(2, 2, 2, 2), dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, False, False), downsamples=(True, True, False, False),
enc_dilations=(1, 1, 1, 1, 1), enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1)) dec_dilations=(1, 1, 1, 1),
unet.init_weights(pretrained=None) pretrained=None)
print(unet) unet.init_weights()
x = torch.randn(2, 3, 128, 128) x = torch.randn(2, 3, 128, 128)
x_outs = unet(x) x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32]) assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])

View File

@ -215,6 +215,7 @@ def _test_encoder_decoder_forward(cfg_file):
from mmseg.models import build_segmentor from mmseg.models import build_segmentor
segmentor = build_segmentor(model) segmentor = build_segmentor(model)
segmentor.init_weights()
if isinstance(segmentor.decode_head, nn.ModuleList): if isinstance(segmentor.decode_head, nn.ModuleList):
num_classes = segmentor.decode_head[-1].num_classes num_classes = segmentor.decode_head[-1].num_classes

View File

@ -131,6 +131,7 @@ def main():
cfg.model, cfg.model,
train_cfg=cfg.get('train_cfg'), train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg')) test_cfg=cfg.get('test_cfg'))
model.init_weights()
logger.info(model) logger.info(model)