77 lines
2.0 KiB
Python
77 lines
2.0 KiB
Python
import argparse
|
|
import os
|
|
import os.path as osp
|
|
import sys
|
|
|
|
import mmcv
|
|
import torch
|
|
from mmcv import Config
|
|
from mmcv.runner import load_checkpoint
|
|
|
|
from models import build_posenet
|
|
import onnxruntime
|
|
import onnx
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for the ONNX export script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` falls back to ``sys.argv[1:]``, so command-line use is
            unchanged; passing a list allows calling this programmatically.

    Returns:
        argparse.Namespace with two positional fields:
            ``config``: path to the model config file.
            ``checkpoint``: path to the checkpoint file to load.
    """
    parser = argparse.ArgumentParser(description='mmpose export onnx model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='path to checkpoint file')
    return parser.parse_args(argv)


def _export_onnx(model, dummy_input, output_file, opset_version=11):
    """Trace ``model`` on ``dummy_input`` and write the ONNX graph to ``output_file``.

    Args:
        model: torch.nn.Module to export.
        dummy_input: Example tensor passed as the model's single positional input.
        output_file: Destination path of the ``.onnx`` file.
        opset_version: ONNX opset to target.
    """
    # torch.onnx.export takes the model's positional inputs as a tuple; name one
    # graph input per positional argument (i0, i1, ...), not per batch row.
    inputs = (dummy_input,)
    input_names = ['i{}'.format(index) for index in range(len(inputs))]
    with torch.no_grad():
        torch.onnx.export(
            model,
            inputs,
            output_file,
            input_names=input_names,
            keep_initializers_as_inputs=True,
            opset_version=opset_version)


def _check_with_onnxruntime(output_file, dummy_input):
    """Validate the exported ONNX file and run one inference pass with ONNX Runtime.

    Prints the model's ONNX IR version, runs ``onnx.checker.check_model`` on it,
    then feeds ``dummy_input`` to the first graph input.

    Returns:
        The first output of the ONNX Runtime session (a numpy array).
    """
    onnx_model = onnx.load_model(output_file)
    print(onnx_model.ir_version)
    onnx.checker.check_model(onnx_model)

    # Pin the CPU provider explicitly: ONNX Runtime >= 1.9 raises when
    # `providers` is omitted on builds that ship several execution providers.
    sess = onnxruntime.InferenceSession(
        output_file, providers=['CPUExecutionProvider'])
    input_name = sess.get_inputs()[0].name
    return sess.run(None, {input_name: dummy_input.cpu().numpy()})[0]


def main():
    """Export a pose model to ONNX and smoke-test the result with ONNX Runtime.

    Reads the config and checkpoint paths from the command line, builds the
    model with ``build_posenet``, loads the checkpoint on CPU, and exports it to
    ``./work_dirs/<config stem>/<config stem>.onnx``. It then checks the file
    with ``onnx.checker``, runs one ONNX Runtime inference on the same random
    input, prints the output shape, and exits with status 0.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    cfg.load_from = args.checkpoint

    # Output location: ./work_dirs/<config stem>/<config stem>.onnx
    config_name = osp.splitext(osp.basename(args.config))[0]
    work_dir = osp.join('./work_dirs', config_name)
    mmcv.mkdir_or_exist(osp.abspath(work_dir))
    output_file = osp.join(work_dir, config_name + '.onnx')

    # Build the model, load weights on CPU, and freeze train-only layers
    # (dropout / batch-norm statistics) for tracing.
    model = build_posenet(cfg.model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    model.eval()
    # NOTE(review): upstream mmpose's exporter replaces `model.forward` with
    # `model.forward_dummy` before export; confirm this project's
    # `forward(img)` returns predictions rather than computing a loss.

    # image_size is unpacked as (width, height); the tensor is NCHW with batch 1.
    input_w, input_h = cfg.data_cfg['image_size']
    dummy_input = torch.rand(1, 3, input_h, input_w)

    _export_onnx(model, dummy_input, output_file, opset_version=11)

    onnxrt_out = _check_with_onnxruntime(output_file, dummy_input)
    print(onnxrt_out.shape)

    sys.exit(0)


if __name__ == '__main__':
    # Run the export only when executed as a script, never on import.
    main()