feat: onnx auto bn for kneron quantization

This commit is contained in:
chingning.chen 2022-03-25 17:03:14 +08:00
parent 0ec0b33ed7
commit b6a868c3e2

View File

@ -2,6 +2,7 @@
# Original: tools/pytorch2onnx.py, modified by Kneron # Original: tools/pytorch2onnx.py, modified by Kneron
import argparse import argparse
import os
import onnx import onnx
import mmcv import mmcv
import numpy as np import numpy as np
@ -19,6 +20,7 @@ from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor from mmseg.models import build_segmentor
from optimizer_scripts.tools import other
from optimizer_scripts.pytorch_exported_onnx_preprocess import ( from optimizer_scripts.pytorch_exported_onnx_preprocess import (
torch_exported_onnx_flow, torch_exported_onnx_flow,
) )
@ -26,6 +28,19 @@ from optimizer_scripts.pytorch_exported_onnx_preprocess import (
torch.manual_seed(3) torch.manual_seed(3)
def _parse_normalize_cfg(test_pipeline):
transforms = None
for pipeline in test_pipeline:
if 'transforms' in pipeline:
transforms = pipeline['transforms']
break
assert transforms is not None, 'Failed to find `transforms`'
norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
assert len(norm_config_li) == 1, '`norm_config` should only have one'
norm_config = norm_config_li[0]
return norm_config
def _convert_batchnorm(module): def _convert_batchnorm(module):
module_output = module module_output = module
if isinstance(module, torch.nn.SyncBatchNorm): if isinstance(module, torch.nn.SyncBatchNorm):
@ -80,6 +95,7 @@ def _prepare_input_img(img_path,
def pytorch2onnx(model, def pytorch2onnx(model,
img, img,
norm_cfg=None,
opset_version=11, opset_version=11,
show=False, show=False,
output_file='tmp.onnx', output_file='tmp.onnx',
@ -130,6 +146,7 @@ def pytorch2onnx(model,
m.ir_version = 6 m.ir_version = 6
m = torch_exported_onnx_flow(m, disable_fuse_bn=False) m = torch_exported_onnx_flow(m, disable_fuse_bn=False)
onnx.save(m, output_file) onnx.save(m, output_file)
print(f'{output_file} optimized by KNERON successfully.')
if verify: if verify:
onnx_model = onnx.load(output_file) onnx_model = onnx.load(output_file)
@ -190,6 +207,40 @@ def pytorch2onnx(model,
err_msg='The outputs are different between Pytorch and ONNX') err_msg='The outputs are different between Pytorch and ONNX')
print('The outputs are same between Pytorch and ONNX') print('The outputs are same between Pytorch and ONNX')
if norm_cfg is not None:
mean = norm_cfg['mean']
std = norm_cfg['std']
# TODO: figure out whether we should add the lines below.
'''
if all(_ == 128. for _ in mean) and all(_ == 256. for _ in std):
print("normalization config perfectly matches kneron "
"quantization requirement; not prepending batch norm.")
return
print("normalization config does not match kneron "
"quantization requirement; prepending batch norm...")
'''
i_n = m.graph.input[0]
if (
i_n.type.tensor_type.shape.dim[1].dim_value != len(mean)
or i_n.type.tensor_type.shape.dim[1].dim_value != len(std)
):
raise ValueError(
f"--pixel-bias-value ({mean}) and --pixel-scale-value "
f"({std}) should be same as input dimension: "
f"{i_n.type.tensor_type.shape.dim[1].dim_value}"
)
norm_bn_bias = [-1 * cm / cs + 128. / cs for cm, cs in zip(mean, std)]
norm_bn_scale = [1 / cs for cs in std]
other.add_bias_scale_bn_after(
m.graph, i_n.name, norm_bn_bias, norm_bn_scale
)
m = other.polish_model(m)
bn_outf = os.path.splitext(output_file)[0] + "_kneron_optimized.onnx"
onnx.save(m, bn_outf)
print(f"ONNX for quantization saved to {bn_outf}")
return
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX') parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
@ -276,10 +327,12 @@ if __name__ == '__main__':
else: else:
img = _demo_mm_inputs(input_shape) img = _demo_mm_inputs(input_shape)
norm_cfg = _parse_normalize_cfg(cfg.test_pipeline)
# convert model to onnx file # convert model to onnx file
pytorch2onnx( pytorch2onnx(
segmentor, segmentor,
img, img,
norm_cfg=norm_cfg,
opset_version=args.opset_version, opset_version=args.opset_version,
show=args.show, show=args.show,
output_file=args.output_file, output_file=args.output_file,