[Refactor]: Unified parameter initialization (#567)
* [Refactor]: Unified parameter initialization * fixed pretrained
This commit is contained in:
parent
5d46314844
commit
0c5b026db1
@ -1,12 +1,12 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint as cp
|
import torch.utils.checkpoint as cp
|
||||||
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
|
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
|
||||||
constant_init, kaiming_init)
|
from mmcv.runner import BaseModule
|
||||||
from mmcv.runner import load_checkpoint
|
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
from mmseg.utils import get_root_logger
|
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
|
|
||||||
|
|
||||||
@ -183,7 +183,7 @@ class InputInjection(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class CGNet(nn.Module):
|
class CGNet(BaseModule):
|
||||||
"""CGNet backbone.
|
"""CGNet backbone.
|
||||||
|
|
||||||
A Light-weight Context Guided Network for Semantic Segmentation
|
A Light-weight Context Guided Network for Semantic Segmentation
|
||||||
@ -210,6 +210,9 @@ class CGNet(nn.Module):
|
|||||||
and its variants only. Default: False.
|
and its variants only. Default: False.
|
||||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||||
memory while slowing down the training speed. Default: False.
|
memory while slowing down the training speed. Default: False.
|
||||||
|
pretrained (str, optional): model pretrained path. Default: None
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Default: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -222,9 +225,31 @@ class CGNet(nn.Module):
|
|||||||
norm_cfg=dict(type='BN', requires_grad=True),
|
norm_cfg=dict(type='BN', requires_grad=True),
|
||||||
act_cfg=dict(type='PReLU'),
|
act_cfg=dict(type='PReLU'),
|
||||||
norm_eval=False,
|
norm_eval=False,
|
||||||
with_cp=False):
|
with_cp=False,
|
||||||
|
pretrained=None,
|
||||||
|
init_cfg=None):
|
||||||
|
|
||||||
|
super(CGNet, self).__init__(init_cfg)
|
||||||
|
|
||||||
|
assert not (init_cfg and pretrained), \
|
||||||
|
'init_cfg and pretrained cannot be setting at the same time'
|
||||||
|
if isinstance(pretrained, str):
|
||||||
|
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||||
|
'please use "init_cfg" instead')
|
||||||
|
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||||
|
elif pretrained is None:
|
||||||
|
if init_cfg is None:
|
||||||
|
self.init_cfg = [
|
||||||
|
dict(type='Kaiming', layer=['Conv2d', 'Linear']),
|
||||||
|
dict(
|
||||||
|
type='Constant',
|
||||||
|
val=1,
|
||||||
|
layer=['_BatchNorm', 'GroupNorm']),
|
||||||
|
dict(type='Constant', val=0, layer='PReLU')
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise TypeError('pretrained must be a str or None')
|
||||||
|
|
||||||
super(CGNet, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.num_channels = num_channels
|
self.num_channels = num_channels
|
||||||
assert isinstance(self.num_channels, tuple) and len(
|
assert isinstance(self.num_channels, tuple) and len(
|
||||||
@ -335,27 +360,6 @@ class CGNet(nn.Module):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
"""Initialize the weights in backbone.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str, optional): Path to pre-trained weights.
|
|
||||||
Defaults to None.
|
|
||||||
"""
|
|
||||||
if isinstance(pretrained, str):
|
|
||||||
logger = get_root_logger()
|
|
||||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
|
||||||
elif pretrained is None:
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
|
||||||
kaiming_init(m)
|
|
||||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
|
||||||
constant_init(m, 1)
|
|
||||||
elif isinstance(m, nn.PReLU):
|
|
||||||
constant_init(m, 0)
|
|
||||||
else:
|
|
||||||
raise TypeError('pretrained must be a str or None')
|
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
"""Convert the model into training mode will keeping the normalization
|
"""Convert the model into training mode will keeping the normalization
|
||||||
layer freezed."""
|
layer freezed."""
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
|
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||||
kaiming_init)
|
from mmcv.runner import BaseModule
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
|
||||||
|
|
||||||
from mmseg.models.decode_heads.psp_head import PPM
|
from mmseg.models.decode_heads.psp_head import PPM
|
||||||
from mmseg.ops import resize
|
from mmseg.ops import resize
|
||||||
@ -247,7 +246,7 @@ class FeatureFusionModule(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class FastSCNN(nn.Module):
|
class FastSCNN(BaseModule):
|
||||||
"""Fast-SCNN Backbone.
|
"""Fast-SCNN Backbone.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -291,6 +290,8 @@ class FastSCNN(nn.Module):
|
|||||||
dict(type='ReLU')
|
dict(type='ReLU')
|
||||||
align_corners (bool): align_corners argument of F.interpolate.
|
align_corners (bool): align_corners argument of F.interpolate.
|
||||||
Default: False
|
Default: False
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Default: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -307,9 +308,18 @@ class FastSCNN(nn.Module):
|
|||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
act_cfg=dict(type='ReLU'),
|
act_cfg=dict(type='ReLU'),
|
||||||
align_corners=False):
|
align_corners=False,
|
||||||
|
init_cfg=None):
|
||||||
|
|
||||||
|
super(FastSCNN, self).__init__(init_cfg)
|
||||||
|
|
||||||
|
if init_cfg is None:
|
||||||
|
self.init_cfg = [
|
||||||
|
dict(type='Kaiming', layer='Conv2d'),
|
||||||
|
dict(
|
||||||
|
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
|
||||||
|
]
|
||||||
|
|
||||||
super(FastSCNN, self).__init__()
|
|
||||||
if global_in_channels != higher_in_channels:
|
if global_in_channels != higher_in_channels:
|
||||||
raise AssertionError('Global Input Channels must be the same \
|
raise AssertionError('Global Input Channels must be the same \
|
||||||
with Higher Input Channels!')
|
with Higher Input Channels!')
|
||||||
@ -357,13 +367,6 @@ class FastSCNN(nn.Module):
|
|||||||
act_cfg=self.act_cfg,
|
act_cfg=self.act_cfg,
|
||||||
align_corners=self.align_corners)
|
align_corners=self.align_corners)
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
kaiming_init(m)
|
|
||||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
|
||||||
constant_init(m, 1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
higher_res_features = self.learning_to_downsample(x)
|
higher_res_features = self.learning_to_downsample(x)
|
||||||
lower_res_features = self.global_feature_extractor(higher_res_features)
|
lower_res_features = self.global_feature_extractor(higher_res_features)
|
||||||
|
|||||||
@ -1,16 +1,16 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
|
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||||
kaiming_init)
|
from mmcv.runner import BaseModule, ModuleList, Sequential
|
||||||
from mmcv.runner import load_checkpoint
|
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
from mmseg.ops import Upsample, resize
|
from mmseg.ops import Upsample, resize
|
||||||
from mmseg.utils import get_root_logger
|
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
from .resnet import BasicBlock, Bottleneck
|
from .resnet import BasicBlock, Bottleneck
|
||||||
|
|
||||||
|
|
||||||
class HRModule(nn.Module):
|
class HRModule(BaseModule):
|
||||||
"""High-Resolution Module for HRNet.
|
"""High-Resolution Module for HRNet.
|
||||||
|
|
||||||
In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
|
In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
|
||||||
@ -26,8 +26,11 @@ class HRModule(nn.Module):
|
|||||||
multiscale_output=True,
|
multiscale_output=True,
|
||||||
with_cp=False,
|
with_cp=False,
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN', requires_grad=True)):
|
norm_cfg=dict(type='BN', requires_grad=True),
|
||||||
super(HRModule, self).__init__()
|
block_init_cfg=None,
|
||||||
|
init_cfg=None):
|
||||||
|
super(HRModule, self).__init__(init_cfg)
|
||||||
|
self.block_init_cfg = block_init_cfg
|
||||||
self._check_branches(num_branches, num_blocks, in_channels,
|
self._check_branches(num_branches, num_blocks, in_channels,
|
||||||
num_channels)
|
num_channels)
|
||||||
|
|
||||||
@ -92,7 +95,8 @@ class HRModule(nn.Module):
|
|||||||
downsample=downsample,
|
downsample=downsample,
|
||||||
with_cp=self.with_cp,
|
with_cp=self.with_cp,
|
||||||
norm_cfg=self.norm_cfg,
|
norm_cfg=self.norm_cfg,
|
||||||
conv_cfg=self.conv_cfg))
|
conv_cfg=self.conv_cfg,
|
||||||
|
init_cfg=self.block_init_cfg))
|
||||||
self.in_channels[branch_index] = \
|
self.in_channels[branch_index] = \
|
||||||
num_channels[branch_index] * block.expansion
|
num_channels[branch_index] * block.expansion
|
||||||
for i in range(1, num_blocks[branch_index]):
|
for i in range(1, num_blocks[branch_index]):
|
||||||
@ -102,9 +106,10 @@ class HRModule(nn.Module):
|
|||||||
num_channels[branch_index],
|
num_channels[branch_index],
|
||||||
with_cp=self.with_cp,
|
with_cp=self.with_cp,
|
||||||
norm_cfg=self.norm_cfg,
|
norm_cfg=self.norm_cfg,
|
||||||
conv_cfg=self.conv_cfg))
|
conv_cfg=self.conv_cfg,
|
||||||
|
init_cfg=self.block_init_cfg))
|
||||||
|
|
||||||
return nn.Sequential(*layers)
|
return Sequential(*layers)
|
||||||
|
|
||||||
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
||||||
"""Build multiple branch."""
|
"""Build multiple branch."""
|
||||||
@ -114,7 +119,7 @@ class HRModule(nn.Module):
|
|||||||
branches.append(
|
branches.append(
|
||||||
self._make_one_branch(i, block, num_blocks, num_channels))
|
self._make_one_branch(i, block, num_blocks, num_channels))
|
||||||
|
|
||||||
return nn.ModuleList(branches)
|
return ModuleList(branches)
|
||||||
|
|
||||||
def _make_fuse_layers(self):
|
def _make_fuse_layers(self):
|
||||||
"""Build fuse layer."""
|
"""Build fuse layer."""
|
||||||
@ -209,7 +214,7 @@ class HRModule(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class HRNet(nn.Module):
|
class HRNet(BaseModule):
|
||||||
"""HRNet backbone.
|
"""HRNet backbone.
|
||||||
|
|
||||||
High-Resolution Representations for Labeling Pixels and Regions
|
High-Resolution Representations for Labeling Pixels and Regions
|
||||||
@ -227,6 +232,9 @@ class HRNet(nn.Module):
|
|||||||
memory while slowing down the training speed.
|
memory while slowing down the training speed.
|
||||||
zero_init_residual (bool): whether to use zero init for last norm layer
|
zero_init_residual (bool): whether to use zero init for last norm layer
|
||||||
in resblocks to let them behave as identity.
|
in resblocks to let them behave as identity.
|
||||||
|
pretrained (str, optional): model pretrained path. Default: None
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Default: None
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from mmseg.models import HRNet
|
>>> from mmseg.models import HRNet
|
||||||
@ -277,14 +285,36 @@ class HRNet(nn.Module):
|
|||||||
norm_cfg=dict(type='BN', requires_grad=True),
|
norm_cfg=dict(type='BN', requires_grad=True),
|
||||||
norm_eval=False,
|
norm_eval=False,
|
||||||
with_cp=False,
|
with_cp=False,
|
||||||
zero_init_residual=False):
|
zero_init_residual=False,
|
||||||
super(HRNet, self).__init__()
|
pretrained=None,
|
||||||
|
init_cfg=None):
|
||||||
|
super(HRNet, self).__init__(init_cfg)
|
||||||
|
|
||||||
|
self.pretrained = pretrained
|
||||||
|
self.zero_init_residual = zero_init_residual
|
||||||
|
assert not (init_cfg and pretrained), \
|
||||||
|
'init_cfg and pretrained cannot be setting at the same time'
|
||||||
|
if isinstance(pretrained, str):
|
||||||
|
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||||
|
'please use "init_cfg" instead')
|
||||||
|
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||||
|
elif pretrained is None:
|
||||||
|
if init_cfg is None:
|
||||||
|
self.init_cfg = [
|
||||||
|
dict(type='Kaiming', layer='Conv2d'),
|
||||||
|
dict(
|
||||||
|
type='Constant',
|
||||||
|
val=1,
|
||||||
|
layer=['_BatchNorm', 'GroupNorm'])
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise TypeError('pretrained must be a str or None')
|
||||||
|
|
||||||
self.extra = extra
|
self.extra = extra
|
||||||
self.conv_cfg = conv_cfg
|
self.conv_cfg = conv_cfg
|
||||||
self.norm_cfg = norm_cfg
|
self.norm_cfg = norm_cfg
|
||||||
self.norm_eval = norm_eval
|
self.norm_eval = norm_eval
|
||||||
self.with_cp = with_cp
|
self.with_cp = with_cp
|
||||||
self.zero_init_residual = zero_init_residual
|
|
||||||
|
|
||||||
# stem net
|
# stem net
|
||||||
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
|
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
|
||||||
@ -430,6 +460,16 @@ class HRNet(nn.Module):
|
|||||||
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
|
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
|
||||||
|
|
||||||
layers = []
|
layers = []
|
||||||
|
block_init_cfg = None
|
||||||
|
if self.pretrained is None and not hasattr(
|
||||||
|
self, 'init_cfg') and self.zero_init_residual:
|
||||||
|
if block is BasicBlock:
|
||||||
|
block_init_cfg = dict(
|
||||||
|
type='Constant', val=0, override=dict(name='norm2'))
|
||||||
|
elif block is Bottleneck:
|
||||||
|
block_init_cfg = dict(
|
||||||
|
type='Constant', val=0, override=dict(name='norm3'))
|
||||||
|
|
||||||
layers.append(
|
layers.append(
|
||||||
block(
|
block(
|
||||||
inplanes,
|
inplanes,
|
||||||
@ -438,7 +478,8 @@ class HRNet(nn.Module):
|
|||||||
downsample=downsample,
|
downsample=downsample,
|
||||||
with_cp=self.with_cp,
|
with_cp=self.with_cp,
|
||||||
norm_cfg=self.norm_cfg,
|
norm_cfg=self.norm_cfg,
|
||||||
conv_cfg=self.conv_cfg))
|
conv_cfg=self.conv_cfg,
|
||||||
|
init_cfg=block_init_cfg))
|
||||||
inplanes = planes * block.expansion
|
inplanes = planes * block.expansion
|
||||||
for i in range(1, blocks):
|
for i in range(1, blocks):
|
||||||
layers.append(
|
layers.append(
|
||||||
@ -447,9 +488,10 @@ class HRNet(nn.Module):
|
|||||||
planes,
|
planes,
|
||||||
with_cp=self.with_cp,
|
with_cp=self.with_cp,
|
||||||
norm_cfg=self.norm_cfg,
|
norm_cfg=self.norm_cfg,
|
||||||
conv_cfg=self.conv_cfg))
|
conv_cfg=self.conv_cfg,
|
||||||
|
init_cfg=block_init_cfg))
|
||||||
|
|
||||||
return nn.Sequential(*layers)
|
return Sequential(*layers)
|
||||||
|
|
||||||
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
|
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
|
||||||
"""Make each stage."""
|
"""Make each stage."""
|
||||||
@ -460,6 +502,16 @@ class HRNet(nn.Module):
|
|||||||
block = self.blocks_dict[layer_config['block']]
|
block = self.blocks_dict[layer_config['block']]
|
||||||
|
|
||||||
hr_modules = []
|
hr_modules = []
|
||||||
|
block_init_cfg = None
|
||||||
|
if self.pretrained is None and not hasattr(
|
||||||
|
self, 'init_cfg') and self.zero_init_residual:
|
||||||
|
if block is BasicBlock:
|
||||||
|
block_init_cfg = dict(
|
||||||
|
type='Constant', val=0, override=dict(name='norm2'))
|
||||||
|
elif block is Bottleneck:
|
||||||
|
block_init_cfg = dict(
|
||||||
|
type='Constant', val=0, override=dict(name='norm3'))
|
||||||
|
|
||||||
for i in range(num_modules):
|
for i in range(num_modules):
|
||||||
# multi_scale_output is only used for the last module
|
# multi_scale_output is only used for the last module
|
||||||
if not multiscale_output and i == num_modules - 1:
|
if not multiscale_output and i == num_modules - 1:
|
||||||
@ -477,35 +529,10 @@ class HRNet(nn.Module):
|
|||||||
reset_multiscale_output,
|
reset_multiscale_output,
|
||||||
with_cp=self.with_cp,
|
with_cp=self.with_cp,
|
||||||
norm_cfg=self.norm_cfg,
|
norm_cfg=self.norm_cfg,
|
||||||
conv_cfg=self.conv_cfg))
|
conv_cfg=self.conv_cfg,
|
||||||
|
block_init_cfg=block_init_cfg))
|
||||||
|
|
||||||
return nn.Sequential(*hr_modules), in_channels
|
return Sequential(*hr_modules), in_channels
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
"""Initialize the weights in backbone.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str, optional): Path to pre-trained weights.
|
|
||||||
Defaults to None.
|
|
||||||
"""
|
|
||||||
if isinstance(pretrained, str):
|
|
||||||
logger = get_root_logger()
|
|
||||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
|
||||||
elif pretrained is None:
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
kaiming_init(m)
|
|
||||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
|
||||||
constant_init(m, 1)
|
|
||||||
|
|
||||||
if self.zero_init_residual:
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, Bottleneck):
|
|
||||||
constant_init(m.norm3, 0)
|
|
||||||
elif isinstance(m, BasicBlock):
|
|
||||||
constant_init(m.norm2, 0)
|
|
||||||
else:
|
|
||||||
raise TypeError('pretrained must be a str or None')
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""Forward function."""
|
"""Forward function."""
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import logging
|
import warnings
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.cnn import ConvModule, constant_init, kaiming_init
|
from mmcv.cnn import ConvModule
|
||||||
from mmcv.runner import load_checkpoint
|
from mmcv.runner import BaseModule
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
@ -10,7 +10,7 @@ from ..utils import InvertedResidual, make_divisible
|
|||||||
|
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class MobileNetV2(nn.Module):
|
class MobileNetV2(BaseModule):
|
||||||
"""MobileNetV2 backbone.
|
"""MobileNetV2 backbone.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -35,6 +35,9 @@ class MobileNetV2(nn.Module):
|
|||||||
and its variants only. Default: False.
|
and its variants only. Default: False.
|
||||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||||
memory while slowing down the training speed. Default: False.
|
memory while slowing down the training speed. Default: False.
|
||||||
|
pretrained (str, optional): model pretrained path. Default: None
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Default: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Parameters to build layers. 3 parameters are needed to construct a
|
# Parameters to build layers. 3 parameters are needed to construct a
|
||||||
@ -52,8 +55,30 @@ class MobileNetV2(nn.Module):
|
|||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
act_cfg=dict(type='ReLU6'),
|
act_cfg=dict(type='ReLU6'),
|
||||||
norm_eval=False,
|
norm_eval=False,
|
||||||
with_cp=False):
|
with_cp=False,
|
||||||
super(MobileNetV2, self).__init__()
|
pretrained=None,
|
||||||
|
init_cfg=None):
|
||||||
|
super(MobileNetV2, self).__init__(init_cfg)
|
||||||
|
|
||||||
|
self.pretrained = pretrained
|
||||||
|
assert not (init_cfg and pretrained), \
|
||||||
|
'init_cfg and pretrained cannot be setting at the same time'
|
||||||
|
if isinstance(pretrained, str):
|
||||||
|
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||||
|
'please use "init_cfg" instead')
|
||||||
|
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||||
|
elif pretrained is None:
|
||||||
|
if init_cfg is None:
|
||||||
|
self.init_cfg = [
|
||||||
|
dict(type='Kaiming', layer='Conv2d'),
|
||||||
|
dict(
|
||||||
|
type='Constant',
|
||||||
|
val=1,
|
||||||
|
layer=['_BatchNorm', 'GroupNorm'])
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise TypeError('pretrained must be a str or None')
|
||||||
|
|
||||||
self.widen_factor = widen_factor
|
self.widen_factor = widen_factor
|
||||||
self.strides = strides
|
self.strides = strides
|
||||||
self.dilations = dilations
|
self.dilations = dilations
|
||||||
@ -133,19 +158,6 @@ class MobileNetV2(nn.Module):
|
|||||||
|
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
if isinstance(pretrained, str):
|
|
||||||
logger = logging.getLogger()
|
|
||||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
|
||||||
elif pretrained is None:
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
kaiming_init(m)
|
|
||||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
|
||||||
constant_init(m, 1)
|
|
||||||
else:
|
|
||||||
raise TypeError('pretrained must be a str or None')
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,9 @@
|
|||||||
import logging
|
import warnings
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import torch.nn as nn
|
from mmcv.cnn import ConvModule
|
||||||
from mmcv.cnn import ConvModule, constant_init, kaiming_init
|
|
||||||
from mmcv.cnn.bricks import Conv2dAdaptivePadding
|
from mmcv.cnn.bricks import Conv2dAdaptivePadding
|
||||||
from mmcv.runner import load_checkpoint
|
from mmcv.runner import BaseModule
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
@ -12,7 +11,7 @@ from ..utils import InvertedResidualV3 as InvertedResidual
|
|||||||
|
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class MobileNetV3(nn.Module):
|
class MobileNetV3(BaseModule):
|
||||||
"""MobileNetV3 backbone.
|
"""MobileNetV3 backbone.
|
||||||
|
|
||||||
This backbone is the improved implementation of `Searching for MobileNetV3
|
This backbone is the improved implementation of `Searching for MobileNetV3
|
||||||
@ -35,6 +34,9 @@ class MobileNetV3(nn.Module):
|
|||||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||||
some memory while slowing down the training speed.
|
some memory while slowing down the training speed.
|
||||||
Default: False.
|
Default: False.
|
||||||
|
pretrained (str, optional): model pretrained path. Default: None
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Default: None
|
||||||
"""
|
"""
|
||||||
# Parameters to build each block:
|
# Parameters to build each block:
|
||||||
# [kernel size, mid channels, out channels, with_se, act type, stride]
|
# [kernel size, mid channels, out channels, with_se, act type, stride]
|
||||||
@ -75,8 +77,30 @@ class MobileNetV3(nn.Module):
|
|||||||
frozen_stages=-1,
|
frozen_stages=-1,
|
||||||
reduction_factor=1,
|
reduction_factor=1,
|
||||||
norm_eval=False,
|
norm_eval=False,
|
||||||
with_cp=False):
|
with_cp=False,
|
||||||
super(MobileNetV3, self).__init__()
|
pretrained=None,
|
||||||
|
init_cfg=None):
|
||||||
|
super(MobileNetV3, self).__init__(init_cfg)
|
||||||
|
|
||||||
|
self.pretrained = pretrained
|
||||||
|
assert not (init_cfg and pretrained), \
|
||||||
|
'init_cfg and pretrained cannot be setting at the same time'
|
||||||
|
if isinstance(pretrained, str):
|
||||||
|
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||||
|
'please use "init_cfg" instead')
|
||||||
|
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||||
|
elif pretrained is None:
|
||||||
|
if init_cfg is None:
|
||||||
|
self.init_cfg = [
|
||||||
|
dict(type='Kaiming', layer='Conv2d'),
|
||||||
|
dict(
|
||||||
|
type='Constant',
|
||||||
|
val=1,
|
||||||
|
layer=['_BatchNorm', 'GroupNorm'])
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise TypeError('pretrained must be a str or None')
|
||||||
|
|
||||||
assert arch in self.arch_settings
|
assert arch in self.arch_settings
|
||||||
assert isinstance(reduction_factor, int) and reduction_factor > 0
|
assert isinstance(reduction_factor, int) and reduction_factor > 0
|
||||||
assert mmcv.is_tuple_of(out_indices, int)
|
assert mmcv.is_tuple_of(out_indices, int)
|
||||||
@ -217,19 +241,6 @@ class MobileNetV3(nn.Module):
|
|||||||
|
|
||||||
return layers
|
return layers
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
if isinstance(pretrained, str):
|
|
||||||
logger = logging.getLogger()
|
|
||||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
|
||||||
elif pretrained is None:
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
kaiming_init(m)
|
|
||||||
elif isinstance(m, nn.BatchNorm2d):
|
|
||||||
constant_init(m, 1)
|
|
||||||
else:
|
|
||||||
raise TypeError('pretrained must be a str or None')
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
outs = []
|
outs = []
|
||||||
for i, layer_name in enumerate(self.layers):
|
for i, layer_name in enumerate(self.layers):
|
||||||
|
|||||||
@ -1,16 +1,16 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint as cp
|
import torch.utils.checkpoint as cp
|
||||||
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
|
from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
|
||||||
constant_init, kaiming_init)
|
from mmcv.runner import BaseModule
|
||||||
from mmcv.runner import load_checkpoint
|
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
from mmseg.utils import get_root_logger
|
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
from ..utils import ResLayer
|
from ..utils import ResLayer
|
||||||
|
|
||||||
|
|
||||||
class BasicBlock(nn.Module):
|
class BasicBlock(BaseModule):
|
||||||
"""Basic block for ResNet."""
|
"""Basic block for ResNet."""
|
||||||
|
|
||||||
expansion = 1
|
expansion = 1
|
||||||
@ -26,8 +26,9 @@ class BasicBlock(nn.Module):
|
|||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
dcn=None,
|
dcn=None,
|
||||||
plugins=None):
|
plugins=None,
|
||||||
super(BasicBlock, self).__init__()
|
init_cfg=None):
|
||||||
|
super(BasicBlock, self).__init__(init_cfg)
|
||||||
assert dcn is None, 'Not implemented yet.'
|
assert dcn is None, 'Not implemented yet.'
|
||||||
assert plugins is None, 'Not implemented yet.'
|
assert plugins is None, 'Not implemented yet.'
|
||||||
|
|
||||||
@ -94,7 +95,7 @@ class BasicBlock(nn.Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Bottleneck(nn.Module):
|
class Bottleneck(BaseModule):
|
||||||
"""Bottleneck block for ResNet.
|
"""Bottleneck block for ResNet.
|
||||||
|
|
||||||
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
|
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
|
||||||
@ -114,8 +115,9 @@ class Bottleneck(nn.Module):
|
|||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
dcn=None,
|
dcn=None,
|
||||||
plugins=None):
|
plugins=None,
|
||||||
super(Bottleneck, self).__init__()
|
init_cfg=None):
|
||||||
|
super(Bottleneck, self).__init__(init_cfg)
|
||||||
assert style in ['pytorch', 'caffe']
|
assert style in ['pytorch', 'caffe']
|
||||||
assert dcn is None or isinstance(dcn, dict)
|
assert dcn is None or isinstance(dcn, dict)
|
||||||
assert plugins is None or isinstance(plugins, list)
|
assert plugins is None or isinstance(plugins, list)
|
||||||
@ -305,7 +307,7 @@ class Bottleneck(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class ResNet(nn.Module):
|
class ResNet(BaseModule):
|
||||||
"""ResNet backbone.
|
"""ResNet backbone.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -346,6 +348,9 @@ class ResNet(nn.Module):
|
|||||||
memory while slowing down the training speed.
|
memory while slowing down the training speed.
|
||||||
zero_init_residual (bool): Whether to use zero init for last norm layer
|
zero_init_residual (bool): Whether to use zero init for last norm layer
|
||||||
in resblocks to let them behave as identity.
|
in resblocks to let them behave as identity.
|
||||||
|
pretrained (str, optional): model pretrained path. Default: None
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Default: None
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from mmseg.models import ResNet
|
>>> from mmseg.models import ResNet
|
||||||
@ -392,10 +397,46 @@ class ResNet(nn.Module):
|
|||||||
multi_grid=None,
|
multi_grid=None,
|
||||||
contract_dilation=False,
|
contract_dilation=False,
|
||||||
with_cp=False,
|
with_cp=False,
|
||||||
zero_init_residual=True):
|
zero_init_residual=True,
|
||||||
|
pretrained=None,
|
||||||
|
init_cfg=None):
|
||||||
super(ResNet, self).__init__()
|
super(ResNet, self).__init__()
|
||||||
if depth not in self.arch_settings:
|
if depth not in self.arch_settings:
|
||||||
raise KeyError(f'invalid depth {depth} for resnet')
|
raise KeyError(f'invalid depth {depth} for resnet')
|
||||||
|
|
||||||
|
self.pretrained = pretrained
|
||||||
|
self.zero_init_residual = zero_init_residual
|
||||||
|
block_init_cfg = None
|
||||||
|
assert not (init_cfg and pretrained), \
|
||||||
|
'init_cfg and pretrained cannot be setting at the same time'
|
||||||
|
if isinstance(pretrained, str):
|
||||||
|
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||||
|
'please use "init_cfg" instead')
|
||||||
|
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||||
|
elif pretrained is None:
|
||||||
|
if init_cfg is None:
|
||||||
|
self.init_cfg = [
|
||||||
|
dict(type='Kaiming', layer='Conv2d'),
|
||||||
|
dict(
|
||||||
|
type='Constant',
|
||||||
|
val=1,
|
||||||
|
layer=['_BatchNorm', 'GroupNorm'])
|
||||||
|
]
|
||||||
|
block = self.arch_settings[depth][0]
|
||||||
|
if self.zero_init_residual:
|
||||||
|
if block is BasicBlock:
|
||||||
|
block_init_cfg = dict(
|
||||||
|
type='Constant',
|
||||||
|
val=0,
|
||||||
|
override=dict(name='norm2'))
|
||||||
|
elif block is Bottleneck:
|
||||||
|
block_init_cfg = dict(
|
||||||
|
type='Constant',
|
||||||
|
val=0,
|
||||||
|
override=dict(name='norm3'))
|
||||||
|
else:
|
||||||
|
raise TypeError('pretrained must be a str or None')
|
||||||
|
|
||||||
self.depth = depth
|
self.depth = depth
|
||||||
self.stem_channels = stem_channels
|
self.stem_channels = stem_channels
|
||||||
self.base_channels = base_channels
|
self.base_channels = base_channels
|
||||||
@ -421,7 +462,6 @@ class ResNet(nn.Module):
|
|||||||
self.plugins = plugins
|
self.plugins = plugins
|
||||||
self.multi_grid = multi_grid
|
self.multi_grid = multi_grid
|
||||||
self.contract_dilation = contract_dilation
|
self.contract_dilation = contract_dilation
|
||||||
self.zero_init_residual = zero_init_residual
|
|
||||||
self.block, stage_blocks = self.arch_settings[depth]
|
self.block, stage_blocks = self.arch_settings[depth]
|
||||||
self.stage_blocks = stage_blocks[:num_stages]
|
self.stage_blocks = stage_blocks[:num_stages]
|
||||||
self.inplanes = stem_channels
|
self.inplanes = stem_channels
|
||||||
@ -456,7 +496,8 @@ class ResNet(nn.Module):
|
|||||||
dcn=dcn,
|
dcn=dcn,
|
||||||
plugins=stage_plugins,
|
plugins=stage_plugins,
|
||||||
multi_grid=stage_multi_grid,
|
multi_grid=stage_multi_grid,
|
||||||
contract_dilation=contract_dilation)
|
contract_dilation=contract_dilation,
|
||||||
|
init_cfg=block_init_cfg)
|
||||||
self.inplanes = planes * self.block.expansion
|
self.inplanes = planes * self.block.expansion
|
||||||
layer_name = f'layer{i+1}'
|
layer_name = f'layer{i+1}'
|
||||||
self.add_module(layer_name, res_layer)
|
self.add_module(layer_name, res_layer)
|
||||||
@ -597,38 +638,6 @@ class ResNet(nn.Module):
|
|||||||
for param in m.parameters():
|
for param in m.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
"""Initialize the weights in backbone.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str, optional): Path to pre-trained weights.
|
|
||||||
Defaults to None.
|
|
||||||
"""
|
|
||||||
if isinstance(pretrained, str):
|
|
||||||
logger = get_root_logger()
|
|
||||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
|
||||||
elif pretrained is None:
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
kaiming_init(m)
|
|
||||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
|
||||||
constant_init(m, 1)
|
|
||||||
|
|
||||||
if self.dcn is not None:
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, Bottleneck) and hasattr(
|
|
||||||
m, 'conv2_offset'):
|
|
||||||
constant_init(m.conv2_offset, 0)
|
|
||||||
|
|
||||||
if self.zero_init_residual:
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, Bottleneck):
|
|
||||||
constant_init(m.norm3, 0)
|
|
||||||
elif isinstance(m, BasicBlock):
|
|
||||||
constant_init(m.norm2, 0)
|
|
||||||
else:
|
|
||||||
raise TypeError('pretrained must be a str or None')
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""Forward function."""
|
"""Forward function."""
|
||||||
if self.deep_stem:
|
if self.deep_stem:
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint as cp
|
import torch.utils.checkpoint as cp
|
||||||
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
|
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
|
||||||
build_norm_layer, constant_init, kaiming_init)
|
build_norm_layer)
|
||||||
from mmcv.runner import load_checkpoint
|
from mmcv.runner import BaseModule
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
from mmseg.utils import get_root_logger
|
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
from ..utils import UpConvBlock
|
from ..utils import UpConvBlock
|
||||||
|
|
||||||
@ -219,7 +220,7 @@ class InterpConv(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class UNet(nn.Module):
|
class UNet(BaseModule):
|
||||||
"""UNet backbone.
|
"""UNet backbone.
|
||||||
U-Net: Convolutional Networks for Biomedical Image Segmentation.
|
U-Net: Convolutional Networks for Biomedical Image Segmentation.
|
||||||
https://arxiv.org/pdf/1505.04597.pdf
|
https://arxiv.org/pdf/1505.04597.pdf
|
||||||
@ -266,6 +267,9 @@ class UNet(nn.Module):
|
|||||||
dcn (bool): Use deformable convolution in convolutional layer or not.
|
dcn (bool): Use deformable convolution in convolutional layer or not.
|
||||||
Default: None.
|
Default: None.
|
||||||
plugins (dict): plugins for convolutional layers. Default: None.
|
plugins (dict): plugins for convolutional layers. Default: None.
|
||||||
|
pretrained (str, optional): model pretrained path. Default: None
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Default: None
|
||||||
|
|
||||||
Notice:
|
Notice:
|
||||||
The input image size should be divisible by the whole downsample rate
|
The input image size should be divisible by the whole downsample rate
|
||||||
@ -291,8 +295,30 @@ class UNet(nn.Module):
|
|||||||
upsample_cfg=dict(type='InterpConv'),
|
upsample_cfg=dict(type='InterpConv'),
|
||||||
norm_eval=False,
|
norm_eval=False,
|
||||||
dcn=None,
|
dcn=None,
|
||||||
plugins=None):
|
plugins=None,
|
||||||
super(UNet, self).__init__()
|
pretrained=None,
|
||||||
|
init_cfg=None):
|
||||||
|
super(UNet, self).__init__(init_cfg)
|
||||||
|
|
||||||
|
self.pretrained = pretrained
|
||||||
|
assert not (init_cfg and pretrained), \
|
||||||
|
'init_cfg and pretrained cannot be setting at the same time'
|
||||||
|
if isinstance(pretrained, str):
|
||||||
|
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||||
|
'please use "init_cfg" instead')
|
||||||
|
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||||
|
elif pretrained is None:
|
||||||
|
if init_cfg is None:
|
||||||
|
self.init_cfg = [
|
||||||
|
dict(type='Kaiming', layer='Conv2d'),
|
||||||
|
dict(
|
||||||
|
type='Constant',
|
||||||
|
val=1,
|
||||||
|
layer=['_BatchNorm', 'GroupNorm'])
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise TypeError('pretrained must be a str or None')
|
||||||
|
|
||||||
assert dcn is None, 'Not implemented yet.'
|
assert dcn is None, 'Not implemented yet.'
|
||||||
assert plugins is None, 'Not implemented yet.'
|
assert plugins is None, 'Not implemented yet.'
|
||||||
assert len(strides) == num_stages, \
|
assert len(strides) == num_stages, \
|
||||||
@ -408,22 +434,3 @@ class UNet(nn.Module):
|
|||||||
f'downsample rate {whole_downsample_rate}, when num_stages is '\
|
f'downsample rate {whole_downsample_rate}, when num_stages is '\
|
||||||
f'{self.num_stages}, strides is {self.strides}, and downsamples '\
|
f'{self.num_stages}, strides is {self.strides}, and downsamples '\
|
||||||
f'is {self.downsamples}.'
|
f'is {self.downsamples}.'
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
"""Initialize the weights in backbone.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str, optional): Path to pre-trained weights.
|
|
||||||
Defaults to None.
|
|
||||||
"""
|
|
||||||
if isinstance(pretrained, str):
|
|
||||||
logger = get_root_logger()
|
|
||||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
|
||||||
elif pretrained is None:
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
kaiming_init(m)
|
|
||||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
|
||||||
constant_init(m, 1)
|
|
||||||
else:
|
|
||||||
raise TypeError('pretrained must be a str or None')
|
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import torch.nn.functional as F
|
|||||||
import torch.utils.checkpoint as cp
|
import torch.utils.checkpoint as cp
|
||||||
from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
|
from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
|
||||||
constant_init, kaiming_init, normal_init)
|
constant_init, kaiming_init, normal_init)
|
||||||
from mmcv.runner import _load_checkpoint
|
from mmcv.runner import BaseModule, _load_checkpoint
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
from mmseg.utils import get_root_logger
|
from mmseg.utils import get_root_logger
|
||||||
@ -203,7 +203,7 @@ class PatchEmbed(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class VisionTransformer(nn.Module):
|
class VisionTransformer(BaseModule):
|
||||||
"""Vision transformer backbone.
|
"""Vision transformer backbone.
|
||||||
|
|
||||||
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
|
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
|
||||||
@ -243,6 +243,9 @@ class VisionTransformer(nn.Module):
|
|||||||
with_cp (bool): Use checkpoint or not. Using checkpoint
|
with_cp (bool): Use checkpoint or not. Using checkpoint
|
||||||
will save some memory while slowing down the training speed.
|
will save some memory while slowing down the training speed.
|
||||||
Default: False.
|
Default: False.
|
||||||
|
pretrained (str, optional): model pretrained path. Default: None
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Default: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -266,8 +269,12 @@ class VisionTransformer(nn.Module):
|
|||||||
out_shape='NCHW',
|
out_shape='NCHW',
|
||||||
with_cls_token=True,
|
with_cls_token=True,
|
||||||
interpolate_mode='bicubic',
|
interpolate_mode='bicubic',
|
||||||
with_cp=False):
|
with_cp=False,
|
||||||
super(VisionTransformer, self).__init__()
|
pretrained=None,
|
||||||
|
init_cfg=None):
|
||||||
|
super(VisionTransformer, self).__init__(init_cfg)
|
||||||
|
self.pretrained = pretrained
|
||||||
|
|
||||||
self.img_size = img_size
|
self.img_size = img_size
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
self.features = self.embed_dim = embed_dim
|
self.features = self.embed_dim = embed_dim
|
||||||
@ -319,7 +326,8 @@ class VisionTransformer(nn.Module):
|
|||||||
self.norm_eval = norm_eval
|
self.norm_eval = norm_eval
|
||||||
self.with_cp = with_cp
|
self.with_cp = with_cp
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
def init_weights(self):
|
||||||
|
pretrained = self.pretrained
|
||||||
if isinstance(pretrained, str):
|
if isinstance(pretrained, str):
|
||||||
logger = get_root_logger()
|
logger = get_root_logger()
|
||||||
checkpoint = _load_checkpoint(pretrained, logger=logger)
|
checkpoint = _load_checkpoint(pretrained, logger=logger)
|
||||||
|
|||||||
@ -2,8 +2,7 @@ from abc import ABCMeta, abstractmethod
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.cnn import normal_init
|
from mmcv.runner import BaseModule, auto_fp16, force_fp32
|
||||||
from mmcv.runner import auto_fp16, force_fp32
|
|
||||||
|
|
||||||
from mmseg.core import build_pixel_sampler
|
from mmseg.core import build_pixel_sampler
|
||||||
from mmseg.ops import resize
|
from mmseg.ops import resize
|
||||||
@ -11,7 +10,7 @@ from ..builder import build_loss
|
|||||||
from ..losses import accuracy
|
from ..losses import accuracy
|
||||||
|
|
||||||
|
|
||||||
class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
|
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||||
"""Base class for BaseDecodeHead.
|
"""Base class for BaseDecodeHead.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -41,6 +40,7 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
|
|||||||
Default: None.
|
Default: None.
|
||||||
align_corners (bool): align_corners argument of F.interpolate.
|
align_corners (bool): align_corners argument of F.interpolate.
|
||||||
Default: False.
|
Default: False.
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -60,8 +60,10 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
|
|||||||
loss_weight=1.0),
|
loss_weight=1.0),
|
||||||
ignore_index=255,
|
ignore_index=255,
|
||||||
sampler=None,
|
sampler=None,
|
||||||
align_corners=False):
|
align_corners=False,
|
||||||
super(BaseDecodeHead, self).__init__()
|
init_cfg=dict(
|
||||||
|
type='Normal', std=0.01, override=dict(name='conv_seg'))):
|
||||||
|
super(BaseDecodeHead, self).__init__(init_cfg)
|
||||||
self._init_inputs(in_channels, in_index, input_transform)
|
self._init_inputs(in_channels, in_index, input_transform)
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
@ -130,10 +132,6 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
|
|||||||
assert isinstance(in_index, int)
|
assert isinstance(in_index, int)
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
"""Initialize weights of classification layer."""
|
|
||||||
normal_init(self.conv_seg, mean=0, std=0.01)
|
|
||||||
|
|
||||||
def _transform_inputs(self, inputs):
|
def _transform_inputs(self, inputs):
|
||||||
"""Transform inputs for decoder.
|
"""Transform inputs for decoder.
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.cnn import ConvModule, normal_init
|
from mmcv.cnn import ConvModule
|
||||||
from mmcv.ops import point_sample
|
from mmcv.ops import point_sample
|
||||||
|
|
||||||
from mmseg.models.builder import HEADS
|
from mmseg.models.builder import HEADS
|
||||||
@ -69,6 +69,8 @@ class PointHead(BaseCascadeDecodeHead):
|
|||||||
conv_cfg=conv_cfg,
|
conv_cfg=conv_cfg,
|
||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
act_cfg=act_cfg,
|
act_cfg=act_cfg,
|
||||||
|
init_cfg=dict(
|
||||||
|
type='Normal', std=0.01, override=dict(name='fc_seg')),
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
self.num_fcs = num_fcs
|
self.num_fcs = num_fcs
|
||||||
@ -101,10 +103,6 @@ class PointHead(BaseCascadeDecodeHead):
|
|||||||
self.dropout = nn.Dropout(self.dropout_ratio)
|
self.dropout = nn.Dropout(self.dropout_ratio)
|
||||||
delattr(self, 'conv_seg')
|
delattr(self, 'conv_seg')
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
"""Initialize weights of classification layer."""
|
|
||||||
normal_init(self.fc_seg, std=0.001)
|
|
||||||
|
|
||||||
def cls_seg(self, feat):
|
def cls_seg(self, feat):
|
||||||
"""Classify each pixel with fc."""
|
"""Classify each pixel with fc."""
|
||||||
if self.dropout is not None:
|
if self.dropout is not None:
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from mmcv.cnn import ConvModule, xavier_init
|
from mmcv.cnn import ConvModule
|
||||||
|
from mmcv.runner import BaseModule, auto_fp16
|
||||||
|
|
||||||
from ..builder import NECKS
|
from ..builder import NECKS
|
||||||
|
|
||||||
|
|
||||||
@NECKS.register_module()
|
@NECKS.register_module()
|
||||||
class FPN(nn.Module):
|
class FPN(BaseModule):
|
||||||
"""Feature Pyramid Network.
|
"""Feature Pyramid Network.
|
||||||
|
|
||||||
This is an implementation of - Feature Pyramid Networks for Object
|
This is an implementation of - Feature Pyramid Networks for Object
|
||||||
@ -43,6 +44,7 @@ class FPN(nn.Module):
|
|||||||
Default: None.
|
Default: None.
|
||||||
upsample_cfg (dict): Config dict for interpolate layer.
|
upsample_cfg (dict): Config dict for interpolate layer.
|
||||||
Default: `dict(mode='nearest')`
|
Default: `dict(mode='nearest')`
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> import torch
|
>>> import torch
|
||||||
@ -73,8 +75,10 @@ class FPN(nn.Module):
|
|||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=None,
|
norm_cfg=None,
|
||||||
act_cfg=None,
|
act_cfg=None,
|
||||||
upsample_cfg=dict(mode='nearest')):
|
upsample_cfg=dict(mode='nearest'),
|
||||||
super(FPN, self).__init__()
|
init_cfg=dict(
|
||||||
|
type='Xavier', layer='Conv2d', distribution='uniform')):
|
||||||
|
super(FPN, self).__init__(init_cfg)
|
||||||
assert isinstance(in_channels, list)
|
assert isinstance(in_channels, list)
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
@ -153,12 +157,7 @@ class FPN(nn.Module):
|
|||||||
inplace=False)
|
inplace=False)
|
||||||
self.fpn_convs.append(extra_fpn_conv)
|
self.fpn_convs.append(extra_fpn_conv)
|
||||||
|
|
||||||
# default init_weights for conv(msra) and norm in ConvModule
|
@auto_fp16()
|
||||||
def init_weights(self):
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
xavier_init(m, distribution='uniform')
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
assert len(inputs) == len(self.in_channels)
|
assert len(inputs) == len(self.in_channels)
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
@ -7,17 +6,14 @@ import mmcv
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
from mmcv.runner import BaseModule, auto_fp16
|
||||||
from mmcv.runner import auto_fp16
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentor(nn.Module):
|
class BaseSegmentor(BaseModule, metaclass=ABCMeta):
|
||||||
"""Base class for segmentors."""
|
"""Base class for segmentors."""
|
||||||
|
|
||||||
__metaclass__ = ABCMeta
|
def __init__(self, init_cfg=None):
|
||||||
|
super(BaseSegmentor, self).__init__(init_cfg)
|
||||||
def __init__(self):
|
|
||||||
super(BaseSegmentor, self).__init__()
|
|
||||||
self.fp16_enabled = False
|
self.fp16_enabled = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -62,17 +58,6 @@ class BaseSegmentor(nn.Module):
|
|||||||
"""Placeholder for augmentation test."""
|
"""Placeholder for augmentation test."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
"""Initialize the weights in segmentor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str, optional): Path to pre-trained weights.
|
|
||||||
Defaults to None.
|
|
||||||
"""
|
|
||||||
if pretrained is not None:
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.info(f'load model from: {pretrained}')
|
|
||||||
|
|
||||||
def forward_test(self, imgs, img_metas, **kwargs):
|
def forward_test(self, imgs, img_metas, **kwargs):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -24,7 +24,8 @@ class CascadeEncoderDecoder(EncoderDecoder):
|
|||||||
auxiliary_head=None,
|
auxiliary_head=None,
|
||||||
train_cfg=None,
|
train_cfg=None,
|
||||||
test_cfg=None,
|
test_cfg=None,
|
||||||
pretrained=None):
|
pretrained=None,
|
||||||
|
init_cfg=None):
|
||||||
self.num_stages = num_stages
|
self.num_stages = num_stages
|
||||||
super(CascadeEncoderDecoder, self).__init__(
|
super(CascadeEncoderDecoder, self).__init__(
|
||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
@ -33,7 +34,8 @@ class CascadeEncoderDecoder(EncoderDecoder):
|
|||||||
auxiliary_head=auxiliary_head,
|
auxiliary_head=auxiliary_head,
|
||||||
train_cfg=train_cfg,
|
train_cfg=train_cfg,
|
||||||
test_cfg=test_cfg,
|
test_cfg=test_cfg,
|
||||||
pretrained=pretrained)
|
pretrained=pretrained,
|
||||||
|
init_cfg=init_cfg)
|
||||||
|
|
||||||
def _init_decode_head(self, decode_head):
|
def _init_decode_head(self, decode_head):
|
||||||
"""Initialize ``decode_head``"""
|
"""Initialize ``decode_head``"""
|
||||||
@ -45,23 +47,6 @@ class CascadeEncoderDecoder(EncoderDecoder):
|
|||||||
self.align_corners = self.decode_head[-1].align_corners
|
self.align_corners = self.decode_head[-1].align_corners
|
||||||
self.num_classes = self.decode_head[-1].num_classes
|
self.num_classes = self.decode_head[-1].num_classes
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
"""Initialize the weights in backbone and heads.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str, optional): Path to pre-trained weights.
|
|
||||||
Defaults to None.
|
|
||||||
"""
|
|
||||||
self.backbone.init_weights(pretrained=pretrained)
|
|
||||||
for i in range(self.num_stages):
|
|
||||||
self.decode_head[i].init_weights()
|
|
||||||
if self.with_auxiliary_head:
|
|
||||||
if isinstance(self.auxiliary_head, nn.ModuleList):
|
|
||||||
for aux_head in self.auxiliary_head:
|
|
||||||
aux_head.init_weights()
|
|
||||||
else:
|
|
||||||
self.auxiliary_head.init_weights()
|
|
||||||
|
|
||||||
def encode_decode(self, img, img_metas):
|
def encode_decode(self, img, img_metas):
|
||||||
"""Encode images with backbone and decode into a semantic segmentation
|
"""Encode images with backbone and decode into a semantic segmentation
|
||||||
map of the same size as input."""
|
map of the same size as input."""
|
||||||
|
|||||||
@ -25,8 +25,13 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
auxiliary_head=None,
|
auxiliary_head=None,
|
||||||
train_cfg=None,
|
train_cfg=None,
|
||||||
test_cfg=None,
|
test_cfg=None,
|
||||||
pretrained=None):
|
pretrained=None,
|
||||||
super(EncoderDecoder, self).__init__()
|
init_cfg=None):
|
||||||
|
super(EncoderDecoder, self).__init__(init_cfg)
|
||||||
|
if pretrained is not None:
|
||||||
|
assert backbone.get('pretrained') is None, \
|
||||||
|
'both backbone and segmentor set pretrained weight'
|
||||||
|
backbone.pretrained = pretrained
|
||||||
self.backbone = builder.build_backbone(backbone)
|
self.backbone = builder.build_backbone(backbone)
|
||||||
if neck is not None:
|
if neck is not None:
|
||||||
self.neck = builder.build_neck(neck)
|
self.neck = builder.build_neck(neck)
|
||||||
@ -36,8 +41,6 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
self.train_cfg = train_cfg
|
self.train_cfg = train_cfg
|
||||||
self.test_cfg = test_cfg
|
self.test_cfg = test_cfg
|
||||||
|
|
||||||
self.init_weights(pretrained=pretrained)
|
|
||||||
|
|
||||||
assert self.with_decode_head
|
assert self.with_decode_head
|
||||||
|
|
||||||
def _init_decode_head(self, decode_head):
|
def _init_decode_head(self, decode_head):
|
||||||
@ -56,24 +59,6 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
else:
|
else:
|
||||||
self.auxiliary_head = builder.build_head(auxiliary_head)
|
self.auxiliary_head = builder.build_head(auxiliary_head)
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
|
||||||
"""Initialize the weights in backbone and heads.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str, optional): Path to pre-trained weights.
|
|
||||||
Defaults to None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
super(EncoderDecoder, self).init_weights(pretrained)
|
|
||||||
self.backbone.init_weights(pretrained=pretrained)
|
|
||||||
self.decode_head.init_weights()
|
|
||||||
if self.with_auxiliary_head:
|
|
||||||
if isinstance(self.auxiliary_head, nn.ModuleList):
|
|
||||||
for aux_head in self.auxiliary_head:
|
|
||||||
aux_head.init_weights()
|
|
||||||
else:
|
|
||||||
self.auxiliary_head.init_weights()
|
|
||||||
|
|
||||||
def extract_feat(self, img):
|
def extract_feat(self, img):
|
||||||
"""Extract features from images."""
|
"""Extract features from images."""
|
||||||
x = self.backbone(img)
|
x = self.backbone(img)
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||||
|
from mmcv.runner import Sequential
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
|
|
||||||
class ResLayer(nn.Sequential):
|
class ResLayer(Sequential):
|
||||||
"""ResLayer to build ResNet style backbone.
|
"""ResLayer to build ResNet style backbone.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -300,8 +300,8 @@ def test_resnet_backbone():
|
|||||||
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
# pretrained must be a string path
|
# pretrained must be a string path
|
||||||
model = ResNet(50)
|
model = ResNet(50, pretrained=0)
|
||||||
model.init_weights(pretrained=0)
|
model.init_weights()
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# Style must be in ['pytorch', 'caffe']
|
# Style must be in ['pytorch', 'caffe']
|
||||||
@ -314,8 +314,9 @@ def test_resnet_backbone():
|
|||||||
assert check_norm_state(model.modules(), False)
|
assert check_norm_state(model.modules(), False)
|
||||||
|
|
||||||
# Test ResNet50 with torchvision pretrained weight
|
# Test ResNet50 with torchvision pretrained weight
|
||||||
model = ResNet(depth=50, norm_eval=True)
|
model = ResNet(
|
||||||
model.init_weights('torchvision://resnet50')
|
depth=50, norm_eval=True, pretrained='torchvision://resnet50')
|
||||||
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
assert check_norm_state(model.modules(), False)
|
assert check_norm_state(model.modules(), False)
|
||||||
|
|
||||||
|
|||||||
@ -734,7 +734,6 @@ def test_unet():
|
|||||||
downsamples=(True, True, True, True),
|
downsamples=(True, True, True, True),
|
||||||
enc_dilations=(1, 1, 1, 1, 1),
|
enc_dilations=(1, 1, 1, 1, 1),
|
||||||
dec_dilations=(1, 1, 1, 1))
|
dec_dilations=(1, 1, 1, 1))
|
||||||
print(unet)
|
|
||||||
x = torch.randn(2, 3, 128, 128)
|
x = torch.randn(2, 3, 128, 128)
|
||||||
x_outs = unet(x)
|
x_outs = unet(x)
|
||||||
assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
|
assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
|
||||||
@ -754,7 +753,6 @@ def test_unet():
|
|||||||
downsamples=(True, True, True, False),
|
downsamples=(True, True, True, False),
|
||||||
enc_dilations=(1, 1, 1, 1, 1),
|
enc_dilations=(1, 1, 1, 1, 1),
|
||||||
dec_dilations=(1, 1, 1, 1))
|
dec_dilations=(1, 1, 1, 1))
|
||||||
print(unet)
|
|
||||||
x = torch.randn(2, 3, 128, 128)
|
x = torch.randn(2, 3, 128, 128)
|
||||||
x_outs = unet(x)
|
x_outs = unet(x)
|
||||||
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
|
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
|
||||||
@ -774,7 +772,6 @@ def test_unet():
|
|||||||
downsamples=(True, True, True, False),
|
downsamples=(True, True, True, False),
|
||||||
enc_dilations=(1, 1, 1, 1, 1),
|
enc_dilations=(1, 1, 1, 1, 1),
|
||||||
dec_dilations=(1, 1, 1, 1))
|
dec_dilations=(1, 1, 1, 1))
|
||||||
print(unet)
|
|
||||||
x = torch.randn(2, 3, 128, 128)
|
x = torch.randn(2, 3, 128, 128)
|
||||||
x_outs = unet(x)
|
x_outs = unet(x)
|
||||||
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
|
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
|
||||||
@ -794,7 +791,6 @@ def test_unet():
|
|||||||
downsamples=(True, True, False, False),
|
downsamples=(True, True, False, False),
|
||||||
enc_dilations=(1, 1, 1, 1, 1),
|
enc_dilations=(1, 1, 1, 1, 1),
|
||||||
dec_dilations=(1, 1, 1, 1))
|
dec_dilations=(1, 1, 1, 1))
|
||||||
print(unet)
|
|
||||||
x = torch.randn(2, 3, 128, 128)
|
x = torch.randn(2, 3, 128, 128)
|
||||||
x_outs = unet(x)
|
x_outs = unet(x)
|
||||||
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
|
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
|
||||||
@ -813,9 +809,9 @@ def test_unet():
|
|||||||
dec_num_convs=(2, 2, 2, 2),
|
dec_num_convs=(2, 2, 2, 2),
|
||||||
downsamples=(True, True, False, False),
|
downsamples=(True, True, False, False),
|
||||||
enc_dilations=(1, 1, 1, 1, 1),
|
enc_dilations=(1, 1, 1, 1, 1),
|
||||||
dec_dilations=(1, 1, 1, 1))
|
dec_dilations=(1, 1, 1, 1),
|
||||||
unet.init_weights(pretrained=None)
|
pretrained=None)
|
||||||
print(unet)
|
unet.init_weights()
|
||||||
x = torch.randn(2, 3, 128, 128)
|
x = torch.randn(2, 3, 128, 128)
|
||||||
x_outs = unet(x)
|
x_outs = unet(x)
|
||||||
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
|
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
|
||||||
|
|||||||
@ -215,6 +215,7 @@ def _test_encoder_decoder_forward(cfg_file):
|
|||||||
|
|
||||||
from mmseg.models import build_segmentor
|
from mmseg.models import build_segmentor
|
||||||
segmentor = build_segmentor(model)
|
segmentor = build_segmentor(model)
|
||||||
|
segmentor.init_weights()
|
||||||
|
|
||||||
if isinstance(segmentor.decode_head, nn.ModuleList):
|
if isinstance(segmentor.decode_head, nn.ModuleList):
|
||||||
num_classes = segmentor.decode_head[-1].num_classes
|
num_classes = segmentor.decode_head[-1].num_classes
|
||||||
|
|||||||
@ -131,6 +131,7 @@ def main():
|
|||||||
cfg.model,
|
cfg.model,
|
||||||
train_cfg=cfg.get('train_cfg'),
|
train_cfg=cfg.get('train_cfg'),
|
||||||
test_cfg=cfg.get('test_cfg'))
|
test_cfg=cfg.get('test_cfg'))
|
||||||
|
model.init_weights()
|
||||||
|
|
||||||
logger.info(model)
|
logger.info(model)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user