77 lines
2.0 KiB
Python

import argparse
import os
import os.path as osp
import sys
import mmcv
import torch
from mmcv import Config
from mmcv.runner import load_checkpoint
from models import build_posenet
import onnxruntime
import onnx
def parse_args():
    """Read the command line of the ONNX export script.

    Returns:
        argparse.Namespace with two positional values: ``config`` (path of
        the test config file) and ``checkpoint`` (path of the checkpoint).
    """
    cli = argparse.ArgumentParser(description='mmpose export onnx model')
    positional = (
        ('config', 'test config file path'),
        ('checkpoint', 'path to checkpoint file'),
    )
    for flag, text in positional:
        cli.add_argument(flag, help=text)
    return cli.parse_args()
def main():
# export to onnx
args = parse_args()
cfg = Config.fromfile(args.config)
cfg.load_from = args.checkpoint
args.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
# build the model and load checkpoint
model = build_posenet(cfg.model)
load_checkpoint(model, args.checkpoint, map_location='cpu')
[input_w, input_h] = cfg.data_cfg['image_size']
model_inputs = torch.rand(1,3,input_h,input_w)
torch_inputs = model_inputs
input_names = []
input_shapes = {}
opset_version = 11
model_name = osp.join(args.work_dir, osp.splitext(osp.basename(args.config))[0]+'.onnx')
with torch.no_grad():
for index, torch_input in enumerate(torch_inputs):
name = "i" + str(index)
input_names.append(name)
input_shapes[name] = torch_input.shape
torch.onnx.export(model, torch_inputs, model_name, input_names=input_names, keep_initializers_as_inputs=True,opset_version=opset_version)
onnx_model = onnx.load_model(model_name)
print(onnx_model.ir_version)
onnx.checker.check_model(onnx_model)
sess = onnxruntime.InferenceSession(model_name)
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
onnxrt_out = sess.run(None, {input_name: torch_inputs.cpu().numpy()})[0]
print(onnxrt_out.shape)
sys.exit(0)
if __name__ == "__main__":
    # Run the export only when this file is executed as a script.
    main()