fix: pytorch 1.10 generates ONNX with ir_version==7, which should be 6
This commit is contained in:
parent
5b99260c9b
commit
0ec0b33ed7
@ -124,6 +124,10 @@ def pytorch2onnx(model,
|
||||
model.forward = origin_forward
|
||||
# NOTE: optimizing onnx for kneron inference
|
||||
m = onnx.load(output_file)
|
||||
# NOTE: PyTorch 1.10.x exports onnx ir_version == 7 for opset 11,
|
||||
# but should be ir_version == 6
|
||||
if opset_version == 11:
|
||||
m.ir_version = 6
|
||||
m = torch_exported_onnx_flow(m, disable_fuse_bn=False)
|
||||
onnx.save(m, output_file)
|
||||
|
||||
@ -143,7 +147,9 @@ def pytorch2onnx(model,
|
||||
]
|
||||
net_feed_input = list(set(input_all) - set(input_initializer))
|
||||
assert (len(net_feed_input) == 1)
|
||||
sess = rt.InferenceSession(output_file)
|
||||
sess = rt.InferenceSession(
|
||||
output_file, providers=['CPUExecutionProvider']
|
||||
)
|
||||
onnx_result = sess.run(
|
||||
None, {net_feed_input[0]: img.detach().numpy()})[0]
|
||||
# show segmentation results
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user