117 lines
3.7 KiB
Python

import argparse
import re, json, yaml
import numpy as np
from utils.public_field import *
import motmetrics as mm
class OCR:
    """Evaluate OCR predictions against ground-truth annotations.

    Both inputs are JSON files exported from MongoDB. Each file is read as an
    iterable of records. Every record carries an image path under
    ``DbKeys.IMAGE_PATH["key"]`` and the OCR annotation under a projection key
    (``DbKeys.LP_OCR["key"]`` by default).
    """

    def __init__(self, inference_result, GT_json_path):
        # Paths to the two JSON exports: model predictions and ground truth.
        self.inference_result = inference_result
        self.GT_json_path = GT_json_path

    def MongoDB2Anno(self, db_path, projection_key, GT=False):
        """Load a MongoDB JSON export and index its annotations by image path.

        :param db_path: path to a JSON file (exported from MongoDB) holding an
            iterable of record dicts.
        :param projection_key: record field whose value is kept as the annotation.
        :param GT: only selects the log prefix ("GT" vs "Preds").
        :return: dict mapping each record's image path to its ``projection_key``
            value. Records whose value is falsy (e.g. an empty list) are skipped.
            If two records share an image path, the later one wins.
        """
        # JSON is UTF-8 by spec: don't depend on the platform locale encoding,
        # and close the file deterministically instead of leaking the handle.
        with open(db_path, encoding="utf-8") as f:
            db_data = json.load(f)

        db_dict = {}
        for d in db_data:
            if d[projection_key]:
                db_dict[d[DbKeys.IMAGE_PATH["key"]]] = d[projection_key]

        print('GT ' if GT else 'Preds ', end='')
        # Reports the number of raw records in the export, including skipped ones.
        print(f" OCR size: {len(db_data)}")
        return db_dict

    @staticmethod
    def _normalize(text):
        """Canonicalize an OCR string for comparison: drop '-' and upper-case it."""
        return text.replace("-", "").upper()

    def evaluateOCR(self, projection_key=DbKeys.LP_OCR["key"], saved_path="evaluation.txt"):
        """**OCR Evaluation**

        Compare the first predicted OCR string of each image with the first
        ground-truth string for the same image path. The comparison ignores
        '-' characters and letter case. A plain-text summary is written to
        ``saved_path``, overwriting it.

        An image is counted as *invalid* when it has no ground-truth entry, an
        empty prediction, or an empty first prediction string. Correct and
        incorrect rates are computed over the remaining (valid) images. Both
        rates are reported as 0.0 when no image is valid.

        :param projection_key: record field holding the OCR annotation.
        :param saved_path: text file the summary is written to.
        """
        gt_dict = self.MongoDB2Anno(self.GT_json_path, projection_key, GT=True)
        pred_dict = self.MongoDB2Anno(self.inference_result, projection_key, GT=False)

        # The "total" is the number of *predicted* images. Ground-truth images
        # that have no prediction at all are not part of the evaluation.
        total_images = len(pred_dict)
        empty_images = []
        correct_count = 0
        incorrect_list = []

        for key, preds in pred_dict.items():
            if key not in gt_dict or not preds or not preds[0]:
                empty_images.append(key)
                continue
            if self._normalize(preds[0]) == self._normalize(gt_dict[key][0]):
                correct_count += 1
            else:
                incorrect_list.append(key)

        valid_count = total_images - len(empty_images)

        def rate(count):
            # Without this guard, an all-invalid (or empty) prediction set
            # raised ZeroDivisionError.
            return round(count / valid_count, 4) if valid_count else 0.0

        with open(saved_path, "w") as fw:
            print("total images count: ", total_images, file=fw)
            print("invalid image count: ", len(empty_images), file=fw)
            print("correctness count: ", correct_count, file=fw)
            print("correctness rate: ", rate(correct_count), file=fw)
            print("incorrectness count: ", len(incorrect_list), file=fw)
            print("incorrectness rate: ", rate(len(incorrect_list)), file=fw)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--yaml', type=str, default='yaml/ocr.yaml', help='yaml for parameters')
    parser.add_argument('--output', type=str, default="output_ocr.txt", help='output text file')
    opt = parser.parse_args()

    # The config only needs plain scalars/mappings, so safe_load is sufficient
    # and refuses to construct arbitrary Python objects from YAML tags.
    with open(opt.yaml, "r", encoding="utf-8") as f:
        params_dict = yaml.safe_load(f)
    print(params_dict)

    # Fail with a readable CLI error instead of a bare KeyError/TypeError
    # when the config is empty or lacks a required path.
    required_keys = ("GT_json_path", "inference_result")
    if not isinstance(params_dict, dict):
        parser.error(f"{opt.yaml} must contain a mapping with keys: {', '.join(required_keys)}")
    missing = [k for k in required_keys if k not in params_dict]
    if missing:
        parser.error(f"{opt.yaml} is missing required key(s): {', '.join(missing)}")

    # GT: ground-truth export path; inference: model prediction export path.
    GT_json_path = params_dict["GT_json_path"]
    inference_result = params_dict["inference_result"]

    evaluator = OCR(inference_result, GT_json_path)
    # The projection key uses evaluateOCR's default; other keys are defined in
    # utils/public_field.py.
    evaluator.evaluateOCR(saved_path=opt.output)