From b6a868c3e299149a3501ec8f9f5bcb2b01c582bb Mon Sep 17 00:00:00 2001 From: "chingning.chen" Date: Fri, 25 Mar 2022 17:03:14 +0800 Subject: [PATCH] feat: onnx auto bn for kneron quantization --- tools/pytorch2onnx_kneron.py | 53 ++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tools/pytorch2onnx_kneron.py b/tools/pytorch2onnx_kneron.py index 297aea9..c7ffc55 100644 --- a/tools/pytorch2onnx_kneron.py +++ b/tools/pytorch2onnx_kneron.py @@ -2,6 +2,7 @@ # Original: tools/pytorch2onnx.py, modified by Kneron import argparse +import os import onnx import mmcv import numpy as np @@ -19,6 +20,7 @@ from mmseg.apis.inference import LoadImage from mmseg.datasets.pipelines import Compose from mmseg.models import build_segmentor +from optimizer_scripts.tools import other from optimizer_scripts.pytorch_exported_onnx_preprocess import ( torch_exported_onnx_flow, ) @@ -26,6 +28,19 @@ from optimizer_scripts.pytorch_exported_onnx_preprocess import ( torch.manual_seed(3) +def _parse_normalize_cfg(test_pipeline): + transforms = None + for pipeline in test_pipeline: + if 'transforms' in pipeline: + transforms = pipeline['transforms'] + break + assert transforms is not None, 'Failed to find `transforms`' + norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize'] + assert len(norm_config_li) == 1, '`norm_config` should only have one' + norm_config = norm_config_li[0] + return norm_config + + def _convert_batchnorm(module): module_output = module if isinstance(module, torch.nn.SyncBatchNorm): @@ -80,6 +95,7 @@ def _prepare_input_img(img_path, def pytorch2onnx(model, img, + norm_cfg=None, opset_version=11, show=False, output_file='tmp.onnx', @@ -130,6 +146,7 @@ def pytorch2onnx(model, m.ir_version = 6 m = torch_exported_onnx_flow(m, disable_fuse_bn=False) onnx.save(m, output_file) + print(f'{output_file} optimized by KNERON successfully.') if verify: onnx_model = onnx.load(output_file) @@ -190,6 +207,40 @@ def pytorch2onnx(model, err_msg='The outputs are different between Pytorch and ONNX') print('The outputs are same between Pytorch and ONNX') + if norm_cfg is not None: + mean = norm_cfg['mean'] + std = norm_cfg['std'] + # TODO: figure out should we add the lines below? + ''' + if all(_ == 128. for _ in mean) and all(_ == 256. for _ in std): + print("normalization config perfectly matches kneron " + "quantization requirement; not prepending batch norm.") + return + print("normalization config does not match kneron " + "quantization requirement; prepending batch norm...") + ''' + i_n = m.graph.input[0] + if ( + i_n.type.tensor_type.shape.dim[1].dim_value != len(mean) + or i_n.type.tensor_type.shape.dim[1].dim_value != len(std) + ): + raise ValueError( + f"--pixel-bias-value ({mean}) and --pixel-scale-value " + f"({std}) should be same as input dimension: " + f"{i_n.type.tensor_type.shape.dim[1].dim_value}" + ) + norm_bn_bias = [-1 * cm / cs + 128. / cs for cm, cs in zip(mean, std)] + norm_bn_scale = [1 / cs for cs in std] + other.add_bias_scale_bn_after( + m.graph, i_n.name, norm_bn_bias, norm_bn_scale + ) + m = other.polish_model(m) + bn_outf = os.path.splitext(output_file)[0] + "_kneron_optimized.onnx" + onnx.save(m, bn_outf) + print(f"ONNX for quantization saved to {bn_outf}") + + return + def parse_args(): parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX') @@ -276,10 +327,12 @@ if __name__ == '__main__': else: img = _demo_mm_inputs(input_shape) + norm_cfg = _parse_normalize_cfg(cfg.test_pipeline) # convert model to onnx file pytorch2onnx( segmentor, img, + norm_cfg=norm_cfg, opset_version=args.opset_version, show=args.show, output_file=args.output_file,