76 lines
2.5 KiB
Python
76 lines
2.5 KiB
Python
import pytest
|
|
import torch
|
|
|
|
from mmdet.core.post_processing import mask_matrix_nms
|
|
|
|
|
|
def _create_mask(N, h, w):
|
|
masks = torch.rand((N, h, w)) > 0.5
|
|
labels = torch.rand(N)
|
|
scores = torch.rand(N)
|
|
return masks, labels, scores
|
|
|
|
|
|
def test_nms_input_errors():
    """Exercise ``mask_matrix_nms`` validation and output sizing.

    Despite its name, this test covers more than input errors:

    * mismatched input lengths and an unknown ``kernel`` raise;
    * empty input yields empty outputs;
    * ``nms_pre`` and ``max_num`` limit the number of returned instances;
    * with identical masks/labels, all scores 1 and ``filter_thr=0.5``,
      only one instance survives the gaussian score decay.
    """
    # masks has 10 entries while labels/scores have 11.
    with pytest.raises(AssertionError):
        mask_matrix_nms(
            torch.rand((10, 28, 28)), torch.rand(11), torch.rand(11))

    # masks, labels and scores all have 10 entries, so the only mismatch is
    # mask_area (8 entries). Length-11 labels/scores would already raise
    # AssertionError (see the case above) and never reach the mask_area
    # check this case is meant to cover.
    masks = torch.rand((10, 28, 28))
    mask_area = masks.sum((1, 2)).float()[:8]
    with pytest.raises(AssertionError):
        mask_matrix_nms(
            masks,
            torch.rand(10),
            torch.rand(10),
            mask_area=mask_area)

    # An unsupported decay kernel name is rejected.
    with pytest.raises(NotImplementedError):
        mask_matrix_nms(
            torch.rand((10, 28, 28)),
            torch.rand(10),
            torch.rand(10),
            kernel='None')

    # Empty input produces empty outputs.
    masks, labels, scores = _create_mask(0, 28, 28)
    score, label, mask, keep_ind = mask_matrix_nms(masks, labels, scores)
    assert len(score) == len(label) == len(mask) == len(keep_ind) == 0

    # Without filter_thr, nms_pre or max_num every instance is returned.
    masks, labels, scores = _create_mask(1000, 28, 28)
    score, label, mask, keep_ind = mask_matrix_nms(masks, labels, scores)
    assert len(score) == len(label) == len(mask) == len(keep_ind) == 1000

    # nms_pre alone limits the output to 500 instances.
    score, label, mask, keep_ind = mask_matrix_nms(
        masks, labels, scores, nms_pre=500)
    assert len(score) == len(label) == len(mask) == len(keep_ind) == 500

    # max_num further limits the output to 100 instances.
    score, label, mask, keep_ind = mask_matrix_nms(
        masks, labels, scores, nms_pre=500, max_num=100)
    assert len(score) == len(label) == len(mask) == len(keep_ind) == 100

    # 1000 copies of the same mask, with the same label and score 1.0.
    # The top-ranked copy keeps score 1 and the duplicates are decayed;
    # filter_thr=0.5 must then drop every decayed duplicate.
    masks, labels, _ = _create_mask(1, 28, 28)
    masks = masks.expand(1000, 28, 28)
    labels = labels.expand(1000)
    scores = torch.tensor([1.0]).expand(1000)
    score, label, mask, keep_ind = mask_matrix_nms(
        masks,
        labels,
        scores,
        nms_pre=500,
        max_num=100,
        kernel='gaussian',
        sigma=2.0,
        filter_thr=0.5)
    assert len(score) == 1
    assert score[0] == 1