From 6b27600a1a4cfa408caa7b9531045beb15c860cb Mon Sep 17 00:00:00 2001 From: "chingning.chen" Date: Thu, 17 Mar 2022 15:42:01 +0800 Subject: [PATCH] feat: pytorch2onnx_kneron.py --- tools/pytorch2onnx_kneron.py | 274 +++++++++++++++++++++++++++++++++++ 1 file changed, 274 insertions(+) create mode 100644 tools/pytorch2onnx_kneron.py diff --git a/tools/pytorch2onnx_kneron.py b/tools/pytorch2onnx_kneron.py new file mode 100644 index 0000000..be373e3 --- /dev/null +++ b/tools/pytorch2onnx_kneron.py @@ -0,0 +1,274 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Original: tools/pytorch2onnx.py, modified by Kneron +import argparse + +import mmcv +import numpy as np +import onnxruntime as rt +import torch +import torch._C +import torch.serialization +from mmcv import DictAction +from mmcv.onnx import register_extra_symbolics +from mmcv.runner import load_checkpoint +from torch import nn + +from mmseg.apis import show_result_pyplot +from mmseg.apis.inference import LoadImage +from mmseg.datasets.pipelines import Compose +from mmseg.models import build_segmentor + +torch.manual_seed(3) + + +def _convert_batchnorm(module): + module_output = module + if isinstance(module, torch.nn.SyncBatchNorm): + module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, + module.momentum, module.affine, + module.track_running_stats) + if module.affine: + module_output.weight.data = module.weight.data.clone().detach() + module_output.bias.data = module.bias.data.clone().detach() + # keep requires_grad unchanged + module_output.weight.requires_grad = module.weight.requires_grad + module_output.bias.requires_grad = module.bias.requires_grad + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + for name, child in module.named_children(): + module_output.add_module(name, _convert_batchnorm(child)) + del module + return module_output + + +def _demo_mm_inputs(input_shape): + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): + input batch dimensions + num_classes (int): + number of semantic classes + """ + (N, C, H, W) = input_shape + rng = np.random.RandomState(0) + img = torch.FloatTensor(rng.rand(*input_shape)) + return img + + +def _prepare_input_img(img_path, + test_pipeline, + shape=None): + # build the data pipeline + if shape is not None: + test_pipeline[1]['img_scale'] = (shape[1], shape[0]) + test_pipeline[1]['transforms'][0]['keep_ratio'] = False + test_pipeline = [LoadImage()] + test_pipeline[1:] + test_pipeline = Compose(test_pipeline) + # prepare data + data = dict(img=img_path) + data = test_pipeline(data) + img = torch.FloatTensor(data['img']).unsqueeze_(0) + return img + + +def pytorch2onnx(model, + img, + opset_version=11, + show=False, + output_file='tmp.onnx', + verify=False): + """Export Pytorch model to ONNX model and verify the outputs are same + between Pytorch and ONNX. + + Args: + model (nn.Module): Pytorch model we want to export. + img (dict): Input tensor (1xCxHxW) + opset_version (int): The onnx op version. Default: 11. + show (bool): Whether print the computation graph. Default: False. + output_file (string): The path to where we store the output ONNX model. + Default: `tmp.onnx`. + verify (bool): Whether compare the outputs between Pytorch and ONNX. + Default: False. + """ + model.cpu().eval() + + if isinstance(model.decode_head, nn.ModuleList): + num_classes = model.decode_head[-1].num_classes + else: + num_classes = model.decode_head.num_classes + + # replace original forward function + model.forward = model.forward_dummy + origin_forward = model.forward + + register_extra_symbolics(opset_version) + with torch.no_grad(): + torch.onnx.export( + model, img, + output_file, + input_names=['input'], + output_names=['output'], + export_params=True, + keep_initializers_as_inputs=False, + verbose=show, + opset_version=opset_version, + dynamic_axes=None) + print(f'Successfully exported ONNX model: {output_file}') + model.forward = origin_forward + + if verify: + # check by onnx + import onnx + onnx_model = onnx.load(output_file) + onnx.checker.check_model(onnx_model) + + # check the numerical value + # get pytorch output + with torch.no_grad(): + pytorch_result = model(img).numpy() + + # get onnx output + input_all = [node.name for node in onnx_model.graph.input] + input_initializer = [ + node.name for node in onnx_model.graph.initializer + ] + net_feed_input = list(set(input_all) - set(input_initializer)) + assert (len(net_feed_input) == 1) + sess = rt.InferenceSession(output_file) + onnx_result = sess.run( + None, {net_feed_input[0]: img.detach().numpy()})[0] + # show segmentation results + if show: + import cv2 + img = img[0][:3, ...].permute(1, 2, 0) * 255 + img = img.detach().numpy().astype(np.uint8) + ori_shape = img.shape[:2] + + # resize onnx_result to ori_shape + onnx_result_ = onnx_result[0].argmax(0) + onnx_result_ = cv2.resize(onnx_result_.astype(np.uint8), + (ori_shape[1], ori_shape[0])) + show_result_pyplot( + model, + img, (onnx_result_, ), + palette=model.PALETTE, + block=False, + title='ONNXRuntime', + opacity=0.5) + + # resize pytorch_result to ori_shape + pytorch_result_ = pytorch_result.squeeze().argmax(0) + pytorch_result_ = cv2.resize(pytorch_result_.astype(np.uint8), + (ori_shape[1], ori_shape[0])) + show_result_pyplot( + model, + img, (pytorch_result_, ), + title='PyTorch', + palette=model.PALETTE, + opacity=0.5) + # compare results + np.testing.assert_allclose( + pytorch_result.astype(np.float32) / num_classes, + onnx_result.astype(np.float32) / num_classes, + rtol=1e-5, + atol=1e-5, + err_msg='The outputs are different between Pytorch and ONNX') + print('The outputs are same between Pytorch and ONNX') + + +def parse_args(): + parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX') + parser.add_argument('config', help='test config file path') + parser.add_argument('--checkpoint', help='checkpoint file', default=None) + parser.add_argument( + '--input-img', type=str, help='Images for input', default=None) + parser.add_argument( + '--show', + action='store_true', + help='show onnx graph and segmentation results') + parser.add_argument( + '--verify', action='store_true', help='verify the onnx model') + parser.add_argument('--output-file', type=str, default='tmp.onnx') + parser.add_argument('--opset-version', type=int, default=11) + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=None, + help='input image height and width.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='Override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + assert args.opset_version == 11, "kneron_toolchain currently only supports opset 11" + + cfg = mmcv.Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + cfg.model.pretrained = None + + test_mode = cfg.model.test_cfg.mode + + if args.shape is None: + if test_mode == 'slide': + crop_size = cfg.model.test_cfg['crop_size'] + input_shape = (1, 3, crop_size[1], crop_size[0]) + else: + img_scale = cfg.test_pipeline[1]['img_scale'] + input_shape = (1, 3, img_scale[1], img_scale[0]) + elif len(args.shape) == 1: + input_shape = (1, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = ( + 1, + 3, + ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + # build the model and load checkpoint + cfg.model.train_cfg = None + segmentor = build_segmentor( + cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg')) + # convert SyncBN to BN + segmentor = _convert_batchnorm(segmentor) + + if args.checkpoint: + checkpoint = load_checkpoint( + segmentor, args.checkpoint, map_location='cpu') + segmentor.CLASSES = checkpoint['meta']['CLASSES'] + segmentor.PALETTE = checkpoint['meta']['PALETTE'] + + # read input or create dummpy input + if args.input_img is not None: + preprocess_shape = (input_shape[2], input_shape[3]) + img = _prepare_input_img( + args.input_img, + cfg.data.test.pipeline, + shape=preprocess_shape) + else: + img = _demo_mm_inputs(input_shape) + + # convert model to onnx file + pytorch2onnx( + segmentor, + img, + opset_version=args.opset_version, + show=args.show, + output_file=args.output_file, + verify=args.verify, + )