117 lines
3.7 KiB
Python
117 lines
3.7 KiB
Python
import argparse
|
|
import re, json, yaml
|
|
import numpy as np
|
|
from utils.public_field import *
|
|
import motmetrics as mm
|
|
|
|
class OCR:
    """Evaluate OCR predictions against ground truth.

    Both inputs are JSON files exported from MongoDB. Each file must contain a
    list of records (dicts) that carry an image path under
    ``DbKeys.IMAGE_PATH["key"]`` and an OCR annotation under a projection key.
    """

    def __init__(self, inference_result, GT_json_path):
        """
        :param inference_result: str, path to the prediction JSON (MongoDB export).
        :param GT_json_path: str, path to the ground-truth JSON (MongoDB export).
        """
        # two json paths
        self.inference_result = inference_result
        self.GT_json_path = GT_json_path

    def MongoDB2Anno(self, db_path, projection_key, GT=False):
        """Load a MongoDB JSON export and index its annotations by image path.

        :param db_path: str, path to a JSON file exported from MongoDB; the file
            must hold a list of records (dicts).
        :param projection_key: str, record field holding the annotation to keep.
            Every record must contain this field (a missing one raises KeyError).
        :param GT: bool, only selects the log prefix ("GT" or "Preds").
        :return: dict mapping each record's image path to its annotation.
            Records whose annotation is falsy (None, [], "", ...) are skipped.
        """
        # JSON is UTF-8 by spec; use a context manager so the handle is closed.
        with open(db_path, encoding="utf-8") as fr:
            db_data = json.load(fr)

        image_key = DbKeys.IMAGE_PATH["key"]
        db_dict = {}
        for d in db_data:
            annotation = d[projection_key]
            if annotation:
                db_dict[d[image_key]] = annotation

        print('GT ' if GT else 'Preds ', end='')
        # NOTE: reports the number of raw records, including the skipped ones.
        print(f" OCR size: {len(db_data)}")
        return db_dict

    @staticmethod
    def _normalize(text):
        """Return *text* with every '-' removed and upper-cased, for comparison."""
        return text.replace("-", "").upper()

    def evaluateOCR(self, projection_key=DbKeys.LP_OCR["key"], saved_path="evaluation.txt"):
        """**OCR Evaluation**

        Description:
            This function is designed specifically for OCR model. It compares
            predictions with ground truth image by image and writes a summary
            report to ``saved_path``.

        Every image in the *prediction* file is counted. An image is "invalid"
        when it has no ground-truth entry or its first predicted string is
        empty. The remaining images are scored by exact match of the first
        predicted string against the first ground-truth string, after removing
        '-' and upper-casing both.

        :param projection_key: str, record field holding the OCR annotation.
        :param saved_path: str, path of the text report to (over)write.
        :return: dict with the reported counts and rates, plus the lists of
            invalid and incorrect image paths. The rates are None (written as
            "N/A" in the report) when no valid image is left to score.
        """
        gt_dict = self.MongoDB2Anno(self.GT_json_path, projection_key, GT=True)
        pred_dict = self.MongoDB2Anno(self.inference_result, projection_key, GT=False)

        # NOTE: the total is taken over the prediction file, so GT images that
        # have no prediction at all are not counted anywhere.
        total_images = len(pred_dict)
        empty_images = []
        correct_count = 0
        incorrect_list = []

        for key, pred in pred_dict.items():
            # Only the first element of each annotation is compared.
            pred_value = pred[0] if key in gt_dict and pred else None
            if not pred_value:
                empty_images.append(key)
                continue

            if self._normalize(pred_value) == self._normalize(gt_dict[key][0]):
                correct_count += 1
            else:
                incorrect_list.append(key)

        # Guard the denominator: with no valid image the rates are undefined.
        valid_count = total_images - len(empty_images)
        if valid_count:
            correct_rate = round(correct_count / valid_count, 4)
            incorrect_rate = round(len(incorrect_list) / valid_count, 4)
        else:
            correct_rate = incorrect_rate = None

        with open(saved_path, "w", encoding="utf-8") as fw:
            print("total images count: ", total_images, file=fw)
            print("invalid image count: ", len(empty_images), file=fw)
            print("correctness count: ", correct_count, file=fw)
            print("correctness rate: ", "N/A" if correct_rate is None else correct_rate, file=fw)
            print("incorrectness count: ", len(incorrect_list), file=fw)
            print("incorrectness rate: ", "N/A" if incorrect_rate is None else incorrect_rate, file=fw)

        return {
            "total": total_images,
            "invalid": len(empty_images),
            "correct": correct_count,
            "correct_rate": correct_rate,
            "incorrect": len(incorrect_list),
            "incorrect_rate": incorrect_rate,
            "invalid_images": empty_images,
            "incorrect_images": incorrect_list,
        }
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: load the evaluation config from YAML, then score
    # the predictions against ground truth and write the report to --output.
    cli = argparse.ArgumentParser()
    cli.add_argument('--yaml', type=str, default='yaml/ocr.yaml', help='yaml for parameters')
    cli.add_argument('--output', type=str, default="output_ocr.txt", help='output text file')
    args = cli.parse_args()

    with open(args.yaml, "r") as cfg_file:
        # NOTE(review): yaml.safe_load would suffice if the config only holds
        # plain scalars/paths -- consider switching away from FullLoader.
        config = yaml.load(cfg_file, Loader=yaml.FullLoader)
    print(config)

    # Both paths are required; a missing key raises KeyError (GT checked first).
    evaluator = OCR(
        GT_json_path=config["GT_json_path"],
        inference_result=config["inference_result"],
    )
    # Uses the default projection key; alternatives live in utils/public_field.py.
    evaluator.evaluateOCR(saved_path=args.output)
|
|
|
|
|
|
|
|
|