[Improvement] Refactor Swin-Transformer (#800)
* [Improvement] Refactor Swin-Transformer * fixed swin test * update patch emebd, add more tests * fixed test * remove pretrain_style * fixed padding * resolve coments * use mmcv 2tuple * refactor init_cfg Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
This commit is contained in:
parent
ab12009414
commit
85227b46c7
@ -23,8 +23,7 @@ model = dict(
|
||||
drop_path_rate=0.3,
|
||||
use_abs_pos_embed=False,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=backbone_norm_cfg,
|
||||
pretrain_style='official'),
|
||||
norm_cfg=backbone_norm_cfg),
|
||||
decode_head=dict(
|
||||
type='UPerHead',
|
||||
in_channels=[96, 192, 384, 768],
|
||||
|
||||
@ -11,8 +11,7 @@ model = dict(
|
||||
window_size=7,
|
||||
use_abs_pos_embed=False,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
pretrain_style='official'),
|
||||
patch_norm=True),
|
||||
decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
|
||||
auxiliary_head=dict(in_channels=384, num_classes=150))
|
||||
|
||||
|
||||
@ -278,8 +278,6 @@ class MixVisionTransformer(BaseModule):
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Defalut: dict(type='GELU').
|
||||
pretrain_style (str): Choose to use official or mmcls pretrain weights.
|
||||
Default: official.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
@ -302,15 +300,10 @@ class MixVisionTransformer(BaseModule):
|
||||
drop_path_rate=0.,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
pretrain_style='official',
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__()
|
||||
|
||||
assert pretrain_style in [
|
||||
'official', 'mmcls'
|
||||
], 'we only support official weights or mmcls weights.'
|
||||
|
||||
if isinstance(pretrained, str) or pretrained is None:
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
@ -330,7 +323,6 @@ class MixVisionTransformer(BaseModule):
|
||||
|
||||
self.out_indices = out_indices
|
||||
assert max(out_indices) < self.num_stages
|
||||
self.pretrain_style = pretrain_style
|
||||
self.pretrained = pretrained
|
||||
self.init_cfg = init_cfg
|
||||
|
||||
@ -350,7 +342,6 @@ class MixVisionTransformer(BaseModule):
|
||||
kernel_size=patch_sizes[i],
|
||||
stride=strides[i],
|
||||
padding=patch_sizes[i] // 2,
|
||||
pad_to_patch_size=False,
|
||||
norm_cfg=norm_cfg)
|
||||
layer = ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
@ -403,8 +394,7 @@ class MixVisionTransformer(BaseModule):
|
||||
outs = []
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
x, H, W = layer[0](x), layer[0].DH, layer[0].DW
|
||||
hw_shape = (H, W)
|
||||
x, hw_shape = layer[0](x)
|
||||
for block in layer[1]:
|
||||
x = block(x, hw_shape)
|
||||
x = layer[2](x)
|
||||
|
||||
@ -1,111 +1,37 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_norm_layer, trunc_normal_init
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
|
||||
from mmcv.cnn.bricks.transformer import FFN, build_dropout
|
||||
from mmcv.cnn.utils.weight_init import constant_init
|
||||
from mmcv.runner import _load_checkpoint
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
from torch.nn.modules.linear import Linear
|
||||
from torch.nn.modules.normalization import LayerNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
|
||||
from mmcv.utils import to_2tuple
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ...utils import get_root_logger
|
||||
from ..builder import ATTENTION, BACKBONES
|
||||
from ..utils import PatchEmbed
|
||||
from ..builder import BACKBONES
|
||||
from ..utils.embed import PatchEmbed, PatchMerging
|
||||
|
||||
|
||||
class PatchMerging(BaseModule):
|
||||
"""Merge patch feature map.
|
||||
|
||||
This layer use nn.Unfold to group feature map by kernel_size, and use norm
|
||||
and linear layer to embed grouped feature map.
|
||||
|
||||
Args:
|
||||
in_channels (int): The num of input channels.
|
||||
out_channels (int): The num of output channels.
|
||||
stride (int | tuple): the stride of the sliding length in the
|
||||
unfold layer. Defaults: 2. (Default to be equal with kernel_size).
|
||||
bias (bool, optional): Whether to add bias in linear layer or not.
|
||||
Defaults: False.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||
Defaults: dict(type='LN').
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride=2,
|
||||
bias=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.stride = stride
|
||||
|
||||
self.sampler = nn.Unfold(
|
||||
kernel_size=stride, dilation=1, padding=0, stride=stride)
|
||||
|
||||
sample_dim = stride**2 * in_channels
|
||||
|
||||
if norm_cfg is not None:
|
||||
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
"""
|
||||
x: x.shape -> [B, H*W, C]
|
||||
hw_shape: (H, W)
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
H, W = hw_shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
|
||||
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
|
||||
|
||||
# stride is fixed to be equal to kernel_size.
|
||||
if (H % self.stride != 0) or (W % self.stride != 0):
|
||||
x = F.pad(x, (0, W % self.stride, 0, H % self.stride))
|
||||
|
||||
# Use nn.Unfold to merge patch. About 25% faster than original method,
|
||||
# but need to modify pretrained model for compatibility
|
||||
x = self.sampler(x) # B, 4*C, H/2*W/2
|
||||
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
|
||||
|
||||
x = self.norm(x) if self.norm else x
|
||||
x = self.reduction(x)
|
||||
|
||||
down_hw_shape = (H + 1) // 2, (W + 1) // 2
|
||||
return x, down_hw_shape
|
||||
|
||||
|
||||
@ATTENTION.register_module()
|
||||
class WindowMSA(BaseModule):
|
||||
"""Window based multi-head self-attention (W-MSA) module with relative
|
||||
position bias.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
||||
Default: True.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
attn_drop_rate (float, optional): Dropout ratio of attention weight.
|
||||
Default: 0.0
|
||||
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.0
|
||||
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
|
||||
init_cfg (dict | None, optional): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
@ -120,13 +46,12 @@ class WindowMSA(BaseModule):
|
||||
proj_drop_rate=0.,
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__()
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.embed_dims = embed_dims
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_embed_dims = embed_dims // num_heads
|
||||
self.scale = qk_scale or head_embed_dims**-0.5
|
||||
self.init_cfg = init_cfg
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
@ -161,8 +86,8 @@ class WindowMSA(BaseModule):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[
|
||||
2] # make torchscript happy (cannot use tensor as tuple)
|
||||
# make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
@ -181,9 +106,7 @@ class WindowMSA(BaseModule):
|
||||
attn = attn.view(B // nW, nW, self.num_heads, N,
|
||||
N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
attn = self.softmax(attn)
|
||||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
@ -199,9 +122,8 @@ class WindowMSA(BaseModule):
|
||||
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
|
||||
|
||||
|
||||
@ATTENTION.register_module()
|
||||
class ShiftWindowMSA(BaseModule):
|
||||
"""Shift Window Multihead Self-Attention Module.
|
||||
"""Shifted Window Multihead Self-Attention Module.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
@ -234,7 +156,7 @@ class ShiftWindowMSA(BaseModule):
|
||||
proj_drop_rate=0,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=0.),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
@ -272,8 +194,7 @@ class ShiftWindowMSA(BaseModule):
|
||||
dims=(1, 2))
|
||||
|
||||
# calculate attention mask for SW-MSA
|
||||
img_mask = torch.zeros((1, H_pad, W_pad, 1),
|
||||
device=query.device) # 1 H W 1
|
||||
img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
@ -333,7 +254,6 @@ class ShiftWindowMSA(BaseModule):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
window_size (int): Window size
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
Returns:
|
||||
@ -350,7 +270,6 @@ class ShiftWindowMSA(BaseModule):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
window_size (int): window size
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
"""
|
||||
@ -369,18 +288,21 @@ class SwinBlock(BaseModule):
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
window size (int, optional): The local window scale. Default: 7.
|
||||
shift (bool): whether to shift window or not. Default False.
|
||||
qkv_bias (int, optional): enable bias for qkv if True. Default: True.
|
||||
window_size (int, optional): The local window scale. Default: 7.
|
||||
shift (bool, optional): whether to shift window or not. Default False.
|
||||
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
drop_rate (float, optional): Dropout rate. Default: 0.
|
||||
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
|
||||
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
|
||||
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
|
||||
act_cfg (dict, optional): The config dict of activation function.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict, optional): The config dict of nomalization.
|
||||
norm_cfg (dict, optional): The config dict of normalization.
|
||||
Default: dict(type='LN').
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Default: False.
|
||||
init_cfg (dict | list | None, optional): The init config.
|
||||
Default: None.
|
||||
"""
|
||||
@ -398,11 +320,12 @@ class SwinBlock(BaseModule):
|
||||
drop_path_rate=0.,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
init_cfg=None):
|
||||
|
||||
super(SwinBlock, self).__init__()
|
||||
super(SwinBlock, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.init_cfg = init_cfg
|
||||
self.with_cp = with_cp
|
||||
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.attn = ShiftWindowMSA(
|
||||
@ -429,15 +352,24 @@ class SwinBlock(BaseModule):
|
||||
init_cfg=None)
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
identity = x
|
||||
x = self.norm1(x)
|
||||
x = self.attn(x, hw_shape)
|
||||
|
||||
x = x + identity
|
||||
def _inner_forward(x):
|
||||
identity = x
|
||||
x = self.norm1(x)
|
||||
x = self.attn(x, hw_shape)
|
||||
|
||||
identity = x
|
||||
x = self.norm2(x)
|
||||
x = self.ffn(x, identity=identity)
|
||||
x = x + identity
|
||||
|
||||
identity = x
|
||||
x = self.norm2(x)
|
||||
x = self.ffn(x, identity=identity)
|
||||
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
|
||||
return x
|
||||
|
||||
@ -450,19 +382,23 @@ class SwinBlockSequence(BaseModule):
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
depth (int): The number of blocks in this stage.
|
||||
window size (int): The local window scale. Default: 7.
|
||||
qkv_bias (int): enable bias for qkv if True. Default: True.
|
||||
window_size (int, optional): The local window scale. Default: 7.
|
||||
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
drop_rate (float, optional): Dropout rate. Default: 0.
|
||||
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
|
||||
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
|
||||
drop_path_rate (float | list[float], optional): Stochastic depth
|
||||
rate. Default: 0.
|
||||
downsample (BaseModule | None, optional): The downsample operation
|
||||
module. Default: None.
|
||||
act_cfg (dict, optional): The config dict of activation function.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict, optional): The config dict of nomalization.
|
||||
norm_cfg (dict, optional): The config dict of normalization.
|
||||
Default: dict(type='LN').
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Default: False.
|
||||
init_cfg (dict | list | None, optional): The init config.
|
||||
Default: None.
|
||||
"""
|
||||
@ -481,14 +417,15 @@ class SwinBlockSequence(BaseModule):
|
||||
downsample=None,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
init_cfg=None):
|
||||
super().__init__()
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.init_cfg = init_cfg
|
||||
|
||||
drop_path_rate = drop_path_rate if isinstance(
|
||||
drop_path_rate,
|
||||
list) else [deepcopy(drop_path_rate) for _ in range(depth)]
|
||||
if isinstance(drop_path_rate, list):
|
||||
drop_path_rates = drop_path_rate
|
||||
assert len(drop_path_rates) == depth
|
||||
else:
|
||||
drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
|
||||
|
||||
self.blocks = ModuleList()
|
||||
for i in range(depth):
|
||||
@ -502,9 +439,10 @@ class SwinBlockSequence(BaseModule):
|
||||
qk_scale=qk_scale,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=drop_path_rate[i],
|
||||
drop_path_rate=drop_path_rates[i],
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
init_cfg=None)
|
||||
self.blocks.append(block)
|
||||
|
||||
@ -538,7 +476,7 @@ class SwinTransformer(BaseModule):
|
||||
embed_dims (int): The feature dimension. Default: 96.
|
||||
patch_size (int | tuple[int]): Patch size. Default: 4.
|
||||
window_size (int): Window size. Default: 7.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
depths (tuple[int]): Depths of each Swin Transformer stage.
|
||||
Default: (2, 2, 6, 2).
|
||||
@ -564,7 +502,12 @@ class SwinTransformer(BaseModule):
|
||||
Default: dict(type='LN').
|
||||
norm_cfg (dict): Config dict for normalization layer at
|
||||
output of backone. Defaults: dict(type='LN').
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
@ -589,9 +532,11 @@ class SwinTransformer(BaseModule):
|
||||
use_abs_pos_embed=False,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
pretrained=None,
|
||||
frozen_stages=-1,
|
||||
init_cfg=None):
|
||||
super(SwinTransformer, self).__init__()
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
if isinstance(pretrain_img_size, int):
|
||||
pretrain_img_size = to_2tuple(pretrain_img_size)
|
||||
@ -602,17 +547,22 @@ class SwinTransformer(BaseModule):
|
||||
f'The size of image should have length 1 or 2, ' \
|
||||
f'but got {len(pretrain_img_size)}'
|
||||
|
||||
if isinstance(pretrained, str) or pretrained is None:
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be specified at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
init_cfg = init_cfg
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
super(SwinTransformer, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
num_layers = len(depths)
|
||||
self.out_indices = out_indices
|
||||
self.use_abs_pos_embed = use_abs_pos_embed
|
||||
self.pretrained = pretrained
|
||||
self.init_cfg = init_cfg
|
||||
|
||||
assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
|
||||
|
||||
@ -622,7 +572,7 @@ class SwinTransformer(BaseModule):
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=strides[0],
|
||||
pad_to_patch_size=True,
|
||||
padding='corner',
|
||||
norm_cfg=norm_cfg if patch_norm else None,
|
||||
init_cfg=None)
|
||||
|
||||
@ -635,11 +585,11 @@ class SwinTransformer(BaseModule):
|
||||
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
# stochastic depth
|
||||
# set stochastic depth decay rule
|
||||
total_depth = sum(depths)
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
|
||||
] # stochastic depth decay rule
|
||||
]
|
||||
|
||||
self.stages = ModuleList()
|
||||
in_channels = embed_dims
|
||||
@ -664,14 +614,13 @@ class SwinTransformer(BaseModule):
|
||||
qk_scale=qk_scale,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=dpr[:depths[i]],
|
||||
drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
||||
downsample=downsample,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
init_cfg=None)
|
||||
self.stages.append(stage)
|
||||
|
||||
dpr = dpr[depths[i]:]
|
||||
if downsample:
|
||||
in_channels = downsample.out_channels
|
||||
|
||||
@ -682,29 +631,67 @@ class SwinTransformer(BaseModule):
|
||||
layer_name = f'norm{i}'
|
||||
self.add_module(layer_name, layer)
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode while keep layers freezed."""
|
||||
super(SwinTransformer, self).train(mode)
|
||||
self._freeze_stages()
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
if self.use_abs_pos_embed:
|
||||
self.absolute_pos_embed.requires_grad = False
|
||||
self.drop_after_pos.eval()
|
||||
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
|
||||
if (i - 1) in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i-1}')
|
||||
norm_layer.eval()
|
||||
for param in norm_layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
m = self.stages[i - 1]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def init_weights(self):
|
||||
if self.pretrained is None:
|
||||
super().init_weights()
|
||||
logger = get_root_logger()
|
||||
if self.init_cfg is None:
|
||||
logger.warn(f'No pre-trained weights for '
|
||||
f'{self.__class__.__name__}, '
|
||||
f'training start from scratch')
|
||||
if self.use_abs_pos_embed:
|
||||
trunc_normal_init(self.absolute_pos_embed, std=0.02)
|
||||
for m in self.modules():
|
||||
if isinstance(m, Linear):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
constant_init(m.bias, 0)
|
||||
elif isinstance(m, LayerNorm):
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
constant_init(m.bias, 0)
|
||||
constant_init(m.weight, 1.0)
|
||||
elif isinstance(self.pretrained, str):
|
||||
logger = get_root_logger()
|
||||
else:
|
||||
assert 'checkpoint' in self.init_cfg, f'Only support ' \
|
||||
f'specify `Pretrained` in ' \
|
||||
f'`init_cfg` in ' \
|
||||
f'{self.__class__.__name__} '
|
||||
ckpt = _load_checkpoint(
|
||||
self.pretrained, logger=logger, map_location='cpu')
|
||||
self.init_cfg.checkpoint, logger=logger, map_location='cpu')
|
||||
if 'state_dict' in ckpt:
|
||||
state_dict = ckpt['state_dict']
|
||||
_state_dict = ckpt['state_dict']
|
||||
elif 'model' in ckpt:
|
||||
state_dict = ckpt['model']
|
||||
_state_dict = ckpt['model']
|
||||
else:
|
||||
state_dict = ckpt
|
||||
_state_dict = ckpt
|
||||
|
||||
state_dict = OrderedDict()
|
||||
for k, v in _state_dict.items():
|
||||
if k.startswith('backbone.'):
|
||||
state_dict[k[9:]] = v
|
||||
|
||||
# strip prefix of state_dict
|
||||
if list(state_dict.keys())[0].startswith('module.'):
|
||||
@ -733,25 +720,22 @@ class SwinTransformer(BaseModule):
|
||||
L2, nH2 = table_current.size()
|
||||
if nH1 != nH2:
|
||||
logger.warning(f'Error in loading {table_key}, pass')
|
||||
else:
|
||||
if L1 != L2:
|
||||
S1 = int(L1**0.5)
|
||||
S2 = int(L2**0.5)
|
||||
table_pretrained_resized = resize(
|
||||
table_pretrained.permute(1, 0).reshape(
|
||||
1, nH1, S1, S1),
|
||||
size=(S2, S2),
|
||||
mode='bicubic')
|
||||
state_dict[table_key] = table_pretrained_resized.view(
|
||||
nH2, L2).permute(1, 0).contiguous()
|
||||
elif L1 != L2:
|
||||
S1 = int(L1**0.5)
|
||||
S2 = int(L2**0.5)
|
||||
table_pretrained_resized = F.interpolate(
|
||||
table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
|
||||
size=(S2, S2),
|
||||
mode='bicubic')
|
||||
state_dict[table_key] = table_pretrained_resized.view(
|
||||
nH2, L2).permute(1, 0).contiguous()
|
||||
|
||||
# load state_dict
|
||||
self.load_state_dict(state_dict, False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
|
||||
hw_shape = (self.patch_embed.DH, self.patch_embed.DW)
|
||||
if self.use_abs_pos_embed:
|
||||
x = x + self.absolute_pos_embed
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
@ -205,7 +205,7 @@ class VisionTransformer(BaseModule):
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
pad_to_patch_size=True,
|
||||
padding='corner',
|
||||
norm_cfg=norm_cfg if patch_norm else None,
|
||||
init_cfg=None,
|
||||
)
|
||||
@ -370,8 +370,8 @@ class VisionTransformer(BaseModule):
|
||||
def forward(self, inputs):
|
||||
B = inputs.shape[0]
|
||||
|
||||
x, hw_shape = self.patch_embed(inputs), (self.patch_embed.DH,
|
||||
self.patch_embed.DW)
|
||||
x, hw_shape = self.patch_embed(inputs)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
@ -1,28 +1,109 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import Sequence
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
from mmcv.runner.base_module import BaseModule
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
from mmcv.utils import to_2tuple
|
||||
|
||||
|
||||
class AdaptivePadding(nn.Module):
|
||||
"""Applies padding to input (if needed) so that input can get fully covered
|
||||
by filter you specified. It support two modes "same" and "corner". The
|
||||
"same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
|
||||
input. The "corner" mode would pad zero to bottom right.
|
||||
|
||||
Args:
|
||||
kernel_size (int | tuple): Size of the kernel:
|
||||
stride (int | tuple): Stride of the filter. Default: 1:
|
||||
dilation (int | tuple): Spacing between kernel elements.
|
||||
Default: 1.
|
||||
padding (str): Support "same" and "corner", "corner" mode
|
||||
would pad zero to bottom right, and "same" mode would
|
||||
pad zero around input. Default: "corner".
|
||||
Example:
|
||||
>>> kernel_size = 16
|
||||
>>> stride = 16
|
||||
>>> dilation = 1
|
||||
>>> input = torch.rand(1, 1, 15, 17)
|
||||
>>> adap_pad = AdaptivePadding(
|
||||
>>> kernel_size=kernel_size,
|
||||
>>> stride=stride,
|
||||
>>> dilation=dilation,
|
||||
>>> padding="corner")
|
||||
>>> out = adap_pad(input)
|
||||
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
|
||||
>>> input = torch.rand(1, 1, 16, 17)
|
||||
>>> out = adap_pad(input)
|
||||
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
|
||||
"""
|
||||
|
||||
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
|
||||
|
||||
super(AdaptivePadding, self).__init__()
|
||||
|
||||
assert padding in ('same', 'corner')
|
||||
|
||||
kernel_size = to_2tuple(kernel_size)
|
||||
stride = to_2tuple(stride)
|
||||
dilation = to_2tuple(dilation)
|
||||
|
||||
self.padding = padding
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
|
||||
def get_pad_shape(self, input_shape):
|
||||
input_h, input_w = input_shape
|
||||
kernel_h, kernel_w = self.kernel_size
|
||||
stride_h, stride_w = self.stride
|
||||
output_h = math.ceil(input_h / stride_h)
|
||||
output_w = math.ceil(input_w / stride_w)
|
||||
pad_h = max((output_h - 1) * stride_h +
|
||||
(kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
|
||||
pad_w = max((output_w - 1) * stride_w +
|
||||
(kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
|
||||
return pad_h, pad_w
|
||||
|
||||
def forward(self, x):
|
||||
pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
if self.padding == 'corner':
|
||||
x = F.pad(x, [0, pad_w, 0, pad_h])
|
||||
elif self.padding == 'same':
|
||||
x = F.pad(x, [
|
||||
pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
|
||||
pad_h - pad_h // 2
|
||||
])
|
||||
return x
|
||||
|
||||
|
||||
# Modified from Pytorch-Image-Models
|
||||
class PatchEmbed(BaseModule):
|
||||
"""Image to Patch Embedding V2.
|
||||
"""Image to Patch Embedding.
|
||||
|
||||
We use a conv layer to implement PatchEmbed.
|
||||
|
||||
Args:
|
||||
in_channels (int): The num of input channels. Default: 3
|
||||
embed_dims (int): The dimensions of embedding. Default: 768
|
||||
conv_type (dict, optional): The config dict for conv layers type
|
||||
selection. Default: None.
|
||||
conv_type (str): The config dict for embedding
|
||||
conv layer type selection. Default: "Conv2d".
|
||||
kernel_size (int): The kernel_size of embedding conv. Default: 16.
|
||||
stride (int): The slide stride of embedding conv.
|
||||
Default: None (Default to be equal with kernel_size).
|
||||
padding (int): The padding length of embedding conv. Default: 0.
|
||||
stride (int, optional): The slide stride of embedding conv.
|
||||
Default: None (Would be set as `kernel_size`).
|
||||
padding (int | tuple | string ): The padding length of
|
||||
embedding conv. When it is a string, it means the mode
|
||||
of adaptive padding, support "same" and "corner" now.
|
||||
Default: "corner".
|
||||
dilation (int): The dilation rate of embedding conv. Default: 1.
|
||||
pad_to_patch_size (bool, optional): Whether to pad feature map shape
|
||||
to multiple patch size. Default: True.
|
||||
bias (bool): Bias of embed conv. Default: True.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||
Default: None.
|
||||
input_size (int | tuple | None): The size of input, which will be
|
||||
used to calculate the out size. Only work when `dynamic_size`
|
||||
is False. Default: None.
|
||||
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
@ -30,39 +111,37 @@ class PatchEmbed(BaseModule):
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
embed_dims=768,
|
||||
conv_type=None,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=16,
|
||||
stride=16,
|
||||
padding=0,
|
||||
stride=None,
|
||||
padding='corner',
|
||||
dilation=1,
|
||||
pad_to_patch_size=True,
|
||||
bias=True,
|
||||
norm_cfg=None,
|
||||
input_size=None,
|
||||
init_cfg=None):
|
||||
super(PatchEmbed, self).__init__()
|
||||
super(PatchEmbed, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.init_cfg = init_cfg
|
||||
|
||||
if stride is None:
|
||||
stride = kernel_size
|
||||
|
||||
self.pad_to_patch_size = pad_to_patch_size
|
||||
kernel_size = to_2tuple(kernel_size)
|
||||
stride = to_2tuple(stride)
|
||||
dilation = to_2tuple(dilation)
|
||||
|
||||
# The default setting of patch size is equal to kernel size.
|
||||
patch_size = kernel_size
|
||||
if isinstance(patch_size, int):
|
||||
patch_size = to_2tuple(patch_size)
|
||||
elif isinstance(patch_size, tuple):
|
||||
if len(patch_size) == 1:
|
||||
patch_size = to_2tuple(patch_size[0])
|
||||
assert len(patch_size) == 2, \
|
||||
f'The size of patch should have length 1 or 2, ' \
|
||||
f'but got {len(patch_size)}'
|
||||
if isinstance(padding, str):
|
||||
self.adap_padding = AdaptivePadding(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
# disable the padding of conv
|
||||
padding = 0
|
||||
else:
|
||||
self.adap_padding = None
|
||||
padding = to_2tuple(padding)
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
# Use conv layer to embed
|
||||
conv_type = conv_type or 'Conv2d'
|
||||
self.projection = build_conv_layer(
|
||||
dict(type=conv_type),
|
||||
in_channels=in_channels,
|
||||
@ -70,31 +149,182 @@ class PatchEmbed(BaseModule):
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation)
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
if norm_cfg is not None:
|
||||
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
H, W = x.shape[2], x.shape[3]
|
||||
if input_size:
|
||||
input_size = to_2tuple(input_size)
|
||||
# `init_out_size` would be used outside to
|
||||
# calculate the num_patches
|
||||
# when `use_abs_pos_embed` outside
|
||||
self.init_input_size = input_size
|
||||
if self.adap_padding:
|
||||
pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
|
||||
input_h, input_w = input_size
|
||||
input_h = input_h + pad_h
|
||||
input_w = input_w + pad_w
|
||||
input_size = (input_h, input_w)
|
||||
|
||||
# TODO: Process overlapping op
|
||||
if self.pad_to_patch_size:
|
||||
# Modify H, W to multiple of patch size.
|
||||
if H % self.patch_size[0] != 0:
|
||||
x = F.pad(
|
||||
x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||
if W % self.patch_size[1] != 0:
|
||||
x = F.pad(
|
||||
x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
||||
h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
|
||||
(kernel_size[0] - 1) - 1) // stride[0] + 1
|
||||
w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
|
||||
(kernel_size[1] - 1) - 1) // stride[1] + 1
|
||||
self.init_out_size = (h_out, w_out)
|
||||
else:
|
||||
self.init_input_size = None
|
||||
self.init_out_size = None
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
|
||||
|
||||
Returns:
|
||||
tuple: Contains merged results and its spatial shape.
|
||||
|
||||
- x (Tensor): Has shape (B, out_h * out_w, embed_dims)
|
||||
- out_size (tuple[int]): Spatial shape of x, arrange as
|
||||
(out_h, out_w).
|
||||
"""
|
||||
|
||||
if self.adap_padding:
|
||||
x = self.adap_padding(x)
|
||||
|
||||
x = self.projection(x)
|
||||
self.DH, self.DW = x.shape[2], x.shape[3]
|
||||
out_size = (x.shape[2], x.shape[3])
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
return x, out_size
|
||||
|
||||
return x
|
||||
|
||||
class PatchMerging(BaseModule):
|
||||
"""Merge patch feature map.
|
||||
|
||||
This layer groups feature map by kernel_size, and applies norm and linear
|
||||
layers to the grouped feature map. Our implementation uses `nn.Unfold` to
|
||||
merge patch, which is about 25% faster than original implementation.
|
||||
Instead, we need to modify pretrained models for compatibility.
|
||||
|
||||
Args:
|
||||
in_channels (int): The num of input channels.
|
||||
out_channels (int): The num of output channels.
|
||||
kernel_size (int | tuple, optional): the kernel size in the unfold
|
||||
layer. Defaults to 2.
|
||||
stride (int | tuple, optional): the stride of the sliding blocks in the
|
||||
unfold layer. Default: None. (Would be set as `kernel_size`)
|
||||
padding (int | tuple | string ): The padding length of
|
||||
embedding conv. When it is a string, it means the mode
|
||||
of adaptive padding, support "same" and "corner" now.
|
||||
Default: "corner".
|
||||
dilation (int | tuple, optional): dilation parameter in the unfold
|
||||
layer. Default: 1.
|
||||
bias (bool, optional): Whether to add bias in linear layer or not.
|
||||
Defaults: False.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=2,
|
||||
stride=None,
|
||||
padding='corner',
|
||||
dilation=1,
|
||||
bias=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
if stride:
|
||||
stride = stride
|
||||
else:
|
||||
stride = kernel_size
|
||||
|
||||
kernel_size = to_2tuple(kernel_size)
|
||||
stride = to_2tuple(stride)
|
||||
dilation = to_2tuple(dilation)
|
||||
|
||||
if isinstance(padding, str):
|
||||
self.adap_padding = AdaptivePadding(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
# disable the padding of unfold
|
||||
padding = 0
|
||||
else:
|
||||
self.adap_padding = None
|
||||
|
||||
padding = to_2tuple(padding)
|
||||
self.sampler = nn.Unfold(
|
||||
kernel_size=kernel_size,
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
stride=stride)
|
||||
|
||||
sample_dim = kernel_size[0] * kernel_size[1] * in_channels
|
||||
|
||||
if norm_cfg is not None:
|
||||
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
|
||||
|
||||
def forward(self, x, input_size):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): Has shape (B, H*W, C_in).
|
||||
input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
tuple: Contains merged results and its spatial shape.
|
||||
|
||||
- x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
|
||||
- out_size (tuple[int]): Spatial shape of x, arrange as
|
||||
(Merged_H, Merged_W).
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
assert isinstance(input_size, Sequence), f'Expect ' \
|
||||
f'input_size is ' \
|
||||
f'`Sequence` ' \
|
||||
f'but get {input_size}'
|
||||
|
||||
H, W = input_size
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
|
||||
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
|
||||
# Use nn.Unfold to merge patch. About 25% faster than original method,
|
||||
# but need to modify pretrained model for compatibility
|
||||
|
||||
if self.adap_padding:
|
||||
x = self.adap_padding(x)
|
||||
H, W = x.shape[-2:]
|
||||
|
||||
x = self.sampler(x)
|
||||
# if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
|
||||
|
||||
out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
|
||||
(self.sampler.kernel_size[0] - 1) -
|
||||
1) // self.sampler.stride[0] + 1
|
||||
out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
|
||||
(self.sampler.kernel_size[1] - 1) -
|
||||
1) // self.sampler.stride[1] + 1
|
||||
|
||||
output_size = (out_h, out_w)
|
||||
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
|
||||
x = self.norm(x) if self.norm else x
|
||||
x = self.reduction(x)
|
||||
return x, output_size
|
||||
|
||||
@ -7,10 +7,6 @@ from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
|
||||
|
||||
|
||||
def test_mit():
|
||||
with pytest.raises(AssertionError):
|
||||
# It's only support official style and mmcls style now.
|
||||
MixVisionTransformer(pretrain_style='timm')
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Pretrained represents pretrain url and must be str or None.
|
||||
MixVisionTransformer(pretrained=123)
|
||||
|
||||
@ -1,8 +1,26 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import SwinTransformer
|
||||
from mmseg.models.backbones.swin import SwinBlock, SwinTransformer
|
||||
|
||||
|
||||
def test_swin_block():
|
||||
# test SwinBlock structure and forward
|
||||
block = SwinBlock(embed_dims=64, num_heads=4, feedforward_channels=256)
|
||||
assert block.ffn.embed_dims == 64
|
||||
assert block.attn.w_msa.num_heads == 4
|
||||
assert block.ffn.feedforward_channels == 256
|
||||
x = torch.randn(1, 56 * 56, 64)
|
||||
x_out = block(x, (56, 56))
|
||||
assert x_out.shape == torch.Size([1, 56 * 56, 64])
|
||||
|
||||
# Test BasicBlock with checkpoint forward
|
||||
block = SwinBlock(
|
||||
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 56 * 56, 64)
|
||||
x_out = block(x, (56, 56))
|
||||
assert x_out.shape == torch.Size([1, 56 * 56, 64])
|
||||
|
||||
|
||||
def test_swin_transformer():
|
||||
@ -10,12 +28,16 @@ def test_swin_transformer():
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Pretrained arg must be str or None.
|
||||
model = SwinTransformer(pretrained=123)
|
||||
SwinTransformer(pretrained=123)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Because swin use non-overlapping patch embed, so the stride of patch
|
||||
# Because swin uses non-overlapping patch embed, so the stride of patch
|
||||
# embed must be equal to patch size.
|
||||
model = SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)
|
||||
SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)
|
||||
|
||||
# test pretrained image size
|
||||
with pytest.raises(AssertionError):
|
||||
SwinTransformer(pretrain_img_size=(224, 224, 224))
|
||||
|
||||
# Test absolute position embedding
|
||||
temp = torch.randn((1, 3, 224, 224))
|
||||
@ -27,12 +49,6 @@ def test_swin_transformer():
|
||||
model = SwinTransformer(patch_norm=False)
|
||||
model(temp)
|
||||
|
||||
# Test pretrain img size
|
||||
model = SwinTransformer(pretrain_img_size=(224, ))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
model = SwinTransformer(pretrain_img_size=(224, 224, 224))
|
||||
|
||||
# Test normal inference
|
||||
temp = torch.randn((1, 3, 512, 512))
|
||||
model = SwinTransformer()
|
||||
@ -42,7 +58,7 @@ def test_swin_transformer():
|
||||
assert outs[2].shape == (1, 384, 32, 32)
|
||||
assert outs[3].shape == (1, 768, 16, 16)
|
||||
|
||||
# Test abnormal inference
|
||||
# Test abnormal inference size
|
||||
temp = torch.randn((1, 3, 511, 511))
|
||||
model = SwinTransformer()
|
||||
outs = model(temp)
|
||||
@ -51,7 +67,7 @@ def test_swin_transformer():
|
||||
assert outs[2].shape == (1, 384, 32, 32)
|
||||
assert outs[3].shape == (1, 768, 16, 16)
|
||||
|
||||
# Test abnormal inference
|
||||
# Test abnormal inference size
|
||||
temp = torch.randn((1, 3, 112, 137))
|
||||
model = SwinTransformer()
|
||||
outs = model(temp)
|
||||
@ -59,3 +75,25 @@ def test_swin_transformer():
|
||||
assert outs[1].shape == (1, 192, 14, 18)
|
||||
assert outs[2].shape == (1, 384, 7, 9)
|
||||
assert outs[3].shape == (1, 768, 4, 5)
|
||||
|
||||
# Test frozen
|
||||
model = SwinTransformer(frozen_stages=4)
|
||||
model.train()
|
||||
for p in model.parameters():
|
||||
assert not p.requires_grad
|
||||
|
||||
# Test absolute position embedding frozen
|
||||
model = SwinTransformer(frozen_stages=4, use_abs_pos_embed=True)
|
||||
model.train()
|
||||
for p in model.parameters():
|
||||
assert not p.requires_grad
|
||||
|
||||
# Test Swin with checkpoint forward
|
||||
temp = torch.randn((1, 3, 224, 224))
|
||||
model = SwinTransformer(with_cp=True)
|
||||
for m in model.modules():
|
||||
if isinstance(m, SwinBlock):
|
||||
assert m.with_cp
|
||||
model.init_weights()
|
||||
model.train()
|
||||
model(temp)
|
||||
|
||||
@ -25,12 +25,6 @@ def test_vit_backbone():
|
||||
x = torch.randn(1, 196)
|
||||
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
# forward inputs must be [N, C, H, W]
|
||||
x = torch.randn(3, 30, 30)
|
||||
model = VisionTransformer()
|
||||
model(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# The length of img_size tuple must be lower than 3.
|
||||
VisionTransformer(img_size=(224, 224, 224))
|
||||
|
||||
0
tests/test_models/test_utils/__init__.py
Normal file
0
tests/test_models/test_utils/__init__.py
Normal file
461
tests/test_models/test_utils/test_embed.py
Normal file
461
tests/test_models/test_utils/test_embed.py
Normal file
@ -0,0 +1,461 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.utils.embed import AdaptivePadding, PatchEmbed, PatchMerging
|
||||
|
||||
|
||||
def test_adaptive_padding():
|
||||
|
||||
for padding in ('same', 'corner'):
|
||||
kernel_size = 16
|
||||
stride = 16
|
||||
dilation = 1
|
||||
input = torch.rand(1, 1, 15, 17)
|
||||
adap_pool = AdaptivePadding(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
out = adap_pool(input)
|
||||
# padding to divisible by 16
|
||||
assert (out.shape[2], out.shape[3]) == (16, 32)
|
||||
input = torch.rand(1, 1, 16, 17)
|
||||
out = adap_pool(input)
|
||||
# padding to divisible by 16
|
||||
assert (out.shape[2], out.shape[3]) == (16, 32)
|
||||
|
||||
kernel_size = (2, 2)
|
||||
stride = (2, 2)
|
||||
dilation = (1, 1)
|
||||
|
||||
adap_pad = AdaptivePadding(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
input = torch.rand(1, 1, 11, 13)
|
||||
out = adap_pad(input)
|
||||
# padding to divisible by 2
|
||||
assert (out.shape[2], out.shape[3]) == (12, 14)
|
||||
|
||||
kernel_size = (2, 2)
|
||||
stride = (10, 10)
|
||||
dilation = (1, 1)
|
||||
|
||||
adap_pad = AdaptivePadding(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
input = torch.rand(1, 1, 10, 13)
|
||||
out = adap_pad(input)
|
||||
# no padding
|
||||
assert (out.shape[2], out.shape[3]) == (10, 13)
|
||||
|
||||
kernel_size = (11, 11)
|
||||
adap_pad = AdaptivePadding(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
input = torch.rand(1, 1, 11, 13)
|
||||
out = adap_pad(input)
|
||||
# all padding
|
||||
assert (out.shape[2], out.shape[3]) == (21, 21)
|
||||
|
||||
# test padding as kernel is (7,9)
|
||||
input = torch.rand(1, 1, 11, 13)
|
||||
stride = (3, 4)
|
||||
kernel_size = (4, 5)
|
||||
dilation = (2, 2)
|
||||
# actually (7, 9)
|
||||
adap_pad = AdaptivePadding(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
dilation_out = adap_pad(input)
|
||||
assert (dilation_out.shape[2], dilation_out.shape[3]) == (16, 21)
|
||||
kernel_size = (7, 9)
|
||||
dilation = (1, 1)
|
||||
adap_pad = AdaptivePadding(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
kernel79_out = adap_pad(input)
|
||||
assert (kernel79_out.shape[2], kernel79_out.shape[3]) == (16, 21)
|
||||
assert kernel79_out.shape == dilation_out.shape
|
||||
|
||||
# assert only support "same" "corner"
|
||||
with pytest.raises(AssertionError):
|
||||
AdaptivePadding(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=1)
|
||||
|
||||
|
||||
def test_patch_embed():
|
||||
B = 2
|
||||
H = 3
|
||||
W = 4
|
||||
C = 3
|
||||
embed_dims = 10
|
||||
kernel_size = 3
|
||||
stride = 1
|
||||
dummy_input = torch.rand(B, C, H, W)
|
||||
patch_merge_1 = PatchEmbed(
|
||||
in_channels=C,
|
||||
embed_dims=embed_dims,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
norm_cfg=None)
|
||||
|
||||
x1, shape = patch_merge_1(dummy_input)
|
||||
# test out shape
|
||||
assert x1.shape == (2, 2, 10)
|
||||
# test outsize is correct
|
||||
assert shape == (1, 2)
|
||||
# test L = out_h * out_w
|
||||
assert shape[0] * shape[1] == x1.shape[1]
|
||||
|
||||
B = 2
|
||||
H = 10
|
||||
W = 10
|
||||
C = 3
|
||||
embed_dims = 10
|
||||
kernel_size = 5
|
||||
stride = 2
|
||||
dummy_input = torch.rand(B, C, H, W)
|
||||
# test dilation
|
||||
patch_merge_2 = PatchEmbed(
|
||||
in_channels=C,
|
||||
embed_dims=embed_dims,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=2,
|
||||
norm_cfg=None,
|
||||
)
|
||||
|
||||
x2, shape = patch_merge_2(dummy_input)
|
||||
# test out shape
|
||||
assert x2.shape == (2, 1, 10)
|
||||
# test outsize is correct
|
||||
assert shape == (1, 1)
|
||||
# test L = out_h * out_w
|
||||
assert shape[0] * shape[1] == x2.shape[1]
|
||||
|
||||
stride = 2
|
||||
input_size = (10, 10)
|
||||
|
||||
dummy_input = torch.rand(B, C, H, W)
|
||||
# test stride and norm
|
||||
patch_merge_3 = PatchEmbed(
|
||||
in_channels=C,
|
||||
embed_dims=embed_dims,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=2,
|
||||
norm_cfg=dict(type='LN'),
|
||||
input_size=input_size)
|
||||
|
||||
x3, shape = patch_merge_3(dummy_input)
|
||||
# test out shape
|
||||
assert x3.shape == (2, 1, 10)
|
||||
# test outsize is correct
|
||||
assert shape == (1, 1)
|
||||
# test L = out_h * out_w
|
||||
assert shape[0] * shape[1] == x3.shape[1]
|
||||
|
||||
# test thte init_out_size with nn.Unfold
|
||||
assert patch_merge_3.init_out_size[1] == (input_size[0] - 2 * 4 -
|
||||
1) // 2 + 1
|
||||
assert patch_merge_3.init_out_size[0] == (input_size[0] - 2 * 4 -
|
||||
1) // 2 + 1
|
||||
H = 11
|
||||
W = 12
|
||||
input_size = (H, W)
|
||||
dummy_input = torch.rand(B, C, H, W)
|
||||
# test stride and norm
|
||||
patch_merge_3 = PatchEmbed(
|
||||
in_channels=C,
|
||||
embed_dims=embed_dims,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=2,
|
||||
norm_cfg=dict(type='LN'),
|
||||
input_size=input_size)
|
||||
|
||||
_, shape = patch_merge_3(dummy_input)
|
||||
# when input_size equal to real input
|
||||
# the out_size shoule be equal to `init_out_size`
|
||||
assert shape == patch_merge_3.init_out_size
|
||||
|
||||
input_size = (H, W)
|
||||
dummy_input = torch.rand(B, C, H, W)
|
||||
# test stride and norm
|
||||
patch_merge_3 = PatchEmbed(
|
||||
in_channels=C,
|
||||
embed_dims=embed_dims,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=2,
|
||||
norm_cfg=dict(type='LN'),
|
||||
input_size=input_size)
|
||||
|
||||
_, shape = patch_merge_3(dummy_input)
|
||||
# when input_size equal to real input
|
||||
# the out_size shoule be equal to `init_out_size`
|
||||
assert shape == patch_merge_3.init_out_size
|
||||
|
||||
# test adap padding
|
||||
for padding in ('same', 'corner'):
|
||||
in_c = 2
|
||||
embed_dims = 3
|
||||
B = 2
|
||||
|
||||
# test stride is 1
|
||||
input_size = (5, 5)
|
||||
kernel_size = (5, 5)
|
||||
stride = (1, 1)
|
||||
dilation = 1
|
||||
bias = False
|
||||
|
||||
x = torch.rand(B, in_c, *input_size)
|
||||
patch_embed = PatchEmbed(
|
||||
in_channels=in_c,
|
||||
embed_dims=embed_dims,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
x_out, out_size = patch_embed(x)
|
||||
assert x_out.size() == (B, 25, 3)
|
||||
assert out_size == (5, 5)
|
||||
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||
|
||||
# test kernel_size == stride
|
||||
input_size = (5, 5)
|
||||
kernel_size = (5, 5)
|
||||
stride = (5, 5)
|
||||
dilation = 1
|
||||
bias = False
|
||||
|
||||
x = torch.rand(B, in_c, *input_size)
|
||||
patch_embed = PatchEmbed(
|
||||
in_channels=in_c,
|
||||
embed_dims=embed_dims,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
x_out, out_size = patch_embed(x)
|
||||
assert x_out.size() == (B, 1, 3)
|
||||
assert out_size == (1, 1)
|
||||
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||
|
||||
# test kernel_size == stride
|
||||
input_size = (6, 5)
|
||||
kernel_size = (5, 5)
|
||||
stride = (5, 5)
|
||||
dilation = 1
|
||||
bias = False
|
||||
|
||||
x = torch.rand(B, in_c, *input_size)
|
||||
patch_embed = PatchEmbed(
|
||||
in_channels=in_c,
|
||||
embed_dims=embed_dims,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
x_out, out_size = patch_embed(x)
|
||||
assert x_out.size() == (B, 2, 3)
|
||||
assert out_size == (2, 1)
|
||||
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||
|
||||
# test different kernel_size with diffrent stride
|
||||
input_size = (6, 5)
|
||||
kernel_size = (6, 2)
|
||||
stride = (6, 2)
|
||||
dilation = 1
|
||||
bias = False
|
||||
|
||||
x = torch.rand(B, in_c, *input_size)
|
||||
patch_embed = PatchEmbed(
|
||||
in_channels=in_c,
|
||||
embed_dims=embed_dims,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
x_out, out_size = patch_embed(x)
|
||||
assert x_out.size() == (B, 3, 3)
|
||||
assert out_size == (1, 3)
|
||||
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||
|
||||
|
||||
def test_patch_merging():
|
||||
|
||||
# Test the model with int padding
|
||||
in_c = 3
|
||||
out_c = 4
|
||||
kernel_size = 3
|
||||
stride = 3
|
||||
padding = 1
|
||||
dilation = 1
|
||||
bias = False
|
||||
# test the case `pad_to_stride` is False
|
||||
patch_merge = PatchMerging(
|
||||
in_channels=in_c,
|
||||
out_channels=out_c,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
B, L, C = 1, 100, 3
|
||||
input_size = (10, 10)
|
||||
x = torch.rand(B, L, C)
|
||||
x_out, out_size = patch_merge(x, input_size)
|
||||
assert x_out.size() == (1, 16, 4)
|
||||
assert out_size == (4, 4)
|
||||
# assert out size is consistent with real output
|
||||
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||
in_c = 4
|
||||
out_c = 5
|
||||
kernel_size = 6
|
||||
stride = 3
|
||||
padding = 2
|
||||
dilation = 2
|
||||
bias = False
|
||||
patch_merge = PatchMerging(
|
||||
in_channels=in_c,
|
||||
out_channels=out_c,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
B, L, C = 1, 100, 4
|
||||
input_size = (10, 10)
|
||||
x = torch.rand(B, L, C)
|
||||
x_out, out_size = patch_merge(x, input_size)
|
||||
assert x_out.size() == (1, 4, 5)
|
||||
assert out_size == (2, 2)
|
||||
# assert out size is consistent with real output
|
||||
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||
|
||||
# Test with adaptive padding
|
||||
for padding in ('same', 'corner'):
|
||||
in_c = 2
|
||||
out_c = 3
|
||||
B = 2
|
||||
|
||||
# test stride is 1
|
||||
input_size = (5, 5)
|
||||
kernel_size = (5, 5)
|
||||
stride = (1, 1)
|
||||
dilation = 1
|
||||
bias = False
|
||||
L = input_size[0] * input_size[1]
|
||||
|
||||
x = torch.rand(B, L, in_c)
|
||||
patch_merge = PatchMerging(
|
||||
in_channels=in_c,
|
||||
out_channels=out_c,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
x_out, out_size = patch_merge(x, input_size)
|
||||
assert x_out.size() == (B, 25, 3)
|
||||
assert out_size == (5, 5)
|
||||
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||
|
||||
# test kernel_size == stride
|
||||
input_size = (5, 5)
|
||||
kernel_size = (5, 5)
|
||||
stride = (5, 5)
|
||||
dilation = 1
|
||||
bias = False
|
||||
L = input_size[0] * input_size[1]
|
||||
|
||||
x = torch.rand(B, L, in_c)
|
||||
patch_merge = PatchMerging(
|
||||
in_channels=in_c,
|
||||
out_channels=out_c,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
x_out, out_size = patch_merge(x, input_size)
|
||||
assert x_out.size() == (B, 1, 3)
|
||||
assert out_size == (1, 1)
|
||||
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||
|
||||
# test kernel_size == stride
|
||||
input_size = (6, 5)
|
||||
kernel_size = (5, 5)
|
||||
stride = (5, 5)
|
||||
dilation = 1
|
||||
bias = False
|
||||
L = input_size[0] * input_size[1]
|
||||
|
||||
x = torch.rand(B, L, in_c)
|
||||
patch_merge = PatchMerging(
|
||||
in_channels=in_c,
|
||||
out_channels=out_c,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
x_out, out_size = patch_merge(x, input_size)
|
||||
assert x_out.size() == (B, 2, 3)
|
||||
assert out_size == (2, 1)
|
||||
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||
|
||||
# test different kernel_size with diffrent stride
|
||||
input_size = (6, 5)
|
||||
kernel_size = (6, 2)
|
||||
stride = (6, 2)
|
||||
dilation = 1
|
||||
bias = False
|
||||
L = input_size[0] * input_size[1]
|
||||
|
||||
x = torch.rand(B, L, in_c)
|
||||
patch_merge = PatchMerging(
|
||||
in_channels=in_c,
|
||||
out_channels=out_c,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
x_out, out_size = patch_merge(x, input_size)
|
||||
assert x_out.size() == (B, 3, 3)
|
||||
assert out_size == (1, 3)
|
||||
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||
Loading…
x
Reference in New Issue
Block a user