feat: onnx auto bn for kneron quantization
This commit is contained in:
parent
0ec0b33ed7
commit
b6a868c3e2
@ -2,6 +2,7 @@
|
||||
# Original: tools/pytorch2onnx.py, modified by Kneron
|
||||
import argparse
|
||||
|
||||
import os
|
||||
import onnx
|
||||
import mmcv
|
||||
import numpy as np
|
||||
@ -19,6 +20,7 @@ from mmseg.apis.inference import LoadImage
|
||||
from mmseg.datasets.pipelines import Compose
|
||||
from mmseg.models import build_segmentor
|
||||
|
||||
from optimizer_scripts.tools import other
|
||||
from optimizer_scripts.pytorch_exported_onnx_preprocess import (
|
||||
torch_exported_onnx_flow,
|
||||
)
|
||||
@ -26,6 +28,19 @@ from optimizer_scripts.pytorch_exported_onnx_preprocess import (
|
||||
torch.manual_seed(3)
|
||||
|
||||
|
||||
def _parse_normalize_cfg(test_pipeline):
|
||||
transforms = None
|
||||
for pipeline in test_pipeline:
|
||||
if 'transforms' in pipeline:
|
||||
transforms = pipeline['transforms']
|
||||
break
|
||||
assert transforms is not None, 'Failed to find `transforms`'
|
||||
norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
|
||||
assert len(norm_config_li) == 1, '`norm_config` should only have one'
|
||||
norm_config = norm_config_li[0]
|
||||
return norm_config
|
||||
|
||||
|
||||
def _convert_batchnorm(module):
|
||||
module_output = module
|
||||
if isinstance(module, torch.nn.SyncBatchNorm):
|
||||
@ -80,6 +95,7 @@ def _prepare_input_img(img_path,
|
||||
|
||||
def pytorch2onnx(model,
|
||||
img,
|
||||
norm_cfg=None,
|
||||
opset_version=11,
|
||||
show=False,
|
||||
output_file='tmp.onnx',
|
||||
@ -130,6 +146,7 @@ def pytorch2onnx(model,
|
||||
m.ir_version = 6
|
||||
m = torch_exported_onnx_flow(m, disable_fuse_bn=False)
|
||||
onnx.save(m, output_file)
|
||||
print(f'{output_file} optimized by KNERON successfully.')
|
||||
|
||||
if verify:
|
||||
onnx_model = onnx.load(output_file)
|
||||
@ -190,6 +207,40 @@ def pytorch2onnx(model,
|
||||
err_msg='The outputs are different between Pytorch and ONNX')
|
||||
print('The outputs are same between Pytorch and ONNX')
|
||||
|
||||
if norm_cfg is not None:
|
||||
mean = norm_cfg['mean']
|
||||
std = norm_cfg['std']
|
||||
# TODO: figure out should we add the lines below?
|
||||
'''
|
||||
if all(_ == 128. for _ in mean) and all(_ == 256. for _ in std):
|
||||
print("normalization config perfectly matches kneron "
|
||||
"quantization requirement; not prepending batch norm.")
|
||||
return
|
||||
print("normalization config does not match kneron "
|
||||
"quantization requirement; prepending batch norm...")
|
||||
'''
|
||||
i_n = m.graph.input[0]
|
||||
if (
|
||||
i_n.type.tensor_type.shape.dim[1].dim_value != len(mean)
|
||||
or i_n.type.tensor_type.shape.dim[1].dim_value != len(std)
|
||||
):
|
||||
raise ValueError(
|
||||
f"--pixel-bias-value ({mean}) and --pixel-scale-value "
|
||||
f"({std}) should be same as input dimension: "
|
||||
f"{i_n.type.tensor_type.shape.dim[1].dim_value}"
|
||||
)
|
||||
norm_bn_bias = [-1 * cm / cs + 128. / cs for cm, cs in zip(mean, std)]
|
||||
norm_bn_scale = [1 / cs for cs in std]
|
||||
other.add_bias_scale_bn_after(
|
||||
m.graph, i_n.name, norm_bn_bias, norm_bn_scale
|
||||
)
|
||||
m = other.polish_model(m)
|
||||
bn_outf = os.path.splitext(output_file)[0] + "_kneron_optimized.onnx"
|
||||
onnx.save(m, bn_outf)
|
||||
print(f"ONNX for quantization saved to {bn_outf}")
|
||||
|
||||
return
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
|
||||
@ -276,10 +327,12 @@ if __name__ == '__main__':
|
||||
else:
|
||||
img = _demo_mm_inputs(input_shape)
|
||||
|
||||
norm_cfg = _parse_normalize_cfg(cfg.test_pipeline)
|
||||
# convert model to onnx file
|
||||
pytorch2onnx(
|
||||
segmentor,
|
||||
img,
|
||||
norm_cfg=norm_cfg,
|
||||
opset_version=args.opset_version,
|
||||
show=args.show,
|
||||
output_file=args.output_file,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user