[Improvement] Refactor Swin-Transformer (#800)

* [Improvement] Refactor Swin-Transformer

* fixed swin test

* update patch emebd, add more tests

* fixed test

* remove pretrain_style

* fixed padding

* resolve coments

* use mmcv 2tuple

* refactor init_cfg

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
This commit is contained in:
Jerry Jiarui XU 2021-09-28 17:46:33 -07:00 committed by GitHub
parent ab12009414
commit 85227b46c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 936 additions and 245 deletions

View File

@ -23,8 +23,7 @@ model = dict(
drop_path_rate=0.3,
use_abs_pos_embed=False,
act_cfg=dict(type='GELU'),
norm_cfg=backbone_norm_cfg,
pretrain_style='official'),
norm_cfg=backbone_norm_cfg),
decode_head=dict(
type='UPerHead',
in_channels=[96, 192, 384, 768],

View File

@ -11,8 +11,7 @@ model = dict(
window_size=7,
use_abs_pos_embed=False,
drop_path_rate=0.3,
patch_norm=True,
pretrain_style='official'),
patch_norm=True),
decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
auxiliary_head=dict(in_channels=384, num_classes=150))

View File

@ -278,8 +278,6 @@ class MixVisionTransformer(BaseModule):
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Defalut: dict(type='GELU').
pretrain_style (str): Choose to use official or mmcls pretrain weights.
Default: official.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
@ -302,15 +300,10 @@ class MixVisionTransformer(BaseModule):
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN', eps=1e-6),
pretrain_style='official',
pretrained=None,
init_cfg=None):
super().__init__()
assert pretrain_style in [
'official', 'mmcls'
], 'we only support official weights or mmcls weights.'
if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
@ -330,7 +323,6 @@ class MixVisionTransformer(BaseModule):
self.out_indices = out_indices
assert max(out_indices) < self.num_stages
self.pretrain_style = pretrain_style
self.pretrained = pretrained
self.init_cfg = init_cfg
@ -350,7 +342,6 @@ class MixVisionTransformer(BaseModule):
kernel_size=patch_sizes[i],
stride=strides[i],
padding=patch_sizes[i] // 2,
pad_to_patch_size=False,
norm_cfg=norm_cfg)
layer = ModuleList([
TransformerEncoderLayer(
@ -403,8 +394,7 @@ class MixVisionTransformer(BaseModule):
outs = []
for i, layer in enumerate(self.layers):
x, H, W = layer[0](x), layer[0].DH, layer[0].DW
hw_shape = (H, W)
x, hw_shape = layer[0](x)
for block in layer[1]:
x = block(x, hw_shape)
x = layer[2](x)

View File

@ -1,111 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer, trunc_normal_init
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmcv.cnn.utils.weight_init import constant_init
from mmcv.runner import _load_checkpoint
from mmcv.runner.base_module import BaseModule, ModuleList
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from mmcv.utils import to_2tuple
from mmseg.ops import resize
from ...utils import get_root_logger
from ..builder import ATTENTION, BACKBONES
from ..utils import PatchEmbed
from ..builder import BACKBONES
from ..utils.embed import PatchEmbed, PatchMerging
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer use nn.Unfold to group feature map by kernel_size, and use norm
and linear layer to embed grouped feature map.
Args:
in_channels (int): The num of input channels.
out_channels (int): The num of output channels.
stride (int | tuple): the stride of the sliding length in the
unfold layer. Defaults: 2. (Default to be equal with kernel_size).
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults: False.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults: dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Defaults: None.
"""
def __init__(self,
in_channels,
out_channels,
stride=2,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.sampler = nn.Unfold(
kernel_size=stride, dilation=1, padding=0, stride=stride)
sample_dim = stride**2 * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
def forward(self, x, hw_shape):
"""
x: x.shape -> [B, H*W, C]
hw_shape: (H, W)
"""
B, L, C = x.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
# stride is fixed to be equal to kernel_size.
if (H % self.stride != 0) or (W % self.stride != 0):
x = F.pad(x, (0, W % self.stride, 0, H % self.stride))
# Use nn.Unfold to merge patch. About 25% faster than original method,
# but need to modify pretrained model for compatibility
x = self.sampler(x) # B, 4*C, H/2*W/2
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
down_hw_shape = (H + 1) // 2, (W + 1) // 2
return x, down_hw_shape
@ATTENTION.register_module()
class WindowMSA(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
Args:
embed_dims (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
window_size (tuple[int]): The height and width of the window.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
init_cfg (dict | None, optional): The Config for initialization.
Default: None.
"""
@ -120,13 +46,12 @@ class WindowMSA(BaseModule):
proj_drop_rate=0.,
init_cfg=None):
super().__init__()
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5
self.init_cfg = init_cfg
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
@ -161,8 +86,8 @@ class WindowMSA(BaseModule):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
# make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
@ -181,9 +106,7 @@ class WindowMSA(BaseModule):
attn = attn.view(B // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
@ -199,9 +122,8 @@ class WindowMSA(BaseModule):
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule):
"""Shift Window Multihead Self-Attention Module.
"""Shifted Window Multihead Self-Attention Module.
Args:
embed_dims (int): Number of input channels.
@ -234,7 +156,7 @@ class ShiftWindowMSA(BaseModule):
proj_drop_rate=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
init_cfg=None):
super().__init__(init_cfg)
super().__init__(init_cfg=init_cfg)
self.window_size = window_size
self.shift_size = shift_size
@ -272,8 +194,7 @@ class ShiftWindowMSA(BaseModule):
dims=(1, 2))
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, H_pad, W_pad, 1),
device=query.device) # 1 H W 1
img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
h_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
@ -333,7 +254,6 @@ class ShiftWindowMSA(BaseModule):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
@ -350,7 +270,6 @@ class ShiftWindowMSA(BaseModule):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
@ -369,18 +288,21 @@ class SwinBlock(BaseModule):
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
window size (int, optional): The local window scale. Default: 7.
shift (bool): whether to shift window or not. Default False.
qkv_bias (int, optional): enable bias for qkv if True. Default: True.
window_size (int, optional): The local window scale. Default: 7.
shift (bool, optional): whether to shift window or not. Default False.
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of nomalization.
norm_cfg (dict, optional): The config dict of normalization.
Default: dict(type='LN').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
@ -398,11 +320,12 @@ class SwinBlock(BaseModule):
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
with_cp=False,
init_cfg=None):
super(SwinBlock, self).__init__()
super(SwinBlock, self).__init__(init_cfg=init_cfg)
self.init_cfg = init_cfg
self.with_cp = with_cp
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = ShiftWindowMSA(
@ -429,15 +352,24 @@ class SwinBlock(BaseModule):
init_cfg=None)
def forward(self, x, hw_shape):
identity = x
x = self.norm1(x)
x = self.attn(x, hw_shape)
x = x + identity
def _inner_forward(x):
identity = x
x = self.norm1(x)
x = self.attn(x, hw_shape)
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
x = x + identity
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
@ -450,19 +382,23 @@ class SwinBlockSequence(BaseModule):
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
depth (int): The number of blocks in this stage.
window size (int): The local window scale. Default: 7.
qkv_bias (int): enable bias for qkv if True. Default: True.
window_size (int, optional): The local window scale. Default: 7.
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
drop_path_rate (float | list[float], optional): Stochastic depth
rate. Default: 0.
downsample (BaseModule | None, optional): The downsample operation
module. Default: None.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of nomalization.
norm_cfg (dict, optional): The config dict of normalization.
Default: dict(type='LN').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
@ -481,14 +417,15 @@ class SwinBlockSequence(BaseModule):
downsample=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
with_cp=False,
init_cfg=None):
super().__init__()
super().__init__(init_cfg=init_cfg)
self.init_cfg = init_cfg
drop_path_rate = drop_path_rate if isinstance(
drop_path_rate,
list) else [deepcopy(drop_path_rate) for _ in range(depth)]
if isinstance(drop_path_rate, list):
drop_path_rates = drop_path_rate
assert len(drop_path_rates) == depth
else:
drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
self.blocks = ModuleList()
for i in range(depth):
@ -502,9 +439,10 @@ class SwinBlockSequence(BaseModule):
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate[i],
drop_path_rate=drop_path_rates[i],
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
init_cfg=None)
self.blocks.append(block)
@ -538,7 +476,7 @@ class SwinTransformer(BaseModule):
embed_dims (int): The feature dimension. Default: 96.
patch_size (int | tuple[int]): Patch size. Default: 4.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
Default: 4.
depths (tuple[int]): Depths of each Swin Transformer stage.
Default: (2, 2, 6, 2).
@ -564,7 +502,12 @@ class SwinTransformer(BaseModule):
Default: dict(type='LN').
norm_cfg (dict): Config dict for normalization layer at
output of backone. Defaults: dict(type='LN').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
pretrained (str, optional): model pretrained path. Default: None.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
@ -589,9 +532,11 @@ class SwinTransformer(BaseModule):
use_abs_pos_embed=False,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
with_cp=False,
pretrained=None,
frozen_stages=-1,
init_cfg=None):
super(SwinTransformer, self).__init__()
self.frozen_stages = frozen_stages
if isinstance(pretrain_img_size, int):
pretrain_img_size = to_2tuple(pretrain_img_size)
@ -602,17 +547,22 @@ class SwinTransformer(BaseModule):
f'The size of image should have length 1 or 2, ' \
f'but got {len(pretrain_img_size)}'
if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be specified at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
init_cfg = init_cfg
else:
raise TypeError('pretrained must be a str or None')
super(SwinTransformer, self).__init__(init_cfg=init_cfg)
num_layers = len(depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.pretrained = pretrained
self.init_cfg = init_cfg
assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
@ -622,7 +572,7 @@ class SwinTransformer(BaseModule):
conv_type='Conv2d',
kernel_size=patch_size,
stride=strides[0],
pad_to_patch_size=True,
padding='corner',
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
@ -635,11 +585,11 @@ class SwinTransformer(BaseModule):
self.drop_after_pos = nn.Dropout(p=drop_rate)
# stochastic depth
# set stochastic depth decay rule
total_depth = sum(depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
]
self.stages = ModuleList()
in_channels = embed_dims
@ -664,14 +614,13 @@ class SwinTransformer(BaseModule):
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[:depths[i]],
drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
downsample=downsample,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp,
init_cfg=None)
self.stages.append(stage)
dpr = dpr[depths[i]:]
if downsample:
in_channels = downsample.out_channels
@ -682,29 +631,67 @@ class SwinTransformer(BaseModule):
layer_name = f'norm{i}'
self.add_module(layer_name, layer)
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.use_abs_pos_embed:
self.absolute_pos_embed.requires_grad = False
self.drop_after_pos.eval()
for i in range(1, self.frozen_stages + 1):
if (i - 1) in self.out_indices:
norm_layer = getattr(self, f'norm{i-1}')
norm_layer.eval()
for param in norm_layer.parameters():
param.requires_grad = False
m = self.stages[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self):
if self.pretrained is None:
super().init_weights()
logger = get_root_logger()
if self.init_cfg is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
if self.use_abs_pos_embed:
trunc_normal_init(self.absolute_pos_embed, std=0.02)
for m in self.modules():
if isinstance(m, Linear):
if isinstance(m, nn.Linear):
trunc_normal_init(m.weight, std=.02)
if m.bias is not None:
constant_init(m.bias, 0)
elif isinstance(m, LayerNorm):
elif isinstance(m, nn.LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
elif isinstance(self.pretrained, str):
logger = get_root_logger()
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
ckpt = _load_checkpoint(
self.pretrained, logger=logger, map_location='cpu')
self.init_cfg.checkpoint, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
state_dict = ckpt['state_dict']
_state_dict = ckpt['state_dict']
elif 'model' in ckpt:
state_dict = ckpt['model']
_state_dict = ckpt['model']
else:
state_dict = ckpt
_state_dict = ckpt
state_dict = OrderedDict()
for k, v in _state_dict.items():
if k.startswith('backbone.'):
state_dict[k[9:]] = v
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
@ -733,25 +720,22 @@ class SwinTransformer(BaseModule):
L2, nH2 = table_current.size()
if nH1 != nH2:
logger.warning(f'Error in loading {table_key}, pass')
else:
if L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
table_pretrained_resized = resize(
table_pretrained.permute(1, 0).reshape(
1, nH1, S1, S1),
size=(S2, S2),
mode='bicubic')
state_dict[table_key] = table_pretrained_resized.view(
nH2, L2).permute(1, 0).contiguous()
elif L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
table_pretrained_resized = F.interpolate(
table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
size=(S2, S2),
mode='bicubic')
state_dict[table_key] = table_pretrained_resized.view(
nH2, L2).permute(1, 0).contiguous()
# load state_dict
self.load_state_dict(state_dict, False)
def forward(self, x):
x = self.patch_embed(x)
x, hw_shape = self.patch_embed(x)
hw_shape = (self.patch_embed.DH, self.patch_embed.DW)
if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)

View File

@ -205,7 +205,7 @@ class VisionTransformer(BaseModule):
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
pad_to_patch_size=True,
padding='corner',
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None,
)
@ -370,8 +370,8 @@ class VisionTransformer(BaseModule):
def forward(self, inputs):
B = inputs.shape[0]
x, hw_shape = self.patch_embed(inputs), (self.patch_embed.DH,
self.patch_embed.DW)
x, hw_shape = self.patch_embed(inputs)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)

View File

@ -1,28 +1,109 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule
from torch.nn.modules.utils import _pair as to_2tuple
from mmcv.utils import to_2tuple
class AdaptivePadding(nn.Module):
"""Applies padding to input (if needed) so that input can get fully covered
by filter you specified. It support two modes "same" and "corner". The
"same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
input. The "corner" mode would pad zero to bottom right.
Args:
kernel_size (int | tuple): Size of the kernel:
stride (int | tuple): Stride of the filter. Default: 1:
dilation (int | tuple): Spacing between kernel elements.
Default: 1.
padding (str): Support "same" and "corner", "corner" mode
would pad zero to bottom right, and "same" mode would
pad zero around input. Default: "corner".
Example:
>>> kernel_size = 16
>>> stride = 16
>>> dilation = 1
>>> input = torch.rand(1, 1, 15, 17)
>>> adap_pad = AdaptivePadding(
>>> kernel_size=kernel_size,
>>> stride=stride,
>>> dilation=dilation,
>>> padding="corner")
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
>>> input = torch.rand(1, 1, 16, 17)
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
"""
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
super(AdaptivePadding, self).__init__()
assert padding in ('same', 'corner')
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
self.padding = padding
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
def get_pad_shape(self, input_shape):
input_h, input_w = input_shape
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.stride
output_h = math.ceil(input_h / stride_h)
output_w = math.ceil(input_w / stride_w)
pad_h = max((output_h - 1) * stride_h +
(kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
pad_w = max((output_w - 1) * stride_w +
(kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
return pad_h, pad_w
def forward(self, x):
pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
if pad_h > 0 or pad_w > 0:
if self.padding == 'corner':
x = F.pad(x, [0, pad_w, 0, pad_h])
elif self.padding == 'same':
x = F.pad(x, [
pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
pad_h - pad_h // 2
])
return x
# Modified from Pytorch-Image-Models
class PatchEmbed(BaseModule):
"""Image to Patch Embedding V2.
"""Image to Patch Embedding.
We use a conv layer to implement PatchEmbed.
Args:
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_type (dict, optional): The config dict for conv layers type
selection. Default: None.
conv_type (str): The config dict for embedding
conv layer type selection. Default: "Conv2d".
kernel_size (int): The kernel_size of embedding conv. Default: 16.
stride (int): The slide stride of embedding conv.
Default: None (Default to be equal with kernel_size).
padding (int): The padding length of embedding conv. Default: 0.
stride (int, optional): The slide stride of embedding conv.
Default: None (Would be set as `kernel_size`).
padding (int | tuple | string ): The padding length of
embedding conv. When it is a string, it means the mode
of adaptive padding, support "same" and "corner" now.
Default: "corner".
dilation (int): The dilation rate of embedding conv. Default: 1.
pad_to_patch_size (bool, optional): Whether to pad feature map shape
to multiple patch size. Default: True.
bias (bool): Bias of embed conv. Default: True.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None.
input_size (int | tuple | None): The size of input, which will be
used to calculate the out size. Only work when `dynamic_size`
is False. Default: None.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None.
"""
@ -30,39 +111,37 @@ class PatchEmbed(BaseModule):
def __init__(self,
in_channels=3,
embed_dims=768,
conv_type=None,
conv_type='Conv2d',
kernel_size=16,
stride=16,
padding=0,
stride=None,
padding='corner',
dilation=1,
pad_to_patch_size=True,
bias=True,
norm_cfg=None,
input_size=None,
init_cfg=None):
super(PatchEmbed, self).__init__()
super(PatchEmbed, self).__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.init_cfg = init_cfg
if stride is None:
stride = kernel_size
self.pad_to_patch_size = pad_to_patch_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
# The default setting of patch size is equal to kernel size.
patch_size = kernel_size
if isinstance(patch_size, int):
patch_size = to_2tuple(patch_size)
elif isinstance(patch_size, tuple):
if len(patch_size) == 1:
patch_size = to_2tuple(patch_size[0])
assert len(patch_size) == 2, \
f'The size of patch should have length 1 or 2, ' \
f'but got {len(patch_size)}'
if isinstance(padding, str):
self.adap_padding = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
# disable the padding of conv
padding = 0
else:
self.adap_padding = None
padding = to_2tuple(padding)
self.patch_size = patch_size
# Use conv layer to embed
conv_type = conv_type or 'Conv2d'
self.projection = build_conv_layer(
dict(type=conv_type),
in_channels=in_channels,
@ -70,31 +149,182 @@ class PatchEmbed(BaseModule):
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
dilation=dilation,
bias=bias)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
def forward(self, x):
H, W = x.shape[2], x.shape[3]
if input_size:
input_size = to_2tuple(input_size)
# `init_out_size` would be used outside to
# calculate the num_patches
# when `use_abs_pos_embed` outside
self.init_input_size = input_size
if self.adap_padding:
pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
input_h, input_w = input_size
input_h = input_h + pad_h
input_w = input_w + pad_w
input_size = (input_h, input_w)
# TODO: Process overlapping op
if self.pad_to_patch_size:
# Modify H, W to multiple of patch size.
if H % self.patch_size[0] != 0:
x = F.pad(
x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
if W % self.patch_size[1] != 0:
x = F.pad(
x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
(kernel_size[0] - 1) - 1) // stride[0] + 1
w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
(kernel_size[1] - 1) - 1) // stride[1] + 1
self.init_out_size = (h_out, w_out)
else:
self.init_input_size = None
self.init_out_size = None
def forward(self, x):
"""
Args:
x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, out_h * out_w, embed_dims)
- out_size (tuple[int]): Spatial shape of x, arrange as
(out_h, out_w).
"""
if self.adap_padding:
x = self.adap_padding(x)
x = self.projection(x)
self.DH, self.DW = x.shape[2], x.shape[3]
out_size = (x.shape[2], x.shape[3])
x = x.flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x, out_size
return x
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer groups feature map by kernel_size, and applies norm and linear
layers to the grouped feature map. Our implementation uses `nn.Unfold` to
merge patch, which is about 25% faster than original implementation.
Instead, we need to modify pretrained models for compatibility.
Args:
in_channels (int): The num of input channels.
out_channels (int): The num of output channels.
kernel_size (int | tuple, optional): the kernel size in the unfold
layer. Defaults to 2.
stride (int | tuple, optional): the stride of the sliding blocks in the
unfold layer. Default: None. (Would be set as `kernel_size`)
padding (int | tuple | string ): The padding length of
embedding conv. When it is a string, it means the mode
of adaptive padding, support "same" and "corner" now.
Default: "corner".
dilation (int | tuple, optional): dilation parameter in the unfold
layer. Default: 1.
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults: False.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=2,
stride=None,
padding='corner',
dilation=1,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
if stride:
stride = stride
else:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
if isinstance(padding, str):
self.adap_padding = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
# disable the padding of unfold
padding = 0
else:
self.adap_padding = None
padding = to_2tuple(padding)
self.sampler = nn.Unfold(
kernel_size=kernel_size,
dilation=dilation,
padding=padding,
stride=stride)
sample_dim = kernel_size[0] * kernel_size[1] * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
def forward(self, x, input_size):
"""
Args:
x (Tensor): Has shape (B, H*W, C_in).
input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
Default: None.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
- out_size (tuple[int]): Spatial shape of x, arrange as
(Merged_H, Merged_W).
"""
B, L, C = x.shape
assert isinstance(input_size, Sequence), f'Expect ' \
f'input_size is ' \
f'`Sequence` ' \
f'but get {input_size}'
H, W = input_size
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
# Use nn.Unfold to merge patch. About 25% faster than original method,
# but need to modify pretrained model for compatibility
if self.adap_padding:
x = self.adap_padding(x)
H, W = x.shape[-2:]
x = self.sampler(x)
# if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
(self.sampler.kernel_size[0] - 1) -
1) // self.sampler.stride[0] + 1
out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
(self.sampler.kernel_size[1] - 1) -
1) // self.sampler.stride[1] + 1
output_size = (out_h, out_w)
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
return x, output_size

View File

@ -7,10 +7,6 @@ from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
def test_mit():
with pytest.raises(AssertionError):
# It's only support official style and mmcls style now.
MixVisionTransformer(pretrain_style='timm')
with pytest.raises(TypeError):
# Pretrained represents pretrain url and must be str or None.
MixVisionTransformer(pretrained=123)

View File

@ -1,8 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.backbones import SwinTransformer
from mmseg.models.backbones.swin import SwinBlock, SwinTransformer
def test_swin_block():
# test SwinBlock structure and forward
block = SwinBlock(embed_dims=64, num_heads=4, feedforward_channels=256)
assert block.ffn.embed_dims == 64
assert block.attn.w_msa.num_heads == 4
assert block.ffn.feedforward_channels == 256
x = torch.randn(1, 56 * 56, 64)
x_out = block(x, (56, 56))
assert x_out.shape == torch.Size([1, 56 * 56, 64])
# Test BasicBlock with checkpoint forward
block = SwinBlock(
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
assert block.with_cp
x = torch.randn(1, 56 * 56, 64)
x_out = block(x, (56, 56))
assert x_out.shape == torch.Size([1, 56 * 56, 64])
def test_swin_transformer():
@ -10,12 +28,16 @@ def test_swin_transformer():
with pytest.raises(TypeError):
# Pretrained arg must be str or None.
model = SwinTransformer(pretrained=123)
SwinTransformer(pretrained=123)
with pytest.raises(AssertionError):
# Because swin use non-overlapping patch embed, so the stride of patch
# Because swin uses non-overlapping patch embed, so the stride of patch
# embed must be equal to patch size.
model = SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)
SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)
# test pretrained image size
with pytest.raises(AssertionError):
SwinTransformer(pretrain_img_size=(224, 224, 224))
# Test absolute position embedding
temp = torch.randn((1, 3, 224, 224))
@ -27,12 +49,6 @@ def test_swin_transformer():
model = SwinTransformer(patch_norm=False)
model(temp)
# Test pretrain img size
model = SwinTransformer(pretrain_img_size=(224, ))
with pytest.raises(AssertionError):
model = SwinTransformer(pretrain_img_size=(224, 224, 224))
# Test normal inference
temp = torch.randn((1, 3, 512, 512))
model = SwinTransformer()
@ -42,7 +58,7 @@ def test_swin_transformer():
assert outs[2].shape == (1, 384, 32, 32)
assert outs[3].shape == (1, 768, 16, 16)
# Test abnormal inference
# Test abnormal inference size
temp = torch.randn((1, 3, 511, 511))
model = SwinTransformer()
outs = model(temp)
@ -51,7 +67,7 @@ def test_swin_transformer():
assert outs[2].shape == (1, 384, 32, 32)
assert outs[3].shape == (1, 768, 16, 16)
# Test abnormal inference
# Test abnormal inference size
temp = torch.randn((1, 3, 112, 137))
model = SwinTransformer()
outs = model(temp)
@ -59,3 +75,25 @@ def test_swin_transformer():
assert outs[1].shape == (1, 192, 14, 18)
assert outs[2].shape == (1, 384, 7, 9)
assert outs[3].shape == (1, 768, 4, 5)
# Test frozen
model = SwinTransformer(frozen_stages=4)
model.train()
for p in model.parameters():
assert not p.requires_grad
# Test absolute position embedding frozen
model = SwinTransformer(frozen_stages=4, use_abs_pos_embed=True)
model.train()
for p in model.parameters():
assert not p.requires_grad
# Test Swin with checkpoint forward
temp = torch.randn((1, 3, 224, 224))
model = SwinTransformer(with_cp=True)
for m in model.modules():
if isinstance(m, SwinBlock):
assert m.with_cp
model.init_weights()
model.train()
model(temp)

View File

@ -25,12 +25,6 @@ def test_vit_backbone():
x = torch.randn(1, 196)
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
with pytest.raises(IndexError):
# forward inputs must be [N, C, H, W]
x = torch.randn(3, 30, 30)
model = VisionTransformer()
model(x)
with pytest.raises(AssertionError):
# The length of img_size tuple must be lower than 3.
VisionTransformer(img_size=(224, 224, 224))

View File

View File

@ -0,0 +1,461 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.utils.embed import AdaptivePadding, PatchEmbed, PatchMerging
def test_adaptive_padding():
for padding in ('same', 'corner'):
kernel_size = 16
stride = 16
dilation = 1
input = torch.rand(1, 1, 15, 17)
adap_pool = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
out = adap_pool(input)
# padding to divisible by 16
assert (out.shape[2], out.shape[3]) == (16, 32)
input = torch.rand(1, 1, 16, 17)
out = adap_pool(input)
# padding to divisible by 16
assert (out.shape[2], out.shape[3]) == (16, 32)
kernel_size = (2, 2)
stride = (2, 2)
dilation = (1, 1)
adap_pad = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
input = torch.rand(1, 1, 11, 13)
out = adap_pad(input)
# padding to divisible by 2
assert (out.shape[2], out.shape[3]) == (12, 14)
kernel_size = (2, 2)
stride = (10, 10)
dilation = (1, 1)
adap_pad = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
input = torch.rand(1, 1, 10, 13)
out = adap_pad(input)
# no padding
assert (out.shape[2], out.shape[3]) == (10, 13)
kernel_size = (11, 11)
adap_pad = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
input = torch.rand(1, 1, 11, 13)
out = adap_pad(input)
# all padding
assert (out.shape[2], out.shape[3]) == (21, 21)
# test padding as kernel is (7,9)
input = torch.rand(1, 1, 11, 13)
stride = (3, 4)
kernel_size = (4, 5)
dilation = (2, 2)
# actually (7, 9)
adap_pad = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
dilation_out = adap_pad(input)
assert (dilation_out.shape[2], dilation_out.shape[3]) == (16, 21)
kernel_size = (7, 9)
dilation = (1, 1)
adap_pad = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
kernel79_out = adap_pad(input)
assert (kernel79_out.shape[2], kernel79_out.shape[3]) == (16, 21)
assert kernel79_out.shape == dilation_out.shape
# assert only support "same" "corner"
with pytest.raises(AssertionError):
AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=1)
def test_patch_embed():
B = 2
H = 3
W = 4
C = 3
embed_dims = 10
kernel_size = 3
stride = 1
dummy_input = torch.rand(B, C, H, W)
patch_merge_1 = PatchEmbed(
in_channels=C,
embed_dims=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=0,
dilation=1,
norm_cfg=None)
x1, shape = patch_merge_1(dummy_input)
# test out shape
assert x1.shape == (2, 2, 10)
# test outsize is correct
assert shape == (1, 2)
# test L = out_h * out_w
assert shape[0] * shape[1] == x1.shape[1]
B = 2
H = 10
W = 10
C = 3
embed_dims = 10
kernel_size = 5
stride = 2
dummy_input = torch.rand(B, C, H, W)
# test dilation
patch_merge_2 = PatchEmbed(
in_channels=C,
embed_dims=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=0,
dilation=2,
norm_cfg=None,
)
x2, shape = patch_merge_2(dummy_input)
# test out shape
assert x2.shape == (2, 1, 10)
# test outsize is correct
assert shape == (1, 1)
# test L = out_h * out_w
assert shape[0] * shape[1] == x2.shape[1]
stride = 2
input_size = (10, 10)
dummy_input = torch.rand(B, C, H, W)
# test stride and norm
patch_merge_3 = PatchEmbed(
in_channels=C,
embed_dims=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=0,
dilation=2,
norm_cfg=dict(type='LN'),
input_size=input_size)
x3, shape = patch_merge_3(dummy_input)
# test out shape
assert x3.shape == (2, 1, 10)
# test outsize is correct
assert shape == (1, 1)
# test L = out_h * out_w
assert shape[0] * shape[1] == x3.shape[1]
# test thte init_out_size with nn.Unfold
assert patch_merge_3.init_out_size[1] == (input_size[0] - 2 * 4 -
1) // 2 + 1
assert patch_merge_3.init_out_size[0] == (input_size[0] - 2 * 4 -
1) // 2 + 1
H = 11
W = 12
input_size = (H, W)
dummy_input = torch.rand(B, C, H, W)
# test stride and norm
patch_merge_3 = PatchEmbed(
in_channels=C,
embed_dims=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=0,
dilation=2,
norm_cfg=dict(type='LN'),
input_size=input_size)
_, shape = patch_merge_3(dummy_input)
# when input_size equal to real input
# the out_size shoule be equal to `init_out_size`
assert shape == patch_merge_3.init_out_size
input_size = (H, W)
dummy_input = torch.rand(B, C, H, W)
# test stride and norm
patch_merge_3 = PatchEmbed(
in_channels=C,
embed_dims=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=0,
dilation=2,
norm_cfg=dict(type='LN'),
input_size=input_size)
_, shape = patch_merge_3(dummy_input)
# when input_size equal to real input
# the out_size shoule be equal to `init_out_size`
assert shape == patch_merge_3.init_out_size
# test adap padding
for padding in ('same', 'corner'):
in_c = 2
embed_dims = 3
B = 2
# test stride is 1
input_size = (5, 5)
kernel_size = (5, 5)
stride = (1, 1)
dilation = 1
bias = False
x = torch.rand(B, in_c, *input_size)
patch_embed = PatchEmbed(
in_channels=in_c,
embed_dims=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
x_out, out_size = patch_embed(x)
assert x_out.size() == (B, 25, 3)
assert out_size == (5, 5)
assert x_out.size(1) == out_size[0] * out_size[1]
# test kernel_size == stride
input_size = (5, 5)
kernel_size = (5, 5)
stride = (5, 5)
dilation = 1
bias = False
x = torch.rand(B, in_c, *input_size)
patch_embed = PatchEmbed(
in_channels=in_c,
embed_dims=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
x_out, out_size = patch_embed(x)
assert x_out.size() == (B, 1, 3)
assert out_size == (1, 1)
assert x_out.size(1) == out_size[0] * out_size[1]
# test kernel_size == stride
input_size = (6, 5)
kernel_size = (5, 5)
stride = (5, 5)
dilation = 1
bias = False
x = torch.rand(B, in_c, *input_size)
patch_embed = PatchEmbed(
in_channels=in_c,
embed_dims=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
x_out, out_size = patch_embed(x)
assert x_out.size() == (B, 2, 3)
assert out_size == (2, 1)
assert x_out.size(1) == out_size[0] * out_size[1]
# test different kernel_size with diffrent stride
input_size = (6, 5)
kernel_size = (6, 2)
stride = (6, 2)
dilation = 1
bias = False
x = torch.rand(B, in_c, *input_size)
patch_embed = PatchEmbed(
in_channels=in_c,
embed_dims=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
x_out, out_size = patch_embed(x)
assert x_out.size() == (B, 3, 3)
assert out_size == (1, 3)
assert x_out.size(1) == out_size[0] * out_size[1]
def test_patch_merging():
# Test the model with int padding
in_c = 3
out_c = 4
kernel_size = 3
stride = 3
padding = 1
dilation = 1
bias = False
# test the case `pad_to_stride` is False
patch_merge = PatchMerging(
in_channels=in_c,
out_channels=out_c,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
B, L, C = 1, 100, 3
input_size = (10, 10)
x = torch.rand(B, L, C)
x_out, out_size = patch_merge(x, input_size)
assert x_out.size() == (1, 16, 4)
assert out_size == (4, 4)
# assert out size is consistent with real output
assert x_out.size(1) == out_size[0] * out_size[1]
in_c = 4
out_c = 5
kernel_size = 6
stride = 3
padding = 2
dilation = 2
bias = False
patch_merge = PatchMerging(
in_channels=in_c,
out_channels=out_c,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
B, L, C = 1, 100, 4
input_size = (10, 10)
x = torch.rand(B, L, C)
x_out, out_size = patch_merge(x, input_size)
assert x_out.size() == (1, 4, 5)
assert out_size == (2, 2)
# assert out size is consistent with real output
assert x_out.size(1) == out_size[0] * out_size[1]
# Test with adaptive padding
for padding in ('same', 'corner'):
in_c = 2
out_c = 3
B = 2
# test stride is 1
input_size = (5, 5)
kernel_size = (5, 5)
stride = (1, 1)
dilation = 1
bias = False
L = input_size[0] * input_size[1]
x = torch.rand(B, L, in_c)
patch_merge = PatchMerging(
in_channels=in_c,
out_channels=out_c,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
x_out, out_size = patch_merge(x, input_size)
assert x_out.size() == (B, 25, 3)
assert out_size == (5, 5)
assert x_out.size(1) == out_size[0] * out_size[1]
# test kernel_size == stride
input_size = (5, 5)
kernel_size = (5, 5)
stride = (5, 5)
dilation = 1
bias = False
L = input_size[0] * input_size[1]
x = torch.rand(B, L, in_c)
patch_merge = PatchMerging(
in_channels=in_c,
out_channels=out_c,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
x_out, out_size = patch_merge(x, input_size)
assert x_out.size() == (B, 1, 3)
assert out_size == (1, 1)
assert x_out.size(1) == out_size[0] * out_size[1]
# test kernel_size == stride
input_size = (6, 5)
kernel_size = (5, 5)
stride = (5, 5)
dilation = 1
bias = False
L = input_size[0] * input_size[1]
x = torch.rand(B, L, in_c)
patch_merge = PatchMerging(
in_channels=in_c,
out_channels=out_c,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
x_out, out_size = patch_merge(x, input_size)
assert x_out.size() == (B, 2, 3)
assert out_size == (2, 1)
assert x_out.size(1) == out_size[0] * out_size[1]
# test different kernel_size with diffrent stride
input_size = (6, 5)
kernel_size = (6, 2)
stride = (6, 2)
dilation = 1
bias = False
L = input_size[0] * input_size[1]
x = torch.rand(B, L, in_c)
patch_merge = PatchMerging(
in_channels=in_c,
out_channels=out_c,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
x_out, out_size = patch_merge(x, input_size)
assert x_out.size() == (B, 3, 3)
assert out_size == (1, 3)
assert x_out.size(1) == out_size[0] * out_size[1]