271 lines
15 KiB
Python
271 lines
15 KiB
Python
import argparse
|
|
import os, json, yaml
|
|
from utils.public_field import *
|
|
from utils.utils import *
|
|
|
|
|
|
class Classification:
    """Classification evaluator.

    Compares MongoDB-exported inference results against MongoDB-exported
    ground-truth annotations and writes per-class TP/TN/FP/FN, accuracy,
    precision, recall and F1 statistics to a text file.
    """

    def __init__(self, inference_result, GT_json_path, mapping):
        """
        :param inference_result: path to the inference-result JSON (MongoDB export).
        :param GT_json_path: path to the ground-truth JSON (MongoDB export).
        :param mapping: class-id -> class-name mapping (an entry of detection_map).
        """
        # two json paths
        self.inference_result = inference_result
        self.GT_json_path = GT_json_path
        self.mapping = mapping

    def MongoDB2Anno(self, db_path, projection_key, GT=False):
        """
        This code transforms db format to coco format.

        :param db_path: string, data path (export from MongoDB).
        :param projection_key: record key to project out (e.g. DbKeys.CLASS3["key"]).
        :param GT: True when loading ground truth; only changes the log prefix.
        :return: dict mapping image path -> value stored under projection_key.
        """
        # open data_path with a context manager so the file descriptor is
        # released deterministically (original json.load(open(...)) leaked it).
        with open(db_path) as f:
            db_data = json.load(f)

        # images: index every record by its image path
        db_dict = {}
        for d in db_data:
            k = d[DbKeys.IMAGE_PATH["key"]]
            db_dict[k] = d[projection_key]

        if GT:
            print('GT ', end='')
        else:
            print('Preds ', end='')
        print(f"classification dict size: {len(db_data)}")
        return db_dict

    def evaluateC(self, scan_threshold=True, thresholds_range=None,
                  detection_mapping="kneron_detection", stat_classes=None,
                  combine_classes=None, overall_top1_accuracy=True,
                  projection_key=DbKeys.CLASS3["key"], saved_path="evaluation.txt"):
        """**Classification Evaluation**

        Description:
            If your output is class, classification is the right metric. Classification
            categorizes each instance into a certain label.

        PS: order is not considered in this function.

        Args:
            :parameter scan_threshold (bool, optional): Scan threshold from 0 to 1 to find the best performance with the
                highest accuracy. Please make sure you don't set a threshold in your application before
                activating this parameter.
            :parameter thresholds_range (List[float], optional): the threshold range used to classify two classes.
                Defaults to [0, 1].
            :parameter detection_mapping (str, optional): the mapping. Defaults to "kneron_detection".
            :parameter stat_classes (List[int], optional): If you only need subset classes from the selected
                detection_mapping, specify them here. ONLY WHEN YOU NEED COMBINE CLASSES. Check their
                numbers in public_field.py. Defaults to None (all classes present in GT).
            :parameter combine_classes (List[List[int]], optional): If this argument is activated, it means that the
                classes are composed of other subclasses. For example, if person(15), vehicle(27) and
                motor(45) are in stat_classes, and vehicle is composed of car(7), truck(25) and bus(6),
                and motor is composed of motorbike(14) and bicycle(2), then stat_classes should be
                [15, 27, 45] and combine_classes should be [[], [6, 7, 25], [2, 14]].
            :parameter overall_top1_accuracy (bool, optional): If this is a multi-label task, deactivate this
                parameter, as it reports an overall top-1 accuracy across all images. Default is True.
            :parameter projection_key: record key that holds the classification result.
            :parameter saved_path (str, optional): output text file for the statistics.
        """
        # Avoid mutable default arguments: materialize the documented defaults here.
        if thresholds_range is None:
            thresholds_range = [0, 1]
        if stat_classes is None:
            stat_classes = []
        if combine_classes is None:
            combine_classes = []

        # read
        gt_dict = self.MongoDB2Anno(self.GT_json_path, projection_key, GT=True)
        ct_dict = self.MongoDB2Anno(self.inference_result, projection_key, GT=False)

        # clean GT: drop images whose ground-truth entry is empty/invalid
        eliminate_list = [[], [""], "", [[]], [[""]], None]
        eliminate_img = [k for k, v in gt_dict.items() if v in eliminate_list]
        for img in eliminate_img:
            del gt_dict[img]
            # pop() instead of del: the prediction dict may not contain this image
            ct_dict.pop(img, None)

        # Overall class/label count
        gt_class_id_count = Counter()

        TP_count = defaultdict(int)
        TN_count = defaultdict(int)
        FP_count = defaultdict(int)
        # BUGFIX: the original initialized FP_count twice and never FN_count.
        FN_count = defaultdict(int)
        no_detection = set()
        image_with_detection_count = 0

        assert isinstance(thresholds_range, list), "thresholds_range must be a list"
        thresholds_range_list = np.arange(float(thresholds_range[0]), float(thresholds_range[1]), 0.1)
        correct_counter_all_classes = [0 for _ in range(len(thresholds_range_list))]

        # combine classes: map each subclass id to its parent stat class id
        combine_mapping = {}
        if stat_classes and combine_classes:
            # BUGFIX: assert message is now a string (the original passed print(...),
            # which evaluates to None and loses the message in the AssertionError).
            assert len(stat_classes) == len(combine_classes), \
                "[ERROR] Number of subclass: {} should be the same as combine class: {}!!!".format(
                    len(stat_classes), len(combine_classes))
            for i in range(len(stat_classes)):
                for subclass in combine_classes[i]:
                    combine_mapping[subclass] = stat_classes[i]

        for k, v in gt_dict.items():
            gtv_np = np.asarray(v)
            # Check dim for error: the nesting depth of the prediction must match
            # the ndim of the ground truth (CLASS1 / CLASS2 / CLASS3 layouts).
            if ct_dict.get(k) and len(ct_dict[k]) != 0:
                cv = ct_dict[k]
                if isinstance(cv[0], list) and len(cv[0]) > 0 and isinstance(cv[0][0], list):
                    # CLASS3
                    assert gtv_np.ndim == 3, \
                        "[ERROR] Inference dimension is different from GT!!! GT dim: {}, \nimage: {}".format(gtv_np.ndim, k)
                elif isinstance(v[0], list):
                    # CLASS2
                    assert gtv_np.ndim == 2, \
                        "[ERROR] Inference dimension is different from GT!!! GT dim: {}, \nimage: {}".format(gtv_np.ndim, k)
                else:
                    # CLASS1
                    assert gtv_np.ndim == 1, \
                        "[ERROR] Inference dimension is different from GT!!! GT dim: {}, \nimage: {}".format(gtv_np.ndim, k)
            else:
                no_detection.add(k)
                continue

            # check dim (normalize to the 2-D [[score, class_id], ...] layout)
            v = check_dim(v)

            # count for total classes number in gt
            if len(v[0]) == 0:
                # BUGFIX: must be a Counter — Counter.__iadd__ with a plain list
                # raises (in-place Counter ops require a mapping with .items()).
                gt_classes = Counter()
            else:
                # If combine_classes, only consider unique parent classes per image
                if not combine_mapping:
                    gt_classes = Counter(int(x[1]) for x in v)
                else:
                    gt_classes = Counter(set(combine_mapping[int(x[1])] for x in v))
            gt_class_id_count += gt_classes
            # this image's result is valid; gt_class_id_count only collects valid results
            image_with_detection_count += 1

        if not stat_classes:
            # list() so the collection is stable and indexable downstream
            stat_classes = list(gt_class_id_count.keys())

        # Get accuracy from different thresholds
        with open(saved_path, "w") as fw:

            if scan_threshold:
                correct_counter_all_classes, _, _, _, _, _, _ = threshold_classification(
                    thresholds_range_list, gt_dict, ct_dict, stat_classes, combine_mapping, overall_top1_accuracy)
                max_correct_count = max(correct_counter_all_classes)
                # If several thresholds tie on the max correct count, keep the largest.
                max_correct_count_index_list = [i for i, n in enumerate(correct_counter_all_classes) if n == max_correct_count]
                best_threshold = round(thresholds_range_list[max(max_correct_count_index_list)], 2)

                # Use the best threshold to get result
                _, TP_count, TN_count, FP_count, FN_count, class_id_fp_error_count, class_id_fn_error_count = \
                    threshold_classification([best_threshold], gt_dict, ct_dict, stat_classes,
                                             combine_mapping, overall_top1_accuracy)

                if overall_top1_accuracy:
                    print("Max correct count (total TP) @ threshold is: ", str(max_correct_count) + "@" + str(best_threshold))
                    print({"=== Use optimized thresh on current dataset ===": ''}, file=fw)
                    print({"Max correct count (total TP) @ threshold is": str(max_correct_count) + "@" + str(best_threshold)}, file=fw)
                else:
                    print("Max correct count (total TP+TN) @ threshold is: ", str(max_correct_count) + "@" + str(best_threshold))
                    print({"=== Use optimized thresh on current dataset ===": ''}, file=fw)
                    print({"Max correct count (total TP+TN) @ threshold is": str(max_correct_count) + "@" + str(best_threshold)}, file=fw)
            else:
                # Do not scan thresholds: evaluate once at threshold 0.
                correct_counter_all_classes, TP_count, TN_count, FP_count, FN_count, class_id_fp_error_count, class_id_fn_error_count = \
                    threshold_classification([0], gt_dict, ct_dict, stat_classes,
                                             combine_mapping, overall_top1_accuracy)
                max_correct_count = correct_counter_all_classes[0]
                best_threshold = 0.0

                if overall_top1_accuracy:
                    print({"=== Use original threshold ===": ''}, file=fw)
                    print("Correct count (total TP) is: ", str(max_correct_count))
                    print({"Correct count (total TP) ": str(max_correct_count)}, file=fw)
                else:
                    print({"=== Use original threshold ===": ''}, file=fw)
                    print("Correct count (total TP+TN) is: ", str(max_correct_count))
                    print({"Correct count (total TP+TN) ": str(max_correct_count)}, file=fw)

            print("gt total images: ", len(gt_dict))
            print({"gt total images": len(gt_dict)}, file=fw)

            print("total image with detection count: ", image_with_detection_count)
            print({"total image with detection count": image_with_detection_count}, file=fw)

            print("image without detection count: ", len(no_detection))
            print({"image without detection count": len(no_detection)}, file=fw)

            # per-class statistics
            precision_list = [0] * len(stat_classes)
            recall_list = [0] * len(stat_classes)
            for idx, k in enumerate(stat_classes):
                TP, TN, FP, FN = TP_count[k], TN_count[k], FP_count[k], FN_count[k]
                accuracy = round((TP + TN) / (TP + TN + FP + FN), 3) if (TP + TN + FP + FN) != 0 else 0
                precision = round(TP / (TP + FP), 3) if (TP + FP) != 0 else 0
                recall = round(TP / (TP + FN), 3) if (TP + FN) != 0 else 0
                f1 = round((2 * precision * recall) / (precision + recall), 3) if (precision + recall) != 0 else 0
                precision_list[idx] = precision
                recall_list[idx] = recall
                # hoist the repeated class-name lookup
                class_name = detection_map[detection_mapping][k]
                # TP
                print(k, class_name, " TP #: ", TP)
                print({str(k) + " " + str(class_name) + " TP #: ": TP}, file=fw)
                # TN
                print(k, class_name, " TN #: ", TN)
                print({str(k) + " " + str(class_name) + " TN #: ": TN}, file=fw)
                # FP
                print(k, class_name, " FP #: ", FP)
                print({str(k) + " " + str(class_name) + " FP #: ": FP}, file=fw)
                # FN
                print(k, class_name, " FN #: ", FN)
                print({str(k) + " " + str(class_name) + " FN #: ": FN}, file=fw)
                # accuracy
                print(k, class_name, " top1-accuracy: ", accuracy)
                print({str(k) + " " + str(class_name) + " accuracy: ": str(accuracy)}, file=fw)
                # precision
                print(k, class_name, " precision: ", precision)
                print({str(k) + " " + str(class_name) + " precision: ": precision}, file=fw)
                # recall
                print(k, class_name, " recall: ", recall)
                print({str(k) + " " + str(class_name) + " recall: ": recall}, file=fw)
                # F1
                print(k, class_name, " F1-score: ", f1)
                print({str(k) + " " + str(class_name) + " F1-score: ": f1}, file=fw)

            # macro-averaged precision (guard against an empty class list)
            mean_precision = sum(precision_list) / len(precision_list) if precision_list else 0
            print("mean precision: ", mean_precision)
            print({"mean precision: ": mean_precision}, file=fw)

            # macro-averaged recall
            mean_recall = sum(recall_list) / len(recall_list) if recall_list else 0
            print("mean recall: ", mean_recall)
            print({"mean recall: ": mean_recall}, file=fw)

            # macro F1 — BUGFIX: guard the ZeroDivisionError when both means are 0
            mean_f1 = (2 * mean_recall * mean_precision) / (mean_precision + mean_recall) \
                if (mean_precision + mean_recall) != 0 else 0
            print("mean F1-score: ", mean_f1)
            print({"mean F1-score: ": mean_f1}, file=fw)

            if overall_top1_accuracy:
                TP_sum = sum(TP_count.values())
                total_instances = TP_sum + sum(FN_count.values())
                # guard against an empty dataset
                top1_accuracy = TP_sum / total_instances if total_instances != 0 else 0
                print({"Overall top-1 accuracy : ": round(top1_accuracy, 6)}, file=fw)
|
|
|
|
|
|
if __name__ == '__main__':

    # CLI: a YAML file supplies all evaluation parameters; --output names the report.
    parser = argparse.ArgumentParser()
    parser.add_argument('--yaml', type=str, default='yaml/classification.yaml', help='yaml for parameters')
    parser.add_argument('--output', type=str, default="output_classification.txt", help='output text file')
    opt = parser.parse_args()

    # safe_load is sufficient for a plain parameter dict and avoids the
    # arbitrary-object construction FullLoader permits on untrusted files.
    with open(opt.yaml, "r") as f:
        params_dict = yaml.safe_load(f)

    # GT
    GT_json_path = params_dict["GT_json_path"]
    # inference
    inference_result = params_dict["inference_result"]
    # mapping
    mapping = detection_map[params_dict["mapping"]]
    # subclasses
    subclass = params_dict["subclass"]
    # class format
    class_format = params_dict["class_format"]
    # top1 acc
    top1_acc = params_dict["overall_top1_accuracy"]
    # scan
    scan_thres = params_dict["scan_threshold"]

    # evaluation
    evaluator = Classification(inference_result, GT_json_path, mapping)
    evaluator.evaluateC(scan_threshold=scan_thres, stat_classes=subclass,
                        overall_top1_accuracy=top1_acc, projection_key=class_format,
                        saved_path=opt.output)