chore: normalization-in-onnx as an argument in pytorch2onnx_kneron

This commit is contained in:
chingning.chen 2022-03-28 16:43:00 +08:00
parent b6a868c3e2
commit caaa31c231

View File

@ -208,17 +208,9 @@ def pytorch2onnx(model,
print('The outputs are same between Pytorch and ONNX') print('The outputs are same between Pytorch and ONNX')
if norm_cfg is not None: if norm_cfg is not None:
print("Prepending BatchNorm layer to ONNX as data normalization...")
mean = norm_cfg['mean'] mean = norm_cfg['mean']
std = norm_cfg['std'] std = norm_cfg['std']
# TODO: figure out should we add the lines below?
'''
if all(_ == 128. for _ in mean) and all(_ == 256. for _ in std):
print("normalization config perfectly matches kneron "
"quantization requirement; not prepending batch norm.")
return
print("normalization config does not match kneron "
"quantization requirement; prepending batch norm...")
'''
i_n = m.graph.input[0] i_n = m.graph.input[0]
if ( if (
i_n.type.tensor_type.shape.dim[1].dim_value != len(mean) i_n.type.tensor_type.shape.dim[1].dim_value != len(mean)
@ -272,6 +264,13 @@ def parse_args():
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space ' 'Note that the quotation marks are necessary and that no white space '
'is allowed.') 'is allowed.')
parser.add_argument(
'--normalization-in-onnx',
action='store_true',
help='Prepend BatchNorm layer to onnx model as a role of data '
'normalization according to the mean and std value in the given'
'cfg file.'
)
args = parser.parse_args() args = parser.parse_args()
return args return args
@ -327,7 +326,10 @@ if __name__ == '__main__':
else: else:
img = _demo_mm_inputs(input_shape) img = _demo_mm_inputs(input_shape)
if args.normalization_in_onnx:
norm_cfg = _parse_normalize_cfg(cfg.test_pipeline) norm_cfg = _parse_normalize_cfg(cfg.test_pipeline)
else:
norm_cfg = None
# convert model to onnx file # convert model to onnx file
pytorch2onnx( pytorch2onnx(
segmentor, segmentor,