import argparse
import os
import json

import yaml

from utils.public_field import *
from utils.utils import *


class Classification:
    """Classification-metric evaluation for MongoDB-exported annotation dumps.

    Compares an inference-result JSON export against a ground-truth JSON
    export (both MongoDB dumps keyed by image path) and reports per-class
    TP/TN/FP/FN, top-1 accuracy, precision, recall and F1, optionally
    scanning thresholds to find the best operating point.
    """

    def __init__(self, inference_result, GT_json_path, mapping):
        # Paths to the two MongoDB-exported JSON files, plus the
        # class-id -> class-name mapping table.
        self.inference_result = inference_result
        self.GT_json_path = GT_json_path
        self.mapping = mapping

    def MongoDB2Anno(self, db_path, projection_key, GT=False):
        """Transform a MongoDB export into a flat ``{image_path: value}`` dict.

        :param db_path: str, path to the JSON file exported from MongoDB.
        :param projection_key: str, per-record field to extract
            (e.g. ``DbKeys.CLASS3["key"]``).
        :param GT: bool, only changes the progress banner printed
            ("GT" vs "Preds").
        :return: dict mapping image path -> value stored under
            ``projection_key``.
        """
        # Use a context manager so the file handle is always released
        # (original code leaked it via json.load(open(...))).
        with open(db_path) as f:
            db_data = json.load(f)

        db_dict = {}
        for record in db_data:
            image_path = record[DbKeys.IMAGE_PATH["key"]]
            db_dict[image_path] = record[projection_key]

        if GT:
            print('GT ', end='')
        else:
            print('Preds ', end='')
        print(f"classification dict size: {len(db_data)}")
        return db_dict

    def evaluateC(self, scan_threshold=True, thresholds_range=[0, 1],
                  detection_mapping="kneron_detection", stat_classes=[],
                  combine_classes=[], overall_top1_accuracy=True,
                  projection_key=DbKeys.CLASS3["key"],
                  saved_path="evaluation.txt"):
        """**Classification Evaluation**

        If your model output is a class label, classification is the right
        metric: each instance is categorized into a certain label. Order of
        labels within an image is not considered.

        NOTE: the mutable default arguments are kept for backward
        compatibility; they are only reassigned, never mutated, so the
        usual shared-state pitfall does not apply here.

        Args:
            scan_threshold (bool, optional): Scan thresholds over
                ``thresholds_range`` (step 0.1) to find the threshold with
                the highest correct count. Make sure your application does
                not already apply a threshold before activating this.
            thresholds_range (List[float], optional): [low, high) range to
                scan. Defaults to [0, 1].
            detection_mapping (str, optional): key into ``detection_map``
                used to resolve class names. Defaults to "kneron_detection".
            stat_classes (List[int], optional): restrict statistics to this
                subset of class ids; required when ``combine_classes`` is
                used. Check ids in public_field.py. Defaults to [] (use all
                classes seen in GT).
            combine_classes (List[List[int]], optional): parallel to
                ``stat_classes``; each entry lists the sub-class ids folded
                into the corresponding stat class. Example: stat_classes
                [15, 27, 45] with combine_classes [[], [6, 7, 25], [2, 14]]
                folds car/truck/bus into vehicle(27) and
                motorbike/bicycle into motor(45).
            overall_top1_accuracy (bool, optional): deactivate for
                multi-label tasks; when True an overall top-1 accuracy
                across all images is reported. Defaults to True.
            projection_key (str, optional): DB field holding the class
                annotation. Defaults to ``DbKeys.CLASS3["key"]``.
            saved_path (str, optional): text file the report is written to.
        """
        # Load GT and prediction dicts keyed by image path.
        gt_dict = self.MongoDB2Anno(self.GT_json_path, projection_key, GT=True)
        ct_dict = self.MongoDB2Anno(self.inference_result, projection_key, GT=False)

        # Clean GT: drop images whose annotation is empty in any of its
        # observed encodings (original list contained duplicates).
        eliminate_list = [[], [""], "", [[]], [[""]], None]
        eliminate_img = [k for k, v in gt_dict.items() if v in eliminate_list]
        for img in eliminate_img:
            del gt_dict[img]
            # BUGFIX: the prediction dump may not contain this image at
            # all; plain `del` raised KeyError in that case.
            ct_dict.pop(img, None)

        # Overall class/label counters.
        gt_class_id_count = Counter()
        TP_count = defaultdict(int)
        TN_count = defaultdict(int)
        FP_count = defaultdict(int)
        # BUGFIX: original assigned FP_count twice and never created FN_count.
        FN_count = defaultdict(int)
        no_detection = set()
        image_with_detection_count = 0

        assert isinstance(thresholds_range, list)
        thresholds_range_list = np.arange(float(thresholds_range[0]),
                                          float(thresholds_range[1]), 0.1)
        correct_counter_all_classes = [0 for _ in range(len(thresholds_range_list))]

        # Build the sub-class -> combined-class lookup, if requested.
        combine_mapping = {}
        if stat_classes and combine_classes:
            # BUGFIX: assert message was `print(...)` which executed eagerly
            # and evaluated to None; use a plain message string.
            assert len(stat_classes) == len(combine_classes), (
                "[ERROR] Number of subclass: {} should be the same as combine "
                "class: {}!!!".format(len(stat_classes), len(combine_classes)))
            for i in range(len(stat_classes)):
                for subclass in combine_classes[i]:
                    combine_mapping[subclass] = stat_classes[i]

        for k, v in gt_dict.items():
            # e.g. /mnt/.../closesye/lS2itZhX.jpg [1, 34]
            gtv_np = np.asarray(v)

            # Dimension sanity check: GT rank must match the prediction
            # structure (CLASS1/2/3 formats). Use .get so a missing key is
            # treated as "no detection" instead of crashing.
            cv = ct_dict.get(k)
            if cv and len(cv) != 0:
                if isinstance(cv[0], list) and len(cv[0]) > 0 and isinstance(cv[0][0], list):
                    # CLASS3
                    assert gtv_np.ndim == 3, (
                        "[ERROR] Inference dimension is different from GT!!! "
                        "GT dim: {}, \nimage: {}".format(gtv_np.ndim, k))
                elif isinstance(v[0], list):
                    # CLASS2
                    assert gtv_np.ndim == 2, (
                        "[ERROR] Inference dimension is different from GT!!! "
                        "GT dim: {}, \nimage: {}".format(gtv_np.ndim, k))
                else:
                    # CLASS1
                    assert gtv_np.ndim == 1, (
                        "[ERROR] Inference dimension is different from GT!!! "
                        "GT dim: {}, \nimage: {}".format(gtv_np.ndim, k))
            else:
                no_detection.add(k)
                continue

            # Normalize GT to the canonical nesting depth (project helper).
            v = check_dim(v)

            # Count class instances in GT for this image.
            if len(v[0]) == 0:
                # BUGFIX: was `[]`; Counter += list raises AttributeError.
                gt_classes = Counter()
            else:
                # With combine_classes active only unique combined ids count.
                if not combine_mapping:
                    gt_classes = Counter([int(x[1]) for x in v])
                else:
                    gt_classes = Counter(set([combine_mapping[int(x[1])] for x in v]))
            gt_class_id_count += gt_classes
            # gt_class_id_count only collects images with a valid result.
            image_with_detection_count += 1

        if not stat_classes:
            stat_classes = gt_class_id_count.keys()

        # Evaluate (optionally scanning thresholds) and write the report.
        with open(saved_path, "w") as fw:
            if scan_threshold:
                correct_counter_all_classes, _, _, _, _, _, _ = threshold_classification(
                    thresholds_range_list, gt_dict, ct_dict, stat_classes,
                    combine_mapping, overall_top1_accuracy)
                max_correct_count = max(correct_counter_all_classes)
                # When several thresholds tie, keep the largest one.
                max_correct_count_index_list = [
                    i for i, n in enumerate(correct_counter_all_classes)
                    if n == max_correct_count]
                best_threshold = round(
                    thresholds_range_list[max(max_correct_count_index_list)], 2)
                # Re-run at the best threshold to obtain detailed counts.
                (_, TP_count, TN_count, FP_count, FN_count,
                 class_id_fp_error_count, class_id_fn_error_count) = threshold_classification(
                    [best_threshold], gt_dict, ct_dict, stat_classes,
                    combine_mapping, overall_top1_accuracy)

                if overall_top1_accuracy:
                    print("Max correct count (total TP) @ threshold is: ",
                          str(max_correct_count) + "@" + str(best_threshold))
                    print({"=== Use optimized thresh on current dataset ===": ''}, file=fw)
                    print({"Max correct count (total TP) @ threshold is":
                           str(max_correct_count) + "@" + str(best_threshold)}, file=fw)
                else:
                    print("Max correct count (total TP+TN) @ threshold is: ",
                          str(max_correct_count) + "@" + str(best_threshold))
                    print({"=== Use optimized thresh on current dataset ===": ''}, file=fw)
                    print({"Max correct count (total TP+TN) @ threshold is":
                           str(max_correct_count) + "@" + str(best_threshold)}, file=fw)
            else:
                # No threshold scan: evaluate once at threshold 0.
                (correct_counter_all_classes, TP_count, TN_count, FP_count, FN_count,
                 class_id_fp_error_count, class_id_fn_error_count) = threshold_classification(
                    [0], gt_dict, ct_dict, stat_classes,
                    combine_mapping, overall_top1_accuracy)
                max_correct_count = correct_counter_all_classes[0]
                best_threshold = 0.0

                if overall_top1_accuracy:
                    print({"=== Use original threshold ===": ''}, file=fw)
                    print("Correct count (total TP) is: ", str(max_correct_count))
                    print({"Correct count (total TP) ": str(max_correct_count)}, file=fw)
                else:
                    print({"=== Use original threshold ===": ''}, file=fw)
                    print("Correct count (total TP+TN) is: ", str(max_correct_count))
                    print({"Correct count (total TP+TN) ": str(max_correct_count)}, file=fw)

            print("gt total images: ", len(gt_dict.keys()))
            print({"gt total images": len(gt_dict.keys())}, file=fw)
            print("total image with detection count: ", image_with_detection_count)
            print({"total image with detection count": image_with_detection_count}, file=fw)
            print("image without detection count: ", len(no_detection))
            print({"image without detection count": len(no_detection)}, file=fw)

            # Per-class metrics.
            precision_list = [0] * len(stat_classes)
            recall_list = [0] * len(stat_classes)
            for idx, k in enumerate(stat_classes):
                TP, TN, FP, FN = TP_count[k], TN_count[k], FP_count[k], FN_count[k]
                total = TP + TN + FP + FN
                accuracy = round((TP + TN) / total, 3) if total != 0 else 0
                precision = round(TP / (TP + FP), 3) if (TP + FP) != 0 else 0
                recall = round(TP / (TP + FN), 3) if (TP + FN) != 0 else 0
                f1 = round((2 * precision * recall) / (precision + recall), 3) \
                    if (precision + recall) != 0 else 0
                precision_list[idx] = precision
                recall_list[idx] = recall

                class_name = detection_map[detection_mapping][k]
                # TP
                print(k, class_name, " TP #: ", TP)
                print({str(k) + " " + str(class_name) + " TP #: ": TP}, file=fw)
                # TN
                print(k, class_name, " TN #: ", TN)
                print({str(k) + " " + str(class_name) + " TN #: ": TN}, file=fw)
                # FP
                print(k, class_name, " FP #: ", FP)
                print({str(k) + " " + str(class_name) + " FP #: ": FP}, file=fw)
                # FN
                print(k, class_name, " FN #: ", FN)
                print({str(k) + " " + str(class_name) + " FN #: ": FN}, file=fw)
                # accuracy
                print(k, class_name, " top1-accuracy: ", accuracy)
                print({str(k) + " " + str(class_name) + " accuracy: ": str(accuracy)}, file=fw)
                # precision
                print(k, class_name, " precision: ", precision)
                print({str(k) + " " + str(class_name) + " precision: ": precision}, file=fw)
                # recall
                print(k, class_name, " recall: ", recall)
                print({str(k) + " " + str(class_name) + " recall: ": recall}, file=fw)
                # F1
                print(k, class_name, " F1-score: ", f1)
                print({str(k) + " " + str(class_name) + " F1-score: ": f1}, file=fw)

            # Macro-averaged metrics. BUGFIX: guard all divisions against
            # zero denominators (empty class list, all-zero metrics).
            mean_precision = (sum(precision_list) / len(precision_list)) if precision_list else 0
            print("mean precision: ", mean_precision)
            print({"mean precision: ": mean_precision}, file=fw)

            mean_recall = (sum(recall_list) / len(recall_list)) if recall_list else 0
            print("mean recall: ", mean_recall)
            print({"mean recall: ": mean_recall}, file=fw)

            mean_f1 = (2 * mean_recall * mean_precision) / (mean_precision + mean_recall) \
                if (mean_precision + mean_recall) != 0 else 0
            print("mean F1-score: ", mean_f1)
            print({"mean F1-score: ": mean_f1}, file=fw)

            if overall_top1_accuracy:
                TP_sum = sum(TP_count.values())
                total_instances = TP_sum + sum(FN_count.values())
                top1_accuracy = TP_sum / total_instances if total_instances != 0 else 0
                print({"Overall top-1 accuracy : ": round(top1_accuracy, 6)}, file=fw)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--yaml', type=str, default='yaml/classification.yaml',
                        help='yaml for parameters')
    parser.add_argument('--output', type=str, default="output_classification.txt",
                        help='output text file')
    opt = parser.parse_args()

    with open(opt.yaml, "r") as f:
        params_dict = yaml.load(f, Loader=yaml.FullLoader)

    # GT / inference paths
    GT_json_path = params_dict["GT_json_path"]
    inference_result = params_dict["inference_result"]
    # class-id -> name mapping
    mapping = detection_map[params_dict["mapping"]]
    # subset of classes to report
    subclass = params_dict["subclass"]
    # DB projection key for the class annotation
    class_format = params_dict["class_format"]
    # report overall top-1 accuracy (single-label tasks)
    top1_acc = params_dict["overall_top1_accuracy"]
    # scan thresholds for the best operating point
    scan_thres = params_dict["scan_threshold"]

    evaluator = Classification(inference_result, GT_json_path, mapping)
    evaluator.evaluateC(scan_threshold=scan_thres, stat_classes=subclass,
                        overall_top1_accuracy=top1_acc,
                        projection_key=class_format, saved_path=opt.output)