fix: pytorch 1.10 generates ONNX with ir_version==7, which should be 6

This commit is contained in:
chingning.chen 2022-03-25 16:03:14 +08:00
parent 5b99260c9b
commit 0ec0b33ed7

View File

@ -124,6 +124,10 @@ def pytorch2onnx(model,
model.forward = origin_forward model.forward = origin_forward
# NOTE: optimizing onnx for kneron inference # NOTE: optimizing onnx for kneron inference
m = onnx.load(output_file) m = onnx.load(output_file)
# NOTE: PyTorch 1.10.x exports onnx ir_version == 7 for opset 11,
# but should be ir_version == 6
if opset_version == 11:
m.ir_version = 6
m = torch_exported_onnx_flow(m, disable_fuse_bn=False) m = torch_exported_onnx_flow(m, disable_fuse_bn=False)
onnx.save(m, output_file) onnx.save(m, output_file)
@ -143,7 +147,9 @@ def pytorch2onnx(model,
] ]
net_feed_input = list(set(input_all) - set(input_initializer)) net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1) assert (len(net_feed_input) == 1)
sess = rt.InferenceSession(output_file) sess = rt.InferenceSession(
output_file, providers=['CPUExecutionProvider']
)
onnx_result = sess.run( onnx_result = sess.run(
None, {net_feed_input[0]: img.detach().numpy()})[0] None, {net_feed_input[0]: img.detach().numpy()})[0]
# show segmentation results # show segmentation results