137 lines
5.4 KiB
Python

import argparse
import os, json, yaml
from utils.public_field import *
from utils.cocoref import COCOv2, COCOevalv2
class ObjectDetection:
    """Evaluate object-detection predictions against ground truth with COCO metrics.

    Both inputs are JSON files in the project's MongoDB export layout.
    ``MongoDB2COCO`` converts them into COCO-style dicts, which are then fed to
    the project's ``COCOv2`` / ``COCOevalv2`` wrappers.
    """

    def __init__(self, inference_result, GT_json_path, mapping):
        """
        :param inference_result: path to the prediction JSON (MongoDB export)
        :param GT_json_path: path to the ground-truth JSON (MongoDB export)
        :param mapping: dict of class id -> class name; it becomes the COCO
            "categories" list, and its keys drive which classes are evaluated
        """
        self.inference_result = inference_result
        self.GT_json_path = GT_json_path
        self.mapping = mapping

    @staticmethod
    def _image_id(image_path):
        """Return the integer image id encoded in the file-name stem of *image_path*.

        e.g. ``/data/000000000139.jpg`` -> ``139``. Raises ``ValueError`` when the
        stem is not an integer.
        """
        return int(os.path.splitext(os.path.basename(image_path))[0])

    def MongoDB2COCO(self, db_path, mapping, GT=False):
        """Convert a MongoDB-exported JSON file into a COCO-style dict.

        :param db_path: path to a JSON file holding a list of records. Each
            record holds the image path under ``DbKeys.IMAGE_PATH["key"]`` and
            its boxes under ``DbKeys.BBOX["key"]``.
        :param mapping: dict of class id -> class name, emitted as "categories"
        :param GT: True for ground truth, where each box row is
            ``[x, y, w, h, category_id]``. False for predictions, where each box
            row is ``[x, y, w, h, score, category_id]``. The xywh layout is
            assumed from ``area = item[2] * item[3]``; confirm against the
            exporter.
        :return: dict with "annotations", "images" and "categories" lists
        """
        # JSON is UTF-8 by specification; the context manager closes the handle.
        with open(db_path, encoding="utf-8") as f:
            db_data = json.load(f)

        coco = {"annotations": [], "images": [], "categories": []}
        ann_id = 0  # sequential annotation ids; only GT annotations carry one
        for record in db_data:
            image_path = record[DbKeys.IMAGE_PATH["key"]]
            image_id = self._image_id(image_path)
            coco["images"].append({"file_name": os.path.basename(image_path), "id": image_id})
            for item in record[DbKeys.BBOX["key"]]:
                ann = {"bbox": item[:4], "image_id": image_id}
                if GT:
                    ann["category_id"] = item[4]
                    ann["id"] = ann_id
                    ann["iscrowd"] = 0
                    ann["area"] = item[2] * item[3]
                    ann_id += 1
                else:
                    ann["score"] = item[4]
                    ann["category_id"] = item[5]
                coco["annotations"].append(ann)

        coco["categories"] = [{"id": idx, "name": name} for idx, name in mapping.items()]
        return coco

    def evaluateOD(self, areaRng_type_table=None, subclass=None, saved_path="evaluation.txt"):
        """Run COCO bbox evaluation per class and over all selected classes.

        :param areaRng_type_table: optional dict of class id -> areaRng type,
            passed to ``COCOevalv2`` and to ``summarize`` for each class. When
            given, it must contain every class evaluated per class.
        :param subclass: optional list of class ids to evaluate; empty/None
            means every class in ``self.mapping``
        :param saved_path: path of the text report. The open file is passed to
            ``summarize`` via its ``saved_path`` argument.
        :raises ValueError: if ``areaRng_type_table`` lacks an evaluated class
        """
        areaRng_type_table = {} if areaRng_type_table is None else areaRng_type_table
        subclass = [] if subclass is None else list(subclass)

        # Per-class reports go in ascending class-id order over the mapping.
        # An empty ``subclass`` selects every class, matching the overall report.
        wanted = set(subclass)
        per_class_ids = [c for c in sorted(self.mapping) if not wanted or c in wanted]

        # Validate before touching any file so a bad config fails fast; an
        # ``assert`` would be stripped under ``python -O``.
        if areaRng_type_table:
            missing = [c for c in per_class_ids if c not in areaRng_type_table]
            if missing:
                raise ValueError(
                    "areaRng_type_table has no entry for class id(s) {}; it must "
                    "map every evaluated class".format(missing))

        coco_gt = self.MongoDB2COCO(self.GT_json_path, mapping=self.mapping, GT=True)
        coco_dt = self.MongoDB2COCO(self.inference_result, mapping=self.mapping, GT=False)
        cocoGt = COCOv2(coco_gt)  # ground-truth API
        cocoDt = cocoGt.loadRes(coco_dt)  # prediction API
        cocoEval = COCOevalv2(cocoGt, cocoDt, 'bbox', areaRng_type_table)

        with open(saved_path, "w") as fw:
            for class_i in per_class_ids:
                header = '===== Class ID: {} Name: {} ====='.format(class_i, self.mapping[class_i])
                print(header)
                fw.write(header + "\n")
                cocoEval.params.catIds = [class_i]
                cocoEval.evaluate()
                cocoEval.accumulate()
                if areaRng_type_table:
                    cocoEval.summarize(areaRng_type=areaRng_type_table[class_i], saved_path=fw)
                else:
                    cocoEval.summarize(saved_path=fw)

            # Overall report over the selected classes (or all of them).
            cocoEval.params.catIds = list(subclass) if subclass else sorted(self.mapping)
            header = '===== All Classes ====='
            print(header)
            fw.write(header + "\n")
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize(saved_path=fw)
if __name__ == '__main__':
    # Entry point: read evaluation parameters from YAML and write a COCO report.
    parser = argparse.ArgumentParser()
    parser.add_argument('--yaml', type=str, default='yaml/object_detection.yaml', help='yaml for parameters')
    parser.add_argument('--output', type=str, default="output_objectDetection.txt", help='output text file')
    opt = parser.parse_args()

    with open(opt.yaml, "r") as f:
        # NOTE(review): FullLoader is not PyYAML's safe loader. If this file
        # can come from an untrusted source, switch to yaml.safe_load.
        params_dict = yaml.load(f, Loader=yaml.FullLoader)

    # A missing or null entry is treated as empty instead of crashing.
    # Class-id keys are coerced to int in case the YAML wrote them as strings.
    params_dict["areaRng_type_table"] = {
        int(k): v for k, v in (params_dict.get("areaRng_type_table") or {}).items()
    }
    params_dict["subclass"] = params_dict.get("subclass") or []
    print(params_dict)

    GT_json_path = params_dict["GT_json_path"]  # ground truth
    inference_result = params_dict["inference_result"]  # predictions
    areaRng_type_table = params_dict["areaRng_type_table"]  # class id -> bbox area range type
    mapping = detection_map[params_dict["mapping"]]  # class id -> class name
    subclass = params_dict["subclass"]  # class ids to evaluate

    evaluator = ObjectDetection(inference_result, GT_json_path, mapping)
    evaluator.evaluateOD(areaRng_type_table=areaRng_type_table, subclass=subclass, saved_path=opt.output)