138 lines
4.0 KiB
Python
138 lines
4.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os
|
|
import os.path as osp
|
|
import warnings
|
|
|
|
import numpy as np
|
|
import onnx
|
|
import onnxruntime as ort
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
# Path of the mmcv ONNXRuntime custom-op library. It stays '' when the mmcv
# import fails, so callers checking ``osp.exists`` will skip registration.
ort_custom_op_path = ''
try:
    from mmcv.ops import get_onnxruntime_op_path
    ort_custom_op_path = get_onnxruntime_op_path()
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    # Implicit string concatenation keeps the message on one line without
    # capturing source indentation, which a backslash continuation inside
    # the literal would embed into the warning text.
    warnings.warn('If input model has custom op from mmcv, '
                  'you may have to build mmcv with ONNXRuntime from source.')
|
|
|
|
|
|
class WrapFunction(nn.Module):
    """Adapt a plain callable into an ``nn.Module`` for torch.onnx.export.

    The wrapper holds no parameters of its own; every call is forwarded
    unchanged to the stored callable so it can be traced like a model.

    Args:
        wrapped_function (callable): the function to be traced/exported.
    """

    def __init__(self, wrapped_function):
        super().__init__()
        self.wrapped_function = wrapped_function

    def forward(self, *args, **kwargs):
        """Invoke the wrapped callable with the given arguments as-is."""
        fn = self.wrapped_function
        return fn(*args, **kwargs)
|
|
|
|
|
|
def ort_validate(model, feats, onnx_io='tmp.onnx'):
    """Validate the output of the onnxruntime backend is the same as the output
    generated by torch.

    The model is exported to ONNX (opset 11) at ``onnx_io``, executed with
    onnxruntime, and every output is compared element-wise with the torch
    forward result. The exported file is removed afterwards, even when export
    or inference fails. Note that the model is moved to CPU and switched to
    eval mode in place.

    Args:
        model (nn.Module | function): the function of model or model
            to be verified. A plain function is wrapped in ``WrapFunction``.
        feats (tuple(list(torch.Tensor)) | list(torch.Tensor) | torch.Tensor):
            the input of model. A tuple is unpacked into positional
            arguments of the forward call.
        onnx_io (str): path of the temporary onnx file to write and run.

    Raises:
        AssertionError: if any output differs beyond ``rtol=1e-03`` /
            ``atol=1e-05``.
    """
    # if model is not an instance of nn.Module, then it is a normal
    # function and it should be wrapped.
    wrap_model = model if isinstance(model, nn.Module) else WrapFunction(model)
    wrap_model.cpu().eval()

    try:
        with torch.no_grad():
            torch.onnx.export(
                wrap_model,
                feats,
                onnx_io,
                export_params=True,
                keep_initializers_as_inputs=True,
                do_constant_folding=True,
                verbose=False,
                opset_version=11)

        # The exported graph takes the flattened input tensors in order. Use
        # the recursive flattener so a bare tensor inside the tuple is kept
        # whole (``list += tensor`` would instead splice in its rows).
        if isinstance(feats, tuple):
            ort_feats = convert_result_list(feats)
        else:
            ort_feats = feats
        # Run the file that was just written, not the helper's default path.
        onnx_outputs = get_ort_model_output(ort_feats, onnx_io)
    finally:
        # remove temp file
        if osp.exists(onnx_io):
            os.remove(onnx_io)

    # Call the module (not ``.forward``) so the reference goes through the
    # same ``__call__`` path that export traced.
    with torch.no_grad():
        if isinstance(feats, tuple):
            torch_outputs = convert_result_list(wrap_model(*feats))
        else:
            torch_outputs = convert_result_list(wrap_model(feats))
    torch_outputs = [output.detach().numpy() for output in torch_outputs]

    # match torch_outputs and onnx_outputs
    for i, onnx_output in enumerate(onnx_outputs):
        np.testing.assert_allclose(
            torch_outputs[i], onnx_output, rtol=1e-03, atol=1e-05)
|
|
|
|
|
|
def get_ort_model_output(feat, onnx_io='tmp.onnx'):
    """Run the model in onnxruntime env.

    The model file is loaded and validated with ``onnx.checker`` first. It
    then runs on the CPU execution provider, with the mmcv custom-op library
    registered when its path exists.

    Args:
        feat (torch.Tensor | list[torch.Tensor]): model input(s). A single
            tensor is fed to the first graph input; a list is fed to the
            first ``len(feat)`` graph inputs in order. Tensors must be
            convertible with ``.numpy()`` (CPU, no grad).
        onnx_io (str): path of the onnx model file to run.

    Returns:
        list[np.ndarray]: onnxruntime inference results, one array per
        graph output.
    """
    onnx_model = onnx.load(onnx_io)
    onnx.checker.check_model(onnx_model)

    session_options = ort.SessionOptions()
    # register custom op for onnxruntime
    if osp.exists(ort_custom_op_path):
        session_options.register_custom_ops_library(ort_custom_op_path)
    # Since ORT 1.9, builds with several execution providers (e.g.
    # onnxruntime-gpu) raise unless ``providers`` is given explicitly; inputs
    # here are host arrays from ``.numpy()``, so run on CPU.
    sess = ort.InferenceSession(
        onnx_io, session_options, providers=['CPUExecutionProvider'])

    if isinstance(feat, torch.Tensor):
        feat = [feat]
    # Query the graph inputs once instead of on every iteration. Indexing
    # keeps the original IndexError when more tensors than inputs are given.
    input_names = [meta.name for meta in sess.get_inputs()]
    ort_inputs = {
        input_names[i]: tensor.numpy()
        for i, tensor in enumerate(feat)
    }
    onnx_outputs = sess.run(None, ort_inputs)
    return onnx_outputs
|
|
|
|
|
|
def convert_result_list(outputs):
    """Flatten a (possibly nested) structure of tensors into a flat list.

    Args:
        outputs (torch.Tensor | list | tuple): a single tensor, or an
            iterable whose elements are tensors or further nested
            lists/tuples of tensors, e.g. the result of a torch forward pass.

    Returns:
        list(torch.Tensor): every tensor found in ``outputs``, in
        depth-first, left-to-right order.
    """
    if isinstance(outputs, torch.Tensor):
        # Base case: a lone tensor becomes a one-element list.
        return [outputs]
    # Recurse into each element and concatenate the flattened pieces in order.
    return [
        tensor for item in outputs for tensor in convert_result_list(item)
    ]
|