115 lines
5.2 KiB
Python

from collections import defaultdict, Counter
import numpy as np
def check_dim(narray):
    """
    This function checks list dim, and returns the CLASS data as a 2-D numpy array.
    : parameter narray (list): list of prediction or GT from CLASS. Accepted layouts:
        - CLASS3 (three levels of nesting, ``[[[...], ...]]``): the inner list of rows is kept.
        - CLASS2 (two levels of nesting, ``[[...], ...]``): only the first row is kept.
        - CLASS1 (flat list ``[...]``): wrapped as a single row.
    : return (np.ndarray): ``np.asarray`` of the selected rows.
    : raises ValueError: if narray is empty or matches none of the layouts above.
    """
    error_message = "[ERROR] Inference result does not belong to CLASS1, CLASS2, or CLASS3!"
    # An empty container has no first element to inspect; report it as an
    # unknown layout instead of failing with an opaque IndexError.
    if len(narray) == 0:
        raise ValueError(error_message)
    first = narray[0]
    if isinstance(first, list) and len(first) > 0 and isinstance(first[0], list):
        # CLASS3: list dim = 3 -> keep the inner list of rows.
        rows = first
    elif isinstance(first, list):
        # CLASS2: list dim = 2 -> keep only the first row.
        rows = [first]
    elif isinstance(narray, list):
        # CLASS1: list dim = 1 -> the whole list is one row.
        rows = [narray]
    else:
        raise ValueError(error_message)
    return np.asarray(rows)
def _gt_target_classes(gt_array, class_id, combine_mapping):
    """
    Return a Counter of the GT labels of one image that equal class_id.
    : parameter gt_array (np.ndarray): image GT as returned by check_dim; column 1 holds the class id.
    : parameter class_id: class currently being evaluated.
    : parameter combine_mapping: optional lookup raw class id -> combined class id (falsy = unused).
    """
    if len(gt_array[0]) == 0:
        # An image without GT annotations is treated as label 0. Filter it by
        # class_id like the non-empty branch below; otherwise every other class
        # would see a spurious GT positive (an FN, or a "Weird Case" error).
        return Counter([0] if class_id == 0 else [])
    if not combine_mapping:
        # Only consider the target class (multiplicity is kept).
        return Counter([int(x[1]) for x in gt_array if int(x[1]) == class_id])
    # With combined classes several raw labels may map to class_id; count it once.
    return Counter({combine_mapping[int(x[1])] for x in gt_array if combine_mapping[int(x[1])] == class_id})


def _ct_target_classes(ct_array, class_id, threshold):
    """
    Return a Counter of one image's predictions of class_id whose score is >= threshold.
    : parameter ct_array (np.ndarray or None): predictions as returned by check_dim
        (column 0 = score, column 1 = class id), or None when the image has no predictions.
    """
    if ct_array is None:
        return Counter()
    return Counter([int(x[1]) for x in ct_array if x[0] >= threshold and x[1] == class_id])


def threshold_classification(thresholds_range_list, gt_dict, ct_dict, stat_classes, combine_mapping=None,
                             overall_top1_accuracy=True, return_rates=False):
    """
    This function uses different thresholds to identify the best accuracy.
    For every threshold and every class in stat_classes, each image in gt_dict is
    classified as TP / TN / FP / FN by comparing its GT labels of that class with its
    predictions of that class scoring at least the threshold.
    : parameter thresholds_range_list (list): score thresholds, should be from 0 to 1.
    : parameter gt_dict (dict): image key -> GT in CLASS format (column 1 = class id).
    : parameter ct_dict (dict): image key -> inference in CLASS format (column 0 = score,
        column 1 = class id). Missing or empty entries mean "no prediction".
    : parameter stat_classes (list): class ids to evaluate (iterating a dict/Counter yields its keys).
    : parameter combine_mapping (list or dict): optional raw class id -> combined class id,
        applied to GT labels only. NOTE(review): predictions are compared unmapped;
        presumably ct_dict already uses combined ids — confirm with callers.
    : parameter overall_top1_accuracy (bool): if False, TN cases also count as correct.
    : parameter return_rates (bool): if True, additionally return the per-threshold, per-class
        TPR, FPR and precision tables (None where the rate is undefined).
    : return (tuple): correct_counter_all_classes (one count per threshold) followed by
        TP_count, TN_count, FP_count, FN_count and the FP / FN image-key sets per class,
        all for the LAST threshold only (they are reset for every threshold).
        With return_rates=True the tuple is extended by (TPR, FPR, precision) tables.
    : raises ValueError: on malformed CLASS data, or when GT and prediction disagree
        in a way that is neither FP nor FN ("Weird Case").
    """
    if combine_mapping is None:
        combine_mapping = []
    # initiation
    print(stat_classes)
    n_thresholds = len(thresholds_range_list)
    n_classes = len(stat_classes)
    correct_counter_all_classes = [0] * n_thresholds
    TPR_thres_class = [[0] * n_classes for _ in range(n_thresholds)]
    FPR_thres_class = [[0] * n_classes for _ in range(n_thresholds)]
    precision_thres_class = [[0] * n_classes for _ in range(n_thresholds)]
    # Created up front so an empty threshold list returns empty counters
    # instead of raising UnboundLocalError at the return statement.
    TP_count, TN_count = defaultdict(int), defaultdict(int)
    FP_count, FN_count = defaultdict(int), defaultdict(int)
    class_id_fp_error_count, class_id_fn_error_count = defaultdict(set), defaultdict(set)

    # Normalise every image once: the arrays depend neither on the threshold nor on the class.
    gt_arrays = {k: check_dim(v) for k, v in gt_dict.items()}
    ct_arrays = {}
    for k in gt_arrays:
        preds = ct_dict.get(k)
        # No (or empty) inference for this image -> no predicted classes.
        ct_arrays[k] = check_dim(preds) if preds and len(preds) != 0 else None

    for i, threshold in enumerate(thresholds_range_list):
        # counts are kept per threshold
        TP_count, TN_count = defaultdict(int), defaultdict(int)
        FP_count, FN_count = defaultdict(int), defaultdict(int)
        class_id_fp_error_count, class_id_fn_error_count = defaultdict(set), defaultdict(set)
        for idx, class_id in enumerate(stat_classes):
            for k, gt_array in gt_arrays.items():
                gt_classes = _gt_target_classes(gt_array, class_id, combine_mapping)
                ct_classes = _ct_target_classes(ct_arrays[k], class_id, threshold)
                if ct_classes == gt_classes:
                    if gt_classes:
                        TP_count[class_id] += 1
                        correct_counter_all_classes[i] += 1
                    else:
                        TN_count[class_id] += 1
                        if not overall_top1_accuracy:
                            correct_counter_all_classes[i] += 1
                elif not gt_classes:
                    # FP: predicted, but not in GT
                    FP_count[class_id] += 1
                    class_id_fp_error_count[class_id].add(k)
                elif not ct_classes:
                    # FN: in GT, but not predicted
                    FN_count[class_id] += 1
                    class_id_fn_error_count[class_id].add(k)
                else:
                    # Both non-empty and only holding class_id, so they differ in multiplicity.
                    print(gt_classes)
                    print(ct_classes)
                    raise ValueError("Weird Case")

            tp, tn = TP_count[class_id], TN_count[class_id]
            fp, fn = FP_count[class_id], FN_count[class_id]
            # TP + FN = 0 or FP + TN = 0 would divide by zero.
            if tp + fn == 0 or fp + tn == 0:
                print(f"Corner case. TP + FN = {tp + fn} or FP + TN = {fp + tn}")
                TPR_thres_class[i][idx] = None
                FPR_thres_class[i][idx] = None
            else:
                # TPR = recall
                TPR_thres_class[i][idx] = tp / (tp + fn)
                FPR_thres_class[i][idx] = fp / (fp + tn)
            precision_thres_class[i][idx] = tp / (tp + fp) if (tp + fp) else None

    result = (correct_counter_all_classes, TP_count, TN_count, FP_count, FN_count,
              class_id_fp_error_count, class_id_fn_error_count)
    if return_rates:
        result += (TPR_thres_class, FPR_thres_class, precision_thres_class)
    return result