30 lines
840 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmdet.models.plugins import DropBlock
def test_dropblock():
    """Exercise DropBlock's forward pass and its constructor validation.

    Runs two forward passes on a random ``(1, 1, 11, 11)`` feature map and
    then checks that three invalid constructor calls each raise
    ``AssertionError``.
    """
    feat = torch.rand(1, 1, 11, 11)

    # drop_prob=1.0 with a block as large as the feature map: every output
    # element is expected to be zero, with the input shape preserved.
    full_drop = DropBlock(1.0, block_size=11, warmup_iters=0)
    zeroed = full_drop(feat)
    # Separate asserts so a failure reports which condition broke.
    assert (zeroed == 0).all()
    assert zeroed.shape == feat.shape

    # With drop_prob=0.5 only the output shape is checked here.
    half_drop = DropBlock(0.5, block_size=5, warmup_iters=0)
    assert half_drop(feat).shape == feat.shape

    # drop_prob must lie in the interval (0, 1]
    with pytest.raises(AssertionError):
        DropBlock(1.5, 3)

    # block_size cannot be an even number
    with pytest.raises(AssertionError):
        DropBlock(0.5, 2)

    # warmup_iters cannot be less than 0
    with pytest.raises(AssertionError):
        DropBlock(0.5, 3, -1)