import argparse
import os
import os.path as osp
import sys

import mmcv
import onnx
import onnxruntime
import torch
from mmcv import Config
from mmcv.runner import load_checkpoint

from models import build_posenet


def parse_args():
    """Parse command-line arguments for the ONNX export script.

    Returns:
        argparse.Namespace: with ``config`` (path to the mmcv config file),
        ``checkpoint`` (path to the weights file) and ``opset_version``
        (ONNX opset to export with, default 11).
    """
    parser = argparse.ArgumentParser(description='mmpose export onnx model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='path to checkpoint file')
    # Optional so existing invocations keep the previous behavior (opset 11).
    parser.add_argument(
        '--opset-version',
        type=int,
        default=11,
        help='ONNX opset version to export with (default: 11)')
    args = parser.parse_args()
    return args


def main():
    """Export the configured pose model to ONNX and sanity-check the result.

    Steps:
      1. Build the model from the config and load the checkpoint on CPU.
      2. Export it with a random dummy input of size (1, 3, H, W), where
         [W, H] = ``cfg.data_cfg['image_size']``.
      3. Save it to ``./work_dirs/<config name>/<config name>.onnx``.
      4. Validate the file with the ONNX checker.
      5. Run it once through ONNX Runtime on CPU and print the shape of the
         first output.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    cfg.load_from = args.checkpoint
    config_name = osp.splitext(osp.basename(args.config))[0]
    args.work_dir = osp.join('./work_dirs', config_name)
    mmcv.mkdir_or_exist(osp.abspath(args.work_dir))

    # Build the model and load the checkpoint.
    model = build_posenet(cfg.model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    # Export in inference mode (BatchNorm running stats, no Dropout); older
    # torch.onnx.export versions do not switch the mode themselves.
    model.eval()
    # NOTE(review): assumes model.forward(img) is traceable with a single
    # image tensor; some mmpose models need `forward_dummy` for export —
    # confirm for models.build_posenet.

    input_w, input_h = cfg.data_cfg['image_size']
    dummy_input = torch.rand(1, 3, input_h, input_w)
    # The model takes exactly one tensor input. The previous loop iterated the
    # batch dimension of that tensor and only produced ['i0'] because batch
    # size was 1.
    input_names = ['i0']
    model_name = osp.join(args.work_dir, config_name + '.onnx')

    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_input,
            model_name,
            input_names=input_names,
            keep_initializers_as_inputs=True,
            opset_version=args.opset_version)

    onnx_model = onnx.load_model(model_name)
    print(onnx_model.ir_version)
    onnx.checker.check_model(onnx_model)

    # Explicit providers: ORT >= 1.9 raises ValueError without them when the
    # build has several execution providers enabled.
    sess = onnxruntime.InferenceSession(
        model_name, providers=['CPUExecutionProvider'])
    input_name = sess.get_inputs()[0].name
    onnxrt_out = sess.run(None, {input_name: dummy_input.cpu().numpy()})[0]
    print(onnxrt_out.shape)


if __name__ == '__main__':
    main()