255 lines
8.3 KiB
Python
255 lines
8.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
import os
|
|
import os.path as osp
|
|
import warnings
|
|
|
|
import numpy as np
|
|
import onnx
|
|
import torch
|
|
from mmcv import Config
|
|
from mmcv.tensorrt import is_tensorrt_plugin_loaded, onnx2trt, save_trt_engine
|
|
|
|
from mmdet.core.export import preprocess_example_input
|
|
from mmdet.core.export.model_wrappers import (ONNXRuntimeDetector,
|
|
TensorRTDetector)
|
|
from mmdet.datasets import DATASETS
|
|
|
|
|
|
def get_GiB(x: int):
    """Convert a size expressed in GiB into a byte count.

    Args:
        x: Number of gibibytes.

    Returns:
        ``x`` multiplied by 2**30.
    """
    bytes_per_gib = 2**30
    return x * bytes_per_gib
|
|
|
|
|
|
def _assert_outputs_close(onnx_results, trt_results, with_mask):
    """Assert that ONNXRuntime and TensorRT detection results agree.

    Args:
        onnx_results: Output of the ONNXRuntime wrapper for one image.
        trt_results: Output of the TensorRT wrapper for the same image.
        with_mask (bool): Whether the model also predicts masks. If so, the
            results are compared part by part (presumably a
            ``(bbox_results, segm_results)`` pair -- confirm against the
            wrapper classes).

    Raises:
        AssertionError: If the outputs differ in length, or in value beyond
            ``rtol=1e-03`` / ``atol=1e-05``.
    """
    if with_mask:
        compare_pairs = list(zip(onnx_results, trt_results))
    else:
        compare_pairs = [(onnx_results, trt_results)]
    err_msg = ('The numerical values are different between ONNXRuntime'
               ' and TensorRT, but it does not necessarily mean the'
               ' exported TensorRT engine is problematic.')
    # check the numerical value
    for ort_part, trt_part in compare_pairs:
        # zip() below would silently ignore trailing entries on a mismatch.
        assert len(ort_part) == len(trt_part), err_msg
        for ort_res, trt_res in zip(ort_part, trt_part):
            np.testing.assert_allclose(
                ort_res, trt_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
    print('The numerical values are the same between ONNXRuntime and '
          'TensorRT')


def onnx2tensorrt(onnx_file,
                  trt_file,
                  input_config,
                  verify=False,
                  show=False,
                  workspace_size=1,
                  verbose=False,
                  classes=None):
    """Build a TensorRT engine from an ONNX detector and save it to disk.

    Args:
        onnx_file (str): Path of the input ONNX model.
        trt_file (str): Output path of the serialized TensorRT engine.
            Missing parent directories are created.
        input_config (dict): Must contain ``min_shape``, ``opt_shape`` and
            ``max_shape``, used as the dynamic-shape profile of the network
            input named ``'input'``. When ``verify`` is True the dict is also
            passed to ``preprocess_example_input``, so it must carry whatever
            keys that helper needs.
        verify (bool): If True, run both ONNXRuntime and TensorRT on the
            example input and assert their outputs are numerically close.
        show (bool): Only used with ``verify``. If True, results are shown
            in windows; otherwise they are written to ``show-ort.png`` and
            ``show-trt.png``.
        workspace_size (int): Max TensorRT workspace size in GiB.
        verbose (bool): Use verbose TensorRT logging while building.
        classes (Sequence[str], optional): Class names given to the
            verification wrappers. Defaults to the module-level ``CLASSES``
            set by this script's ``__main__`` section.

    Raises:
        ValueError: If ``verify`` is True and no class names are available.
        AssertionError: If verification finds mismatching outputs.
    """
    import tensorrt as trt

    if verify and classes is None:
        # Backward compatibility: older callers relied on the module-level
        # CLASSES global. Resolve it before the (slow) engine build so a
        # missing value fails fast instead of raising NameError afterwards.
        classes = globals().get('CLASSES')
        if classes is None:
            raise ValueError('`classes` must be provided when verify=True')

    onnx_model = onnx.load(onnx_file)
    max_shape = input_config['max_shape']
    min_shape = input_config['min_shape']
    opt_shape = input_config['opt_shape']
    fp16_mode = False
    # create trt engine; the profile is keyed by the ONNX input name 'input'
    opt_shape_dict = {'input': [min_shape, opt_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16_mode,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if not verify:
        return

    # prepare input (batch of one image with its meta)
    one_img, one_meta = preprocess_example_input(input_config)
    img_list = [one_img.cuda().contiguous()]
    img_meta_list = [[one_meta]]

    # wrap ONNX and TensorRT model
    ort_model = ONNXRuntimeDetector(onnx_file, classes, device_id=0)
    trt_model = TensorRTDetector(trt_file, classes, device_id=0)

    # inference with wrapped model
    with torch.no_grad():
        onnx_results = ort_model(
            img_list, img_metas=img_meta_list, return_loss=False)[0]
        trt_results = trt_model(
            img_list, img_metas=img_meta_list, return_loss=False)[0]

    # Without `show`, visualizations go to these files instead of windows.
    if show:
        out_file_ort, out_file_trt = None, None
    else:
        out_file_ort, out_file_trt = 'show-ort.png', 'show-trt.png'
    show_img = one_meta['show_img']
    score_thr = 0.3
    ort_model.show_result(
        show_img,
        onnx_results,
        score_thr=score_thr,
        show=show,
        win_name='ONNXRuntime',
        out_file=out_file_ort)
    trt_model.show_result(
        show_img,
        trt_results,
        score_thr=score_thr,
        show=show,
        win_name='TensorRT',
        out_file=out_file_trt)

    _assert_outputs_close(onnx_results, trt_results, trt_model.with_masks)
|
|
|
|
|
|
def parse_normalize_cfg(test_pipeline):
    """Locate the single ``Normalize`` transform in a test pipeline.

    The first pipeline step that has a ``transforms`` key is searched for
    entries whose ``type`` is ``'Normalize'``.

    Args:
        test_pipeline (list[dict]): Test pipeline from the config.

    Returns:
        dict: The ``Normalize`` transform config (the same object found in
        the pipeline, not a copy).

    Raises:
        AssertionError: If no step has ``transforms`` or the number of
            ``Normalize`` entries is not exactly one.
    """
    transforms = next(
        (step['transforms'] for step in test_pipeline
         if 'transforms' in step),
        None)
    assert transforms is not None, 'Failed to find `transforms`'

    normalize_steps = [
        step for step in transforms if step['type'] == 'Normalize'
    ]
    assert len(normalize_steps) == 1, '`norm_config` should only have one'
    (norm_config, ) = normalize_steps
    return norm_config
|
|
|
|
|
|
def parse_args():
    """Parse the command-line arguments of the ONNX -> TensorRT converter.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='Convert MMDetection models from ONNX to TensorRT')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('model', help='Filename of input ONNX model')
    parser.add_argument(
        '--trt-file',
        type=str,
        default='tmp.trt',
        help='Filename of output TensorRT engine')
    parser.add_argument(
        '--input-img', type=str, default='', help='Image for test')
    parser.add_argument(
        '--show', action='store_true', help='Whether to show output results')
    # Implicit string concatenation is used for long help texts so that no
    # source indentation leaks into the literals.
    parser.add_argument(
        '--dataset',
        type=str,
        default='coco',
        help='Dataset name. This argument is deprecated and will be '
        'removed in future releases.')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Verify the outputs of ONNXRuntime and TensorRT')
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether to enable verbose logging while creating '
        'TensorRT engine. Defaults to False.')
    # NOTE: store_false -> defaults to True (RGB); passing the flag sets it
    # to False. The value is not read elsewhere in this script.
    parser.add_argument(
        '--to-rgb',
        action='store_false',
        help='Feed model with RGB or BGR image. Default is RGB. This '
        'argument is deprecated and will be removed in future releases.')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[400, 600],
        help='Input size of the model')
    parser.add_argument(
        '--mean',
        type=float,
        nargs='+',
        default=[123.675, 116.28, 103.53],
        help='Mean value used for preprocess input data. This argument '
        'is deprecated and will be removed in future releases.')
    parser.add_argument(
        '--std',
        type=float,
        nargs='+',
        default=[58.395, 57.12, 57.375],
        help='Standard deviation used for preprocess input data. '
        'This argument is deprecated and will be removed in future releases.')
    parser.add_argument(
        '--min-shape',
        type=int,
        nargs='+',
        default=None,
        help='Minimum input size of the model in TensorRT')
    parser.add_argument(
        '--max-shape',
        type=int,
        nargs='+',
        default=None,
        help='Maximum input size of the model in TensorRT')
    parser.add_argument(
        '--workspace-size',
        type=int,
        default=1,
        help='Max workspace size in GiB')

    args = parser.parse_args()
    return args
|
|
|
|
|
|
if __name__ == '__main__':

    assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
    args = parse_args()
    # Implicit concatenation keeps source indentation out of the message.
    warnings.warn(
        'Arguments like `--to-rgb`, `--mean`, `--std`, `--dataset` would be '
        'parsed directly from config file and are deprecated and will be '
        'removed in future releases.')
    if not args.input_img:
        args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.jpg')

    cfg = Config.fromfile(args.config)

    def parse_shape(shape):
        """Expand a 1- or 2-element CLI size into a 4-D input shape.

        ``[s]`` becomes ``(1, 3, s, s)``; ``[a, b]`` becomes
        ``(1, 3, a, b)``. Any other length raises ``ValueError``.
        """
        # Check the argument itself (not args.shape): this helper is also
        # used for --min-shape / --max-shape.
        if len(shape) == 1:
            return (1, 3, shape[0], shape[0])
        if len(shape) == 2:
            return (1, 3) + tuple(shape)
        raise ValueError(f'invalid input shape: {shape}')

    if args.shape:
        input_shape = parse_shape(args.shape)
    else:
        # NOTE(review): --shape has a non-empty default and nargs='+', so
        # this config-based fallback appears unreachable from the CLI.
        img_scale = cfg.test_pipeline[1]['img_scale']
        input_shape = (1, 3, img_scale[1], img_scale[0])

    # Without explicit bounds, the dynamic range collapses to input_shape.
    max_shape = parse_shape(args.max_shape) if args.max_shape else input_shape
    min_shape = parse_shape(args.min_shape) if args.min_shape else input_shape

    dataset_type = cfg.data.test['type']
    dataset = DATASETS.get(dataset_type)
    assert dataset is not None, f'Dataset {dataset_type} is not registered'
    # onnx2tensorrt() reads this module-level name when --verify is set.
    CLASSES = dataset.CLASSES
    normalize_cfg = parse_normalize_cfg(cfg.test_pipeline)

    input_config = {
        'min_shape': min_shape,
        'opt_shape': input_shape,
        'max_shape': max_shape,
        'input_shape': input_shape,
        'input_path': args.input_img,
        'normalize_cfg': normalize_cfg
    }
    # Create TensorRT engine
    onnx2tensorrt(
        args.model,
        args.trt_file,
        input_config,
        verify=args.verify,
        show=args.show,
        workspace_size=args.workspace_size,
        verbose=args.verbose)
|