fix: pytorch 1.10 generates ONNX with ir_version==7, which should be 6
This commit is contained in:
parent
5b99260c9b
commit
0ec0b33ed7
@ -124,6 +124,10 @@ def pytorch2onnx(model,
|
|||||||
model.forward = origin_forward
|
model.forward = origin_forward
|
||||||
# NOTE: optimizing onnx for kneron inference
|
# NOTE: optimizing onnx for kneron inference
|
||||||
m = onnx.load(output_file)
|
m = onnx.load(output_file)
|
||||||
|
# NOTE: PyTorch 1.10.x exports onnx ir_version == 7 for opset 11,
|
||||||
|
# but should be ir_version == 6
|
||||||
|
if opset_version == 11:
|
||||||
|
m.ir_version = 6
|
||||||
m = torch_exported_onnx_flow(m, disable_fuse_bn=False)
|
m = torch_exported_onnx_flow(m, disable_fuse_bn=False)
|
||||||
onnx.save(m, output_file)
|
onnx.save(m, output_file)
|
||||||
|
|
||||||
@ -143,7 +147,9 @@ def pytorch2onnx(model,
|
|||||||
]
|
]
|
||||||
net_feed_input = list(set(input_all) - set(input_initializer))
|
net_feed_input = list(set(input_all) - set(input_initializer))
|
||||||
assert (len(net_feed_input) == 1)
|
assert (len(net_feed_input) == 1)
|
||||||
sess = rt.InferenceSession(output_file)
|
sess = rt.InferenceSession(
|
||||||
|
output_file, providers=['CPUExecutionProvider']
|
||||||
|
)
|
||||||
onnx_result = sess.run(
|
onnx_result = sess.run(
|
||||||
None, {net_feed_input[0]: img.detach().numpy()})[0]
|
None, {net_feed_input[0]: img.detach().numpy()})[0]
|
||||||
# show segmentation results
|
# show segmentation results
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user