feat: onnx auto bn for kneron quantization
This commit is contained in:
parent
0ec0b33ed7
commit
b6a868c3e2
@ -2,6 +2,7 @@
|
|||||||
# Original: tools/pytorch2onnx.py, modified by Kneron
|
# Original: tools/pytorch2onnx.py, modified by Kneron
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import os
|
||||||
import onnx
|
import onnx
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -19,6 +20,7 @@ from mmseg.apis.inference import LoadImage
|
|||||||
from mmseg.datasets.pipelines import Compose
|
from mmseg.datasets.pipelines import Compose
|
||||||
from mmseg.models import build_segmentor
|
from mmseg.models import build_segmentor
|
||||||
|
|
||||||
|
from optimizer_scripts.tools import other
|
||||||
from optimizer_scripts.pytorch_exported_onnx_preprocess import (
|
from optimizer_scripts.pytorch_exported_onnx_preprocess import (
|
||||||
torch_exported_onnx_flow,
|
torch_exported_onnx_flow,
|
||||||
)
|
)
|
||||||
@ -26,6 +28,19 @@ from optimizer_scripts.pytorch_exported_onnx_preprocess import (
|
|||||||
torch.manual_seed(3)
|
torch.manual_seed(3)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_normalize_cfg(test_pipeline):
|
||||||
|
transforms = None
|
||||||
|
for pipeline in test_pipeline:
|
||||||
|
if 'transforms' in pipeline:
|
||||||
|
transforms = pipeline['transforms']
|
||||||
|
break
|
||||||
|
assert transforms is not None, 'Failed to find `transforms`'
|
||||||
|
norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
|
||||||
|
assert len(norm_config_li) == 1, '`norm_config` should only have one'
|
||||||
|
norm_config = norm_config_li[0]
|
||||||
|
return norm_config
|
||||||
|
|
||||||
|
|
||||||
def _convert_batchnorm(module):
|
def _convert_batchnorm(module):
|
||||||
module_output = module
|
module_output = module
|
||||||
if isinstance(module, torch.nn.SyncBatchNorm):
|
if isinstance(module, torch.nn.SyncBatchNorm):
|
||||||
@ -80,6 +95,7 @@ def _prepare_input_img(img_path,
|
|||||||
|
|
||||||
def pytorch2onnx(model,
|
def pytorch2onnx(model,
|
||||||
img,
|
img,
|
||||||
|
norm_cfg=None,
|
||||||
opset_version=11,
|
opset_version=11,
|
||||||
show=False,
|
show=False,
|
||||||
output_file='tmp.onnx',
|
output_file='tmp.onnx',
|
||||||
@ -130,6 +146,7 @@ def pytorch2onnx(model,
|
|||||||
m.ir_version = 6
|
m.ir_version = 6
|
||||||
m = torch_exported_onnx_flow(m, disable_fuse_bn=False)
|
m = torch_exported_onnx_flow(m, disable_fuse_bn=False)
|
||||||
onnx.save(m, output_file)
|
onnx.save(m, output_file)
|
||||||
|
print(f'{output_file} optimized by KNERON successfully.')
|
||||||
|
|
||||||
if verify:
|
if verify:
|
||||||
onnx_model = onnx.load(output_file)
|
onnx_model = onnx.load(output_file)
|
||||||
@ -190,6 +207,40 @@ def pytorch2onnx(model,
|
|||||||
err_msg='The outputs are different between Pytorch and ONNX')
|
err_msg='The outputs are different between Pytorch and ONNX')
|
||||||
print('The outputs are same between Pytorch and ONNX')
|
print('The outputs are same between Pytorch and ONNX')
|
||||||
|
|
||||||
|
if norm_cfg is not None:
|
||||||
|
mean = norm_cfg['mean']
|
||||||
|
std = norm_cfg['std']
|
||||||
|
# TODO: figure out should we add the lines below?
|
||||||
|
'''
|
||||||
|
if all(_ == 128. for _ in mean) and all(_ == 256. for _ in std):
|
||||||
|
print("normalization config perfectly matches kneron "
|
||||||
|
"quantization requirement; not prepending batch norm.")
|
||||||
|
return
|
||||||
|
print("normalization config does not match kneron "
|
||||||
|
"quantization requirement; prepending batch norm...")
|
||||||
|
'''
|
||||||
|
i_n = m.graph.input[0]
|
||||||
|
if (
|
||||||
|
i_n.type.tensor_type.shape.dim[1].dim_value != len(mean)
|
||||||
|
or i_n.type.tensor_type.shape.dim[1].dim_value != len(std)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"--pixel-bias-value ({mean}) and --pixel-scale-value "
|
||||||
|
f"({std}) should be same as input dimension: "
|
||||||
|
f"{i_n.type.tensor_type.shape.dim[1].dim_value}"
|
||||||
|
)
|
||||||
|
norm_bn_bias = [-1 * cm / cs + 128. / cs for cm, cs in zip(mean, std)]
|
||||||
|
norm_bn_scale = [1 / cs for cs in std]
|
||||||
|
other.add_bias_scale_bn_after(
|
||||||
|
m.graph, i_n.name, norm_bn_bias, norm_bn_scale
|
||||||
|
)
|
||||||
|
m = other.polish_model(m)
|
||||||
|
bn_outf = os.path.splitext(output_file)[0] + "_kneron_optimized.onnx"
|
||||||
|
onnx.save(m, bn_outf)
|
||||||
|
print(f"ONNX for quantization saved to {bn_outf}")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
|
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
|
||||||
@ -276,10 +327,12 @@ if __name__ == '__main__':
|
|||||||
else:
|
else:
|
||||||
img = _demo_mm_inputs(input_shape)
|
img = _demo_mm_inputs(input_shape)
|
||||||
|
|
||||||
|
norm_cfg = _parse_normalize_cfg(cfg.test_pipeline)
|
||||||
# convert model to onnx file
|
# convert model to onnx file
|
||||||
pytorch2onnx(
|
pytorch2onnx(
|
||||||
segmentor,
|
segmentor,
|
||||||
img,
|
img,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
opset_version=args.opset_version,
|
opset_version=args.opset_version,
|
||||||
show=args.show,
|
show=args.show,
|
||||||
output_file=args.output_file,
|
output_file=args.output_file,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user