64 lines
1.9 KiB
Python
64 lines
1.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmseg.models.decode_heads import SETRMLAHead
|
|
from .utils import to_cuda
|
|
|
|
|
|
def test_setr_mla_head(capsys):
|
|
|
|
with pytest.raises(AssertionError):
|
|
# MLA requires input multiple stage feature information.
|
|
SETRMLAHead(in_channels=32, channels=16, num_classes=19, in_index=1)
|
|
|
|
with pytest.raises(AssertionError):
|
|
# multiple in_indexs requires multiple in_channels.
|
|
SETRMLAHead(
|
|
in_channels=32, channels=16, num_classes=19, in_index=(0, 1, 2, 3))
|
|
|
|
with pytest.raises(AssertionError):
|
|
# channels should be len(in_channels) * mla_channels
|
|
SETRMLAHead(
|
|
in_channels=(32, 32, 32, 32),
|
|
channels=32,
|
|
mla_channels=16,
|
|
in_index=(0, 1, 2, 3),
|
|
num_classes=19)
|
|
|
|
# test inference of MLA head
|
|
img_size = (32, 32)
|
|
patch_size = 16
|
|
head = SETRMLAHead(
|
|
in_channels=(32, 32, 32, 32),
|
|
channels=64,
|
|
mla_channels=16,
|
|
in_index=(0, 1, 2, 3),
|
|
num_classes=19,
|
|
norm_cfg=dict(type='BN'))
|
|
|
|
h, w = img_size[0] // patch_size, img_size[1] // patch_size
|
|
# Input square NCHW format feature information
|
|
x = [
|
|
torch.randn(1, 32, h, w),
|
|
torch.randn(1, 32, h, w),
|
|
torch.randn(1, 32, h, w),
|
|
torch.randn(1, 32, h, w)
|
|
]
|
|
if torch.cuda.is_available():
|
|
head, x = to_cuda(head, x)
|
|
out = head(x)
|
|
assert out.shape == (1, head.num_classes, h * 4, w * 4)
|
|
|
|
# Input non-square NCHW format feature information
|
|
x = [
|
|
torch.randn(1, 32, h, w * 2),
|
|
torch.randn(1, 32, h, w * 2),
|
|
torch.randn(1, 32, h, w * 2),
|
|
torch.randn(1, 32, h, w * 2)
|
|
]
|
|
if torch.cuda.is_available():
|
|
head, x = to_cuda(head, x)
|
|
out = head(x)
|
|
assert out.shape == (1, head.num_classes, h * 4, w * 8)
|