[Fix] Fix the bug that when all pixels in an image are ignored, the ac… (#1336)

* [Fix] Fix the bug that when all pixels in an image are ignored, the accuracy calculation raises ZeroDivisionError

* use eps so the denominator can never be zero

* compare results with torch.allclose in the tests

* add a test in which every pixel is ignored

* add eps to the numerator as well, so the all-ignored case reports 100% accuracy
Authored by Rockey on 2022-03-09 13:20:46 +08:00; committed by EricWu
parent 50c80aa6f2
commit 70d7bbd198
2 changed files with 21 additions and 11 deletions
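
For context: the whole fix is to fold torch.finfo(torch.float32).eps into both the numerator and the denominator of the accuracy ratio, so an empty (fully ignored) target yields eps / eps = 100% instead of dividing by zero. A minimal standalone sketch of that guard; safe_accuracy is a toy name, not part of the codebase:

import torch

def safe_accuracy(correct_mask: torch.Tensor) -> torch.Tensor:
    # Adding eps to both terms keeps the ratio finite; when no pixels
    # remain after filtering out ignore_index, it degrades to
    # eps / eps = 1, i.e. an accuracy of 100%.
    eps = torch.finfo(torch.float32).eps
    correct_k = correct_mask.float().sum(0, keepdim=True) + eps
    total_num = correct_mask.numel() + eps
    return correct_k * (100.0 / total_num)

# All pixels ignored -> an empty mask; this used to raise ZeroDivisionError.
print(safe_accuracy(torch.zeros(0, dtype=torch.bool)))  # tensor([100.])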


@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 import torch.nn as nn
@@ -46,10 +47,13 @@ def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
         correct = correct & (pred_value > thresh).t()
     correct = correct[:, target != ignore_index]
     res = []
+    eps = torch.finfo(torch.float32).eps
     for k in topk:
-        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
-        res.append(
-            correct_k.mul_(100.0 / target[target != ignore_index].numel()))
+        # Avoid causing ZeroDivisionError when all pixels
+        # of an image are ignored
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
+        total_num = target[target != ignore_index].numel() + eps
+        res.append(correct_k.mul_(100.0 / total_num))
     return res[0] if return_single else res
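
A quick sanity check of the patched function; the import path is our assumption (in mmsegmentation the helper lives in the losses package):

import torch
from mmseg.models.losses import accuracy

pred = torch.rand(4, 3)                         # 4 samples, 3 classes
target = torch.full((4, ), 2, dtype=torch.long)

# Every target carries the ignored label, so nothing is evaluated;
# before this commit the call raised ZeroDivisionError.
acc = accuracy(pred, target, topk=1, ignore_index=2)
assert torch.allclose(acc, torch.tensor(100.0))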


@@ -56,50 +56,56 @@ def test_accuracy():
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=None)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
     # test for ignore_index with a wrong prediction of that index
     true_label = torch.Tensor([2, 3, 1, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
     # test for ignore_index 1 with a wrong prediction of other index
     true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 75
+    assert torch.allclose(acc, torch.tensor(75.0))
     # test for ignore_index 4 with a wrong prediction of other index
     true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=4)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 80
+    assert torch.allclose(acc, torch.tensor(80.0))
+    # test for ignoring all the pixels
+    true_label = torch.Tensor([2, 2, 2, 2, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=2)
+    acc = accuracy(pred, true_label)
+    assert torch.allclose(acc, torch.tensor(100.0))
     # test for top1
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
     # test for top1 with score thresh=0.8
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, thresh=0.8)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 40
+    assert torch.allclose(acc, torch.tensor(40.0))
     # test for top2
     accuracy = Accuracy(topk=2)
     label = torch.Tensor([3, 2, 0, 0, 2]).long()
     acc = accuracy(pred, label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
     # test for both top1 and top2
     accuracy = Accuracy(topk=(1, 2))
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     acc = accuracy(pred, true_label)
     for a in acc:
-        assert a.item() == 100
+        assert torch.allclose(a, torch.tensor(100.0))
     # topk is larger than pred class number
     with pytest.raises(AssertionError):
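
Why the assertions move from acc.item() == 100 to torch.allclose: once eps enters both terms, results are only approximately equal to the exact percentage, so strict equality would be flaky. A small illustration with assumed numbers (4 correct out of 5):

import torch

eps = torch.finfo(torch.float32).eps
correct_k = torch.tensor([4.0]) + eps           # numerator, as in the patch
total_num = 5 + eps                             # denominator, as in the patch
acc = correct_k * (100.0 / total_num)
print(acc.item())                               # 79.99999..., not exactly 80
assert torch.allclose(acc, torch.tensor(80.0))  # allclose absorbs the eps skew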