115 lines
5.2 KiB
Python
115 lines
5.2 KiB
Python
from collections import defaultdict, Counter
|
|
import numpy as np
|
|
|
|
def check_dim(narray):
    """
    Normalize a CLASS1 / CLASS2 / CLASS3 list into a numpy array of rows.

    : parameter narray (list): list of prediction or GT from CLASS.
        - CLASS3 (three list levels, e.g. ``[[[a, b], [c, d]]]``): the first
          element, itself a list of rows, is used.
        - CLASS2 (two list levels, e.g. ``[[a, b], [c, d]]``): only the first
          row is kept, as ``[narray[0]]``.
          NOTE(review): later rows are dropped here — confirm this is intended.
        - CLASS1 (one list level, e.g. ``[a, b]``): wrapped as a single row.
        - ``[]``: treated as an empty CLASS1 list and returned as ``[[]]``
          (shape ``(1, 0)``), the same result ``[[]]`` and ``[[[]]]`` give.
    : returns: ``np.asarray`` of the normalized rows.
    : raises ValueError: if neither ``narray[0]`` nor ``narray`` is a list.
    """
    # An empty list has no first element to inspect; without this guard the
    # indexing below raises IndexError even though it is a (trivial) CLASS1 list.
    if isinstance(narray, list) and not narray:
        return np.asarray([narray])

    first = narray[0]
    # CLASS3 list dim = 3
    if isinstance(first, list) and len(first) > 0 and isinstance(first[0], list):
        narray = first
    # CLASS2 list dim = 2
    elif isinstance(first, list):
        narray = [first]
    # CLASS1 list dim = 1
    elif isinstance(narray, list):
        narray = [narray]
    else:
        raise ValueError("[ERROR] Inference result does not belong to CLASS1, CLASS2, or CLASS3!")

    return np.asarray(narray)
|
|
|
|
def _normalize_images(gt_dict, ct_dict):
    """Pair every GT image with its ``check_dim``-normalized GT and prediction rows.

    Returns a list of ``(image_key, gt_rows, ct_rows)`` tuples in ``gt_dict``
    order. ``ct_rows`` is None when ``ct_dict[image_key]`` is falsy or empty.
    """
    images = []
    for image_key, raw_gt in gt_dict.items():
        gt_rows = check_dim(raw_gt)
        raw_ct = ct_dict[image_key]
        ct_rows = check_dim(raw_ct) if raw_ct and len(raw_ct) != 0 else None
        images.append((image_key, gt_rows, ct_rows))
    return images


def _gt_class_counter(gt_rows, class_id, combine_mapping):
    """Count the ground-truth labels (column 1) of one image that equal ``class_id``."""
    if len(gt_rows[0]) == 0:
        # An image without GT rows is represented by label 0. Apply the same
        # class filter as for labelled images so it only counts towards class 0;
        # without the filter, "no GT, no prediction" became an FN for every class
        # and "no GT, one prediction" hit the "Weird Case" error.
        return Counter([0]) if class_id == 0 else Counter()
    if not combine_mapping:
        # only consider the target class
        return Counter(int(row[1]) for row in gt_rows if int(row[1]) == class_id)
    # With combined classes each combined id counts at most once per image.
    return Counter({combine_mapping[int(row[1])] for row in gt_rows if combine_mapping[int(row[1])] == class_id})


def _ct_class_counter(ct_rows, class_id, threshold):
    """Count one image's predictions with score (column 0) >= ``threshold`` and label ``class_id``.

    ``ct_rows`` is None for an image without predictions, which gives an empty Counter.
    """
    if ct_rows is None:
        return Counter()
    return Counter(int(row[1]) for row in ct_rows if row[0] >= threshold and row[1] == class_id)


def threshold_classification(thresholds_range_list, gt_dict, ct_dict, stat_classes, combine_mapping=None, overall_top1_accuracy=True):
    """
    Sweep confidence thresholds and count per-class TP / TN / FP / FN over images.

    For every threshold and every class in ``stat_classes``, each image of
    ``gt_dict`` is scored by comparing two Counters: its GT labels equal to the
    class, and its predictions with score >= threshold and label equal to the class.

    : parameter thresholds_range_list (list): thresholds to try, should be from 0 to 1.
    : parameter gt_dict (dict): image key -> GT rows in a layout accepted by
        ``check_dim``; column 1 of each row is the class id.
    : parameter ct_dict (Counter): image key -> inference rows in the same layout;
        column 0 is the score, column 1 the class id. Indexed with
        ``ct_dict[key]`` for every GT key, so a plain dict must hold every key.
    : parameter stat_classes (list): GT classes to evaluate.
    : parameter combine_mapping (list or dict, optional): maps a raw GT class id
        to a combined class id; falsy (default) disables combining.
    : parameter overall_top1_accuracy (bool): if True only TPs count as correct;
        if False TNs are counted as correct as well.

    : returns: ``(correct_counter_all_classes, TP_count, TN_count, FP_count,
        FN_count, class_id_fp_error_count, class_id_fn_error_count)``.
        ``correct_counter_all_classes`` has one entry per threshold; the other
        six are reset per threshold, so they describe the LAST threshold only.
    : raises ValueError: "Weird Case" when GT and predictions both contain the
        class with different multiplicities; also whatever ``check_dim`` raises.
    """
    # initiation
    print(stat_classes)
    n_thresholds = len(thresholds_range_list)
    n_classes = len(stat_classes)
    correct_counter_all_classes = [0] * n_thresholds
    # NOTE(review): these per-(threshold, class) tables are filled in but never
    # returned, so callers cannot read them.
    TPR_thres_class = [[0] * n_classes for _ in range(n_thresholds)]
    FPR_thres_class = [[0] * n_classes for _ in range(n_thresholds)]
    precision_thres_class = [[0] * n_classes for _ in range(n_thresholds)]

    # Normalize each image once instead of once per (threshold, class) pair;
    # skipped when the loops below would not touch the data at all.
    images = _normalize_images(gt_dict, ct_dict) if n_thresholds and n_classes else []

    # Bound up front so an empty threshold list returns empty counters.
    TP_count = defaultdict(int)
    TN_count = defaultdict(int)
    FP_count = defaultdict(int)
    FN_count = defaultdict(int)
    class_id_fp_error_count = defaultdict(set)
    class_id_fn_error_count = defaultdict(set)

    for i, threshold in enumerate(thresholds_range_list):
        # count under different threshold
        TP_count = defaultdict(int)
        TN_count = defaultdict(int)
        FP_count = defaultdict(int)
        FN_count = defaultdict(int)
        class_id_fp_error_count = defaultdict(set)
        class_id_fn_error_count = defaultdict(set)

        for idx, class_id in enumerate(stat_classes):
            for image_key, gt_rows, ct_rows in images:
                gt_classes = _gt_class_counter(gt_rows, class_id, combine_mapping)
                # Recomputed for every image: an image without predictions gets an
                # empty Counter rather than reusing the previous image's value.
                ct_classes = _ct_class_counter(ct_rows, class_id, threshold)

                if ct_classes == gt_classes:
                    if gt_classes and ct_classes:
                        TP_count[class_id] += 1
                        correct_counter_all_classes[i] += 1
                    else:
                        TN_count[class_id] += 1
                        if not overall_top1_accuracy:
                            correct_counter_all_classes[i] += 1
                elif not gt_classes and ct_classes:
                    FP_count[class_id] += 1
                    class_id_fp_error_count[class_id].add(image_key)
                elif gt_classes and not ct_classes:
                    FN_count[class_id] += 1
                    class_id_fn_error_count[class_id].add(image_key)
                else:
                    # Both non-empty but unequal: same class, different multiplicity.
                    print(gt_classes)
                    print(ct_classes)
                    raise ValueError("Weird Case")

            tp, tn = TP_count[class_id], TN_count[class_id]
            fp, fn = FP_count[class_id], FN_count[class_id]
            positives = tp + fn
            negatives = fp + tn
            # TP + FN == 0 or FP + TN == 0 would divide by zero.
            if positives == 0 or negatives == 0:
                print(f"Corner case. TP + FN = {positives} or FP + TN = {negatives}")
                TPR_thres_class[i][idx] = None
                FPR_thres_class[i][idx] = None
            else:
                TPR_thres_class[i][idx] = tp / positives  # TPR = recall
                FPR_thres_class[i][idx] = fp / negatives
            precision_thres_class[i][idx] = tp / (tp + fp) if tp + fp else None

    return correct_counter_all_classes, TP_count, TN_count, FP_count, FN_count, class_id_fp_error_count, class_id_fn_error_count