[Improvement] Refactor Swin-Transformer (#800)
* [Improvement] Refactor Swin-Transformer * fixed swin test * update patch emebd, add more tests * fixed test * remove pretrain_style * fixed padding * resolve coments * use mmcv 2tuple * refactor init_cfg Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
This commit is contained in:
parent
ab12009414
commit
85227b46c7
@ -23,8 +23,7 @@ model = dict(
|
|||||||
drop_path_rate=0.3,
|
drop_path_rate=0.3,
|
||||||
use_abs_pos_embed=False,
|
use_abs_pos_embed=False,
|
||||||
act_cfg=dict(type='GELU'),
|
act_cfg=dict(type='GELU'),
|
||||||
norm_cfg=backbone_norm_cfg,
|
norm_cfg=backbone_norm_cfg),
|
||||||
pretrain_style='official'),
|
|
||||||
decode_head=dict(
|
decode_head=dict(
|
||||||
type='UPerHead',
|
type='UPerHead',
|
||||||
in_channels=[96, 192, 384, 768],
|
in_channels=[96, 192, 384, 768],
|
||||||
|
|||||||
@ -11,8 +11,7 @@ model = dict(
|
|||||||
window_size=7,
|
window_size=7,
|
||||||
use_abs_pos_embed=False,
|
use_abs_pos_embed=False,
|
||||||
drop_path_rate=0.3,
|
drop_path_rate=0.3,
|
||||||
patch_norm=True,
|
patch_norm=True),
|
||||||
pretrain_style='official'),
|
|
||||||
decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
|
decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
|
||||||
auxiliary_head=dict(in_channels=384, num_classes=150))
|
auxiliary_head=dict(in_channels=384, num_classes=150))
|
||||||
|
|
||||||
|
|||||||
@ -278,8 +278,6 @@ class MixVisionTransformer(BaseModule):
|
|||||||
Default: dict(type='LN')
|
Default: dict(type='LN')
|
||||||
act_cfg (dict): The activation config for FFNs.
|
act_cfg (dict): The activation config for FFNs.
|
||||||
Defalut: dict(type='GELU').
|
Defalut: dict(type='GELU').
|
||||||
pretrain_style (str): Choose to use official or mmcls pretrain weights.
|
|
||||||
Default: official.
|
|
||||||
pretrained (str, optional): model pretrained path. Default: None.
|
pretrained (str, optional): model pretrained path. Default: None.
|
||||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
Default: None.
|
Default: None.
|
||||||
@ -302,15 +300,10 @@ class MixVisionTransformer(BaseModule):
|
|||||||
drop_path_rate=0.,
|
drop_path_rate=0.,
|
||||||
act_cfg=dict(type='GELU'),
|
act_cfg=dict(type='GELU'),
|
||||||
norm_cfg=dict(type='LN', eps=1e-6),
|
norm_cfg=dict(type='LN', eps=1e-6),
|
||||||
pretrain_style='official',
|
|
||||||
pretrained=None,
|
pretrained=None,
|
||||||
init_cfg=None):
|
init_cfg=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert pretrain_style in [
|
|
||||||
'official', 'mmcls'
|
|
||||||
], 'we only support official weights or mmcls weights.'
|
|
||||||
|
|
||||||
if isinstance(pretrained, str) or pretrained is None:
|
if isinstance(pretrained, str) or pretrained is None:
|
||||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||||
'please use "init_cfg" instead')
|
'please use "init_cfg" instead')
|
||||||
@ -330,7 +323,6 @@ class MixVisionTransformer(BaseModule):
|
|||||||
|
|
||||||
self.out_indices = out_indices
|
self.out_indices = out_indices
|
||||||
assert max(out_indices) < self.num_stages
|
assert max(out_indices) < self.num_stages
|
||||||
self.pretrain_style = pretrain_style
|
|
||||||
self.pretrained = pretrained
|
self.pretrained = pretrained
|
||||||
self.init_cfg = init_cfg
|
self.init_cfg = init_cfg
|
||||||
|
|
||||||
@ -350,7 +342,6 @@ class MixVisionTransformer(BaseModule):
|
|||||||
kernel_size=patch_sizes[i],
|
kernel_size=patch_sizes[i],
|
||||||
stride=strides[i],
|
stride=strides[i],
|
||||||
padding=patch_sizes[i] // 2,
|
padding=patch_sizes[i] // 2,
|
||||||
pad_to_patch_size=False,
|
|
||||||
norm_cfg=norm_cfg)
|
norm_cfg=norm_cfg)
|
||||||
layer = ModuleList([
|
layer = ModuleList([
|
||||||
TransformerEncoderLayer(
|
TransformerEncoderLayer(
|
||||||
@ -403,8 +394,7 @@ class MixVisionTransformer(BaseModule):
|
|||||||
outs = []
|
outs = []
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
x, H, W = layer[0](x), layer[0].DH, layer[0].DW
|
x, hw_shape = layer[0](x)
|
||||||
hw_shape = (H, W)
|
|
||||||
for block in layer[1]:
|
for block in layer[1]:
|
||||||
x = block(x, hw_shape)
|
x = block(x, hw_shape)
|
||||||
x = layer[2](x)
|
x = layer[2](x)
|
||||||
|
|||||||
@ -1,111 +1,37 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections import OrderedDict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from mmcv.cnn import build_norm_layer, trunc_normal_init
|
import torch.utils.checkpoint as cp
|
||||||
|
from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
|
||||||
from mmcv.cnn.bricks.transformer import FFN, build_dropout
|
from mmcv.cnn.bricks.transformer import FFN, build_dropout
|
||||||
from mmcv.cnn.utils.weight_init import constant_init
|
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
|
||||||
from mmcv.runner import _load_checkpoint
|
from mmcv.utils import to_2tuple
|
||||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
|
||||||
from torch.nn.modules.linear import Linear
|
|
||||||
from torch.nn.modules.normalization import LayerNorm
|
|
||||||
from torch.nn.modules.utils import _pair as to_2tuple
|
|
||||||
|
|
||||||
from mmseg.ops import resize
|
|
||||||
from ...utils import get_root_logger
|
from ...utils import get_root_logger
|
||||||
from ..builder import ATTENTION, BACKBONES
|
from ..builder import BACKBONES
|
||||||
from ..utils import PatchEmbed
|
from ..utils.embed import PatchEmbed, PatchMerging
|
||||||
|
|
||||||
|
|
||||||
class PatchMerging(BaseModule):
|
|
||||||
"""Merge patch feature map.
|
|
||||||
|
|
||||||
This layer use nn.Unfold to group feature map by kernel_size, and use norm
|
|
||||||
and linear layer to embed grouped feature map.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int): The num of input channels.
|
|
||||||
out_channels (int): The num of output channels.
|
|
||||||
stride (int | tuple): the stride of the sliding length in the
|
|
||||||
unfold layer. Defaults: 2. (Default to be equal with kernel_size).
|
|
||||||
bias (bool, optional): Whether to add bias in linear layer or not.
|
|
||||||
Defaults: False.
|
|
||||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
|
||||||
Defaults: dict(type='LN').
|
|
||||||
init_cfg (dict, optional): The extra config for initialization.
|
|
||||||
Defaults: None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
stride=2,
|
|
||||||
bias=False,
|
|
||||||
norm_cfg=dict(type='LN'),
|
|
||||||
init_cfg=None):
|
|
||||||
super().__init__(init_cfg)
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.stride = stride
|
|
||||||
|
|
||||||
self.sampler = nn.Unfold(
|
|
||||||
kernel_size=stride, dilation=1, padding=0, stride=stride)
|
|
||||||
|
|
||||||
sample_dim = stride**2 * in_channels
|
|
||||||
|
|
||||||
if norm_cfg is not None:
|
|
||||||
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
|
|
||||||
else:
|
|
||||||
self.norm = None
|
|
||||||
|
|
||||||
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x, hw_shape):
|
|
||||||
"""
|
|
||||||
x: x.shape -> [B, H*W, C]
|
|
||||||
hw_shape: (H, W)
|
|
||||||
"""
|
|
||||||
B, L, C = x.shape
|
|
||||||
H, W = hw_shape
|
|
||||||
assert L == H * W, 'input feature has wrong size'
|
|
||||||
|
|
||||||
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
|
|
||||||
|
|
||||||
# stride is fixed to be equal to kernel_size.
|
|
||||||
if (H % self.stride != 0) or (W % self.stride != 0):
|
|
||||||
x = F.pad(x, (0, W % self.stride, 0, H % self.stride))
|
|
||||||
|
|
||||||
# Use nn.Unfold to merge patch. About 25% faster than original method,
|
|
||||||
# but need to modify pretrained model for compatibility
|
|
||||||
x = self.sampler(x) # B, 4*C, H/2*W/2
|
|
||||||
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
|
|
||||||
|
|
||||||
x = self.norm(x) if self.norm else x
|
|
||||||
x = self.reduction(x)
|
|
||||||
|
|
||||||
down_hw_shape = (H + 1) // 2, (W + 1) // 2
|
|
||||||
return x, down_hw_shape
|
|
||||||
|
|
||||||
|
|
||||||
@ATTENTION.register_module()
|
|
||||||
class WindowMSA(BaseModule):
|
class WindowMSA(BaseModule):
|
||||||
"""Window based multi-head self-attention (W-MSA) module with relative
|
"""Window based multi-head self-attention (W-MSA) module with relative
|
||||||
position bias.
|
position bias.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embed_dims (int): Number of input channels.
|
embed_dims (int): Number of input channels.
|
||||||
window_size (tuple[int]): The height and width of the window.
|
|
||||||
num_heads (int): Number of attention heads.
|
num_heads (int): Number of attention heads.
|
||||||
|
window_size (tuple[int]): The height and width of the window.
|
||||||
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
||||||
Default: True.
|
Default: True.
|
||||||
qk_scale (float | None, optional): Override default qk scale of
|
qk_scale (float | None, optional): Override default qk scale of
|
||||||
head_dim ** -0.5 if set. Default: None.
|
head_dim ** -0.5 if set. Default: None.
|
||||||
attn_drop_rate (float, optional): Dropout ratio of attention weight.
|
attn_drop_rate (float, optional): Dropout ratio of attention weight.
|
||||||
Default: 0.0
|
Default: 0.0
|
||||||
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.0
|
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
|
||||||
init_cfg (dict | None, optional): The Config for initialization.
|
init_cfg (dict | None, optional): The Config for initialization.
|
||||||
Default: None.
|
Default: None.
|
||||||
"""
|
"""
|
||||||
@ -120,13 +46,12 @@ class WindowMSA(BaseModule):
|
|||||||
proj_drop_rate=0.,
|
proj_drop_rate=0.,
|
||||||
init_cfg=None):
|
init_cfg=None):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__(init_cfg=init_cfg)
|
||||||
self.embed_dims = embed_dims
|
self.embed_dims = embed_dims
|
||||||
self.window_size = window_size # Wh, Ww
|
self.window_size = window_size # Wh, Ww
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
head_embed_dims = embed_dims // num_heads
|
head_embed_dims = embed_dims // num_heads
|
||||||
self.scale = qk_scale or head_embed_dims**-0.5
|
self.scale = qk_scale or head_embed_dims**-0.5
|
||||||
self.init_cfg = init_cfg
|
|
||||||
|
|
||||||
# define a parameter table of relative position bias
|
# define a parameter table of relative position bias
|
||||||
self.relative_position_bias_table = nn.Parameter(
|
self.relative_position_bias_table = nn.Parameter(
|
||||||
@ -161,8 +86,8 @@ class WindowMSA(BaseModule):
|
|||||||
B, N, C = x.shape
|
B, N, C = x.shape
|
||||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||||
q, k, v = qkv[0], qkv[1], qkv[
|
# make torchscript happy (cannot use tensor as tuple)
|
||||||
2] # make torchscript happy (cannot use tensor as tuple)
|
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||||
|
|
||||||
q = q * self.scale
|
q = q * self.scale
|
||||||
attn = (q @ k.transpose(-2, -1))
|
attn = (q @ k.transpose(-2, -1))
|
||||||
@ -182,8 +107,6 @@ class WindowMSA(BaseModule):
|
|||||||
N) + mask.unsqueeze(1).unsqueeze(0)
|
N) + mask.unsqueeze(1).unsqueeze(0)
|
||||||
attn = attn.view(-1, self.num_heads, N, N)
|
attn = attn.view(-1, self.num_heads, N, N)
|
||||||
attn = self.softmax(attn)
|
attn = self.softmax(attn)
|
||||||
else:
|
|
||||||
attn = self.softmax(attn)
|
|
||||||
|
|
||||||
attn = self.attn_drop(attn)
|
attn = self.attn_drop(attn)
|
||||||
|
|
||||||
@ -199,9 +122,8 @@ class WindowMSA(BaseModule):
|
|||||||
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
|
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
|
||||||
|
|
||||||
|
|
||||||
@ATTENTION.register_module()
|
|
||||||
class ShiftWindowMSA(BaseModule):
|
class ShiftWindowMSA(BaseModule):
|
||||||
"""Shift Window Multihead Self-Attention Module.
|
"""Shifted Window Multihead Self-Attention Module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embed_dims (int): Number of input channels.
|
embed_dims (int): Number of input channels.
|
||||||
@ -234,7 +156,7 @@ class ShiftWindowMSA(BaseModule):
|
|||||||
proj_drop_rate=0,
|
proj_drop_rate=0,
|
||||||
dropout_layer=dict(type='DropPath', drop_prob=0.),
|
dropout_layer=dict(type='DropPath', drop_prob=0.),
|
||||||
init_cfg=None):
|
init_cfg=None):
|
||||||
super().__init__(init_cfg)
|
super().__init__(init_cfg=init_cfg)
|
||||||
|
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.shift_size = shift_size
|
self.shift_size = shift_size
|
||||||
@ -272,8 +194,7 @@ class ShiftWindowMSA(BaseModule):
|
|||||||
dims=(1, 2))
|
dims=(1, 2))
|
||||||
|
|
||||||
# calculate attention mask for SW-MSA
|
# calculate attention mask for SW-MSA
|
||||||
img_mask = torch.zeros((1, H_pad, W_pad, 1),
|
img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
|
||||||
device=query.device) # 1 H W 1
|
|
||||||
h_slices = (slice(0, -self.window_size),
|
h_slices = (slice(0, -self.window_size),
|
||||||
slice(-self.window_size,
|
slice(-self.window_size,
|
||||||
-self.shift_size), slice(-self.shift_size, None))
|
-self.shift_size), slice(-self.shift_size, None))
|
||||||
@ -333,7 +254,6 @@ class ShiftWindowMSA(BaseModule):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
windows: (num_windows*B, window_size, window_size, C)
|
windows: (num_windows*B, window_size, window_size, C)
|
||||||
window_size (int): Window size
|
|
||||||
H (int): Height of image
|
H (int): Height of image
|
||||||
W (int): Width of image
|
W (int): Width of image
|
||||||
Returns:
|
Returns:
|
||||||
@ -350,7 +270,6 @@ class ShiftWindowMSA(BaseModule):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: (B, H, W, C)
|
x: (B, H, W, C)
|
||||||
window_size (int): window size
|
|
||||||
Returns:
|
Returns:
|
||||||
windows: (num_windows*B, window_size, window_size, C)
|
windows: (num_windows*B, window_size, window_size, C)
|
||||||
"""
|
"""
|
||||||
@ -369,18 +288,21 @@ class SwinBlock(BaseModule):
|
|||||||
embed_dims (int): The feature dimension.
|
embed_dims (int): The feature dimension.
|
||||||
num_heads (int): Parallel attention heads.
|
num_heads (int): Parallel attention heads.
|
||||||
feedforward_channels (int): The hidden dimension for FFNs.
|
feedforward_channels (int): The hidden dimension for FFNs.
|
||||||
window size (int, optional): The local window scale. Default: 7.
|
window_size (int, optional): The local window scale. Default: 7.
|
||||||
shift (bool): whether to shift window or not. Default False.
|
shift (bool, optional): whether to shift window or not. Default False.
|
||||||
qkv_bias (int, optional): enable bias for qkv if True. Default: True.
|
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
|
||||||
qk_scale (float | None, optional): Override default qk scale of
|
qk_scale (float | None, optional): Override default qk scale of
|
||||||
head_dim ** -0.5 if set. Default: None.
|
head_dim ** -0.5 if set. Default: None.
|
||||||
drop_rate (float, optional): Dropout rate. Default: 0.
|
drop_rate (float, optional): Dropout rate. Default: 0.
|
||||||
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
|
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
|
||||||
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
|
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
|
||||||
act_cfg (dict, optional): The config dict of activation function.
|
act_cfg (dict, optional): The config dict of activation function.
|
||||||
Default: dict(type='GELU').
|
Default: dict(type='GELU').
|
||||||
norm_cfg (dict, optional): The config dict of nomalization.
|
norm_cfg (dict, optional): The config dict of normalization.
|
||||||
Default: dict(type='LN').
|
Default: dict(type='LN').
|
||||||
|
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||||
|
will save some memory while slowing down the training speed.
|
||||||
|
Default: False.
|
||||||
init_cfg (dict | list | None, optional): The init config.
|
init_cfg (dict | list | None, optional): The init config.
|
||||||
Default: None.
|
Default: None.
|
||||||
"""
|
"""
|
||||||
@ -398,11 +320,12 @@ class SwinBlock(BaseModule):
|
|||||||
drop_path_rate=0.,
|
drop_path_rate=0.,
|
||||||
act_cfg=dict(type='GELU'),
|
act_cfg=dict(type='GELU'),
|
||||||
norm_cfg=dict(type='LN'),
|
norm_cfg=dict(type='LN'),
|
||||||
|
with_cp=False,
|
||||||
init_cfg=None):
|
init_cfg=None):
|
||||||
|
|
||||||
super(SwinBlock, self).__init__()
|
super(SwinBlock, self).__init__(init_cfg=init_cfg)
|
||||||
|
|
||||||
self.init_cfg = init_cfg
|
self.with_cp = with_cp
|
||||||
|
|
||||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||||
self.attn = ShiftWindowMSA(
|
self.attn = ShiftWindowMSA(
|
||||||
@ -429,6 +352,8 @@ class SwinBlock(BaseModule):
|
|||||||
init_cfg=None)
|
init_cfg=None)
|
||||||
|
|
||||||
def forward(self, x, hw_shape):
|
def forward(self, x, hw_shape):
|
||||||
|
|
||||||
|
def _inner_forward(x):
|
||||||
identity = x
|
identity = x
|
||||||
x = self.norm1(x)
|
x = self.norm1(x)
|
||||||
x = self.attn(x, hw_shape)
|
x = self.attn(x, hw_shape)
|
||||||
@ -441,6 +366,13 @@ class SwinBlock(BaseModule):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
if self.with_cp and x.requires_grad:
|
||||||
|
x = cp.checkpoint(_inner_forward, x)
|
||||||
|
else:
|
||||||
|
x = _inner_forward(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SwinBlockSequence(BaseModule):
|
class SwinBlockSequence(BaseModule):
|
||||||
"""Implements one stage in Swin Transformer.
|
"""Implements one stage in Swin Transformer.
|
||||||
@ -450,19 +382,23 @@ class SwinBlockSequence(BaseModule):
|
|||||||
num_heads (int): Parallel attention heads.
|
num_heads (int): Parallel attention heads.
|
||||||
feedforward_channels (int): The hidden dimension for FFNs.
|
feedforward_channels (int): The hidden dimension for FFNs.
|
||||||
depth (int): The number of blocks in this stage.
|
depth (int): The number of blocks in this stage.
|
||||||
window size (int): The local window scale. Default: 7.
|
window_size (int, optional): The local window scale. Default: 7.
|
||||||
qkv_bias (int): enable bias for qkv if True. Default: True.
|
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
|
||||||
qk_scale (float | None, optional): Override default qk scale of
|
qk_scale (float | None, optional): Override default qk scale of
|
||||||
head_dim ** -0.5 if set. Default: None.
|
head_dim ** -0.5 if set. Default: None.
|
||||||
drop_rate (float, optional): Dropout rate. Default: 0.
|
drop_rate (float, optional): Dropout rate. Default: 0.
|
||||||
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
|
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
|
||||||
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
|
drop_path_rate (float | list[float], optional): Stochastic depth
|
||||||
|
rate. Default: 0.
|
||||||
downsample (BaseModule | None, optional): The downsample operation
|
downsample (BaseModule | None, optional): The downsample operation
|
||||||
module. Default: None.
|
module. Default: None.
|
||||||
act_cfg (dict, optional): The config dict of activation function.
|
act_cfg (dict, optional): The config dict of activation function.
|
||||||
Default: dict(type='GELU').
|
Default: dict(type='GELU').
|
||||||
norm_cfg (dict, optional): The config dict of nomalization.
|
norm_cfg (dict, optional): The config dict of normalization.
|
||||||
Default: dict(type='LN').
|
Default: dict(type='LN').
|
||||||
|
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||||
|
will save some memory while slowing down the training speed.
|
||||||
|
Default: False.
|
||||||
init_cfg (dict | list | None, optional): The init config.
|
init_cfg (dict | list | None, optional): The init config.
|
||||||
Default: None.
|
Default: None.
|
||||||
"""
|
"""
|
||||||
@ -481,14 +417,15 @@ class SwinBlockSequence(BaseModule):
|
|||||||
downsample=None,
|
downsample=None,
|
||||||
act_cfg=dict(type='GELU'),
|
act_cfg=dict(type='GELU'),
|
||||||
norm_cfg=dict(type='LN'),
|
norm_cfg=dict(type='LN'),
|
||||||
|
with_cp=False,
|
||||||
init_cfg=None):
|
init_cfg=None):
|
||||||
super().__init__()
|
super().__init__(init_cfg=init_cfg)
|
||||||
|
|
||||||
self.init_cfg = init_cfg
|
if isinstance(drop_path_rate, list):
|
||||||
|
drop_path_rates = drop_path_rate
|
||||||
drop_path_rate = drop_path_rate if isinstance(
|
assert len(drop_path_rates) == depth
|
||||||
drop_path_rate,
|
else:
|
||||||
list) else [deepcopy(drop_path_rate) for _ in range(depth)]
|
drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
|
||||||
|
|
||||||
self.blocks = ModuleList()
|
self.blocks = ModuleList()
|
||||||
for i in range(depth):
|
for i in range(depth):
|
||||||
@ -502,9 +439,10 @@ class SwinBlockSequence(BaseModule):
|
|||||||
qk_scale=qk_scale,
|
qk_scale=qk_scale,
|
||||||
drop_rate=drop_rate,
|
drop_rate=drop_rate,
|
||||||
attn_drop_rate=attn_drop_rate,
|
attn_drop_rate=attn_drop_rate,
|
||||||
drop_path_rate=drop_path_rate[i],
|
drop_path_rate=drop_path_rates[i],
|
||||||
act_cfg=act_cfg,
|
act_cfg=act_cfg,
|
||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
|
with_cp=with_cp,
|
||||||
init_cfg=None)
|
init_cfg=None)
|
||||||
self.blocks.append(block)
|
self.blocks.append(block)
|
||||||
|
|
||||||
@ -538,7 +476,7 @@ class SwinTransformer(BaseModule):
|
|||||||
embed_dims (int): The feature dimension. Default: 96.
|
embed_dims (int): The feature dimension. Default: 96.
|
||||||
patch_size (int | tuple[int]): Patch size. Default: 4.
|
patch_size (int | tuple[int]): Patch size. Default: 4.
|
||||||
window_size (int): Window size. Default: 7.
|
window_size (int): Window size. Default: 7.
|
||||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
|
||||||
Default: 4.
|
Default: 4.
|
||||||
depths (tuple[int]): Depths of each Swin Transformer stage.
|
depths (tuple[int]): Depths of each Swin Transformer stage.
|
||||||
Default: (2, 2, 6, 2).
|
Default: (2, 2, 6, 2).
|
||||||
@ -564,7 +502,12 @@ class SwinTransformer(BaseModule):
|
|||||||
Default: dict(type='LN').
|
Default: dict(type='LN').
|
||||||
norm_cfg (dict): Config dict for normalization layer at
|
norm_cfg (dict): Config dict for normalization layer at
|
||||||
output of backone. Defaults: dict(type='LN').
|
output of backone. Defaults: dict(type='LN').
|
||||||
|
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||||
|
will save some memory while slowing down the training speed.
|
||||||
|
Default: False.
|
||||||
pretrained (str, optional): model pretrained path. Default: None.
|
pretrained (str, optional): model pretrained path. Default: None.
|
||||||
|
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||||
|
-1 means not freezing any parameters.
|
||||||
init_cfg (dict, optional): The Config for initialization.
|
init_cfg (dict, optional): The Config for initialization.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
@ -589,9 +532,11 @@ class SwinTransformer(BaseModule):
|
|||||||
use_abs_pos_embed=False,
|
use_abs_pos_embed=False,
|
||||||
act_cfg=dict(type='GELU'),
|
act_cfg=dict(type='GELU'),
|
||||||
norm_cfg=dict(type='LN'),
|
norm_cfg=dict(type='LN'),
|
||||||
|
with_cp=False,
|
||||||
pretrained=None,
|
pretrained=None,
|
||||||
|
frozen_stages=-1,
|
||||||
init_cfg=None):
|
init_cfg=None):
|
||||||
super(SwinTransformer, self).__init__()
|
self.frozen_stages = frozen_stages
|
||||||
|
|
||||||
if isinstance(pretrain_img_size, int):
|
if isinstance(pretrain_img_size, int):
|
||||||
pretrain_img_size = to_2tuple(pretrain_img_size)
|
pretrain_img_size = to_2tuple(pretrain_img_size)
|
||||||
@ -602,17 +547,22 @@ class SwinTransformer(BaseModule):
|
|||||||
f'The size of image should have length 1 or 2, ' \
|
f'The size of image should have length 1 or 2, ' \
|
||||||
f'but got {len(pretrain_img_size)}'
|
f'but got {len(pretrain_img_size)}'
|
||||||
|
|
||||||
if isinstance(pretrained, str) or pretrained is None:
|
assert not (init_cfg and pretrained), \
|
||||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
'init_cfg and pretrained cannot be specified at the same time'
|
||||||
|
if isinstance(pretrained, str):
|
||||||
|
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||||
'please use "init_cfg" instead')
|
'please use "init_cfg" instead')
|
||||||
|
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||||
|
elif pretrained is None:
|
||||||
|
init_cfg = init_cfg
|
||||||
else:
|
else:
|
||||||
raise TypeError('pretrained must be a str or None')
|
raise TypeError('pretrained must be a str or None')
|
||||||
|
|
||||||
|
super(SwinTransformer, self).__init__(init_cfg=init_cfg)
|
||||||
|
|
||||||
num_layers = len(depths)
|
num_layers = len(depths)
|
||||||
self.out_indices = out_indices
|
self.out_indices = out_indices
|
||||||
self.use_abs_pos_embed = use_abs_pos_embed
|
self.use_abs_pos_embed = use_abs_pos_embed
|
||||||
self.pretrained = pretrained
|
|
||||||
self.init_cfg = init_cfg
|
|
||||||
|
|
||||||
assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
|
assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
|
||||||
|
|
||||||
@ -622,7 +572,7 @@ class SwinTransformer(BaseModule):
|
|||||||
conv_type='Conv2d',
|
conv_type='Conv2d',
|
||||||
kernel_size=patch_size,
|
kernel_size=patch_size,
|
||||||
stride=strides[0],
|
stride=strides[0],
|
||||||
pad_to_patch_size=True,
|
padding='corner',
|
||||||
norm_cfg=norm_cfg if patch_norm else None,
|
norm_cfg=norm_cfg if patch_norm else None,
|
||||||
init_cfg=None)
|
init_cfg=None)
|
||||||
|
|
||||||
@ -635,11 +585,11 @@ class SwinTransformer(BaseModule):
|
|||||||
|
|
||||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||||
|
|
||||||
# stochastic depth
|
# set stochastic depth decay rule
|
||||||
total_depth = sum(depths)
|
total_depth = sum(depths)
|
||||||
dpr = [
|
dpr = [
|
||||||
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
|
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
|
||||||
] # stochastic depth decay rule
|
]
|
||||||
|
|
||||||
self.stages = ModuleList()
|
self.stages = ModuleList()
|
||||||
in_channels = embed_dims
|
in_channels = embed_dims
|
||||||
@ -664,14 +614,13 @@ class SwinTransformer(BaseModule):
|
|||||||
qk_scale=qk_scale,
|
qk_scale=qk_scale,
|
||||||
drop_rate=drop_rate,
|
drop_rate=drop_rate,
|
||||||
attn_drop_rate=attn_drop_rate,
|
attn_drop_rate=attn_drop_rate,
|
||||||
drop_path_rate=dpr[:depths[i]],
|
drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
||||||
downsample=downsample,
|
downsample=downsample,
|
||||||
act_cfg=act_cfg,
|
act_cfg=act_cfg,
|
||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
|
with_cp=with_cp,
|
||||||
init_cfg=None)
|
init_cfg=None)
|
||||||
self.stages.append(stage)
|
self.stages.append(stage)
|
||||||
|
|
||||||
dpr = dpr[depths[i]:]
|
|
||||||
if downsample:
|
if downsample:
|
||||||
in_channels = downsample.out_channels
|
in_channels = downsample.out_channels
|
||||||
|
|
||||||
@ -682,29 +631,67 @@ class SwinTransformer(BaseModule):
|
|||||||
layer_name = f'norm{i}'
|
layer_name = f'norm{i}'
|
||||||
self.add_module(layer_name, layer)
|
self.add_module(layer_name, layer)
|
||||||
|
|
||||||
|
def train(self, mode=True):
|
||||||
|
"""Convert the model into training mode while keep layers freezed."""
|
||||||
|
super(SwinTransformer, self).train(mode)
|
||||||
|
self._freeze_stages()
|
||||||
|
|
||||||
|
def _freeze_stages(self):
|
||||||
|
if self.frozen_stages >= 0:
|
||||||
|
self.patch_embed.eval()
|
||||||
|
for param in self.patch_embed.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
if self.use_abs_pos_embed:
|
||||||
|
self.absolute_pos_embed.requires_grad = False
|
||||||
|
self.drop_after_pos.eval()
|
||||||
|
|
||||||
|
for i in range(1, self.frozen_stages + 1):
|
||||||
|
|
||||||
|
if (i - 1) in self.out_indices:
|
||||||
|
norm_layer = getattr(self, f'norm{i-1}')
|
||||||
|
norm_layer.eval()
|
||||||
|
for param in norm_layer.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
m = self.stages[i - 1]
|
||||||
|
m.eval()
|
||||||
|
for param in m.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
if self.pretrained is None:
|
logger = get_root_logger()
|
||||||
super().init_weights()
|
if self.init_cfg is None:
|
||||||
|
logger.warn(f'No pre-trained weights for '
|
||||||
|
f'{self.__class__.__name__}, '
|
||||||
|
f'training start from scratch')
|
||||||
if self.use_abs_pos_embed:
|
if self.use_abs_pos_embed:
|
||||||
trunc_normal_init(self.absolute_pos_embed, std=0.02)
|
trunc_normal_init(self.absolute_pos_embed, std=0.02)
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, Linear):
|
if isinstance(m, nn.Linear):
|
||||||
trunc_normal_init(m.weight, std=.02)
|
trunc_normal_init(m.weight, std=.02)
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
constant_init(m.bias, 0)
|
constant_init(m.bias, 0)
|
||||||
elif isinstance(m, LayerNorm):
|
elif isinstance(m, nn.LayerNorm):
|
||||||
constant_init(m.bias, 0)
|
constant_init(m.bias, 0)
|
||||||
constant_init(m.weight, 1.0)
|
constant_init(m.weight, 1.0)
|
||||||
elif isinstance(self.pretrained, str):
|
|
||||||
logger = get_root_logger()
|
|
||||||
ckpt = _load_checkpoint(
|
|
||||||
self.pretrained, logger=logger, map_location='cpu')
|
|
||||||
if 'state_dict' in ckpt:
|
|
||||||
state_dict = ckpt['state_dict']
|
|
||||||
elif 'model' in ckpt:
|
|
||||||
state_dict = ckpt['model']
|
|
||||||
else:
|
else:
|
||||||
state_dict = ckpt
|
assert 'checkpoint' in self.init_cfg, f'Only support ' \
|
||||||
|
f'specify `Pretrained` in ' \
|
||||||
|
f'`init_cfg` in ' \
|
||||||
|
f'{self.__class__.__name__} '
|
||||||
|
ckpt = _load_checkpoint(
|
||||||
|
self.init_cfg.checkpoint, logger=logger, map_location='cpu')
|
||||||
|
if 'state_dict' in ckpt:
|
||||||
|
_state_dict = ckpt['state_dict']
|
||||||
|
elif 'model' in ckpt:
|
||||||
|
_state_dict = ckpt['model']
|
||||||
|
else:
|
||||||
|
_state_dict = ckpt
|
||||||
|
|
||||||
|
state_dict = OrderedDict()
|
||||||
|
for k, v in _state_dict.items():
|
||||||
|
if k.startswith('backbone.'):
|
||||||
|
state_dict[k[9:]] = v
|
||||||
|
|
||||||
# strip prefix of state_dict
|
# strip prefix of state_dict
|
||||||
if list(state_dict.keys())[0].startswith('module.'):
|
if list(state_dict.keys())[0].startswith('module.'):
|
||||||
@ -733,13 +720,11 @@ class SwinTransformer(BaseModule):
|
|||||||
L2, nH2 = table_current.size()
|
L2, nH2 = table_current.size()
|
||||||
if nH1 != nH2:
|
if nH1 != nH2:
|
||||||
logger.warning(f'Error in loading {table_key}, pass')
|
logger.warning(f'Error in loading {table_key}, pass')
|
||||||
else:
|
elif L1 != L2:
|
||||||
if L1 != L2:
|
|
||||||
S1 = int(L1**0.5)
|
S1 = int(L1**0.5)
|
||||||
S2 = int(L2**0.5)
|
S2 = int(L2**0.5)
|
||||||
table_pretrained_resized = resize(
|
table_pretrained_resized = F.interpolate(
|
||||||
table_pretrained.permute(1, 0).reshape(
|
table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
|
||||||
1, nH1, S1, S1),
|
|
||||||
size=(S2, S2),
|
size=(S2, S2),
|
||||||
mode='bicubic')
|
mode='bicubic')
|
||||||
state_dict[table_key] = table_pretrained_resized.view(
|
state_dict[table_key] = table_pretrained_resized.view(
|
||||||
@ -749,9 +734,8 @@ class SwinTransformer(BaseModule):
|
|||||||
self.load_state_dict(state_dict, False)
|
self.load_state_dict(state_dict, False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.patch_embed(x)
|
x, hw_shape = self.patch_embed(x)
|
||||||
|
|
||||||
hw_shape = (self.patch_embed.DH, self.patch_embed.DW)
|
|
||||||
if self.use_abs_pos_embed:
|
if self.use_abs_pos_embed:
|
||||||
x = x + self.absolute_pos_embed
|
x = x + self.absolute_pos_embed
|
||||||
x = self.drop_after_pos(x)
|
x = self.drop_after_pos(x)
|
||||||
|
|||||||
@ -205,7 +205,7 @@ class VisionTransformer(BaseModule):
|
|||||||
conv_type='Conv2d',
|
conv_type='Conv2d',
|
||||||
kernel_size=patch_size,
|
kernel_size=patch_size,
|
||||||
stride=patch_size,
|
stride=patch_size,
|
||||||
pad_to_patch_size=True,
|
padding='corner',
|
||||||
norm_cfg=norm_cfg if patch_norm else None,
|
norm_cfg=norm_cfg if patch_norm else None,
|
||||||
init_cfg=None,
|
init_cfg=None,
|
||||||
)
|
)
|
||||||
@ -370,8 +370,8 @@ class VisionTransformer(BaseModule):
|
|||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
B = inputs.shape[0]
|
B = inputs.shape[0]
|
||||||
|
|
||||||
x, hw_shape = self.patch_embed(inputs), (self.patch_embed.DH,
|
x, hw_shape = self.patch_embed(inputs)
|
||||||
self.patch_embed.DW)
|
|
||||||
# stole cls_tokens impl from Phil Wang, thanks
|
# stole cls_tokens impl from Phil Wang, thanks
|
||||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||||
x = torch.cat((cls_tokens, x), dim=1)
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
|||||||
@ -1,28 +1,109 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import math
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||||
from mmcv.runner.base_module import BaseModule
|
from mmcv.runner.base_module import BaseModule
|
||||||
from torch.nn.modules.utils import _pair as to_2tuple
|
from mmcv.utils import to_2tuple
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptivePadding(nn.Module):
|
||||||
|
"""Applies padding to input (if needed) so that input can get fully covered
|
||||||
|
by filter you specified. It support two modes "same" and "corner". The
|
||||||
|
"same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
|
||||||
|
input. The "corner" mode would pad zero to bottom right.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kernel_size (int | tuple): Size of the kernel:
|
||||||
|
stride (int | tuple): Stride of the filter. Default: 1:
|
||||||
|
dilation (int | tuple): Spacing between kernel elements.
|
||||||
|
Default: 1.
|
||||||
|
padding (str): Support "same" and "corner", "corner" mode
|
||||||
|
would pad zero to bottom right, and "same" mode would
|
||||||
|
pad zero around input. Default: "corner".
|
||||||
|
Example:
|
||||||
|
>>> kernel_size = 16
|
||||||
|
>>> stride = 16
|
||||||
|
>>> dilation = 1
|
||||||
|
>>> input = torch.rand(1, 1, 15, 17)
|
||||||
|
>>> adap_pad = AdaptivePadding(
|
||||||
|
>>> kernel_size=kernel_size,
|
||||||
|
>>> stride=stride,
|
||||||
|
>>> dilation=dilation,
|
||||||
|
>>> padding="corner")
|
||||||
|
>>> out = adap_pad(input)
|
||||||
|
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
|
||||||
|
>>> input = torch.rand(1, 1, 16, 17)
|
||||||
|
>>> out = adap_pad(input)
|
||||||
|
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
|
||||||
|
|
||||||
|
super(AdaptivePadding, self).__init__()
|
||||||
|
|
||||||
|
assert padding in ('same', 'corner')
|
||||||
|
|
||||||
|
kernel_size = to_2tuple(kernel_size)
|
||||||
|
stride = to_2tuple(stride)
|
||||||
|
dilation = to_2tuple(dilation)
|
||||||
|
|
||||||
|
self.padding = padding
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
self.dilation = dilation
|
||||||
|
|
||||||
|
def get_pad_shape(self, input_shape):
|
||||||
|
input_h, input_w = input_shape
|
||||||
|
kernel_h, kernel_w = self.kernel_size
|
||||||
|
stride_h, stride_w = self.stride
|
||||||
|
output_h = math.ceil(input_h / stride_h)
|
||||||
|
output_w = math.ceil(input_w / stride_w)
|
||||||
|
pad_h = max((output_h - 1) * stride_h +
|
||||||
|
(kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
|
||||||
|
pad_w = max((output_w - 1) * stride_w +
|
||||||
|
(kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
|
||||||
|
return pad_h, pad_w
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
|
||||||
|
if pad_h > 0 or pad_w > 0:
|
||||||
|
if self.padding == 'corner':
|
||||||
|
x = F.pad(x, [0, pad_w, 0, pad_h])
|
||||||
|
elif self.padding == 'same':
|
||||||
|
x = F.pad(x, [
|
||||||
|
pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
|
||||||
|
pad_h - pad_h // 2
|
||||||
|
])
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
# Modified from Pytorch-Image-Models
|
|
||||||
class PatchEmbed(BaseModule):
|
class PatchEmbed(BaseModule):
|
||||||
"""Image to Patch Embedding V2.
|
"""Image to Patch Embedding.
|
||||||
|
|
||||||
We use a conv layer to implement PatchEmbed.
|
We use a conv layer to implement PatchEmbed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_channels (int): The num of input channels. Default: 3
|
in_channels (int): The num of input channels. Default: 3
|
||||||
embed_dims (int): The dimensions of embedding. Default: 768
|
embed_dims (int): The dimensions of embedding. Default: 768
|
||||||
conv_type (dict, optional): The config dict for conv layers type
|
conv_type (str): The config dict for embedding
|
||||||
selection. Default: None.
|
conv layer type selection. Default: "Conv2d".
|
||||||
kernel_size (int): The kernel_size of embedding conv. Default: 16.
|
kernel_size (int): The kernel_size of embedding conv. Default: 16.
|
||||||
stride (int): The slide stride of embedding conv.
|
stride (int, optional): The slide stride of embedding conv.
|
||||||
Default: None (Default to be equal with kernel_size).
|
Default: None (Would be set as `kernel_size`).
|
||||||
padding (int): The padding length of embedding conv. Default: 0.
|
padding (int | tuple | string ): The padding length of
|
||||||
|
embedding conv. When it is a string, it means the mode
|
||||||
|
of adaptive padding, support "same" and "corner" now.
|
||||||
|
Default: "corner".
|
||||||
dilation (int): The dilation rate of embedding conv. Default: 1.
|
dilation (int): The dilation rate of embedding conv. Default: 1.
|
||||||
pad_to_patch_size (bool, optional): Whether to pad feature map shape
|
bias (bool): Bias of embed conv. Default: True.
|
||||||
to multiple patch size. Default: True.
|
|
||||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||||
|
Default: None.
|
||||||
|
input_size (int | tuple | None): The size of input, which will be
|
||||||
|
used to calculate the out size. Only work when `dynamic_size`
|
||||||
|
is False. Default: None.
|
||||||
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
|
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
|
||||||
Default: None.
|
Default: None.
|
||||||
"""
|
"""
|
||||||
@ -30,39 +111,37 @@ class PatchEmbed(BaseModule):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_channels=3,
|
in_channels=3,
|
||||||
embed_dims=768,
|
embed_dims=768,
|
||||||
conv_type=None,
|
conv_type='Conv2d',
|
||||||
kernel_size=16,
|
kernel_size=16,
|
||||||
stride=16,
|
stride=None,
|
||||||
padding=0,
|
padding='corner',
|
||||||
dilation=1,
|
dilation=1,
|
||||||
pad_to_patch_size=True,
|
bias=True,
|
||||||
norm_cfg=None,
|
norm_cfg=None,
|
||||||
|
input_size=None,
|
||||||
init_cfg=None):
|
init_cfg=None):
|
||||||
super(PatchEmbed, self).__init__()
|
super(PatchEmbed, self).__init__(init_cfg=init_cfg)
|
||||||
|
|
||||||
self.embed_dims = embed_dims
|
self.embed_dims = embed_dims
|
||||||
self.init_cfg = init_cfg
|
|
||||||
|
|
||||||
if stride is None:
|
if stride is None:
|
||||||
stride = kernel_size
|
stride = kernel_size
|
||||||
|
|
||||||
self.pad_to_patch_size = pad_to_patch_size
|
kernel_size = to_2tuple(kernel_size)
|
||||||
|
stride = to_2tuple(stride)
|
||||||
|
dilation = to_2tuple(dilation)
|
||||||
|
|
||||||
# The default setting of patch size is equal to kernel size.
|
if isinstance(padding, str):
|
||||||
patch_size = kernel_size
|
self.adap_padding = AdaptivePadding(
|
||||||
if isinstance(patch_size, int):
|
kernel_size=kernel_size,
|
||||||
patch_size = to_2tuple(patch_size)
|
stride=stride,
|
||||||
elif isinstance(patch_size, tuple):
|
dilation=dilation,
|
||||||
if len(patch_size) == 1:
|
padding=padding)
|
||||||
patch_size = to_2tuple(patch_size[0])
|
# disable the padding of conv
|
||||||
assert len(patch_size) == 2, \
|
padding = 0
|
||||||
f'The size of patch should have length 1 or 2, ' \
|
else:
|
||||||
f'but got {len(patch_size)}'
|
self.adap_padding = None
|
||||||
|
padding = to_2tuple(padding)
|
||||||
|
|
||||||
self.patch_size = patch_size
|
|
||||||
|
|
||||||
# Use conv layer to embed
|
|
||||||
conv_type = conv_type or 'Conv2d'
|
|
||||||
self.projection = build_conv_layer(
|
self.projection = build_conv_layer(
|
||||||
dict(type=conv_type),
|
dict(type=conv_type),
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
@ -70,31 +149,182 @@ class PatchEmbed(BaseModule):
|
|||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
dilation=dilation)
|
dilation=dilation,
|
||||||
|
bias=bias)
|
||||||
|
|
||||||
if norm_cfg is not None:
|
if norm_cfg is not None:
|
||||||
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||||
else:
|
else:
|
||||||
self.norm = None
|
self.norm = None
|
||||||
|
|
||||||
def forward(self, x):
|
if input_size:
|
||||||
H, W = x.shape[2], x.shape[3]
|
input_size = to_2tuple(input_size)
|
||||||
|
# `init_out_size` would be used outside to
|
||||||
|
# calculate the num_patches
|
||||||
|
# when `use_abs_pos_embed` outside
|
||||||
|
self.init_input_size = input_size
|
||||||
|
if self.adap_padding:
|
||||||
|
pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
|
||||||
|
input_h, input_w = input_size
|
||||||
|
input_h = input_h + pad_h
|
||||||
|
input_w = input_w + pad_w
|
||||||
|
input_size = (input_h, input_w)
|
||||||
|
|
||||||
# TODO: Process overlapping op
|
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
||||||
if self.pad_to_patch_size:
|
h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
|
||||||
# Modify H, W to multiple of patch size.
|
(kernel_size[0] - 1) - 1) // stride[0] + 1
|
||||||
if H % self.patch_size[0] != 0:
|
w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
|
||||||
x = F.pad(
|
(kernel_size[1] - 1) - 1) // stride[1] + 1
|
||||||
x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
self.init_out_size = (h_out, w_out)
|
||||||
if W % self.patch_size[1] != 0:
|
else:
|
||||||
x = F.pad(
|
self.init_input_size = None
|
||||||
x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))
|
self.init_out_size = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: Contains merged results and its spatial shape.
|
||||||
|
|
||||||
|
- x (Tensor): Has shape (B, out_h * out_w, embed_dims)
|
||||||
|
- out_size (tuple[int]): Spatial shape of x, arrange as
|
||||||
|
(out_h, out_w).
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.adap_padding:
|
||||||
|
x = self.adap_padding(x)
|
||||||
|
|
||||||
x = self.projection(x)
|
x = self.projection(x)
|
||||||
self.DH, self.DW = x.shape[2], x.shape[3]
|
out_size = (x.shape[2], x.shape[3])
|
||||||
x = x.flatten(2).transpose(1, 2)
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
if self.norm is not None:
|
if self.norm is not None:
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
return x, out_size
|
||||||
|
|
||||||
return x
|
|
||||||
|
class PatchMerging(BaseModule):
|
||||||
|
"""Merge patch feature map.
|
||||||
|
|
||||||
|
This layer groups feature map by kernel_size, and applies norm and linear
|
||||||
|
layers to the grouped feature map. Our implementation uses `nn.Unfold` to
|
||||||
|
merge patch, which is about 25% faster than original implementation.
|
||||||
|
Instead, we need to modify pretrained models for compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): The num of input channels.
|
||||||
|
out_channels (int): The num of output channels.
|
||||||
|
kernel_size (int | tuple, optional): the kernel size in the unfold
|
||||||
|
layer. Defaults to 2.
|
||||||
|
stride (int | tuple, optional): the stride of the sliding blocks in the
|
||||||
|
unfold layer. Default: None. (Would be set as `kernel_size`)
|
||||||
|
padding (int | tuple | string ): The padding length of
|
||||||
|
embedding conv. When it is a string, it means the mode
|
||||||
|
of adaptive padding, support "same" and "corner" now.
|
||||||
|
Default: "corner".
|
||||||
|
dilation (int | tuple, optional): dilation parameter in the unfold
|
||||||
|
layer. Default: 1.
|
||||||
|
bias (bool, optional): Whether to add bias in linear layer or not.
|
||||||
|
Defaults: False.
|
||||||
|
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||||
|
Default: dict(type='LN').
|
||||||
|
init_cfg (dict, optional): The extra config for initialization.
|
||||||
|
Default: None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=2,
|
||||||
|
stride=None,
|
||||||
|
padding='corner',
|
||||||
|
dilation=1,
|
||||||
|
bias=False,
|
||||||
|
norm_cfg=dict(type='LN'),
|
||||||
|
init_cfg=None):
|
||||||
|
super().__init__(init_cfg=init_cfg)
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
if stride:
|
||||||
|
stride = stride
|
||||||
|
else:
|
||||||
|
stride = kernel_size
|
||||||
|
|
||||||
|
kernel_size = to_2tuple(kernel_size)
|
||||||
|
stride = to_2tuple(stride)
|
||||||
|
dilation = to_2tuple(dilation)
|
||||||
|
|
||||||
|
if isinstance(padding, str):
|
||||||
|
self.adap_padding = AdaptivePadding(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
padding=padding)
|
||||||
|
# disable the padding of unfold
|
||||||
|
padding = 0
|
||||||
|
else:
|
||||||
|
self.adap_padding = None
|
||||||
|
|
||||||
|
padding = to_2tuple(padding)
|
||||||
|
self.sampler = nn.Unfold(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
dilation=dilation,
|
||||||
|
padding=padding,
|
||||||
|
stride=stride)
|
||||||
|
|
||||||
|
sample_dim = kernel_size[0] * kernel_size[1] * in_channels
|
||||||
|
|
||||||
|
if norm_cfg is not None:
|
||||||
|
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
|
||||||
|
else:
|
||||||
|
self.norm = None
|
||||||
|
|
||||||
|
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
|
||||||
|
|
||||||
|
def forward(self, x, input_size):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (Tensor): Has shape (B, H*W, C_in).
|
||||||
|
input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
|
||||||
|
Default: None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: Contains merged results and its spatial shape.
|
||||||
|
|
||||||
|
- x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
|
||||||
|
- out_size (tuple[int]): Spatial shape of x, arrange as
|
||||||
|
(Merged_H, Merged_W).
|
||||||
|
"""
|
||||||
|
B, L, C = x.shape
|
||||||
|
assert isinstance(input_size, Sequence), f'Expect ' \
|
||||||
|
f'input_size is ' \
|
||||||
|
f'`Sequence` ' \
|
||||||
|
f'but get {input_size}'
|
||||||
|
|
||||||
|
H, W = input_size
|
||||||
|
assert L == H * W, 'input feature has wrong size'
|
||||||
|
|
||||||
|
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
|
||||||
|
# Use nn.Unfold to merge patch. About 25% faster than original method,
|
||||||
|
# but need to modify pretrained model for compatibility
|
||||||
|
|
||||||
|
if self.adap_padding:
|
||||||
|
x = self.adap_padding(x)
|
||||||
|
H, W = x.shape[-2:]
|
||||||
|
|
||||||
|
x = self.sampler(x)
|
||||||
|
# if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
|
||||||
|
|
||||||
|
out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
|
||||||
|
(self.sampler.kernel_size[0] - 1) -
|
||||||
|
1) // self.sampler.stride[0] + 1
|
||||||
|
out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
|
||||||
|
(self.sampler.kernel_size[1] - 1) -
|
||||||
|
1) // self.sampler.stride[1] + 1
|
||||||
|
|
||||||
|
output_size = (out_h, out_w)
|
||||||
|
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
|
||||||
|
x = self.norm(x) if self.norm else x
|
||||||
|
x = self.reduction(x)
|
||||||
|
return x, output_size
|
||||||
|
|||||||
@ -7,10 +7,6 @@ from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
|
|||||||
|
|
||||||
|
|
||||||
def test_mit():
|
def test_mit():
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
# It's only support official style and mmcls style now.
|
|
||||||
MixVisionTransformer(pretrain_style='timm')
|
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
# Pretrained represents pretrain url and must be str or None.
|
# Pretrained represents pretrain url and must be str or None.
|
||||||
MixVisionTransformer(pretrained=123)
|
MixVisionTransformer(pretrained=123)
|
||||||
|
|||||||
@ -1,8 +1,26 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmseg.models.backbones import SwinTransformer
|
from mmseg.models.backbones.swin import SwinBlock, SwinTransformer
|
||||||
|
|
||||||
|
|
||||||
|
def test_swin_block():
|
||||||
|
# test SwinBlock structure and forward
|
||||||
|
block = SwinBlock(embed_dims=64, num_heads=4, feedforward_channels=256)
|
||||||
|
assert block.ffn.embed_dims == 64
|
||||||
|
assert block.attn.w_msa.num_heads == 4
|
||||||
|
assert block.ffn.feedforward_channels == 256
|
||||||
|
x = torch.randn(1, 56 * 56, 64)
|
||||||
|
x_out = block(x, (56, 56))
|
||||||
|
assert x_out.shape == torch.Size([1, 56 * 56, 64])
|
||||||
|
|
||||||
|
# Test BasicBlock with checkpoint forward
|
||||||
|
block = SwinBlock(
|
||||||
|
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
|
||||||
|
assert block.with_cp
|
||||||
|
x = torch.randn(1, 56 * 56, 64)
|
||||||
|
x_out = block(x, (56, 56))
|
||||||
|
assert x_out.shape == torch.Size([1, 56 * 56, 64])
|
||||||
|
|
||||||
|
|
||||||
def test_swin_transformer():
|
def test_swin_transformer():
|
||||||
@ -10,12 +28,16 @@ def test_swin_transformer():
|
|||||||
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
# Pretrained arg must be str or None.
|
# Pretrained arg must be str or None.
|
||||||
model = SwinTransformer(pretrained=123)
|
SwinTransformer(pretrained=123)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# Because swin use non-overlapping patch embed, so the stride of patch
|
# Because swin uses non-overlapping patch embed, so the stride of patch
|
||||||
# embed must be equal to patch size.
|
# embed must be equal to patch size.
|
||||||
model = SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)
|
SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)
|
||||||
|
|
||||||
|
# test pretrained image size
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
SwinTransformer(pretrain_img_size=(224, 224, 224))
|
||||||
|
|
||||||
# Test absolute position embedding
|
# Test absolute position embedding
|
||||||
temp = torch.randn((1, 3, 224, 224))
|
temp = torch.randn((1, 3, 224, 224))
|
||||||
@ -27,12 +49,6 @@ def test_swin_transformer():
|
|||||||
model = SwinTransformer(patch_norm=False)
|
model = SwinTransformer(patch_norm=False)
|
||||||
model(temp)
|
model(temp)
|
||||||
|
|
||||||
# Test pretrain img size
|
|
||||||
model = SwinTransformer(pretrain_img_size=(224, ))
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
model = SwinTransformer(pretrain_img_size=(224, 224, 224))
|
|
||||||
|
|
||||||
# Test normal inference
|
# Test normal inference
|
||||||
temp = torch.randn((1, 3, 512, 512))
|
temp = torch.randn((1, 3, 512, 512))
|
||||||
model = SwinTransformer()
|
model = SwinTransformer()
|
||||||
@ -42,7 +58,7 @@ def test_swin_transformer():
|
|||||||
assert outs[2].shape == (1, 384, 32, 32)
|
assert outs[2].shape == (1, 384, 32, 32)
|
||||||
assert outs[3].shape == (1, 768, 16, 16)
|
assert outs[3].shape == (1, 768, 16, 16)
|
||||||
|
|
||||||
# Test abnormal inference
|
# Test abnormal inference size
|
||||||
temp = torch.randn((1, 3, 511, 511))
|
temp = torch.randn((1, 3, 511, 511))
|
||||||
model = SwinTransformer()
|
model = SwinTransformer()
|
||||||
outs = model(temp)
|
outs = model(temp)
|
||||||
@ -51,7 +67,7 @@ def test_swin_transformer():
|
|||||||
assert outs[2].shape == (1, 384, 32, 32)
|
assert outs[2].shape == (1, 384, 32, 32)
|
||||||
assert outs[3].shape == (1, 768, 16, 16)
|
assert outs[3].shape == (1, 768, 16, 16)
|
||||||
|
|
||||||
# Test abnormal inference
|
# Test abnormal inference size
|
||||||
temp = torch.randn((1, 3, 112, 137))
|
temp = torch.randn((1, 3, 112, 137))
|
||||||
model = SwinTransformer()
|
model = SwinTransformer()
|
||||||
outs = model(temp)
|
outs = model(temp)
|
||||||
@ -59,3 +75,25 @@ def test_swin_transformer():
|
|||||||
assert outs[1].shape == (1, 192, 14, 18)
|
assert outs[1].shape == (1, 192, 14, 18)
|
||||||
assert outs[2].shape == (1, 384, 7, 9)
|
assert outs[2].shape == (1, 384, 7, 9)
|
||||||
assert outs[3].shape == (1, 768, 4, 5)
|
assert outs[3].shape == (1, 768, 4, 5)
|
||||||
|
|
||||||
|
# Test frozen
|
||||||
|
model = SwinTransformer(frozen_stages=4)
|
||||||
|
model.train()
|
||||||
|
for p in model.parameters():
|
||||||
|
assert not p.requires_grad
|
||||||
|
|
||||||
|
# Test absolute position embedding frozen
|
||||||
|
model = SwinTransformer(frozen_stages=4, use_abs_pos_embed=True)
|
||||||
|
model.train()
|
||||||
|
for p in model.parameters():
|
||||||
|
assert not p.requires_grad
|
||||||
|
|
||||||
|
# Test Swin with checkpoint forward
|
||||||
|
temp = torch.randn((1, 3, 224, 224))
|
||||||
|
model = SwinTransformer(with_cp=True)
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, SwinBlock):
|
||||||
|
assert m.with_cp
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
model(temp)
|
||||||
|
|||||||
@ -25,12 +25,6 @@ def test_vit_backbone():
|
|||||||
x = torch.randn(1, 196)
|
x = torch.randn(1, 196)
|
||||||
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
|
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
|
||||||
|
|
||||||
with pytest.raises(IndexError):
|
|
||||||
# forward inputs must be [N, C, H, W]
|
|
||||||
x = torch.randn(3, 30, 30)
|
|
||||||
model = VisionTransformer()
|
|
||||||
model(x)
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# The length of img_size tuple must be lower than 3.
|
# The length of img_size tuple must be lower than 3.
|
||||||
VisionTransformer(img_size=(224, 224, 224))
|
VisionTransformer(img_size=(224, 224, 224))
|
||||||
|
|||||||
0
tests/test_models/test_utils/__init__.py
Normal file
0
tests/test_models/test_utils/__init__.py
Normal file
461
tests/test_models/test_utils/test_embed.py
Normal file
461
tests/test_models/test_utils/test_embed.py
Normal file
@ -0,0 +1,461 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmseg.models.utils.embed import AdaptivePadding, PatchEmbed, PatchMerging
|
||||||
|
|
||||||
|
|
||||||
|
def test_adaptive_padding():
|
||||||
|
|
||||||
|
for padding in ('same', 'corner'):
|
||||||
|
kernel_size = 16
|
||||||
|
stride = 16
|
||||||
|
dilation = 1
|
||||||
|
input = torch.rand(1, 1, 15, 17)
|
||||||
|
adap_pool = AdaptivePadding(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
padding=padding)
|
||||||
|
out = adap_pool(input)
|
||||||
|
# padding to divisible by 16
|
||||||
|
assert (out.shape[2], out.shape[3]) == (16, 32)
|
||||||
|
input = torch.rand(1, 1, 16, 17)
|
||||||
|
out = adap_pool(input)
|
||||||
|
# padding to divisible by 16
|
||||||
|
assert (out.shape[2], out.shape[3]) == (16, 32)
|
||||||
|
|
||||||
|
kernel_size = (2, 2)
|
||||||
|
stride = (2, 2)
|
||||||
|
dilation = (1, 1)
|
||||||
|
|
||||||
|
adap_pad = AdaptivePadding(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
padding=padding)
|
||||||
|
input = torch.rand(1, 1, 11, 13)
|
||||||
|
out = adap_pad(input)
|
||||||
|
# padding to divisible by 2
|
||||||
|
assert (out.shape[2], out.shape[3]) == (12, 14)
|
||||||
|
|
||||||
|
kernel_size = (2, 2)
|
||||||
|
stride = (10, 10)
|
||||||
|
dilation = (1, 1)
|
||||||
|
|
||||||
|
adap_pad = AdaptivePadding(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
padding=padding)
|
||||||
|
input = torch.rand(1, 1, 10, 13)
|
||||||
|
out = adap_pad(input)
|
||||||
|
# no padding
|
||||||
|
assert (out.shape[2], out.shape[3]) == (10, 13)
|
||||||
|
|
||||||
|
kernel_size = (11, 11)
|
||||||
|
adap_pad = AdaptivePadding(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
padding=padding)
|
||||||
|
input = torch.rand(1, 1, 11, 13)
|
||||||
|
out = adap_pad(input)
|
||||||
|
# all padding
|
||||||
|
assert (out.shape[2], out.shape[3]) == (21, 21)
|
||||||
|
|
||||||
|
# test padding as kernel is (7,9)
|
||||||
|
input = torch.rand(1, 1, 11, 13)
|
||||||
|
stride = (3, 4)
|
||||||
|
kernel_size = (4, 5)
|
||||||
|
dilation = (2, 2)
|
||||||
|
# actually (7, 9)
|
||||||
|
adap_pad = AdaptivePadding(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
padding=padding)
|
||||||
|
dilation_out = adap_pad(input)
|
||||||
|
assert (dilation_out.shape[2], dilation_out.shape[3]) == (16, 21)
|
||||||
|
kernel_size = (7, 9)
|
||||||
|
dilation = (1, 1)
|
||||||
|
adap_pad = AdaptivePadding(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
padding=padding)
|
||||||
|
kernel79_out = adap_pad(input)
|
||||||
|
assert (kernel79_out.shape[2], kernel79_out.shape[3]) == (16, 21)
|
||||||
|
assert kernel79_out.shape == dilation_out.shape
|
||||||
|
|
||||||
|
# assert only support "same" "corner"
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
AdaptivePadding(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
padding=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_embed():
|
||||||
|
B = 2
|
||||||
|
H = 3
|
||||||
|
W = 4
|
||||||
|
C = 3
|
||||||
|
embed_dims = 10
|
||||||
|
kernel_size = 3
|
||||||
|
stride = 1
|
||||||
|
dummy_input = torch.rand(B, C, H, W)
|
||||||
|
patch_merge_1 = PatchEmbed(
|
||||||
|
in_channels=C,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
norm_cfg=None)
|
||||||
|
|
||||||
|
x1, shape = patch_merge_1(dummy_input)
|
||||||
|
# test out shape
|
||||||
|
assert x1.shape == (2, 2, 10)
|
||||||
|
# test outsize is correct
|
||||||
|
assert shape == (1, 2)
|
||||||
|
# test L = out_h * out_w
|
||||||
|
assert shape[0] * shape[1] == x1.shape[1]
|
||||||
|
|
||||||
|
B = 2
|
||||||
|
H = 10
|
||||||
|
W = 10
|
||||||
|
C = 3
|
||||||
|
embed_dims = 10
|
||||||
|
kernel_size = 5
|
||||||
|
stride = 2
|
||||||
|
dummy_input = torch.rand(B, C, H, W)
|
||||||
|
# test dilation
|
||||||
|
patch_merge_2 = PatchEmbed(
|
||||||
|
in_channels=C,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=0,
|
||||||
|
dilation=2,
|
||||||
|
norm_cfg=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
x2, shape = patch_merge_2(dummy_input)
|
||||||
|
# test out shape
|
||||||
|
assert x2.shape == (2, 1, 10)
|
||||||
|
# test outsize is correct
|
||||||
|
assert shape == (1, 1)
|
||||||
|
# test L = out_h * out_w
|
||||||
|
assert shape[0] * shape[1] == x2.shape[1]
|
||||||
|
|
||||||
|
stride = 2
|
||||||
|
input_size = (10, 10)
|
||||||
|
|
||||||
|
dummy_input = torch.rand(B, C, H, W)
|
||||||
|
# test stride and norm
|
||||||
|
patch_merge_3 = PatchEmbed(
|
||||||
|
in_channels=C,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=0,
|
||||||
|
dilation=2,
|
||||||
|
norm_cfg=dict(type='LN'),
|
||||||
|
input_size=input_size)
|
||||||
|
|
||||||
|
x3, shape = patch_merge_3(dummy_input)
|
||||||
|
# test out shape
|
||||||
|
assert x3.shape == (2, 1, 10)
|
||||||
|
# test outsize is correct
|
||||||
|
assert shape == (1, 1)
|
||||||
|
# test L = out_h * out_w
|
||||||
|
assert shape[0] * shape[1] == x3.shape[1]
|
||||||
|
|
||||||
|
# test thte init_out_size with nn.Unfold
|
||||||
|
assert patch_merge_3.init_out_size[1] == (input_size[0] - 2 * 4 -
|
||||||
|
1) // 2 + 1
|
||||||
|
assert patch_merge_3.init_out_size[0] == (input_size[0] - 2 * 4 -
|
||||||
|
1) // 2 + 1
|
||||||
|
H = 11
|
||||||
|
W = 12
|
||||||
|
input_size = (H, W)
|
||||||
|
dummy_input = torch.rand(B, C, H, W)
|
||||||
|
# test stride and norm
|
||||||
|
patch_merge_3 = PatchEmbed(
|
||||||
|
in_channels=C,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=0,
|
||||||
|
dilation=2,
|
||||||
|
norm_cfg=dict(type='LN'),
|
||||||
|
input_size=input_size)
|
||||||
|
|
||||||
|
_, shape = patch_merge_3(dummy_input)
|
||||||
|
# when input_size equal to real input
|
||||||
|
# the out_size shoule be equal to `init_out_size`
|
||||||
|
assert shape == patch_merge_3.init_out_size
|
||||||
|
|
||||||
|
input_size = (H, W)
|
||||||
|
dummy_input = torch.rand(B, C, H, W)
|
||||||
|
# test stride and norm
|
||||||
|
patch_merge_3 = PatchEmbed(
|
||||||
|
in_channels=C,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=0,
|
||||||
|
dilation=2,
|
||||||
|
norm_cfg=dict(type='LN'),
|
||||||
|
input_size=input_size)
|
||||||
|
|
||||||
|
_, shape = patch_merge_3(dummy_input)
|
||||||
|
# when input_size equal to real input
|
||||||
|
# the out_size shoule be equal to `init_out_size`
|
||||||
|
assert shape == patch_merge_3.init_out_size
|
||||||
|
|
||||||
|
# test adap padding
|
||||||
|
for padding in ('same', 'corner'):
|
||||||
|
in_c = 2
|
||||||
|
embed_dims = 3
|
||||||
|
B = 2
|
||||||
|
|
||||||
|
# test stride is 1
|
||||||
|
input_size = (5, 5)
|
||||||
|
kernel_size = (5, 5)
|
||||||
|
stride = (1, 1)
|
||||||
|
dilation = 1
|
||||||
|
bias = False
|
||||||
|
|
||||||
|
x = torch.rand(B, in_c, *input_size)
|
||||||
|
patch_embed = PatchEmbed(
|
||||||
|
in_channels=in_c,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
bias=bias)
|
||||||
|
|
||||||
|
x_out, out_size = patch_embed(x)
|
||||||
|
assert x_out.size() == (B, 25, 3)
|
||||||
|
assert out_size == (5, 5)
|
||||||
|
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||||
|
|
||||||
|
# test kernel_size == stride
|
||||||
|
input_size = (5, 5)
|
||||||
|
kernel_size = (5, 5)
|
||||||
|
stride = (5, 5)
|
||||||
|
dilation = 1
|
||||||
|
bias = False
|
||||||
|
|
||||||
|
x = torch.rand(B, in_c, *input_size)
|
||||||
|
patch_embed = PatchEmbed(
|
||||||
|
in_channels=in_c,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
bias=bias)
|
||||||
|
|
||||||
|
x_out, out_size = patch_embed(x)
|
||||||
|
assert x_out.size() == (B, 1, 3)
|
||||||
|
assert out_size == (1, 1)
|
||||||
|
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||||
|
|
||||||
|
# test kernel_size == stride
|
||||||
|
input_size = (6, 5)
|
||||||
|
kernel_size = (5, 5)
|
||||||
|
stride = (5, 5)
|
||||||
|
dilation = 1
|
||||||
|
bias = False
|
||||||
|
|
||||||
|
x = torch.rand(B, in_c, *input_size)
|
||||||
|
patch_embed = PatchEmbed(
|
||||||
|
in_channels=in_c,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
bias=bias)
|
||||||
|
|
||||||
|
x_out, out_size = patch_embed(x)
|
||||||
|
assert x_out.size() == (B, 2, 3)
|
||||||
|
assert out_size == (2, 1)
|
||||||
|
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||||
|
|
||||||
|
# test different kernel_size with diffrent stride
|
||||||
|
input_size = (6, 5)
|
||||||
|
kernel_size = (6, 2)
|
||||||
|
stride = (6, 2)
|
||||||
|
dilation = 1
|
||||||
|
bias = False
|
||||||
|
|
||||||
|
x = torch.rand(B, in_c, *input_size)
|
||||||
|
patch_embed = PatchEmbed(
|
||||||
|
in_channels=in_c,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
bias=bias)
|
||||||
|
|
||||||
|
x_out, out_size = patch_embed(x)
|
||||||
|
assert x_out.size() == (B, 3, 3)
|
||||||
|
assert out_size == (1, 3)
|
||||||
|
assert x_out.size(1) == out_size[0] * out_size[1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_merging():
    """Verify PatchMerging output shapes and reported spatial sizes.

    Covers integer padding (with and without dilation) and the adaptive
    padding modes ('same' / 'corner') across several kernel/stride
    combinations.
    """

    def _check(in_c, out_c, kernel_size, stride, padding, dilation, bias,
               batch, input_size, want_shape, want_size):
        # One round-trip: build the module, feed (B, H*W, C) tokens, and
        # compare both the output tensor shape and the reported size.
        merge = PatchMerging(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)
        tokens = torch.rand(batch, input_size[0] * input_size[1], in_c)
        out, out_size = merge(tokens, input_size)
        assert out.size() == want_shape
        assert out_size == want_size
        # reported size must be consistent with the real output length
        assert out.size(1) == out_size[0] * out_size[1]

    # Test the model with int padding
    _check(3, 4, 3, 3, 1, 1, False, 1, (10, 10), (1, 16, 4), (4, 4))
    # int padding with dilation
    _check(4, 5, 6, 3, 2, 2, False, 1, (10, 10), (1, 4, 5), (2, 2))

    # Test with adaptive padding
    for padding in ('same', 'corner'):
        # stride of 1 keeps the spatial size unchanged
        _check(2, 3, (5, 5), (1, 1), padding, 1, False, 2, (5, 5),
               (2, 25, 3), (5, 5))
        # kernel_size == stride, input divides evenly
        _check(2, 3, (5, 5), (5, 5), padding, 1, False, 2, (5, 5),
               (2, 1, 3), (1, 1))
        # kernel_size == stride, height needs adaptive padding
        _check(2, 3, (5, 5), (5, 5), padding, 1, False, 2, (6, 5),
               (2, 2, 3), (2, 1))
        # different kernel_size with different stride per dimension
        _check(2, 3, (6, 2), (6, 2), padding, 1, False, 2, (6, 5),
               (2, 3, 3), (1, 3))
||||||
Loading…
x
Reference in New Issue
Block a user