# ReID tracking evaluation: compares inference results against ground truth
# and reports MOT-challenge metrics via motmetrics.
import argparse
|
|
import re, json, yaml
|
|
import numpy as np
|
|
from utils.public_field import *
|
|
import motmetrics as mm
|
|
|
|
class ReID:
    """Tracking re-identification (ReID) evaluator.

    Loads two MongoDB JSON exports (predictions and ground truth), matches
    detections per frame by IoU, and computes MOT-challenge metrics with
    ``motmetrics``.
    """

    def __init__(self, inference_result, GT_json_path):
        """
        :param inference_result: str, path to the predictions JSON export.
        :param GT_json_path: str, path to the ground-truth JSON export.
        """
        self.inference_result = inference_result
        self.GT_json_path = GT_json_path

    def MongoDB2Anno(self, db_path, projection_key, GT=False):
        """Load a MongoDB JSON export and index it by image path.

        :param db_path: str, path to a JSON file exported from MongoDB
            (a list of documents).
        :param projection_key: str, document field to project out
            (e.g. ``DbKeys.TRACK_ID["key"]``).
        :param GT: bool, only changes the log prefix ("GT" vs "Preds").
        :return: dict mapping image path -> document[projection_key];
            documents with an empty/falsy projected field are skipped.
        """
        # Context manager closes the file deterministically
        # (original json.load(open(...)) leaked the handle).
        with open(db_path) as f:
            db_data = json.load(f)

        db_dict = {}
        for doc in db_data:
            # Skip documents whose projected field is empty/missing.
            if doc[projection_key]:
                db_dict[doc[DbKeys.IMAGE_PATH["key"]]] = doc[projection_key]

        prefix = 'GT ' if GT else 'Preds '
        print(f"{prefix} ReID size: {len(db_data)}")
        return db_dict

    def render_summary(self, summary, formatters=None, namemap=None, buf=None):
        """Render a metrics summary to console-friendly tabular output.

        Params
        ------
        summary : pd.DataFrame
            Dataframe containing summaries in rows.

        Kwargs
        ------
        buf : StringIO-like, optional
            Buffer to write to instead of stdout.
        formatters : dict, optional
            Dictionary defining custom formatters for individual metrics,
            e.g. ``{'mota': '{:.2%}'.format}``. Preset formatters are
            available from ``MetricsHost.formatters``.
        namemap : dict, optional
            Dictionary defining new metric names for display, e.g.
            ``{'num_false_positives': 'FP'}``.

        Returns
        -------
        dict
            The (possibly renamed) summary as ``{column: [values]}``.
        """
        if namemap is not None:
            summary = summary.rename(columns=namemap)
            if formatters is not None:
                # Re-key formatters to the renamed (display) column names.
                formatters = {namemap.get(c, c): f for c, f in formatters.items()}

        output = summary.to_string(buf=buf, formatters=formatters)
        # to_string returns None when a buffer is supplied; the original
        # printed the literal string "None" in that case.
        if output is not None:
            print(output)

        return summary.to_dict("list")

    def evaluateReID(self, skip_rate: int = 0, max_iou: float = 0.5, projection_key=None, saved_path="evaluation.txt"):
        """**Tracking REID Evaluation**

        Matches GT and predicted boxes frame by frame (IoU assignment),
        accumulates MOT events, and writes the MOT-challenge summary to
        ``saved_path``.

        Args:
            skip_rate (int, optional): Number of frames to skip between
                evaluated frames. Defaults to 0 (evaluate every frame).
            max_iou (float, optional): Maximum IoU distance accepted as a
                match by ``mm.distances.iou_matrix``. Defaults to 0.5.
            projection_key (optional): Document field holding per-frame
                ``[x, y, w, h, ..., pid]`` entries. Defaults to
                ``DbKeys.TRACK_ID["key"]`` (resolved lazily so the class
                can be imported without DbKeys at definition time).
            saved_path (str, optional): Output text file for the metrics.
        """
        if projection_key is None:
            projection_key = DbKeys.TRACK_ID["key"]

        # Read both exports, keyed by image path.
        gt_dict = self.MongoDB2Anno(self.GT_json_path, projection_key, GT=True)
        ct_dict = self.MongoDB2Anno(self.inference_result, projection_key, GT=False)

        # Frames present in both GT and predictions, in natural sort order
        # (so "frame_10" sorts after "frame_2").
        def natural_keys(text):
            return [int(c) if c.isdigit() else c for c in re.split(r'(\d+)', text)]

        frame_order_list = sorted((k for k in gt_dict if k in ct_dict), key=natural_keys)
        # Keep every (skip_rate + 1)-th frame.
        frame_order_list = frame_order_list[::skip_rate + 1]

        # Frames with at least one MISS/FP/SWITCH, for failure-case analysis.
        fn_list = []
        acc = mm.MOTAccumulator(auto_id=True)
        for frame_idx, frame in enumerate(frame_order_list):
            gt_bbox_pid = gt_dict[frame]
            ct_bbox_pid = ct_dict[frame]
            # First four entries are the bbox; last entry is the person id.
            gt_bboxes = np.array([bbox[:4] for bbox in gt_bbox_pid])
            ct_bboxes = np.array([bbox[:4] for bbox in ct_bbox_pid])
            iou_distance = mm.distances.iou_matrix(gt_bboxes, ct_bboxes, max_iou=max_iou)
            gt_pid = np.array([bbox[-1] for bbox in gt_bbox_pid])
            ct_pid = np.array([bbox[-1] for bbox in ct_bbox_pid])
            # auto_id=True assigns sequential frame ids matching frame_idx.
            acc.update(gt_pid, ct_pid, iou_distance)
            if acc.mot_events.loc[frame_idx]["Type"].isin(["MISS", "FP", "SWITCH"]).any():
                fn_list.append(frame)

        # Build up MOT-challenge metrics over the accumulated events.
        mh = mm.metrics.create()
        summary = mh.compute(
            acc,
            metrics=mm.metrics.motchallenge_metrics,
            name='Tracker')
        strsummary = self.render_summary(
            summary,
            formatters=mh.formatters,
            namemap=mm.io.motchallenge_metric_names
        )

        with open(saved_path, "w") as fw:
            for k, v in strsummary.items():
                # Single-row summary: each metric list holds one value.
                v = v[0]
                if isinstance(v, float):
                    v = round(v, 3)
                print("{}: {}".format(k, v), file=fw)

            # reference
            url = "https://github.com/cheind/py-motmetrics"
            print({"This Evaluation is based on: ": url}, file=fw)
|
|
|
|
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--yaml', type=str, default='yaml/reid.yaml', help='yaml for parameters')
    parser.add_argument('--output', type=str, default="output_reid.txt", help='output text file')
    opt = parser.parse_args()

    # safe_load: the config is plain data; FullLoader can construct
    # arbitrary Python objects and is unnecessary here.
    with open(opt.yaml, "r") as f:
        params_dict = yaml.safe_load(f)

    print(params_dict)

    # Ground-truth and inference JSON paths from the config.
    GT_json_path = params_dict["GT_json_path"]
    inference_result = params_dict["inference_result"]

    # evaluation — please find the projection key in public_field.py
    evaluator = ReID(inference_result, GT_json_path)
    evaluator.evaluateReID(skip_rate=0, max_iou=0.5, saved_path=opt.output)
|
|
|
|
|
|
|
|
|