from collections import defaultdict, Counter

import numpy as np


def check_dim(narray):
    """
    Normalize a CLASS result list to a 2-D array of rows.

    The nesting depth of the input decides how it is interpreted:
      * CLASS3 (depth 3, ``[[[...], ...]]``): the first element is used as the rows.
      * CLASS2 (depth 2, ``[[...], ...]``): only the first row is kept.
      * CLASS1 (depth 1, ``[...]``): the list itself becomes the single row.

    : parameter narray (list): list of prediction or GT from CLASS.
    : return (np.ndarray): the selected rows passed through ``np.asarray``.
    : raises ValueError: if the input matches none of the three layouts.
    : raises IndexError: if ``narray`` is empty (``narray[0]`` does not exist).
    """
    first = narray[0]
    # CLASS3: list dim = 3
    if isinstance(first, list) and len(first) > 0 and isinstance(first[0], list):
        rows = first
    # CLASS2: list dim = 2
    elif isinstance(first, list):
        rows = [first]
    # CLASS1: list dim = 1
    elif isinstance(narray, list):
        rows = [narray]
    else:
        raise ValueError("[ERROR] Inference result does not belong to CLASS1, CLASS2, or CLASS3!")
    return np.asarray(rows)


def _prediction_rows(prediction):
    """Return the normalized prediction rows of one image, or None when there are none."""
    if not prediction:
        return None
    rows = check_dim(prediction)
    # A prediction such as [[]] normalizes to an empty array: treat it as "no prediction".
    return rows if rows.size else None


def _gt_class_counter(gt_rows, class_id, combine_mapping):
    """Count the GT entries of one image that belong to ``class_id``."""
    if len(gt_rows[0]) == 0:
        # NOTE(review): an empty GT is counted as class 0 for *every* class_id, so it
        # is a positive even while another class is being evaluated. Kept as in the
        # original; confirm whether it should apply only when class_id == 0.
        return Counter([0])
    if not combine_mapping:
        # only consider the target class
        return Counter(int(row[1]) for row in gt_rows if int(row[1]) == class_id)
    # With a combine mapping, each combined class is counted at most once per image.
    combined = {combine_mapping[int(row[1])] for row in gt_rows}
    return Counter(c for c in combined if c == class_id)


def _pred_class_counter(pred_rows, threshold, class_id):
    """Count the predictions of one image for ``class_id`` scoring at least ``threshold``."""
    if pred_rows is None:
        return Counter()
    # NOTE(review): combine_mapping is not applied to predictions; presumably they are
    # already expressed in the combined class space -- confirm against the caller.
    return Counter(
        int(row[1]) for row in pred_rows if row[0] >= threshold and row[1] == class_id
    )


def threshold_classification(thresholds_range_list, gt_dict, ct_dict, stat_classes,
                             combine_mapping=None, overall_top1_accuracy=True):
    """
    Count per-class TP/TN/FP/FN over all images for each confidence threshold.

    Each row of a GT or prediction entry is read as ``row[0]`` = score and
    ``row[1]`` = class id. For every (threshold, class, image) triple, the GT
    entries of that class are compared with the predictions of that class whose
    score is at least the threshold.

    : parameter thresholds_range_list (list): thresholds to evaluate, should be from 0 to 1.
    : parameter gt_dict (dict): image key -> GT in CLASS format (see ``check_dim``).
    : parameter ct_dict (dict): image key -> inference result in CLASS format. A missing
        or empty entry means the image has no prediction.
    : parameter stat_classes (list): class ids to evaluate (iterating a dict/Counter
        yields its keys).
    : parameter combine_mapping (list or dict, optional): maps a GT class id to its
        combined class id. Empty or None disables combining.
    : parameter overall_top1_accuracy (bool): when False, true negatives are also added
        to the per-threshold correct counter.
    : return (tuple): ``(correct_counter_all_classes, TP_count, TN_count, FP_count,
        FN_count, class_id_fp_error_count, class_id_fn_error_count)``. The first item
        has one entry per threshold. The remaining items are keyed by class id and
        describe the LAST threshold only (they are empty for an empty threshold list).
    : raises ValueError: if the GT and the predictions of an image are both non-empty
        for a class but their counts differ ("Weird Case"), or if an entry is not in
        CLASS format.
    """
    # initiation
    print(stat_classes)

    n_thresholds = len(thresholds_range_list)
    n_classes = len(stat_classes)
    correct_counter_all_classes = [0] * n_thresholds
    # NOTE(review): these per-threshold/per-class rate tables are filled in below but
    # are not part of the return value.
    TPR_thres_class = [[0] * n_classes for _ in range(n_thresholds)]
    FPR_thres_class = [[0] * n_classes for _ in range(n_thresholds)]
    precision_thres_class = [[0] * n_classes for _ in range(n_thresholds)]

    # Created up front so the return statement is valid for an empty threshold list.
    TP_count = defaultdict(int)
    TN_count = defaultdict(int)
    FP_count = defaultdict(int)
    FN_count = defaultdict(int)
    class_id_fp_error_count = defaultdict(set)
    class_id_fn_error_count = defaultdict(set)

    # Normalizing an image does not depend on the threshold or class: do it once.
    gt_rows_by_image = {k: check_dim(v) for k, v in gt_dict.items()}
    pred_rows_by_image = {k: _prediction_rows(ct_dict.get(k)) for k in gt_dict}

    # threshold range
    for i, threshold in enumerate(thresholds_range_list):
        # counts restart for every threshold
        TP_count = defaultdict(int)
        TN_count = defaultdict(int)
        FP_count = defaultdict(int)
        FN_count = defaultdict(int)
        class_id_fp_error_count = defaultdict(set)
        class_id_fn_error_count = defaultdict(set)

        # iterate classes
        for idx, class_id in enumerate(stat_classes):
            # iterate images
            for k, gt_rows in gt_rows_by_image.items():
                gt_classes = _gt_class_counter(gt_rows, class_id, combine_mapping)
                # Always rebuilt per image, so an image without predictions can never
                # inherit the counter of the previously visited image.
                ct_classes = _pred_class_counter(pred_rows_by_image[k], threshold, class_id)

                # add up correct pictures
                if ct_classes == gt_classes:
                    if gt_classes and ct_classes:
                        TP_count[class_id] += 1
                        correct_counter_all_classes[i] += 1
                    else:
                        TN_count[class_id] += 1
                        if not overall_top1_accuracy:
                            correct_counter_all_classes[i] += 1
                # FP
                elif not gt_classes and ct_classes:
                    FP_count[class_id] += 1
                    class_id_fp_error_count[class_id].add(k)
                # FN
                elif gt_classes and not ct_classes:
                    FN_count[class_id] += 1
                    class_id_fn_error_count[class_id].add(k)
                else:
                    # Both sides are non-empty but disagree on the count.
                    print(gt_classes)
                    print(ct_classes)
                    raise ValueError("Weird Case")

            positives = TP_count[class_id] + FN_count[class_id]
            negatives = FP_count[class_id] + TN_count[class_id]
            # exception for the corner case TP + FN = 0 or FP + TN = 0 which leads to
            # division by zero.
            if positives == 0 or negatives == 0:
                print(f"Corner case. TP + FN = {TP_count[class_id]+FN_count[class_id]} or FP + TN = {FP_count[class_id]+TN_count[class_id]}")
                TPR_thres_class[i][idx] = None
                FPR_thres_class[i][idx] = None
            else:
                # TPR = recall
                TPR_thres_class[i][idx] = TP_count[class_id] / positives
                FPR_thres_class[i][idx] = FP_count[class_id] / negatives

            predicted_positives = TP_count[class_id] + FP_count[class_id]
            precision_thres_class[i][idx] = (
                TP_count[class_id] / predicted_positives if predicted_positives else None
            )

    return (correct_counter_all_classes, TP_count, TN_count, FP_count, FN_count,
            class_id_fp_error_count, class_id_fn_error_count)