[Improvement] Refactor Swin-Transformer (#800)

* [Improvement] Refactor Swin-Transformer

* fixed swin test

* update patch embed, add more tests

* fixed test

* remove pretrain_style

* fixed padding

* resolve comments

* use mmcv 2tuple

* refactor init_cfg

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
This commit is contained in:
Jerry Jiarui XU 2021-09-28 17:46:33 -07:00 committed by GitHub
parent ab12009414
commit 85227b46c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 936 additions and 245 deletions

View File

@ -23,8 +23,7 @@ model = dict(
drop_path_rate=0.3, drop_path_rate=0.3,
use_abs_pos_embed=False, use_abs_pos_embed=False,
act_cfg=dict(type='GELU'), act_cfg=dict(type='GELU'),
norm_cfg=backbone_norm_cfg, norm_cfg=backbone_norm_cfg),
pretrain_style='official'),
decode_head=dict( decode_head=dict(
type='UPerHead', type='UPerHead',
in_channels=[96, 192, 384, 768], in_channels=[96, 192, 384, 768],

View File

@ -11,8 +11,7 @@ model = dict(
window_size=7, window_size=7,
use_abs_pos_embed=False, use_abs_pos_embed=False,
drop_path_rate=0.3, drop_path_rate=0.3,
patch_norm=True, patch_norm=True),
pretrain_style='official'),
decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150), decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
auxiliary_head=dict(in_channels=384, num_classes=150)) auxiliary_head=dict(in_channels=384, num_classes=150))

View File

@ -278,8 +278,6 @@ class MixVisionTransformer(BaseModule):
Default: dict(type='LN') Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs. act_cfg (dict): The activation config for FFNs.
Defalut: dict(type='GELU'). Defalut: dict(type='GELU').
pretrain_style (str): Choose to use official or mmcls pretrain weights.
Default: official.
pretrained (str, optional): model pretrained path. Default: None. pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict. init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None. Default: None.
@ -302,15 +300,10 @@ class MixVisionTransformer(BaseModule):
drop_path_rate=0., drop_path_rate=0.,
act_cfg=dict(type='GELU'), act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN', eps=1e-6), norm_cfg=dict(type='LN', eps=1e-6),
pretrain_style='official',
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super().__init__() super().__init__()
assert pretrain_style in [
'official', 'mmcls'
], 'we only support official weights or mmcls weights.'
if isinstance(pretrained, str) or pretrained is None: if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is a deprecated, ' warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead') 'please use "init_cfg" instead')
@ -330,7 +323,6 @@ class MixVisionTransformer(BaseModule):
self.out_indices = out_indices self.out_indices = out_indices
assert max(out_indices) < self.num_stages assert max(out_indices) < self.num_stages
self.pretrain_style = pretrain_style
self.pretrained = pretrained self.pretrained = pretrained
self.init_cfg = init_cfg self.init_cfg = init_cfg
@ -350,7 +342,6 @@ class MixVisionTransformer(BaseModule):
kernel_size=patch_sizes[i], kernel_size=patch_sizes[i],
stride=strides[i], stride=strides[i],
padding=patch_sizes[i] // 2, padding=patch_sizes[i] // 2,
pad_to_patch_size=False,
norm_cfg=norm_cfg) norm_cfg=norm_cfg)
layer = ModuleList([ layer = ModuleList([
TransformerEncoderLayer( TransformerEncoderLayer(
@ -403,8 +394,7 @@ class MixVisionTransformer(BaseModule):
outs = [] outs = []
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
x, H, W = layer[0](x), layer[0].DH, layer[0].DW x, hw_shape = layer[0](x)
hw_shape = (H, W)
for block in layer[1]: for block in layer[1]:
x = block(x, hw_shape) x = block(x, hw_shape)
x = layer[2](x) x = layer[2](x)

View File

@ -1,111 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import warnings
from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.cnn import build_norm_layer, trunc_normal_init import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
from mmcv.cnn.bricks.transformer import FFN, build_dropout from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmcv.cnn.utils.weight_init import constant_init from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from mmcv.runner import _load_checkpoint from mmcv.utils import to_2tuple
from mmcv.runner.base_module import BaseModule, ModuleList
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from ...utils import get_root_logger from ...utils import get_root_logger
from ..builder import ATTENTION, BACKBONES from ..builder import BACKBONES
from ..utils import PatchEmbed from ..utils.embed import PatchEmbed, PatchMerging
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer use nn.Unfold to group feature map by kernel_size, and use norm
and linear layer to embed grouped feature map.
Args:
in_channels (int): The num of input channels.
out_channels (int): The num of output channels.
stride (int | tuple): the stride of the sliding length in the
unfold layer. Defaults: 2. (Default to be equal with kernel_size).
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults: False.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults: dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Defaults: None.
"""
def __init__(self,
in_channels,
out_channels,
stride=2,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.sampler = nn.Unfold(
kernel_size=stride, dilation=1, padding=0, stride=stride)
sample_dim = stride**2 * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
def forward(self, x, hw_shape):
"""
x: x.shape -> [B, H*W, C]
hw_shape: (H, W)
"""
B, L, C = x.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
# stride is fixed to be equal to kernel_size.
if (H % self.stride != 0) or (W % self.stride != 0):
x = F.pad(x, (0, W % self.stride, 0, H % self.stride))
# Use nn.Unfold to merge patch. About 25% faster than original method,
# but need to modify pretrained model for compatibility
x = self.sampler(x) # B, 4*C, H/2*W/2
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
down_hw_shape = (H + 1) // 2, (W + 1) // 2
return x, down_hw_shape
@ATTENTION.register_module()
class WindowMSA(BaseModule): class WindowMSA(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative """Window based multi-head self-attention (W-MSA) module with relative
position bias. position bias.
Args: Args:
embed_dims (int): Number of input channels. embed_dims (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads. num_heads (int): Number of attention heads.
window_size (tuple[int]): The height and width of the window.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True. Default: True.
qk_scale (float | None, optional): Override default qk scale of qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None. head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight. attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0 Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.0 proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
init_cfg (dict | None, optional): The Config for initialization. init_cfg (dict | None, optional): The Config for initialization.
Default: None. Default: None.
""" """
@ -120,13 +46,12 @@ class WindowMSA(BaseModule):
proj_drop_rate=0., proj_drop_rate=0.,
init_cfg=None): init_cfg=None):
super().__init__() super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims self.embed_dims = embed_dims
self.window_size = window_size # Wh, Ww self.window_size = window_size # Wh, Ww
self.num_heads = num_heads self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5 self.scale = qk_scale or head_embed_dims**-0.5
self.init_cfg = init_cfg
# define a parameter table of relative position bias # define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter( self.relative_position_bias_table = nn.Parameter(
@ -161,8 +86,8 @@ class WindowMSA(BaseModule):
B, N, C = x.shape B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4) C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[ # make torchscript happy (cannot use tensor as tuple)
2] # make torchscript happy (cannot use tensor as tuple) q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale q = q * self.scale
attn = (q @ k.transpose(-2, -1)) attn = (q @ k.transpose(-2, -1))
@ -182,8 +107,6 @@ class WindowMSA(BaseModule):
N) + mask.unsqueeze(1).unsqueeze(0) N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N) attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn) attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn) attn = self.attn_drop(attn)
@ -199,9 +122,8 @@ class WindowMSA(BaseModule):
return (seq1[:, None] + seq2[None, :]).reshape(1, -1) return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule): class ShiftWindowMSA(BaseModule):
"""Shift Window Multihead Self-Attention Module. """Shifted Window Multihead Self-Attention Module.
Args: Args:
embed_dims (int): Number of input channels. embed_dims (int): Number of input channels.
@ -234,7 +156,7 @@ class ShiftWindowMSA(BaseModule):
proj_drop_rate=0, proj_drop_rate=0,
dropout_layer=dict(type='DropPath', drop_prob=0.), dropout_layer=dict(type='DropPath', drop_prob=0.),
init_cfg=None): init_cfg=None):
super().__init__(init_cfg) super().__init__(init_cfg=init_cfg)
self.window_size = window_size self.window_size = window_size
self.shift_size = shift_size self.shift_size = shift_size
@ -272,8 +194,7 @@ class ShiftWindowMSA(BaseModule):
dims=(1, 2)) dims=(1, 2))
# calculate attention mask for SW-MSA # calculate attention mask for SW-MSA
img_mask = torch.zeros((1, H_pad, W_pad, 1), img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
device=query.device) # 1 H W 1
h_slices = (slice(0, -self.window_size), h_slices = (slice(0, -self.window_size),
slice(-self.window_size, slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None)) -self.shift_size), slice(-self.shift_size, None))
@ -333,7 +254,6 @@ class ShiftWindowMSA(BaseModule):
""" """
Args: Args:
windows: (num_windows*B, window_size, window_size, C) windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image H (int): Height of image
W (int): Width of image W (int): Width of image
Returns: Returns:
@ -350,7 +270,6 @@ class ShiftWindowMSA(BaseModule):
""" """
Args: Args:
x: (B, H, W, C) x: (B, H, W, C)
window_size (int): window size
Returns: Returns:
windows: (num_windows*B, window_size, window_size, C) windows: (num_windows*B, window_size, window_size, C)
""" """
@ -369,18 +288,21 @@ class SwinBlock(BaseModule):
embed_dims (int): The feature dimension. embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads. num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs. feedforward_channels (int): The hidden dimension for FFNs.
window size (int, optional): The local window scale. Default: 7. window_size (int, optional): The local window scale. Default: 7.
shift (bool): whether to shift window or not. Default False. shift (bool, optional): whether to shift window or not. Default False.
qkv_bias (int, optional): enable bias for qkv if True. Default: True. qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None. head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0. drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0. attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2. drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
act_cfg (dict, optional): The config dict of activation function. act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU'). Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of nomalization. norm_cfg (dict, optional): The config dict of normalization.
Default: dict(type='LN'). Default: dict(type='LN').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
init_cfg (dict | list | None, optional): The init config. init_cfg (dict | list | None, optional): The init config.
Default: None. Default: None.
""" """
@ -398,11 +320,12 @@ class SwinBlock(BaseModule):
drop_path_rate=0., drop_path_rate=0.,
act_cfg=dict(type='GELU'), act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'), norm_cfg=dict(type='LN'),
with_cp=False,
init_cfg=None): init_cfg=None):
super(SwinBlock, self).__init__() super(SwinBlock, self).__init__(init_cfg=init_cfg)
self.init_cfg = init_cfg self.with_cp = with_cp
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = ShiftWindowMSA( self.attn = ShiftWindowMSA(
@ -429,6 +352,8 @@ class SwinBlock(BaseModule):
init_cfg=None) init_cfg=None)
def forward(self, x, hw_shape): def forward(self, x, hw_shape):
def _inner_forward(x):
identity = x identity = x
x = self.norm1(x) x = self.norm1(x)
x = self.attn(x, hw_shape) x = self.attn(x, hw_shape)
@ -441,6 +366,13 @@ class SwinBlock(BaseModule):
return x return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class SwinBlockSequence(BaseModule): class SwinBlockSequence(BaseModule):
"""Implements one stage in Swin Transformer. """Implements one stage in Swin Transformer.
@ -450,19 +382,23 @@ class SwinBlockSequence(BaseModule):
num_heads (int): Parallel attention heads. num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs. feedforward_channels (int): The hidden dimension for FFNs.
depth (int): The number of blocks in this stage. depth (int): The number of blocks in this stage.
window size (int): The local window scale. Default: 7. window_size (int, optional): The local window scale. Default: 7.
qkv_bias (int): enable bias for qkv if True. Default: True. qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None. head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0. drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0. attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2. drop_path_rate (float | list[float], optional): Stochastic depth
rate. Default: 0.
downsample (BaseModule | None, optional): The downsample operation downsample (BaseModule | None, optional): The downsample operation
module. Default: None. module. Default: None.
act_cfg (dict, optional): The config dict of activation function. act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU'). Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of nomalization. norm_cfg (dict, optional): The config dict of normalization.
Default: dict(type='LN'). Default: dict(type='LN').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
init_cfg (dict | list | None, optional): The init config. init_cfg (dict | list | None, optional): The init config.
Default: None. Default: None.
""" """
@ -481,14 +417,15 @@ class SwinBlockSequence(BaseModule):
downsample=None, downsample=None,
act_cfg=dict(type='GELU'), act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'), norm_cfg=dict(type='LN'),
with_cp=False,
init_cfg=None): init_cfg=None):
super().__init__() super().__init__(init_cfg=init_cfg)
self.init_cfg = init_cfg if isinstance(drop_path_rate, list):
drop_path_rates = drop_path_rate
drop_path_rate = drop_path_rate if isinstance( assert len(drop_path_rates) == depth
drop_path_rate, else:
list) else [deepcopy(drop_path_rate) for _ in range(depth)] drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
self.blocks = ModuleList() self.blocks = ModuleList()
for i in range(depth): for i in range(depth):
@ -502,9 +439,10 @@ class SwinBlockSequence(BaseModule):
qk_scale=qk_scale, qk_scale=qk_scale,
drop_rate=drop_rate, drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate, attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate[i], drop_path_rate=drop_path_rates[i],
act_cfg=act_cfg, act_cfg=act_cfg,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
with_cp=with_cp,
init_cfg=None) init_cfg=None)
self.blocks.append(block) self.blocks.append(block)
@ -538,7 +476,7 @@ class SwinTransformer(BaseModule):
embed_dims (int): The feature dimension. Default: 96. embed_dims (int): The feature dimension. Default: 96.
patch_size (int | tuple[int]): Patch size. Default: 4. patch_size (int | tuple[int]): Patch size. Default: 4.
window_size (int): Window size. Default: 7. window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
Default: 4. Default: 4.
depths (tuple[int]): Depths of each Swin Transformer stage. depths (tuple[int]): Depths of each Swin Transformer stage.
Default: (2, 2, 6, 2). Default: (2, 2, 6, 2).
@ -564,7 +502,12 @@ class SwinTransformer(BaseModule):
Default: dict(type='LN'). Default: dict(type='LN').
norm_cfg (dict): Config dict for normalization layer at norm_cfg (dict): Config dict for normalization layer at
output of backone. Defaults: dict(type='LN'). output of backone. Defaults: dict(type='LN').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
pretrained (str, optional): model pretrained path. Default: None. pretrained (str, optional): model pretrained path. Default: None.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
init_cfg (dict, optional): The Config for initialization. init_cfg (dict, optional): The Config for initialization.
Defaults to None. Defaults to None.
""" """
@ -589,9 +532,11 @@ class SwinTransformer(BaseModule):
use_abs_pos_embed=False, use_abs_pos_embed=False,
act_cfg=dict(type='GELU'), act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'), norm_cfg=dict(type='LN'),
with_cp=False,
pretrained=None, pretrained=None,
frozen_stages=-1,
init_cfg=None): init_cfg=None):
super(SwinTransformer, self).__init__() self.frozen_stages = frozen_stages
if isinstance(pretrain_img_size, int): if isinstance(pretrain_img_size, int):
pretrain_img_size = to_2tuple(pretrain_img_size) pretrain_img_size = to_2tuple(pretrain_img_size)
@ -602,17 +547,22 @@ class SwinTransformer(BaseModule):
f'The size of image should have length 1 or 2, ' \ f'The size of image should have length 1 or 2, ' \
f'but got {len(pretrain_img_size)}' f'but got {len(pretrain_img_size)}'
if isinstance(pretrained, str) or pretrained is None: assert not (init_cfg and pretrained), \
warnings.warn('DeprecationWarning: pretrained is a deprecated, ' 'init_cfg and pretrained cannot be specified at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead') 'please use "init_cfg" instead')
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
init_cfg = init_cfg
else: else:
raise TypeError('pretrained must be a str or None') raise TypeError('pretrained must be a str or None')
super(SwinTransformer, self).__init__(init_cfg=init_cfg)
num_layers = len(depths) num_layers = len(depths)
self.out_indices = out_indices self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed self.use_abs_pos_embed = use_abs_pos_embed
self.pretrained = pretrained
self.init_cfg = init_cfg
assert strides[0] == patch_size, 'Use non-overlapping patch embed.' assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
@ -622,7 +572,7 @@ class SwinTransformer(BaseModule):
conv_type='Conv2d', conv_type='Conv2d',
kernel_size=patch_size, kernel_size=patch_size,
stride=strides[0], stride=strides[0],
pad_to_patch_size=True, padding='corner',
norm_cfg=norm_cfg if patch_norm else None, norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None) init_cfg=None)
@ -635,11 +585,11 @@ class SwinTransformer(BaseModule):
self.drop_after_pos = nn.Dropout(p=drop_rate) self.drop_after_pos = nn.Dropout(p=drop_rate)
# stochastic depth # set stochastic depth decay rule
total_depth = sum(depths) total_depth = sum(depths)
dpr = [ dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth) x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule ]
self.stages = ModuleList() self.stages = ModuleList()
in_channels = embed_dims in_channels = embed_dims
@ -664,14 +614,13 @@ class SwinTransformer(BaseModule):
qk_scale=qk_scale, qk_scale=qk_scale,
drop_rate=drop_rate, drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate, attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[:depths[i]], drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
downsample=downsample, downsample=downsample,
act_cfg=act_cfg, act_cfg=act_cfg,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
with_cp=with_cp,
init_cfg=None) init_cfg=None)
self.stages.append(stage) self.stages.append(stage)
dpr = dpr[depths[i]:]
if downsample: if downsample:
in_channels = downsample.out_channels in_channels = downsample.out_channels
@ -682,29 +631,67 @@ class SwinTransformer(BaseModule):
layer_name = f'norm{i}' layer_name = f'norm{i}'
self.add_module(layer_name, layer) self.add_module(layer_name, layer)
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.use_abs_pos_embed:
self.absolute_pos_embed.requires_grad = False
self.drop_after_pos.eval()
for i in range(1, self.frozen_stages + 1):
if (i - 1) in self.out_indices:
norm_layer = getattr(self, f'norm{i-1}')
norm_layer.eval()
for param in norm_layer.parameters():
param.requires_grad = False
m = self.stages[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self): def init_weights(self):
if self.pretrained is None: logger = get_root_logger()
super().init_weights() if self.init_cfg is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
if self.use_abs_pos_embed: if self.use_abs_pos_embed:
trunc_normal_init(self.absolute_pos_embed, std=0.02) trunc_normal_init(self.absolute_pos_embed, std=0.02)
for m in self.modules(): for m in self.modules():
if isinstance(m, Linear): if isinstance(m, nn.Linear):
trunc_normal_init(m.weight, std=.02) trunc_normal_init(m.weight, std=.02)
if m.bias is not None: if m.bias is not None:
constant_init(m.bias, 0) constant_init(m.bias, 0)
elif isinstance(m, LayerNorm): elif isinstance(m, nn.LayerNorm):
constant_init(m.bias, 0) constant_init(m.bias, 0)
constant_init(m.weight, 1.0) constant_init(m.weight, 1.0)
elif isinstance(self.pretrained, str):
logger = get_root_logger()
ckpt = _load_checkpoint(
self.pretrained, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
state_dict = ckpt['state_dict']
elif 'model' in ckpt:
state_dict = ckpt['model']
else: else:
state_dict = ckpt assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
ckpt = _load_checkpoint(
self.init_cfg.checkpoint, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
_state_dict = ckpt['state_dict']
elif 'model' in ckpt:
_state_dict = ckpt['model']
else:
_state_dict = ckpt
state_dict = OrderedDict()
for k, v in _state_dict.items():
if k.startswith('backbone.'):
state_dict[k[9:]] = v
# strip prefix of state_dict # strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'): if list(state_dict.keys())[0].startswith('module.'):
@ -733,13 +720,11 @@ class SwinTransformer(BaseModule):
L2, nH2 = table_current.size() L2, nH2 = table_current.size()
if nH1 != nH2: if nH1 != nH2:
logger.warning(f'Error in loading {table_key}, pass') logger.warning(f'Error in loading {table_key}, pass')
else: elif L1 != L2:
if L1 != L2:
S1 = int(L1**0.5) S1 = int(L1**0.5)
S2 = int(L2**0.5) S2 = int(L2**0.5)
table_pretrained_resized = resize( table_pretrained_resized = F.interpolate(
table_pretrained.permute(1, 0).reshape( table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
1, nH1, S1, S1),
size=(S2, S2), size=(S2, S2),
mode='bicubic') mode='bicubic')
state_dict[table_key] = table_pretrained_resized.view( state_dict[table_key] = table_pretrained_resized.view(
@ -749,9 +734,8 @@ class SwinTransformer(BaseModule):
self.load_state_dict(state_dict, False) self.load_state_dict(state_dict, False)
def forward(self, x): def forward(self, x):
x = self.patch_embed(x) x, hw_shape = self.patch_embed(x)
hw_shape = (self.patch_embed.DH, self.patch_embed.DW)
if self.use_abs_pos_embed: if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed x = x + self.absolute_pos_embed
x = self.drop_after_pos(x) x = self.drop_after_pos(x)

View File

@ -205,7 +205,7 @@ class VisionTransformer(BaseModule):
conv_type='Conv2d', conv_type='Conv2d',
kernel_size=patch_size, kernel_size=patch_size,
stride=patch_size, stride=patch_size,
pad_to_patch_size=True, padding='corner',
norm_cfg=norm_cfg if patch_norm else None, norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None, init_cfg=None,
) )
@ -370,8 +370,8 @@ class VisionTransformer(BaseModule):
def forward(self, inputs): def forward(self, inputs):
B = inputs.shape[0] B = inputs.shape[0]
x, hw_shape = self.patch_embed(inputs), (self.patch_embed.DH, x, hw_shape = self.patch_embed(inputs)
self.patch_embed.DW)
# stole cls_tokens impl from Phil Wang, thanks # stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1) cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1) x = torch.cat((cls_tokens, x), dim=1)

View File

@ -1,28 +1,109 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule from mmcv.runner.base_module import BaseModule
from torch.nn.modules.utils import _pair as to_2tuple from mmcv.utils import to_2tuple
class AdaptivePadding(nn.Module):
"""Applies padding to input (if needed) so that input can get fully covered
by filter you specified. It support two modes "same" and "corner". The
"same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
input. The "corner" mode would pad zero to bottom right.
Args:
kernel_size (int | tuple): Size of the kernel:
stride (int | tuple): Stride of the filter. Default: 1:
dilation (int | tuple): Spacing between kernel elements.
Default: 1.
padding (str): Support "same" and "corner", "corner" mode
would pad zero to bottom right, and "same" mode would
pad zero around input. Default: "corner".
Example:
>>> kernel_size = 16
>>> stride = 16
>>> dilation = 1
>>> input = torch.rand(1, 1, 15, 17)
>>> adap_pad = AdaptivePadding(
>>> kernel_size=kernel_size,
>>> stride=stride,
>>> dilation=dilation,
>>> padding="corner")
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
>>> input = torch.rand(1, 1, 16, 17)
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
"""
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
super(AdaptivePadding, self).__init__()
assert padding in ('same', 'corner')
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
self.padding = padding
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
def get_pad_shape(self, input_shape):
input_h, input_w = input_shape
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.stride
output_h = math.ceil(input_h / stride_h)
output_w = math.ceil(input_w / stride_w)
pad_h = max((output_h - 1) * stride_h +
(kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
pad_w = max((output_w - 1) * stride_w +
(kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
return pad_h, pad_w
def forward(self, x):
pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
if pad_h > 0 or pad_w > 0:
if self.padding == 'corner':
x = F.pad(x, [0, pad_w, 0, pad_h])
elif self.padding == 'same':
x = F.pad(x, [
pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
pad_h - pad_h // 2
])
return x
# Modified from Pytorch-Image-Models
class PatchEmbed(BaseModule): class PatchEmbed(BaseModule):
"""Image to Patch Embedding V2. """Image to Patch Embedding.
We use a conv layer to implement PatchEmbed. We use a conv layer to implement PatchEmbed.
Args: Args:
in_channels (int): The num of input channels. Default: 3 in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768 embed_dims (int): The dimensions of embedding. Default: 768
conv_type (dict, optional): The config dict for conv layers type conv_type (str): The config dict for embedding
selection. Default: None. conv layer type selection. Default: "Conv2d".
kernel_size (int): The kernel_size of embedding conv. Default: 16. kernel_size (int): The kernel_size of embedding conv. Default: 16.
stride (int): The slide stride of embedding conv. stride (int, optional): The slide stride of embedding conv.
Default: None (Default to be equal with kernel_size). Default: None (Would be set as `kernel_size`).
padding (int): The padding length of embedding conv. Default: 0. padding (int | tuple | string ): The padding length of
embedding conv. When it is a string, it means the mode
of adaptive padding, support "same" and "corner" now.
Default: "corner".
dilation (int): The dilation rate of embedding conv. Default: 1. dilation (int): The dilation rate of embedding conv. Default: 1.
pad_to_patch_size (bool, optional): Whether to pad feature map shape bias (bool): Bias of embed conv. Default: True.
to multiple patch size. Default: True.
norm_cfg (dict, optional): Config dict for normalization layer. norm_cfg (dict, optional): Config dict for normalization layer.
Default: None.
input_size (int | tuple | None): The size of input, which will be
used to calculate the out size. Only work when `dynamic_size`
is False. Default: None.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None. Default: None.
""" """
@ -30,39 +111,37 @@ class PatchEmbed(BaseModule):
def __init__(self, def __init__(self,
in_channels=3, in_channels=3,
embed_dims=768, embed_dims=768,
conv_type=None, conv_type='Conv2d',
kernel_size=16, kernel_size=16,
stride=16, stride=None,
padding=0, padding='corner',
dilation=1, dilation=1,
pad_to_patch_size=True, bias=True,
norm_cfg=None, norm_cfg=None,
input_size=None,
init_cfg=None): init_cfg=None):
super(PatchEmbed, self).__init__() super(PatchEmbed, self).__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims self.embed_dims = embed_dims
self.init_cfg = init_cfg
if stride is None: if stride is None:
stride = kernel_size stride = kernel_size
self.pad_to_patch_size = pad_to_patch_size kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
# The default setting of patch size is equal to kernel size. if isinstance(padding, str):
patch_size = kernel_size self.adap_padding = AdaptivePadding(
if isinstance(patch_size, int): kernel_size=kernel_size,
patch_size = to_2tuple(patch_size) stride=stride,
elif isinstance(patch_size, tuple): dilation=dilation,
if len(patch_size) == 1: padding=padding)
patch_size = to_2tuple(patch_size[0]) # disable the padding of conv
assert len(patch_size) == 2, \ padding = 0
f'The size of patch should have length 1 or 2, ' \ else:
f'but got {len(patch_size)}' self.adap_padding = None
padding = to_2tuple(padding)
self.patch_size = patch_size
# Use conv layer to embed
conv_type = conv_type or 'Conv2d'
self.projection = build_conv_layer( self.projection = build_conv_layer(
dict(type=conv_type), dict(type=conv_type),
in_channels=in_channels, in_channels=in_channels,
@ -70,31 +149,182 @@ class PatchEmbed(BaseModule):
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
padding=padding, padding=padding,
dilation=dilation) dilation=dilation,
bias=bias)
if norm_cfg is not None: if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1] self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else: else:
self.norm = None self.norm = None
def forward(self, x): if input_size:
H, W = x.shape[2], x.shape[3] input_size = to_2tuple(input_size)
# `init_out_size` would be used outside to
# calculate the num_patches
# when `use_abs_pos_embed` outside
self.init_input_size = input_size
if self.adap_padding:
pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
input_h, input_w = input_size
input_h = input_h + pad_h
input_w = input_w + pad_w
input_size = (input_h, input_w)
# TODO: Process overlapping op # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
if self.pad_to_patch_size: h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
# Modify H, W to multiple of patch size. (kernel_size[0] - 1) - 1) // stride[0] + 1
if H % self.patch_size[0] != 0: w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
x = F.pad( (kernel_size[1] - 1) - 1) // stride[1] + 1
x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) self.init_out_size = (h_out, w_out)
if W % self.patch_size[1] != 0: else:
x = F.pad( self.init_input_size = None
x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0)) self.init_out_size = None
def forward(self, x):
"""
Args:
x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, out_h * out_w, embed_dims)
- out_size (tuple[int]): Spatial shape of x, arrange as
(out_h, out_w).
"""
if self.adap_padding:
x = self.adap_padding(x)
x = self.projection(x) x = self.projection(x)
self.DH, self.DW = x.shape[2], x.shape[3] out_size = (x.shape[2], x.shape[3])
x = x.flatten(2).transpose(1, 2) x = x.flatten(2).transpose(1, 2)
if self.norm is not None: if self.norm is not None:
x = self.norm(x) x = self.norm(x)
return x, out_size
return x
class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer groups the feature map by kernel_size, and applies norm and
    linear layers to the grouped feature map. Our implementation uses
    `nn.Unfold` to merge patches, which is about 25% faster than the original
    implementation. Instead, we need to modify pretrained models for
    compatibility.

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in
            the unfold layer. Default: None. (Would be set as `kernel_size`)
        padding (int | tuple | string ): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Default: 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Default: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=2,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Fall back to a non-overlapping merge when no stride is given.
        # (Replaces the original no-op `if stride: stride = stride`.)
        if not stride:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            # A string padding selects an adaptive-padding mode; the unfold
            # op itself then runs with zero built-in padding.
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of unfold
            padding = 0
        else:
            self.adap_padding = None

        padding = to_2tuple(padding)
        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        # Each unfolded column stacks kernel_h * kernel_w input pixels
        # for every input channel.
        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

    def forward(self, x, input_size):
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arrange as
                (H, W). Default: None.

        Returns:
            tuple: Contains merged results and its spatial shape.

                - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
                - out_size (tuple[int]): Spatial shape of x, arrange as
                    (Merged_H, Merged_W).
        """
        B, L, C = x.shape
        assert isinstance(input_size, Sequence), f'Expect ' \
            f'input_size is ' \
            f'`Sequence` ' \
            f'but get {input_size}'

        H, W = input_size
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility

        if self.adap_padding:
            x = self.adap_padding(x)
            # Adaptive padding may have enlarged the spatial shape.
            H, W = x.shape[-2:]

        x = self.sampler(x)
        # if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)

        # Output spatial size follows the standard conv/unfold formula, see
        # https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
                 (self.sampler.kernel_size[0] - 1) -
                 1) // self.sampler.stride[0] + 1
        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
                 (self.sampler.kernel_size[1] - 1) -
                 1) // self.sampler.stride[1] + 1

        output_size = (out_h, out_w)
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
        x = self.norm(x) if self.norm else x
        x = self.reduction(x)
        return x, output_size

View File

@ -7,10 +7,6 @@ from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
def test_mit(): def test_mit():
with pytest.raises(AssertionError):
# It's only support official style and mmcls style now.
MixVisionTransformer(pretrain_style='timm')
with pytest.raises(TypeError): with pytest.raises(TypeError):
# Pretrained represents pretrain url and must be str or None. # Pretrained represents pretrain url and must be str or None.
MixVisionTransformer(pretrained=123) MixVisionTransformer(pretrained=123)

View File

@ -1,8 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest import pytest
import torch import torch
from mmseg.models.backbones import SwinTransformer from mmseg.models.backbones.swin import SwinBlock, SwinTransformer
def test_swin_block():
# test SwinBlock structure and forward
block = SwinBlock(embed_dims=64, num_heads=4, feedforward_channels=256)
assert block.ffn.embed_dims == 64
assert block.attn.w_msa.num_heads == 4
assert block.ffn.feedforward_channels == 256
x = torch.randn(1, 56 * 56, 64)
x_out = block(x, (56, 56))
assert x_out.shape == torch.Size([1, 56 * 56, 64])
# Test BasicBlock with checkpoint forward
block = SwinBlock(
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
assert block.with_cp
x = torch.randn(1, 56 * 56, 64)
x_out = block(x, (56, 56))
assert x_out.shape == torch.Size([1, 56 * 56, 64])
def test_swin_transformer(): def test_swin_transformer():
@ -10,12 +28,16 @@ def test_swin_transformer():
with pytest.raises(TypeError): with pytest.raises(TypeError):
# Pretrained arg must be str or None. # Pretrained arg must be str or None.
model = SwinTransformer(pretrained=123) SwinTransformer(pretrained=123)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
# Because swin use non-overlapping patch embed, so the stride of patch # Because swin uses non-overlapping patch embed, so the stride of patch
# embed must be equal to patch size. # embed must be equal to patch size.
model = SwinTransformer(strides=(2, 2, 2, 2), patch_size=4) SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)
# test pretrained image size
with pytest.raises(AssertionError):
SwinTransformer(pretrain_img_size=(224, 224, 224))
# Test absolute position embedding # Test absolute position embedding
temp = torch.randn((1, 3, 224, 224)) temp = torch.randn((1, 3, 224, 224))
@ -27,12 +49,6 @@ def test_swin_transformer():
model = SwinTransformer(patch_norm=False) model = SwinTransformer(patch_norm=False)
model(temp) model(temp)
# Test pretrain img size
model = SwinTransformer(pretrain_img_size=(224, ))
with pytest.raises(AssertionError):
model = SwinTransformer(pretrain_img_size=(224, 224, 224))
# Test normal inference # Test normal inference
temp = torch.randn((1, 3, 512, 512)) temp = torch.randn((1, 3, 512, 512))
model = SwinTransformer() model = SwinTransformer()
@ -42,7 +58,7 @@ def test_swin_transformer():
assert outs[2].shape == (1, 384, 32, 32) assert outs[2].shape == (1, 384, 32, 32)
assert outs[3].shape == (1, 768, 16, 16) assert outs[3].shape == (1, 768, 16, 16)
# Test abnormal inference # Test abnormal inference size
temp = torch.randn((1, 3, 511, 511)) temp = torch.randn((1, 3, 511, 511))
model = SwinTransformer() model = SwinTransformer()
outs = model(temp) outs = model(temp)
@ -51,7 +67,7 @@ def test_swin_transformer():
assert outs[2].shape == (1, 384, 32, 32) assert outs[2].shape == (1, 384, 32, 32)
assert outs[3].shape == (1, 768, 16, 16) assert outs[3].shape == (1, 768, 16, 16)
# Test abnormal inference # Test abnormal inference size
temp = torch.randn((1, 3, 112, 137)) temp = torch.randn((1, 3, 112, 137))
model = SwinTransformer() model = SwinTransformer()
outs = model(temp) outs = model(temp)
@ -59,3 +75,25 @@ def test_swin_transformer():
assert outs[1].shape == (1, 192, 14, 18) assert outs[1].shape == (1, 192, 14, 18)
assert outs[2].shape == (1, 384, 7, 9) assert outs[2].shape == (1, 384, 7, 9)
assert outs[3].shape == (1, 768, 4, 5) assert outs[3].shape == (1, 768, 4, 5)
# Test frozen
model = SwinTransformer(frozen_stages=4)
model.train()
for p in model.parameters():
assert not p.requires_grad
# Test absolute position embedding frozen
model = SwinTransformer(frozen_stages=4, use_abs_pos_embed=True)
model.train()
for p in model.parameters():
assert not p.requires_grad
# Test Swin with checkpoint forward
temp = torch.randn((1, 3, 224, 224))
model = SwinTransformer(with_cp=True)
for m in model.modules():
if isinstance(m, SwinBlock):
assert m.with_cp
model.init_weights()
model.train()
model(temp)

View File

@ -25,12 +25,6 @@ def test_vit_backbone():
x = torch.randn(1, 196) x = torch.randn(1, 196)
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear') VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
with pytest.raises(IndexError):
# forward inputs must be [N, C, H, W]
x = torch.randn(3, 30, 30)
model = VisionTransformer()
model(x)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
# The length of img_size tuple must be lower than 3. # The length of img_size tuple must be lower than 3.
VisionTransformer(img_size=(224, 224, 224)) VisionTransformer(img_size=(224, 224, 224))

View File

View File

@ -0,0 +1,461 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.utils.embed import AdaptivePadding, PatchEmbed, PatchMerging
def test_adaptive_padding():
    """AdaptivePadding must grow inputs to the next sliding-window grid."""
    for mode in ('same', 'corner'):
        # 16x16 window, stride 16: both 15x17 and 16x17 inputs are padded
        # up to the next multiple of 16 -> (16, 32).
        layer = AdaptivePadding(
            kernel_size=16, stride=16, dilation=1, padding=mode)
        out = layer(torch.rand(1, 1, 15, 17))
        assert (out.shape[2], out.shape[3]) == (16, 32)
        out = layer(torch.rand(1, 1, 16, 17))
        assert (out.shape[2], out.shape[3]) == (16, 32)

        # 2x2 window, stride 2: pad 11x13 up to even sizes (12, 14).
        layer = AdaptivePadding(
            kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), padding=mode)
        out = layer(torch.rand(1, 1, 11, 13))
        assert (out.shape[2], out.shape[3]) == (12, 14)

        # Stride larger than the input: a single window fits, no padding.
        layer = AdaptivePadding(
            kernel_size=(2, 2), stride=(10, 10), dilation=(1, 1), padding=mode)
        out = layer(torch.rand(1, 1, 10, 13))
        assert (out.shape[2], out.shape[3]) == (10, 13)

        # Kernel larger than the input on both axes: pad everywhere.
        layer = AdaptivePadding(
            kernel_size=(11, 11),
            stride=(10, 10),
            dilation=(1, 1),
            padding=mode)
        out = layer(torch.rand(1, 1, 11, 13))
        assert (out.shape[2], out.shape[3]) == (21, 21)

        # A dilated 4x5 kernel has effective extent 7x9 and must pad
        # exactly like an undilated 7x9 kernel.
        x = torch.rand(1, 1, 11, 13)
        layer = AdaptivePadding(
            kernel_size=(4, 5), stride=(3, 4), dilation=(2, 2), padding=mode)
        dilated_out = layer(x)
        assert (dilated_out.shape[2], dilated_out.shape[3]) == (16, 21)

        layer = AdaptivePadding(
            kernel_size=(7, 9), stride=(3, 4), dilation=(1, 1), padding=mode)
        plain_out = layer(x)
        assert (plain_out.shape[2], plain_out.shape[3]) == (16, 21)
        assert plain_out.shape == dilated_out.shape

    # Only the "same" and "corner" modes are supported.
    with pytest.raises(AssertionError):
        AdaptivePadding(
            kernel_size=(7, 9), stride=(3, 4), dilation=(1, 1), padding=1)
def test_patch_embed():
    """PatchEmbed should flatten conv patches and report correct out sizes.

    Also checks that ``init_out_size`` (pre-computed from ``input_size`` at
    construction time) matches the size actually produced at forward time,
    and exercises the adaptive padding modes.
    """
    B = 2
    H = 3
    W = 4
    C = 3
    embed_dims = 10
    kernel_size = 3
    stride = 1
    dummy_input = torch.rand(B, C, H, W)
    patch_merge_1 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=1,
        norm_cfg=None)

    x1, shape = patch_merge_1(dummy_input)
    # test out shape
    assert x1.shape == (2, 2, 10)
    # test outsize is correct
    assert shape == (1, 2)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x1.shape[1]

    B = 2
    H = 10
    W = 10
    C = 3
    embed_dims = 10
    kernel_size = 5
    stride = 2
    dummy_input = torch.rand(B, C, H, W)
    # test dilation
    patch_merge_2 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=None,
    )

    x2, shape = patch_merge_2(dummy_input)
    # test out shape
    assert x2.shape == (2, 1, 10)
    # test outsize is correct
    assert shape == (1, 1)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x2.shape[1]

    stride = 2
    input_size = (10, 10)

    dummy_input = torch.rand(B, C, H, W)
    # test stride and norm
    patch_merge_3 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=dict(type='LN'),
        input_size=input_size)

    x3, shape = patch_merge_3(dummy_input)
    # test out shape
    assert x3.shape == (2, 1, 10)
    # test outsize is correct
    assert shape == (1, 1)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x3.shape[1]

    # test the pre-computed init_out_size of the embedding conv
    # (kernel 5 with dilation 2 spans 9 pixels -> 2 * 4 border shrink)
    assert patch_merge_3.init_out_size[1] == (input_size[0] - 2 * 4 -
                                              1) // 2 + 1
    assert patch_merge_3.init_out_size[0] == (input_size[0] - 2 * 4 -
                                              1) // 2 + 1

    H = 11
    W = 12
    input_size = (H, W)
    dummy_input = torch.rand(B, C, H, W)
    # test stride and norm
    patch_merge_3 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=dict(type='LN'),
        input_size=input_size)

    _, shape = patch_merge_3(dummy_input)
    # when input_size is equal to the real input shape,
    # the out_size should be equal to `init_out_size`
    # (an exact duplicate of this check was removed here)
    assert shape == patch_merge_3.init_out_size

    # test adap padding
    for padding in ('same', 'corner'):
        in_c = 2
        embed_dims = 3
        B = 2

        # test stride is 1
        input_size = (5, 5)
        kernel_size = (5, 5)
        stride = (1, 1)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 25, 3)
        assert out_size == (5, 5)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test kernel_size == stride
        input_size = (5, 5)
        kernel_size = (5, 5)
        stride = (5, 5)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 1, 3)
        assert out_size == (1, 1)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test kernel_size == stride with a non-divisible input height
        input_size = (6, 5)
        kernel_size = (5, 5)
        stride = (5, 5)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 2, 3)
        assert out_size == (2, 1)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test different kernel_size with different stride
        input_size = (6, 5)
        kernel_size = (6, 2)
        stride = (6, 2)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 3, 3)
        assert out_size == (1, 3)
        assert x_out.size(1) == out_size[0] * out_size[1]
def test_patch_merging():
    """Check PatchMerging output shapes for integer and adaptive padding."""
    # Integer padding: kernel 3 / stride 3 / pad 1 on a 10x10 map -> 4x4.
    merge = PatchMerging(
        in_channels=3,
        out_channels=4,
        kernel_size=3,
        stride=3,
        padding=1,
        dilation=1,
        bias=False)
    x = torch.rand(1, 100, 3)
    merged, merged_size = merge(x, (10, 10))
    assert merged.size() == (1, 16, 4)
    assert merged_size == (4, 4)
    # out size must be consistent with the real output
    assert merged.size(1) == merged_size[0] * merged_size[1]

    # Integer padding combined with dilation.
    merge = PatchMerging(
        in_channels=4,
        out_channels=5,
        kernel_size=6,
        stride=3,
        padding=2,
        dilation=2,
        bias=False)
    x = torch.rand(1, 100, 4)
    merged, merged_size = merge(x, (10, 10))
    assert merged.size() == (1, 4, 5)
    assert merged_size == (2, 2)
    # out size must be consistent with the real output
    assert merged.size(1) == merged_size[0] * merged_size[1]

    # Adaptive padding modes.
    for mode in ('same', 'corner'):
        # Stride 1 with a full-size 5x5 kernel keeps the spatial size.
        merge = PatchMerging(
            in_channels=2,
            out_channels=3,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=mode,
            dilation=1,
            bias=False)
        x = torch.rand(2, 25, 2)
        merged, merged_size = merge(x, (5, 5))
        assert merged.size() == (2, 25, 3)
        assert merged_size == (5, 5)
        assert merged.size(1) == merged_size[0] * merged_size[1]

        # kernel_size == stride: a single non-overlapping window.
        merge = PatchMerging(
            in_channels=2,
            out_channels=3,
            kernel_size=(5, 5),
            stride=(5, 5),
            padding=mode,
            dilation=1,
            bias=False)
        x = torch.rand(2, 25, 2)
        merged, merged_size = merge(x, (5, 5))
        assert merged.size() == (2, 1, 3)
        assert merged_size == (1, 1)
        assert merged.size(1) == merged_size[0] * merged_size[1]

        # kernel_size == stride on a 6x5 map: height gets padded up.
        merge = PatchMerging(
            in_channels=2,
            out_channels=3,
            kernel_size=(5, 5),
            stride=(5, 5),
            padding=mode,
            dilation=1,
            bias=False)
        x = torch.rand(2, 30, 2)
        merged, merged_size = merge(x, (6, 5))
        assert merged.size() == (2, 2, 3)
        assert merged_size == (2, 1)
        assert merged.size(1) == merged_size[0] * merged_size[1]

        # Rectangular kernel/stride (6, 2) on a 6x5 map.
        merge = PatchMerging(
            in_channels=2,
            out_channels=3,
            kernel_size=(6, 2),
            stride=(6, 2),
            padding=mode,
            dilation=1,
            bias=False)
        x = torch.rand(2, 30, 2)
        merged, merged_size = merge(x, (6, 5))
        assert merged.size() == (2, 3, 3)
        assert merged_size == (1, 3)
        assert merged.size(1) == merged_size[0] * merged_size[1]