diff --git a/tools/pytorch2onnx_kneron.py b/tools/pytorch2onnx_kneron.py index c7ffc55..4d89ad8 100644 --- a/tools/pytorch2onnx_kneron.py +++ b/tools/pytorch2onnx_kneron.py @@ -208,17 +208,9 @@ def pytorch2onnx(model, print('The outputs are same between Pytorch and ONNX') if norm_cfg is not None: + print("Prepending BatchNorm layer to ONNX as data normalization...") mean = norm_cfg['mean'] std = norm_cfg['std'] - # TODO: figure out should we add the lines below? - ''' - if all(_ == 128. for _ in mean) and all(_ == 256. for _ in std): - print("normalization config perfectly matches kneron " - "quantization requirement; not prepending batch norm.") - return - print("normalization config does not match kneron " - "quantization requirement; prepending batch norm...") - ''' i_n = m.graph.input[0] if ( i_n.type.tensor_type.shape.dim[1].dim_value != len(mean) @@ -272,6 +264,13 @@ def parse_args(): 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') + parser.add_argument( + '--normalization-in-onnx', + action='store_true', + help='Prepend BatchNorm layer to onnx model as a role of data ' + 'normalization according to the mean and std value in the given' + 'cfg file.' + ) args = parser.parse_args() return args @@ -327,7 +326,10 @@ if __name__ == '__main__': else: img = _demo_mm_inputs(input_shape) - norm_cfg = _parse_normalize_cfg(cfg.test_pipeline) + if args.normalization_in_onnx: + norm_cfg = _parse_normalize_cfg(cfg.test_pipeline) + else: + norm_cfg = None # convert model to onnx file pytorch2onnx( segmentor,