From 85227b46c7b5eea82651e23e55f11cce33275298 Mon Sep 17 00:00:00 2001 From: Jerry Jiarui XU Date: Tue, 28 Sep 2021 17:46:33 -0700 Subject: [PATCH] [Improvement] Refactor Swin-Transformer (#800) * [Improvement] Refactor Swin-Transformer * fixed swin test * update patch emebd, add more tests * fixed test * remove pretrain_style * fixed padding * resolve coments * use mmcv 2tuple * refactor init_cfg Co-authored-by: Junjun2016 --- configs/_base_/models/upernet_swin.py | 3 +- ...512x512_160k_ade20k_pretrain_224x224_1K.py | 3 +- mmseg/models/backbones/mit.py | 12 +- mmseg/models/backbones/swin.py | 298 ++++++----- mmseg/models/backbones/vit.py | 6 +- mmseg/models/utils/embed.py | 324 ++++++++++-- tests/test_models/test_backbones/test_mit.py | 4 - tests/test_models/test_backbones/test_swin.py | 64 ++- tests/test_models/test_backbones/test_vit.py | 6 - tests/test_models/test_utils/__init__.py | 0 tests/test_models/test_utils/test_embed.py | 461 ++++++++++++++++++ 11 files changed, 936 insertions(+), 245 deletions(-) create mode 100644 tests/test_models/test_utils/__init__.py create mode 100644 tests/test_models/test_utils/test_embed.py diff --git a/configs/_base_/models/upernet_swin.py b/configs/_base_/models/upernet_swin.py index 30ee050..71b5162 100644 --- a/configs/_base_/models/upernet_swin.py +++ b/configs/_base_/models/upernet_swin.py @@ -23,8 +23,7 @@ model = dict( drop_path_rate=0.3, use_abs_pos_embed=False, act_cfg=dict(type='GELU'), - norm_cfg=backbone_norm_cfg, - pretrain_style='official'), + norm_cfg=backbone_norm_cfg), decode_head=dict( type='UPerHead', in_channels=[96, 192, 384, 768], diff --git a/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py b/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py index 8dd8404..67eb4df 100644 --- a/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py +++ 
b/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py @@ -11,8 +11,7 @@ model = dict( window_size=7, use_abs_pos_embed=False, drop_path_rate=0.3, - patch_norm=True, - pretrain_style='official'), + patch_norm=True), decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150), auxiliary_head=dict(in_channels=384, num_classes=150)) diff --git a/mmseg/models/backbones/mit.py b/mmseg/models/backbones/mit.py index ee8bbfa..54d9856 100644 --- a/mmseg/models/backbones/mit.py +++ b/mmseg/models/backbones/mit.py @@ -278,8 +278,6 @@ class MixVisionTransformer(BaseModule): Default: dict(type='LN') act_cfg (dict): The activation config for FFNs. Defalut: dict(type='GELU'). - pretrain_style (str): Choose to use official or mmcls pretrain weights. - Default: official. pretrained (str, optional): model pretrained path. Default: None. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. @@ -302,15 +300,10 @@ class MixVisionTransformer(BaseModule): drop_path_rate=0., act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN', eps=1e-6), - pretrain_style='official', pretrained=None, init_cfg=None): super().__init__() - assert pretrain_style in [ - 'official', 'mmcls' - ], 'we only support official weights or mmcls weights.' 
- if isinstance(pretrained, str) or pretrained is None: warnings.warn('DeprecationWarning: pretrained is a deprecated, ' 'please use "init_cfg" instead') @@ -330,7 +323,6 @@ class MixVisionTransformer(BaseModule): self.out_indices = out_indices assert max(out_indices) < self.num_stages - self.pretrain_style = pretrain_style self.pretrained = pretrained self.init_cfg = init_cfg @@ -350,7 +342,6 @@ class MixVisionTransformer(BaseModule): kernel_size=patch_sizes[i], stride=strides[i], padding=patch_sizes[i] // 2, - pad_to_patch_size=False, norm_cfg=norm_cfg) layer = ModuleList([ TransformerEncoderLayer( @@ -403,8 +394,7 @@ class MixVisionTransformer(BaseModule): outs = [] for i, layer in enumerate(self.layers): - x, H, W = layer[0](x), layer[0].DH, layer[0].DW - hw_shape = (H, W) + x, hw_shape = layer[0](x) for block in layer[1]: x = block(x, hw_shape) x = layer[2](x) diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py index 7de1883..9133d8c 100644 --- a/mmseg/models/backbones/swin.py +++ b/mmseg/models/backbones/swin.py @@ -1,111 +1,37 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import warnings +from collections import OrderedDict from copy import deepcopy import torch import torch.nn as nn import torch.nn.functional as F -from mmcv.cnn import build_norm_layer, trunc_normal_init +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init from mmcv.cnn.bricks.transformer import FFN, build_dropout -from mmcv.cnn.utils.weight_init import constant_init -from mmcv.runner import _load_checkpoint -from mmcv.runner.base_module import BaseModule, ModuleList -from torch.nn.modules.linear import Linear -from torch.nn.modules.normalization import LayerNorm -from torch.nn.modules.utils import _pair as to_2tuple +from mmcv.runner import BaseModule, ModuleList, _load_checkpoint +from mmcv.utils import to_2tuple -from mmseg.ops import resize from ...utils import get_root_logger -from ..builder import ATTENTION, BACKBONES -from ..utils import PatchEmbed +from ..builder import BACKBONES +from ..utils.embed import PatchEmbed, PatchMerging -class PatchMerging(BaseModule): - """Merge patch feature map. - - This layer use nn.Unfold to group feature map by kernel_size, and use norm - and linear layer to embed grouped feature map. - - Args: - in_channels (int): The num of input channels. - out_channels (int): The num of output channels. - stride (int | tuple): the stride of the sliding length in the - unfold layer. Defaults: 2. (Default to be equal with kernel_size). - bias (bool, optional): Whether to add bias in linear layer or not. - Defaults: False. - norm_cfg (dict, optional): Config dict for normalization layer. - Defaults: dict(type='LN'). - init_cfg (dict, optional): The extra config for initialization. - Defaults: None. 
- """ - - def __init__(self, - in_channels, - out_channels, - stride=2, - bias=False, - norm_cfg=dict(type='LN'), - init_cfg=None): - super().__init__(init_cfg) - self.in_channels = in_channels - self.out_channels = out_channels - self.stride = stride - - self.sampler = nn.Unfold( - kernel_size=stride, dilation=1, padding=0, stride=stride) - - sample_dim = stride**2 * in_channels - - if norm_cfg is not None: - self.norm = build_norm_layer(norm_cfg, sample_dim)[1] - else: - self.norm = None - - self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) - - def forward(self, x, hw_shape): - """ - x: x.shape -> [B, H*W, C] - hw_shape: (H, W) - """ - B, L, C = x.shape - H, W = hw_shape - assert L == H * W, 'input feature has wrong size' - - x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W - - # stride is fixed to be equal to kernel_size. - if (H % self.stride != 0) or (W % self.stride != 0): - x = F.pad(x, (0, W % self.stride, 0, H % self.stride)) - - # Use nn.Unfold to merge patch. About 25% faster than original method, - # but need to modify pretrained model for compatibility - x = self.sampler(x) # B, 4*C, H/2*W/2 - x = x.transpose(1, 2) # B, H/2*W/2, 4*C - - x = self.norm(x) if self.norm else x - x = self.reduction(x) - - down_hw_shape = (H + 1) // 2, (W + 1) // 2 - return x, down_hw_shape - - -@ATTENTION.register_module() class WindowMSA(BaseModule): """Window based multi-head self-attention (W-MSA) module with relative position bias. Args: embed_dims (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. Default: True. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. Default: None. attn_drop_rate (float, optional): Dropout ratio of attention weight. 
Default: 0.0 - proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. init_cfg (dict | None, optional): The Config for initialization. Default: None. """ @@ -120,13 +46,12 @@ class WindowMSA(BaseModule): proj_drop_rate=0., init_cfg=None): - super().__init__() + super().__init__(init_cfg=init_cfg) self.embed_dims = embed_dims self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_embed_dims = embed_dims // num_heads self.scale = qk_scale or head_embed_dims**-0.5 - self.init_cfg = init_cfg # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( @@ -161,8 +86,8 @@ class WindowMSA(BaseModule): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[ - 2] # make torchscript happy (cannot use tensor as tuple) + # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale attn = (q @ k.transpose(-2, -1)) @@ -181,9 +106,7 @@ class WindowMSA(BaseModule): attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) + attn = self.softmax(attn) attn = self.attn_drop(attn) @@ -199,9 +122,8 @@ class WindowMSA(BaseModule): return (seq1[:, None] + seq2[None, :]).reshape(1, -1) -@ATTENTION.register_module() class ShiftWindowMSA(BaseModule): - """Shift Window Multihead Self-Attention Module. + """Shifted Window Multihead Self-Attention Module. Args: embed_dims (int): Number of input channels. 
@@ -234,7 +156,7 @@ class ShiftWindowMSA(BaseModule): proj_drop_rate=0, dropout_layer=dict(type='DropPath', drop_prob=0.), init_cfg=None): - super().__init__(init_cfg) + super().__init__(init_cfg=init_cfg) self.window_size = window_size self.shift_size = shift_size @@ -272,8 +194,7 @@ class ShiftWindowMSA(BaseModule): dims=(1, 2)) # calculate attention mask for SW-MSA - img_mask = torch.zeros((1, H_pad, W_pad, 1), - device=query.device) # 1 H W 1 + img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device) h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) @@ -333,7 +254,6 @@ class ShiftWindowMSA(BaseModule): """ Args: windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size H (int): Height of image W (int): Width of image Returns: @@ -350,7 +270,6 @@ class ShiftWindowMSA(BaseModule): """ Args: x: (B, H, W, C) - window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ @@ -369,18 +288,21 @@ class SwinBlock(BaseModule): embed_dims (int): The feature dimension. num_heads (int): Parallel attention heads. feedforward_channels (int): The hidden dimension for FFNs. - window size (int, optional): The local window scale. Default: 7. - shift (bool): whether to shift window or not. Default False. - qkv_bias (int, optional): enable bias for qkv if True. Default: True. + window_size (int, optional): The local window scale. Default: 7. + shift (bool, optional): whether to shift window or not. Default False. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. Default: None. drop_rate (float, optional): Dropout rate. Default: 0. attn_drop_rate (float, optional): Attention dropout rate. Default: 0. - drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2. + drop_path_rate (float, optional): Stochastic depth rate. 
Default: 0. act_cfg (dict, optional): The config dict of activation function. Default: dict(type='GELU'). - norm_cfg (dict, optional): The config dict of nomalization. + norm_cfg (dict, optional): The config dict of normalization. Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. init_cfg (dict | list | None, optional): The init config. Default: None. """ @@ -398,11 +320,12 @@ class SwinBlock(BaseModule): drop_path_rate=0., act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN'), + with_cp=False, init_cfg=None): - super(SwinBlock, self).__init__() + super(SwinBlock, self).__init__(init_cfg=init_cfg) - self.init_cfg = init_cfg + self.with_cp = with_cp self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] self.attn = ShiftWindowMSA( @@ -429,15 +352,24 @@ class SwinBlock(BaseModule): init_cfg=None) def forward(self, x, hw_shape): - identity = x - x = self.norm1(x) - x = self.attn(x, hw_shape) - x = x + identity + def _inner_forward(x): + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) - identity = x - x = self.norm2(x) - x = self.ffn(x, identity=identity) + x = x + identity + + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) return x @@ -450,19 +382,23 @@ class SwinBlockSequence(BaseModule): num_heads (int): Parallel attention heads. feedforward_channels (int): The hidden dimension for FFNs. depth (int): The number of blocks in this stage. - window size (int): The local window scale. Default: 7. - qkv_bias (int): enable bias for qkv if True. Default: True. + window_size (int, optional): The local window scale. Default: 7. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 
Default: None. drop_rate (float, optional): Dropout rate. Default: 0. attn_drop_rate (float, optional): Attention dropout rate. Default: 0. - drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2. + drop_path_rate (float | list[float], optional): Stochastic depth + rate. Default: 0. downsample (BaseModule | None, optional): The downsample operation module. Default: None. act_cfg (dict, optional): The config dict of activation function. Default: dict(type='GELU'). - norm_cfg (dict, optional): The config dict of nomalization. + norm_cfg (dict, optional): The config dict of normalization. Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. init_cfg (dict | list | None, optional): The init config. Default: None. """ @@ -481,14 +417,15 @@ class SwinBlockSequence(BaseModule): downsample=None, act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN'), + with_cp=False, init_cfg=None): - super().__init__() + super().__init__(init_cfg=init_cfg) - self.init_cfg = init_cfg - - drop_path_rate = drop_path_rate if isinstance( - drop_path_rate, - list) else [deepcopy(drop_path_rate) for _ in range(depth)] + if isinstance(drop_path_rate, list): + drop_path_rates = drop_path_rate + assert len(drop_path_rates) == depth + else: + drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)] self.blocks = ModuleList() for i in range(depth): @@ -502,9 +439,10 @@ class SwinBlockSequence(BaseModule): qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate[i], + drop_path_rate=drop_path_rates[i], act_cfg=act_cfg, norm_cfg=norm_cfg, + with_cp=with_cp, init_cfg=None) self.blocks.append(block) @@ -538,7 +476,7 @@ class SwinTransformer(BaseModule): embed_dims (int): The feature dimension. Default: 96. patch_size (int | tuple[int]): Patch size. Default: 4. window_size (int): Window size. Default: 7. 
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. Default: 4. depths (tuple[int]): Depths of each Swin Transformer stage. Default: (2, 2, 6, 2). @@ -564,7 +502,12 @@ class SwinTransformer(BaseModule): Default: dict(type='LN'). norm_cfg (dict): Config dict for normalization layer at output of backone. Defaults: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. pretrained (str, optional): model pretrained path. Default: None. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. init_cfg (dict, optional): The Config for initialization. Defaults to None. """ @@ -589,9 +532,11 @@ class SwinTransformer(BaseModule): use_abs_pos_embed=False, act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN'), + with_cp=False, pretrained=None, + frozen_stages=-1, init_cfg=None): - super(SwinTransformer, self).__init__() + self.frozen_stages = frozen_stages if isinstance(pretrain_img_size, int): pretrain_img_size = to_2tuple(pretrain_img_size) @@ -602,17 +547,22 @@ class SwinTransformer(BaseModule): f'The size of image should have length 1 or 2, ' \ f'but got {len(pretrain_img_size)}' - if isinstance(pretrained, str) or pretrained is None: - warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' 'please use "init_cfg" instead') + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + init_cfg = init_cfg else: raise TypeError('pretrained must be a str or None') + super(SwinTransformer, self).__init__(init_cfg=init_cfg) + num_layers = len(depths) self.out_indices = out_indices self.use_abs_pos_embed = 
use_abs_pos_embed - self.pretrained = pretrained - self.init_cfg = init_cfg assert strides[0] == patch_size, 'Use non-overlapping patch embed.' @@ -622,7 +572,7 @@ class SwinTransformer(BaseModule): conv_type='Conv2d', kernel_size=patch_size, stride=strides[0], - pad_to_patch_size=True, + padding='corner', norm_cfg=norm_cfg if patch_norm else None, init_cfg=None) @@ -635,11 +585,11 @@ class SwinTransformer(BaseModule): self.drop_after_pos = nn.Dropout(p=drop_rate) - # stochastic depth + # set stochastic depth decay rule total_depth = sum(depths) dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, total_depth) - ] # stochastic depth decay rule + ] self.stages = ModuleList() in_channels = embed_dims @@ -664,14 +614,13 @@ class SwinTransformer(BaseModule): qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, - drop_path_rate=dpr[:depths[i]], + drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])], downsample=downsample, act_cfg=act_cfg, norm_cfg=norm_cfg, + with_cp=with_cp, init_cfg=None) self.stages.append(stage) - - dpr = dpr[depths[i]:] if downsample: in_channels = downsample.out_channels @@ -682,29 +631,67 @@ class SwinTransformer(BaseModule): layer_name = f'norm{i}' self.add_module(layer_name, layer) + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + if self.use_abs_pos_embed: + self.absolute_pos_embed.requires_grad = False + self.drop_after_pos.eval() + + for i in range(1, self.frozen_stages + 1): + + if (i - 1) in self.out_indices: + norm_layer = getattr(self, f'norm{i-1}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + m = self.stages[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = 
False + def init_weights(self): - if self.pretrained is None: - super().init_weights() + logger = get_root_logger() + if self.init_cfg is None: + logger.warn(f'No pre-trained weights for ' + f'{self.__class__.__name__}, ' + f'training start from scratch') if self.use_abs_pos_embed: trunc_normal_init(self.absolute_pos_embed, std=0.02) for m in self.modules(): - if isinstance(m, Linear): + if isinstance(m, nn.Linear): trunc_normal_init(m.weight, std=.02) if m.bias is not None: constant_init(m.bias, 0) - elif isinstance(m, LayerNorm): + elif isinstance(m, nn.LayerNorm): constant_init(m.bias, 0) constant_init(m.weight, 1.0) - elif isinstance(self.pretrained, str): - logger = get_root_logger() + else: + assert 'checkpoint' in self.init_cfg, f'Only support ' \ + f'specify `Pretrained` in ' \ + f'`init_cfg` in ' \ + f'{self.__class__.__name__} ' ckpt = _load_checkpoint( - self.pretrained, logger=logger, map_location='cpu') + self.init_cfg.checkpoint, logger=logger, map_location='cpu') if 'state_dict' in ckpt: - state_dict = ckpt['state_dict'] + _state_dict = ckpt['state_dict'] elif 'model' in ckpt: - state_dict = ckpt['model'] + _state_dict = ckpt['model'] else: - state_dict = ckpt + _state_dict = ckpt + + state_dict = OrderedDict() + for k, v in _state_dict.items(): + if k.startswith('backbone.'): + state_dict[k[9:]] = v # strip prefix of state_dict if list(state_dict.keys())[0].startswith('module.'): @@ -733,25 +720,22 @@ class SwinTransformer(BaseModule): L2, nH2 = table_current.size() if nH1 != nH2: logger.warning(f'Error in loading {table_key}, pass') - else: - if L1 != L2: - S1 = int(L1**0.5) - S2 = int(L2**0.5) - table_pretrained_resized = resize( - table_pretrained.permute(1, 0).reshape( - 1, nH1, S1, S1), - size=(S2, S2), - mode='bicubic') - state_dict[table_key] = table_pretrained_resized.view( - nH2, L2).permute(1, 0).contiguous() + elif L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 
0).reshape(1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0).contiguous() # load state_dict self.load_state_dict(state_dict, False) def forward(self, x): - x = self.patch_embed(x) + x, hw_shape = self.patch_embed(x) - hw_shape = (self.patch_embed.DH, self.patch_embed.DW) if self.use_abs_pos_embed: x = x + self.absolute_pos_embed x = self.drop_after_pos(x) diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py index 668d278..5939964 100644 --- a/mmseg/models/backbones/vit.py +++ b/mmseg/models/backbones/vit.py @@ -205,7 +205,7 @@ class VisionTransformer(BaseModule): conv_type='Conv2d', kernel_size=patch_size, stride=patch_size, - pad_to_patch_size=True, + padding='corner', norm_cfg=norm_cfg if patch_norm else None, init_cfg=None, ) @@ -370,8 +370,8 @@ class VisionTransformer(BaseModule): def forward(self, inputs): B = inputs.shape[0] - x, hw_shape = self.patch_embed(inputs), (self.patch_embed.DH, - self.patch_embed.DW) + x, hw_shape = self.patch_embed(inputs) + # stole cls_tokens impl from Phil Wang, thanks cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) diff --git a/mmseg/models/utils/embed.py b/mmseg/models/utils/embed.py index c0cf143..1515675 100644 --- a/mmseg/models/utils/embed.py +++ b/mmseg/models/utils/embed.py @@ -1,28 +1,109 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Sequence + +import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.runner.base_module import BaseModule -from torch.nn.modules.utils import _pair as to_2tuple +from mmcv.utils import to_2tuple + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". 
The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1. + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". + Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + + def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): + + super(AdaptivePadding, self).__init__() + + assert padding in ('same', 'corner') + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == 'corner': + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == 'same': + x = F.pad(x, [ + 
pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2 + ]) + return x -# Modified from Pytorch-Image-Models class PatchEmbed(BaseModule): - """Image to Patch Embedding V2. + """Image to Patch Embedding. We use a conv layer to implement PatchEmbed. + Args: in_channels (int): The num of input channels. Default: 3 embed_dims (int): The dimensions of embedding. Default: 768 - conv_type (dict, optional): The config dict for conv layers type - selection. Default: None. + conv_type (str): The config dict for embedding + conv layer type selection. Default: "Conv2d". kernel_size (int): The kernel_size of embedding conv. Default: 16. - stride (int): The slide stride of embedding conv. - Default: None (Default to be equal with kernel_size). - padding (int): The padding length of embedding conv. Default: 0. + stride (int, optional): The slide stride of embedding conv. + Default: None (Would be set as `kernel_size`). + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". dilation (int): The dilation rate of embedding conv. Default: 1. - pad_to_patch_size (bool, optional): Whether to pad feature map shape - to multiple patch size. Default: True. + bias (bool): Bias of embed conv. Default: True. norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only work when `dynamic_size` + is False. Default: None. init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. Default: None. 
""" @@ -30,39 +111,37 @@ class PatchEmbed(BaseModule): def __init__(self, in_channels=3, embed_dims=768, - conv_type=None, + conv_type='Conv2d', kernel_size=16, - stride=16, - padding=0, + stride=None, + padding='corner', dilation=1, - pad_to_patch_size=True, + bias=True, norm_cfg=None, + input_size=None, init_cfg=None): - super(PatchEmbed, self).__init__() + super(PatchEmbed, self).__init__(init_cfg=init_cfg) self.embed_dims = embed_dims - self.init_cfg = init_cfg - if stride is None: stride = kernel_size - self.pad_to_patch_size = pad_to_patch_size + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) - # The default setting of patch size is equal to kernel size. - patch_size = kernel_size - if isinstance(patch_size, int): - patch_size = to_2tuple(patch_size) - elif isinstance(patch_size, tuple): - if len(patch_size) == 1: - patch_size = to_2tuple(patch_size[0]) - assert len(patch_size) == 2, \ - f'The size of patch should have length 1 or 2, ' \ - f'but got {len(patch_size)}' + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of conv + padding = 0 + else: + self.adap_padding = None + padding = to_2tuple(padding) - self.patch_size = patch_size - - # Use conv layer to embed - conv_type = conv_type or 'Conv2d' self.projection = build_conv_layer( dict(type=conv_type), in_channels=in_channels, @@ -70,31 +149,182 @@ class PatchEmbed(BaseModule): kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation) + dilation=dilation, + bias=bias) if norm_cfg is not None: self.norm = build_norm_layer(norm_cfg, embed_dims)[1] else: self.norm = None - def forward(self, x): - H, W = x.shape[2], x.shape[3] + if input_size: + input_size = to_2tuple(input_size) + # `init_out_size` would be used outside to + # calculate the num_patches + # when `use_abs_pos_embed` outside + self.init_input_size = 
input_size + if self.adap_padding: + pad_h, pad_w = self.adap_padding.get_pad_shape(input_size) + input_h, input_w = input_size + input_h = input_h + pad_h + input_w = input_w + pad_w + input_size = (input_h, input_w) - # TODO: Process overlapping op - if self.pad_to_patch_size: - # Modify H, W to multiple of patch size. - if H % self.patch_size[0] != 0: - x = F.pad( - x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) - if W % self.patch_size[1] != 0: - x = F.pad( - x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0)) + # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + h_out = (input_size[0] + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (input_size[1] + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adap_padding: + x = self.adap_padding(x) x = self.projection(x) - self.DH, self.DW = x.shape[2], x.shape[3] + out_size = (x.shape[2], x.shape[3]) x = x.flatten(2).transpose(1, 2) - if self.norm is not None: x = self.norm(x) + return x, out_size - return x + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + out_channels (int): The num of output channels. 
+ kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding='corner', + dilation=1, + bias=False, + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). 
+ input_size (tuple[int]): The spatial shape of x, arranged as (H, W). + Required; H * W must equal L. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arranged as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size diff --git a/tests/test_models/test_backbones/test_mit.py b/tests/test_models/test_backbones/test_mit.py index 86d98bf..536f2b3 100644 --- a/tests/test_models/test_backbones/test_mit.py +++ b/tests/test_models/test_backbones/test_mit.py @@ -7,10 +7,6 @@ from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN def test_mit(): - with pytest.raises(AssertionError): - # It's only support official style and mmcls style now. - MixVisionTransformer(pretrain_style='timm') - with pytest.raises(TypeError): # Pretrained represents pretrain url and must be str or None. 
MixVisionTransformer(pretrained=123) diff --git a/tests/test_models/test_backbones/test_swin.py b/tests/test_models/test_backbones/test_swin.py index edb2f83..0529d1e 100644 --- a/tests/test_models/test_backbones/test_swin.py +++ b/tests/test_models/test_backbones/test_swin.py @@ -1,8 +1,26 @@ -# Copyright (c) OpenMMLab. All rights reserved. import pytest import torch -from mmseg.models.backbones import SwinTransformer +from mmseg.models.backbones.swin import SwinBlock, SwinTransformer + + +def test_swin_block(): + # test SwinBlock structure and forward + block = SwinBlock(embed_dims=64, num_heads=4, feedforward_channels=256) + assert block.ffn.embed_dims == 64 + assert block.attn.w_msa.num_heads == 4 + assert block.ffn.feedforward_channels == 256 + x = torch.randn(1, 56 * 56, 64) + x_out = block(x, (56, 56)) + assert x_out.shape == torch.Size([1, 56 * 56, 64]) + + # Test SwinBlock with checkpoint forward + block = SwinBlock( + embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True) + assert block.with_cp + x = torch.randn(1, 56 * 56, 64) + x_out = block(x, (56, 56)) + assert x_out.shape == torch.Size([1, 56 * 56, 64]) def test_swin_transformer(): @@ -10,12 +28,16 @@ def test_swin_transformer(): with pytest.raises(TypeError): # Pretrained arg must be str or None. - model = SwinTransformer(pretrained=123) + SwinTransformer(pretrained=123) with pytest.raises(AssertionError): - # Because swin use non-overlapping patch embed, so the stride of patch + # Because swin uses non-overlapping patch embed, so the stride of patch # embed must be equal to patch size. 
- model = SwinTransformer(strides=(2, 2, 2, 2), patch_size=4) + SwinTransformer(strides=(2, 2, 2, 2), patch_size=4) + + # test pretrained image size + with pytest.raises(AssertionError): + SwinTransformer(pretrain_img_size=(224, 224, 224)) # Test absolute position embedding temp = torch.randn((1, 3, 224, 224)) @@ -27,12 +49,6 @@ def test_swin_transformer(): model = SwinTransformer(patch_norm=False) model(temp) - # Test pretrain img size - model = SwinTransformer(pretrain_img_size=(224, )) - - with pytest.raises(AssertionError): - model = SwinTransformer(pretrain_img_size=(224, 224, 224)) - # Test normal inference temp = torch.randn((1, 3, 512, 512)) model = SwinTransformer() @@ -42,7 +58,7 @@ def test_swin_transformer(): assert outs[2].shape == (1, 384, 32, 32) assert outs[3].shape == (1, 768, 16, 16) - # Test abnormal inference + # Test abnormal inference size temp = torch.randn((1, 3, 511, 511)) model = SwinTransformer() outs = model(temp) @@ -51,7 +67,7 @@ def test_swin_transformer(): assert outs[2].shape == (1, 384, 32, 32) assert outs[3].shape == (1, 768, 16, 16) - # Test abnormal inference + # Test abnormal inference size temp = torch.randn((1, 3, 112, 137)) model = SwinTransformer() outs = model(temp) @@ -59,3 +75,25 @@ def test_swin_transformer(): assert outs[1].shape == (1, 192, 14, 18) assert outs[2].shape == (1, 384, 7, 9) assert outs[3].shape == (1, 768, 4, 5) + + # Test frozen + model = SwinTransformer(frozen_stages=4) + model.train() + for p in model.parameters(): + assert not p.requires_grad + + # Test absolute position embedding frozen + model = SwinTransformer(frozen_stages=4, use_abs_pos_embed=True) + model.train() + for p in model.parameters(): + assert not p.requires_grad + + # Test Swin with checkpoint forward + temp = torch.randn((1, 3, 224, 224)) + model = SwinTransformer(with_cp=True) + for m in model.modules(): + if isinstance(m, SwinBlock): + assert m.with_cp + model.init_weights() + model.train() + model(temp) diff --git 
a/tests/test_models/test_backbones/test_vit.py b/tests/test_models/test_backbones/test_vit.py index c9afe07..5dbb51e 100644 --- a/tests/test_models/test_backbones/test_vit.py +++ b/tests/test_models/test_backbones/test_vit.py @@ -25,12 +25,6 @@ def test_vit_backbone(): x = torch.randn(1, 196) VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear') - with pytest.raises(IndexError): - # forward inputs must be [N, C, H, W] - x = torch.randn(3, 30, 30) - model = VisionTransformer() - model(x) - with pytest.raises(AssertionError): # The length of img_size tuple must be lower than 3. VisionTransformer(img_size=(224, 224, 224)) diff --git a/tests/test_models/test_utils/__init__.py b/tests/test_models/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_models/test_utils/test_embed.py b/tests/test_models/test_utils/test_embed.py new file mode 100644 index 0000000..2c6857d --- /dev/null +++ b/tests/test_models/test_utils/test_embed.py @@ -0,0 +1,461 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import pytest +import torch + +from mmseg.models.utils.embed import AdaptivePadding, PatchEmbed, PatchMerging + + +def test_adaptive_padding(): + + for padding in ('same', 'corner'): + kernel_size = 16 + stride = 16 + dilation = 1 + input = torch.rand(1, 1, 15, 17) + adap_pool = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + out = adap_pool(input) + # padding to divisible by 16 + assert (out.shape[2], out.shape[3]) == (16, 32) + input = torch.rand(1, 1, 16, 17) + out = adap_pool(input) + # padding to divisible by 16 + assert (out.shape[2], out.shape[3]) == (16, 32) + + kernel_size = (2, 2) + stride = (2, 2) + dilation = (1, 1) + + adap_pad = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + input = torch.rand(1, 1, 11, 13) + out = adap_pad(input) + # padding to divisible by 2 + assert (out.shape[2], out.shape[3]) == (12, 14) + + kernel_size = (2, 2) + stride = (10, 10) + dilation = (1, 1) + + adap_pad = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + input = torch.rand(1, 1, 10, 13) + out = adap_pad(input) + # no padding + assert (out.shape[2], out.shape[3]) == (10, 13) + + kernel_size = (11, 11) + adap_pad = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + input = torch.rand(1, 1, 11, 13) + out = adap_pad(input) + # all padding + assert (out.shape[2], out.shape[3]) == (21, 21) + + # test padding as kernel is (7,9) + input = torch.rand(1, 1, 11, 13) + stride = (3, 4) + kernel_size = (4, 5) + dilation = (2, 2) + # actually (7, 9) + adap_pad = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + dilation_out = adap_pad(input) + assert (dilation_out.shape[2], dilation_out.shape[3]) == (16, 21) + kernel_size = (7, 9) + dilation = (1, 1) + adap_pad = AdaptivePadding( + kernel_size=kernel_size, + 
stride=stride, + dilation=dilation, + padding=padding) + kernel79_out = adap_pad(input) + assert (kernel79_out.shape[2], kernel79_out.shape[3]) == (16, 21) + assert kernel79_out.shape == dilation_out.shape + + # assert only "same" and "corner" are supported + with pytest.raises(AssertionError): + AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=1) + + +def test_patch_embed(): + B = 2 + H = 3 + W = 4 + C = 3 + embed_dims = 10 + kernel_size = 3 + stride = 1 + dummy_input = torch.rand(B, C, H, W) + patch_merge_1 = PatchEmbed( + in_channels=C, + embed_dims=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=1, + norm_cfg=None) + + x1, shape = patch_merge_1(dummy_input) + # test out shape + assert x1.shape == (2, 2, 10) + # test outsize is correct + assert shape == (1, 2) + # test L = out_h * out_w + assert shape[0] * shape[1] == x1.shape[1] + + B = 2 + H = 10 + W = 10 + C = 3 + embed_dims = 10 + kernel_size = 5 + stride = 2 + dummy_input = torch.rand(B, C, H, W) + # test dilation + patch_merge_2 = PatchEmbed( + in_channels=C, + embed_dims=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=2, + norm_cfg=None, + ) + + x2, shape = patch_merge_2(dummy_input) + # test out shape + assert x2.shape == (2, 1, 10) + # test outsize is correct + assert shape == (1, 1) + # test L = out_h * out_w + assert shape[0] * shape[1] == x2.shape[1] + + stride = 2 + input_size = (10, 10) + + dummy_input = torch.rand(B, C, H, W) + # test stride and norm + patch_merge_3 = PatchEmbed( + in_channels=C, + embed_dims=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=2, + norm_cfg=dict(type='LN'), + input_size=input_size) + + x3, shape = patch_merge_3(dummy_input) + # test out shape + assert x3.shape == (2, 1, 10) + # test outsize is correct + assert shape == (1, 1) + # test L = out_h * out_w + assert shape[0] * shape[1] == x3.shape[1] + + # test the init_out_size with nn.Unfold + 
assert patch_merge_3.init_out_size[1] == (input_size[0] - 2 * 4 - + 1) // 2 + 1 + assert patch_merge_3.init_out_size[0] == (input_size[0] - 2 * 4 - + 1) // 2 + 1 + H = 11 + W = 12 + input_size = (H, W) + dummy_input = torch.rand(B, C, H, W) + # test stride and norm + patch_merge_3 = PatchEmbed( + in_channels=C, + embed_dims=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=2, + norm_cfg=dict(type='LN'), + input_size=input_size) + + _, shape = patch_merge_3(dummy_input) + # when input_size equal to real input + # the out_size should be equal to `init_out_size` + assert shape == patch_merge_3.init_out_size + + input_size = (H, W) + dummy_input = torch.rand(B, C, H, W) + # test stride and norm + patch_merge_3 = PatchEmbed( + in_channels=C, + embed_dims=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=2, + norm_cfg=dict(type='LN'), + input_size=input_size) + + _, shape = patch_merge_3(dummy_input) + # when input_size equal to real input + # the out_size should be equal to `init_out_size` + assert shape == patch_merge_3.init_out_size + + # test adap padding + for padding in ('same', 'corner'): + in_c = 2 + embed_dims = 3 + B = 2 + + # test stride is 1 + input_size = (5, 5) + kernel_size = (5, 5) + stride = (1, 1) + dilation = 1 + bias = False + + x = torch.rand(B, in_c, *input_size) + patch_embed = PatchEmbed( + in_channels=in_c, + embed_dims=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + x_out, out_size = patch_embed(x) + assert x_out.size() == (B, 25, 3) + assert out_size == (5, 5) + assert x_out.size(1) == out_size[0] * out_size[1] + + # test kernel_size == stride + input_size = (5, 5) + kernel_size = (5, 5) + stride = (5, 5) + dilation = 1 + bias = False + + x = torch.rand(B, in_c, *input_size) + patch_embed = PatchEmbed( + in_channels=in_c, + embed_dims=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + 
dilation=dilation, + bias=bias) + + x_out, out_size = patch_embed(x) + assert x_out.size() == (B, 1, 3) + assert out_size == (1, 1) + assert x_out.size(1) == out_size[0] * out_size[1] + + # test kernel_size == stride + input_size = (6, 5) + kernel_size = (5, 5) + stride = (5, 5) + dilation = 1 + bias = False + + x = torch.rand(B, in_c, *input_size) + patch_embed = PatchEmbed( + in_channels=in_c, + embed_dims=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + x_out, out_size = patch_embed(x) + assert x_out.size() == (B, 2, 3) + assert out_size == (2, 1) + assert x_out.size(1) == out_size[0] * out_size[1] + + # test different kernel_size with different stride + input_size = (6, 5) + kernel_size = (6, 2) + stride = (6, 2) + dilation = 1 + bias = False + + x = torch.rand(B, in_c, *input_size) + patch_embed = PatchEmbed( + in_channels=in_c, + embed_dims=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + x_out, out_size = patch_embed(x) + assert x_out.size() == (B, 3, 3) + assert out_size == (1, 3) + assert x_out.size(1) == out_size[0] * out_size[1] + + +def test_patch_merging(): + + # Test the model with int padding + in_c = 3 + out_c = 4 + kernel_size = 3 + stride = 3 + padding = 1 + dilation = 1 + bias = False + # test the case where `padding` is an int (no adaptive padding) + patch_merge = PatchMerging( + in_channels=in_c, + out_channels=out_c, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + B, L, C = 1, 100, 3 + input_size = (10, 10) + x = torch.rand(B, L, C) + x_out, out_size = patch_merge(x, input_size) + assert x_out.size() == (1, 16, 4) + assert out_size == (4, 4) + # assert out size is consistent with real output + assert x_out.size(1) == out_size[0] * out_size[1] + in_c = 4 + out_c = 5 + kernel_size = 6 + stride = 3 + padding = 2 + dilation = 2 + bias = False + patch_merge = PatchMerging( + in_channels=in_c, + 
out_channels=out_c, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + B, L, C = 1, 100, 4 + input_size = (10, 10) + x = torch.rand(B, L, C) + x_out, out_size = patch_merge(x, input_size) + assert x_out.size() == (1, 4, 5) + assert out_size == (2, 2) + # assert out size is consistent with real output + assert x_out.size(1) == out_size[0] * out_size[1] + + # Test with adaptive padding + for padding in ('same', 'corner'): + in_c = 2 + out_c = 3 + B = 2 + + # test stride is 1 + input_size = (5, 5) + kernel_size = (5, 5) + stride = (1, 1) + dilation = 1 + bias = False + L = input_size[0] * input_size[1] + + x = torch.rand(B, L, in_c) + patch_merge = PatchMerging( + in_channels=in_c, + out_channels=out_c, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + x_out, out_size = patch_merge(x, input_size) + assert x_out.size() == (B, 25, 3) + assert out_size == (5, 5) + assert x_out.size(1) == out_size[0] * out_size[1] + + # test kernel_size == stride + input_size = (5, 5) + kernel_size = (5, 5) + stride = (5, 5) + dilation = 1 + bias = False + L = input_size[0] * input_size[1] + + x = torch.rand(B, L, in_c) + patch_merge = PatchMerging( + in_channels=in_c, + out_channels=out_c, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + x_out, out_size = patch_merge(x, input_size) + assert x_out.size() == (B, 1, 3) + assert out_size == (1, 1) + assert x_out.size(1) == out_size[0] * out_size[1] + + # test kernel_size == stride + input_size = (6, 5) + kernel_size = (5, 5) + stride = (5, 5) + dilation = 1 + bias = False + L = input_size[0] * input_size[1] + + x = torch.rand(B, L, in_c) + patch_merge = PatchMerging( + in_channels=in_c, + out_channels=out_c, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + x_out, out_size = patch_merge(x, input_size) + assert x_out.size() == (B, 2, 3) + 
assert out_size == (2, 1) + assert x_out.size(1) == out_size[0] * out_size[1] + + # test different kernel_size with different stride + input_size = (6, 5) + kernel_size = (6, 2) + stride = (6, 2) + dilation = 1 + bias = False + L = input_size[0] * input_size[1] + + x = torch.rand(B, L, in_c) + patch_merge = PatchMerging( + in_channels=in_c, + out_channels=out_c, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + x_out, out_size = patch_merge(x, input_size) + assert x_out.size() == (B, 3, 3) + assert out_size == (1, 3) + assert x_out.size(1) == out_size[0] * out_size[1]