import argparse
import re, json, yaml
import numpy as np
from utils.public_field import *
import motmetrics as mm


class OCR:
    """Evaluate OCR predictions against ground truth, both exported from MongoDB as JSON."""

    def __init__(self, inference_result, GT_json_path):
        # Both arguments are paths to JSON exports: the predictions and the ground truth.
        self.inference_result = inference_result
        self.GT_json_path = GT_json_path

    def MongoDB2Anno(self, db_path, projection_key, GT=False):
        """Load a MongoDB JSON export and index its annotations by image path.

        :param db_path: path to a JSON file exported from MongoDB. Presumably a
            list of record dicts -- TODO confirm against the exporter.
        :param projection_key: record field holding the annotation to extract.
        :param GT: only selects the log prefix ('GT' vs 'Preds').
        :return: dict mapping ``record[DbKeys.IMAGE_PATH["key"]]`` to
            ``record[projection_key]``. Records whose annotation is missing or
            falsy (None, empty list/string) are skipped; if several records share
            an image path, the last one wins.
        """
        # Context manager so the file handle is closed even if parsing fails.
        with open(db_path, encoding="utf-8") as f:
            db_data = json.load(f)

        image_key = DbKeys.IMAGE_PATH["key"]
        db_dict = {}
        for record in db_data:
            # .get(): a record without the projection field is skipped instead of
            # aborting the whole load with a KeyError.
            annotation = record.get(projection_key)
            if annotation:
                db_dict[record[image_key]] = annotation

        print('GT ' if GT else 'Preds ', end='')
        print(f" OCR size: {len(db_data)}")
        return db_dict

    @staticmethod
    def _normalize(text):
        """Canonicalize an OCR string for comparison: drop '-' and upper-case it."""
        return text.replace("-", "").upper()

    @staticmethod
    def _rate(count, total):
        """Return ``count / total`` rounded to 4 places, or 0.0 when ``total`` is 0."""
        return round(count / total, 4) if total else 0.0

    def evaluateOCR(self, projection_key=DbKeys.LP_OCR["key"], saved_path="evaluation.txt"):
        """**OCR Evaluation**

        Compare the first predicted string of each image with the first GT
        string for the same image path, ignoring '-' and letter case, and write
        a summary report to ``saved_path``.

        An image is counted as *invalid* when it has no GT entry, or its first
        predicted string is empty. Rates are computed over the valid images
        only; if there are none, the rates are reported as 0.0 instead of
        raising ZeroDivisionError.

        :param projection_key: record field holding the OCR result (see public_field.py).
        :param saved_path: path of the text report to write (overwritten).
        :return: dict with the same counts and rates that are written to the report.
        """
        gt_dict = self.MongoDB2Anno(self.GT_json_path, projection_key, GT=True)
        pred_dict = self.MongoDB2Anno(self.inference_result, projection_key, GT=False)

        # NOTE(review): the total is the number of *prediction* images, not GT
        # images; GT images without any prediction are not counted anywhere.
        total_images = len(pred_dict)
        empty_images = []
        correct_count = 0
        incorrect_list = []
        for key, pred_values in pred_dict.items():
            if key not in gt_dict or pred_values == []:
                empty_images.append(key)
                continue
            value = pred_values[0]
            if not value:
                empty_images.append(key)
                continue
            if self._normalize(gt_dict[key][0]) == self._normalize(value):
                correct_count += 1
            else:
                incorrect_list.append(key)

        valid_count = total_images - len(empty_images)
        correct_rate = self._rate(correct_count, valid_count)
        incorrect_rate = self._rate(len(incorrect_list), valid_count)

        with open(saved_path, "w", encoding="utf-8") as fw:
            print("total images count: ", total_images, file=fw)
            print("invalid image count: ", len(empty_images), file=fw)
            print("correctness count: ", correct_count, file=fw)
            print("correctness rate: ", correct_rate, file=fw)
            print("incorrectness count: ", len(incorrect_list), file=fw)
            print("incorrectness rate: ", incorrect_rate, file=fw)

        return {
            "total": total_images,
            "invalid": len(empty_images),
            "correct": correct_count,
            "correct_rate": correct_rate,
            "incorrect": len(incorrect_list),
            "incorrect_rate": incorrect_rate,
        }


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--yaml', type=str, default='yaml/ocr.yaml', help='yaml for parameters')
    parser.add_argument('--output', type=str, default="output_ocr.txt", help='output text file')
    opt = parser.parse_args()

    with open(opt.yaml, "r") as f:
        # safe_load: the config only needs plain scalars/mappings, so there is no
        # reason to allow construction of arbitrary tagged Python objects.
        params_dict = yaml.safe_load(f)
    print(params_dict)

    # GT
    GT_json_path = params_dict["GT_json_path"]
    # inference
    inference_result = params_dict["inference_result"]
    # evaluation
    evaluator = OCR(inference_result, GT_json_path)
    # please find projection key in public_field.py
    evaluator.evaluateOCR(saved_path=opt.output)