# All modification made by Kneron Corp.: Copyright (c) 2022 Kneron Corp.
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import warnings
import os
import onnx
import mmcv
import numpy as np
import onnxruntime as rt
import torch
from mmcv import DictAction
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint
from torch import nn

from mmseg.apis import show_result_pyplot
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor

from optimizer_scripts.tools import other
from optimizer_scripts.pytorch_exported_onnx_preprocess import (
    torch_exported_onnx_flow)

torch.manual_seed(3)


def _parse_normalize_cfg(test_pipeline):
    """Return the single ``Normalize`` step of a test pipeline.

    The first pipeline entry that has a ``'transforms'`` key is searched.

    Args:
        test_pipeline: Sequence of pipeline step configs (dict-like).

    Returns:
        The config dict whose ``'type'`` equals ``'Normalize'``.

    Raises:
        AssertionError: If no entry carries ``'transforms'`` or if the
            number of ``Normalize`` steps found is not exactly one.
    """
    transforms = None
    for pipeline in test_pipeline:
        if 'transforms' in pipeline:
            transforms = pipeline['transforms']
            break
    assert transforms is not None, 'Failed to find `transforms`'
    norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
    assert len(norm_config_li) == 1, '`norm_config` should only have one'
    return norm_config_li[0]


def _convert_batchnorm(module):
    """Recursively replace ``SyncBatchNorm`` modules with ``BatchNorm2d``.

    Affine parameters (with their ``requires_grad`` flags) and running
    statistics are carried over to the replacement module.

    Args:
        module: Root ``torch.nn.Module`` to convert.

    Returns:
        The converted module. This is ``module`` itself unless ``module`` is
        a ``SyncBatchNorm``, in which case a new ``BatchNorm2d`` is returned.
    """
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(
            module.num_features, module.eps, module.momentum, module.affine,
            module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    return module_output


def _demo_mm_inputs(input_shape):
    """Create a reproducible random input tensor.

    Args:
        input_shape: Four-element shape ``(N, C, H, W)``.

    Returns:
        ``torch.FloatTensor`` of ``input_shape`` filled with uniform random
        values drawn from a NumPy ``RandomState`` seeded with 0.

    Raises:
        ValueError: If ``input_shape`` does not unpack into four values.
    """
    # Unpacking is kept purely as a check that the shape has four entries.
    _n, _c, _h, _w = input_shape
    rng = np.random.RandomState(0)
    img = torch.FloatTensor(rng.rand(*input_shape))
    return img


def _prepare_input_img(img_path, test_pipeline, shape=None):
    """Load an image through the test pipeline and return it as a batch.

    Note:
        When ``shape`` is given, ``test_pipeline[1]`` is modified in place
        (``img_scale`` is overwritten and ``keep_ratio`` of its first
        transform is disabled), so the caller's config object changes too.

    Args:
        img_path: Image passed to the pipeline under the ``'img'`` key.
        test_pipeline: List of pipeline step configs; the first step is
            replaced by ``LoadImage()``.
        shape: Optional two-element sequence. It is written to ``img_scale``
            in reversed order, i.e. ``(shape[1], shape[0])``.

    Returns:
        ``torch.FloatTensor`` built from the pipeline's ``'img'`` output with
        a leading batch dimension added.
    """
    # build the data pipeline
    if shape is not None:
        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
        test_pipeline[1]['transforms'][0]['keep_ratio'] = False
    test_pipeline = [LoadImage()] + test_pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img_path)
    data = test_pipeline(data)
    img = torch.FloatTensor(data['img']).unsqueeze_(0)
    return img


def _export_and_optimize(model, img, output_file, opset_version, show):
    """Export ``model`` to ``output_file`` and run the Kneron optimizer on it.

    The caller is expected to have installed the forward function that
    should be traced before calling this helper.

    Returns:
        The optimized ONNX model object (also saved to ``output_file``).
    """
    register_extra_symbolics(opset_version)
    with torch.no_grad():
        torch.onnx.export(
            model,
            img,
            output_file,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=show,
            opset_version=opset_version,
            dynamic_axes=None)
    print(f'Successfully exported ONNX model: {output_file} '
          f'(opset_version={opset_version})')

    # NOTE: optimize onnx
    m = onnx.load(output_file)
    if opset_version == 11:
        m.ir_version = 6
    m = torch_exported_onnx_flow(m, disable_fuse_bn=False)
    onnx.save(m, output_file)
    print(f'{output_file} optimized by KNERON successfully.')
    return m


def _verify_onnx(model, img, output_file, num_classes, show):
    """Check that ONNX Runtime and PyTorch produce matching outputs.

    ``model(img)`` is called with a single positional argument, so the caller
    must still have the export-time forward installed on ``model``.

    Raises:
        AssertionError: If the ONNX graph does not have exactly one
            non-initializer input, or if the outputs differ beyond
            ``rtol=1e-5`` / ``atol=1e-5``.
    """
    # check by onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    # get pytorch output
    with torch.no_grad():
        pytorch_result = model(img).numpy()

    # get onnx output; graph inputs that are initializers are not fed
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [
        node.name for node in onnx_model.graph.initializer
    ]
    net_feed_input = list(set(input_all) - set(input_initializer))
    assert len(net_feed_input) == 1
    sess = rt.InferenceSession(
        output_file, providers=['CPUExecutionProvider'])
    onnx_result = sess.run(
        None, {net_feed_input[0]: img.detach().numpy()})[0]

    # show segmentation results
    if show:
        import cv2
        img_show = img[0][:3, ...].permute(1, 2, 0) * 255
        img_show = img_show.detach().numpy().astype(np.uint8)
        ori_shape = img_show.shape[:2]

        # resize onnx_result to ori_shape
        onnx_result_ = onnx_result[0].argmax(0)
        onnx_result_ = cv2.resize(
            onnx_result_.astype(np.uint8), (ori_shape[1], ori_shape[0]))
        show_result_pyplot(
            model,
            img_show, (onnx_result_, ),
            palette=model.PALETTE,
            block=False,
            title='ONNXRuntime',
            opacity=0.5)

        # resize pytorch_result to ori_shape
        pytorch_result_ = pytorch_result.squeeze().argmax(0)
        pytorch_result_ = cv2.resize(
            pytorch_result_.astype(np.uint8), (ori_shape[1], ori_shape[0]))
        show_result_pyplot(
            model,
            img_show, (pytorch_result_, ),
            title='PyTorch',
            palette=model.PALETTE,
            opacity=0.5)

    # compare results
    np.testing.assert_allclose(
        pytorch_result.astype(np.float32) / num_classes,
        onnx_result.astype(np.float32) / num_classes,
        rtol=1e-5,
        atol=1e-5,
        err_msg='The outputs are different between Pytorch and ONNX')
    print('The outputs are same between Pytorch and ONNX.')


def _prepend_normalization(m, norm_cfg, output_file):
    """Add a normalization BatchNorm at the graph input and save a copy.

    The result is written next to ``output_file`` with the suffix
    ``_bn_prepended.onnx``.

    Raises:
        ValueError: If the length of ``mean`` or ``std`` differs from
            dimension 1 of the first graph input.
    """
    print("Prepending BatchNorm layer to ONNX as data normalization...")
    mean = norm_cfg['mean']
    std = norm_cfg['std']
    i_n = m.graph.input[0]
    in_channels = i_n.type.tensor_type.shape.dim[1].dim_value
    if in_channels != len(mean) or in_channels != len(std):
        raise ValueError(
            f"--pixel-bias-value ({mean}) and --pixel-scale-value ({std}) "
            "should match input dimension.")
    # bias = (128 - mean) / std and scale = 1 / std.
    # NOTE(review): the +128 term presumably compensates for an input that
    # is shifted by -128 on the target device -- confirm.
    norm_bn_bias = [-1 * cm / cs + 128. / cs for cm, cs in zip(mean, std)]
    norm_bn_scale = [1 / cs for cs in std]
    other.add_bias_scale_bn_after(
        m.graph, i_n.name, norm_bn_bias, norm_bn_scale)
    m = other.polish_model(m)
    bn_outf = os.path.splitext(output_file)[0] + "_bn_prepended.onnx"
    onnx.save(m, bn_outf)
    print(f"BN-Prepended ONNX saved to {bn_outf}")


def pytorch2onnx(model,
                 img,
                 norm_cfg=None,
                 opset_version=13,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export a segmentor to ONNX and optimize it with the Kneron flow.

    The model is moved to CPU and put in eval mode. During export and
    verification ``model.forward`` is temporarily replaced by
    ``model.forward_dummy``; the original forward is restored afterwards,
    even if an error occurs.

    Args:
        model: Segmentor exposing ``decode_head`` and ``forward_dummy``.
        img: Input tensor used for tracing (and for verification).
        norm_cfg: Optional dict with ``'mean'`` and ``'std'``. When given, an
            additional ``*_bn_prepended.onnx`` file is written with a
            normalization BatchNorm placed at the graph input.
        opset_version: ONNX opset used for export.
        show: Passed as ``verbose`` to ``torch.onnx.export``; with ``verify``
            it also displays the segmentation results.
        output_file: Destination path of the exported ONNX model.
        verify: Whether to compare ONNX Runtime output with PyTorch output.

    Returns:
        None.
    """
    model.cpu().eval()

    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    # Remember the real forward BEFORE patching it; capturing it afterwards
    # would only save forward_dummy and make the restore a no-op.
    origin_forward = model.forward
    model.forward = model.forward_dummy
    try:
        m = _export_and_optimize(model, img, output_file, opset_version,
                                 show)
        if verify:
            # Runs while forward_dummy is still installed, because the
            # check calls model(img) with the image only.
            _verify_onnx(model, img, output_file, num_classes, show)
    finally:
        model.forward = origin_forward

    if norm_cfg is not None:
        _prepend_normalization(m, norm_cfg, output_file)
    return


def parse_args():
    """Parse the command-line arguments of the conversion script.

    Returns:
        ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument(
        '--input-img', type=str, help='Images for input', default=None)
    parser.add_argument(
        '--show',
        action='store_true',
        help='show onnx graph and segmentation results')
    parser.add_argument(
        '--verify', action='store_true', help='verify the onnx model')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument(
        '--opset-version', type=int, default=13)  # default opset=13
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=None,
        help='input image height and width.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override config options.')
    parser.add_argument(
        '--normalization-in-onnx',
        action='store_true',
        help='Prepend BN for normalization.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    if args.opset_version < 11:
        raise ValueError(
            f"Only opset_version >=11 is supported (got {args.opset_version}).")

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg.model.pretrained = None

    # Resolve the export input shape: from the config when --shape is not
    # given, otherwise from the command line (one value means a square).
    test_mode = cfg.model.test_cfg.mode
    if args.shape is None:
        if test_mode == 'slide':
            crop_size = cfg.model.test_cfg['crop_size']
            input_shape = (1, 3, crop_size[1], crop_size[0])
        else:
            img_scale = cfg.test_pipeline[1]['img_scale']
            input_shape = (1, 3, img_scale[1], img_scale[0])
    else:
        if test_mode == 'slide':
            warnings.warn(
                "Shape assignment for slide-mode models may cause unexpected results.")
        if len(args.shape) == 1:
            input_shape = (1, 3, args.shape[0], args.shape[0])
        elif len(args.shape) == 2:
            input_shape = (1, 3) + tuple(args.shape)
        else:
            raise ValueError('Invalid input shape')

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    if args.checkpoint:
        checkpoint = load_checkpoint(
            segmentor, args.checkpoint, map_location='cpu')
        segmentor.CLASSES = checkpoint['meta']['CLASSES']
        segmentor.PALETTE = checkpoint['meta']['PALETTE']

    # read input or create dummy input
    if args.input_img is not None:
        preprocess_shape = (input_shape[2], input_shape[3])
        img = _prepare_input_img(
            args.input_img, cfg.data.test.pipeline, shape=preprocess_shape)
    else:
        img = _demo_mm_inputs(input_shape)

    if args.normalization_in_onnx:
        norm_cfg = _parse_normalize_cfg(cfg.test_pipeline)
    else:
        norm_cfg = None

    # convert model to onnx file
    pytorch2onnx(
        segmentor,
        img,
        norm_cfg=norm_cfg,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
    )