diff --git a/mmseg/models/losses/accuracy.py b/mmseg/models/losses/accuracy.py index 7cd15e2..28d55c4 100644 --- a/mmseg/models/losses/accuracy.py +++ b/mmseg/models/losses/accuracy.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch import torch.nn as nn @@ -46,10 +47,13 @@ def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): correct = correct & (pred_value > thresh).t() correct = correct[:, target != ignore_index] res = [] + eps = torch.finfo(torch.float32).eps for k in topk: - correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) - res.append( - correct_k.mul_(100.0 / target[target != ignore_index].numel())) + # Avoid causing ZeroDivisionError when all pixels + # of an image are ignored + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps + total_num = target[target != ignore_index].numel() + eps + res.append(correct_k.mul_(100.0 / total_num)) return res[0] if return_single else res diff --git a/tests/test_models/test_losses/test_utils.py b/tests/test_models/test_losses/test_utils.py index ac5c666..ab9927f 100644 --- a/tests/test_models/test_losses/test_utils.py +++ b/tests/test_models/test_losses/test_utils.py @@ -56,50 +56,56 @@ def test_accuracy(): true_label = torch.Tensor([2, 3, 0, 1, 2]).long() accuracy = Accuracy(topk=1, ignore_index=None) acc = accuracy(pred, true_label) - assert acc.item() == 100 + assert torch.allclose(acc, torch.tensor(100.0)) # test for ignore_index with a wrong prediction of that index true_label = torch.Tensor([2, 3, 1, 1, 2]).long() accuracy = Accuracy(topk=1, ignore_index=1) acc = accuracy(pred, true_label) - assert acc.item() == 100 + assert torch.allclose(acc, torch.tensor(100.0)) # test for ignore_index 1 with a wrong prediction of other index true_label = torch.Tensor([2, 0, 0, 1, 2]).long() accuracy = Accuracy(topk=1, ignore_index=1) acc = accuracy(pred, true_label) - assert acc.item() == 75 + assert torch.allclose(acc, torch.tensor(75.0)) # test for ignore_index 4 with a wrong prediction of other index true_label = torch.Tensor([2, 0, 0, 1, 2]).long() accuracy = Accuracy(topk=1, ignore_index=4) acc = accuracy(pred, true_label) - assert acc.item() == 80 + assert torch.allclose(acc, torch.tensor(80.0)) + + # test for ignoring all the pixels + true_label = torch.Tensor([2, 2, 2, 2, 2]).long() + accuracy = Accuracy(topk=1, ignore_index=2) + acc = accuracy(pred, true_label) + assert torch.allclose(acc, torch.tensor(100.0)) # test for top1 true_label = torch.Tensor([2, 3, 0, 1, 2]).long() accuracy = Accuracy(topk=1) acc = accuracy(pred, true_label) - assert acc.item() == 100 + assert torch.allclose(acc, torch.tensor(100.0)) # test for top1 with score thresh=0.8 true_label = torch.Tensor([2, 3, 0, 1, 2]).long() accuracy = Accuracy(topk=1, thresh=0.8) acc = accuracy(pred, true_label) - assert acc.item() == 40 + assert torch.allclose(acc, torch.tensor(40.0)) # test for top2 accuracy = Accuracy(topk=2) label = torch.Tensor([3, 2, 0, 0, 2]).long() acc = accuracy(pred, label) - assert acc.item() == 100 + assert torch.allclose(acc, torch.tensor(100.0)) # test for both top1 and top2 accuracy = Accuracy(topk=(1, 2)) true_label = torch.Tensor([2, 3, 0, 1, 2]).long() acc = accuracy(pred, true_label) for a in acc: - assert a.item() == 100 + assert torch.allclose(a, torch.tensor(100.0)) # topk is larger than pred class number with pytest.raises(AssertionError):