83 lines
3.8 KiB
Python

import sys
import os
import argparse
import yaml
import json
from tqdm import tqdm
# Extend the import search path with the project sub-directories listed below.
# They are resolved against the current working directory (not against this
# file's location), so the imports that follow only succeed when the script is
# started from the directory that contains these folders.
current_path = os.getcwd()
sys.path.extend(
    current_path + suffix
    for suffix in (
        '/prepostprocess',
        '/prepostprocess/kneron_preprocessing',
        '/kneron_globalconstant',
        '/kneron_globalconstant/base',
        '/kneron_globalconstant/kneron_utils',
    )
)
from prepostprocess import kneron_preprocessing as kp
from yolov5.yolov5_runner import Yolov5Runner
from function.function_runner import FunctionRunner
from rsn_affine.rsn_affine_runner import RsnAffineRunner
from lite_hrnet.lite_hrnet_runner import LiteHrnetRunner
def inference(dataset_path, yolov5_params, rsn_affine_params, lite_hrnet_params):
    """
    Run the full runner pipeline on every image found in a directory.

    Each image is passed through five stages, in this order: the YOLOv5
    runner, a ``FunctionRunner`` of type 60, the RSN affine runner, the
    Lite-HRNet runner and finally a ``FunctionRunner`` of type 48.  The
    output of the last stage is stored per image.

    Args:
        dataset_path: Directory that is scanned (non-recursively) for images.
            Only files whose extension is ``.png`` or ``.jpg`` are processed;
            the comparison is case-insensitive, so ``IMG.JPG`` is included.
        yolov5_params: Mapping with the ``Yolov5Runner`` constructor
            arguments.  Every key listed in ``yolov5_init_keys`` below must
            be present; other keys are ignored.
        rsn_affine_params: Mapping providing ``image_size`` and ``scale_ext``.
        lite_hrnet_params: Mapping providing ``model_path``.

    Returns:
        A list with one dict per processed image, holding ``img_path`` (the
        joined path of the image) and ``lmk_coco_body_17pts`` (the output of
        the final stage).  Entries follow the sorted order of the file names.

    Raises:
        KeyError: If one of the parameter mappings lacks a required key.
    """
    # Forward exactly the keyword arguments the runner is constructed with, so
    # unrelated entries in the params file never reach the constructor.
    yolov5_init_keys = (
        'model_path', 'model_id', 'yaml_path',
        'grid20_path', 'grid40_path', 'grid80_path',
        'num_classes', 'input_shape', 'conf_thres', 'iou_thres',
        'top_k_num', 'detection_type', 'vanish_point',
        'label_mapping', 'class_name', 'anchors',
    )
    yolov5_runner = Yolov5Runner(**{key: yolov5_params[key] for key in yolov5_init_keys})
    function_runner_2 = FunctionRunner(type=60, thresh_head_iou=0.8, thresh_fbox_iou=0.9, thresh_person_score=0.3)
    rsn_affine_runner = RsnAffineRunner(
        image_size=rsn_affine_params['image_size'],
        scale_ext=rsn_affine_params['scale_ext'],
    )
    lite_hrnet_runner = LiteHrnetRunner(model_path=lite_hrnet_params['model_path'])
    function_runner = FunctionRunner(type=48)

    image_extensions = ('.png', '.jpg')
    results = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # output reproducible from run to run.
    for img_name in tqdm(sorted(os.listdir(dataset_path))):
        # splitext() looks at the real extension only, so a file that is
        # merely *named* "jpg" is not mistaken for an image, and lower()
        # accepts upper-case extensions such as ".JPG".
        if os.path.splitext(img_name)[1].lower() not in image_extensions:
            continue
        img_path = os.path.join(dataset_path, img_name)
        yolo_out = yolov5_runner.run(img_path)
        # The second value returned by this stage is not used further on.
        func60_out, _unused = function_runner_2.run(img_path, yolo_out)
        affine_out_0, affine_out_1 = rsn_affine_runner.run(img_path, func60_out)
        hrnet_out = lite_hrnet_runner.run(img_path, [affine_out_0, affine_out_1])
        final_out = function_runner.run(img_path, hrnet_out)
        results.append({'img_path': img_path, 'lmk_coco_body_17pts': final_out})
    return results
def parse_args(args):
    """
    Parse the command-line arguments.

    Args:
        args: List of argument strings, e.g. ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``img_path``,
        ``yolov5_params``, ``rsn_affine_params``, ``lite_hrnet_params`` and
        ``save_path``.

    Raises:
        SystemExit: Raised by argparse when an option is missing or unknown.
    """
    parser = argparse.ArgumentParser(description='Simple inference script for inference an object detection network.')
    # Every option is needed by the script, so mark them required: a missing
    # flag then yields a usage message instead of a failure later on when a
    # ``None`` path is used.
    parser.add_argument('--img-path', type=str, required=True, help='Path to the image dataset directory.')
    parser.add_argument('--yolov5_params', type=str, required=True, help='Path to the yolov5 init params file.')
    parser.add_argument('--rsn_affine_params', type=str, required=True, help='Path to the rsn_affine init params file.')
    parser.add_argument('--lite_hrnet_params', type=str, required=True, help='Path to the lite-hrnet init params file.')
    parser.add_argument('--save-path', type=str, required=True, help='Path to save output in json.')
    # Parse once and reuse the result for both the echo and the return value.
    parsed = parser.parse_args(args)
    print(vars(parsed))
    return parsed
def main(args=None):
    """
    Script entry point: load the three parameter files, run inference and
    dump the predictions as JSON.

    Args:
        args: Optional list of command-line argument strings.  When ``None``
            the arguments are taken from ``sys.argv[1:]``.
    """
    argv = sys.argv[1:] if args is None else args
    options = parse_args(argv)

    def _read_yaml(path):
        # NOTE(review): FullLoader accepts more than plain data; consider
        # yaml.safe_load if these files can come from an untrusted source.
        with open(path) as stream:
            return yaml.load(stream, Loader=yaml.FullLoader)

    # Load order matches the order of the inference() parameters.
    yolov5_params = _read_yaml(options.yolov5_params)
    rsn_affine_params = _read_yaml(options.rsn_affine_params)
    lite_hrnet_params = _read_yaml(options.lite_hrnet_params)

    predictions = inference(
        options.img_path,
        yolov5_params,
        rsn_affine_params,
        lite_hrnet_params,
    )

    with open(options.save_path, 'w') as out_file:
        json.dump(predictions, out_file)
# Run the pipeline only when the file is executed as a script; main() reads
# its arguments from sys.argv when none are passed in.
if __name__ == '__main__':
    main(args=None)