diff --git a/mmseg/core/evaluation/mean_iou.py b/mmseg/core/evaluation/mean_iou.py index f0b4234..301cfd0 100644 --- a/mmseg/core/evaluation/mean_iou.py +++ b/mmseg/core/evaluation/mean_iou.py @@ -34,7 +34,7 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index): return area_intersect, area_union, area_pred_label, area_label -def mean_iou(results, gt_seg_maps, num_classes, ignore_index): +def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None): """Calculate Intersection and Union (IoU) Args: @@ -42,6 +42,8 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index): gt_seg_maps (list[ndarray]): list of ground truth segmentation maps num_classes (int): Number of categories ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. Returns: float: Overall accuracy on all images. @@ -66,5 +68,7 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index): all_acc = total_area_intersect.sum() / total_area_label.sum() acc = total_area_intersect / total_area_label iou = total_area_intersect / total_area_union - + if nan_to_num is not None: + return all_acc, np.nan_to_num(acc, nan=nan_to_num), \ + np.nan_to_num(iou, nan=nan_to_num) return all_acc, acc, iou diff --git a/tests/test_mean_iou.py b/tests/test_mean_iou.py index 48a3df8..74a2b78 100644 --- a/tests/test_mean_iou.py +++ b/tests/test_mean_iou.py @@ -54,3 +54,10 @@ def test_mean_iou(): assert all_acc == all_acc_l assert np.allclose(acc, acc_l) assert np.allclose(iou, iou_l) + + results = np.random.randint(0, 5, size=pred_size) + label = np.random.randint(0, 4, size=pred_size) + all_acc, acc, iou = mean_iou( + results, label, num_classes, ignore_index=255, nan_to_num=-1) + assert acc[-1] == -1 + assert iou[-1] == -1