diff --git a/tools/pytorch2onnx_kneron.py b/tools/pytorch2onnx_kneron.py index 8602c32..297aea9 100644 --- a/tools/pytorch2onnx_kneron.py +++ b/tools/pytorch2onnx_kneron.py @@ -124,6 +124,10 @@ def pytorch2onnx(model, model.forward = origin_forward # NOTE: optimizing onnx for kneron inference m = onnx.load(output_file) + # NOTE: PyTorch 1.10.x exports onnx ir_version == 7 for opset 11, + # but should be ir_version == 6 + if opset_version == 11: + m.ir_version = 6 m = torch_exported_onnx_flow(m, disable_fuse_bn=False) onnx.save(m, output_file) @@ -143,7 +147,9 @@ def pytorch2onnx(model, ] net_feed_input = list(set(input_all) - set(input_initializer)) assert (len(net_feed_input) == 1) - sess = rt.InferenceSession(output_file) + sess = rt.InferenceSession( + output_file, providers=['CPUExecutionProvider'] + ) onnx_result = sess.run( None, {net_feed_input[0]: img.detach().numpy()})[0] # show segmentation results