chore: normalization-in-onnx as an argument in pytorch2onnx_kneron
This commit is contained in:
parent
b6a868c3e2
commit
caaa31c231
@ -208,17 +208,9 @@ def pytorch2onnx(model,
|
|||||||
print('The outputs are same between Pytorch and ONNX')
|
print('The outputs are same between Pytorch and ONNX')
|
||||||
|
|
||||||
if norm_cfg is not None:
|
if norm_cfg is not None:
|
||||||
|
print("Prepending BatchNorm layer to ONNX as data normalization...")
|
||||||
mean = norm_cfg['mean']
|
mean = norm_cfg['mean']
|
||||||
std = norm_cfg['std']
|
std = norm_cfg['std']
|
||||||
# TODO: figure out should we add the lines below?
|
|
||||||
'''
|
|
||||||
if all(_ == 128. for _ in mean) and all(_ == 256. for _ in std):
|
|
||||||
print("normalization config perfectly matches kneron "
|
|
||||||
"quantization requirement; not prepending batch norm.")
|
|
||||||
return
|
|
||||||
print("normalization config does not match kneron "
|
|
||||||
"quantization requirement; prepending batch norm...")
|
|
||||||
'''
|
|
||||||
i_n = m.graph.input[0]
|
i_n = m.graph.input[0]
|
||||||
if (
|
if (
|
||||||
i_n.type.tensor_type.shape.dim[1].dim_value != len(mean)
|
i_n.type.tensor_type.shape.dim[1].dim_value != len(mean)
|
||||||
@ -272,6 +264,13 @@ def parse_args():
|
|||||||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
||||||
'Note that the quotation marks are necessary and that no white space '
|
'Note that the quotation marks are necessary and that no white space '
|
||||||
'is allowed.')
|
'is allowed.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--normalization-in-onnx',
|
||||||
|
action='store_true',
|
||||||
|
help='Prepend BatchNorm layer to onnx model as a role of data '
|
||||||
|
'normalization according to the mean and std value in the given'
|
||||||
|
'cfg file.'
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@ -327,7 +326,10 @@ if __name__ == '__main__':
|
|||||||
else:
|
else:
|
||||||
img = _demo_mm_inputs(input_shape)
|
img = _demo_mm_inputs(input_shape)
|
||||||
|
|
||||||
|
if args.normalization_in_onnx:
|
||||||
norm_cfg = _parse_normalize_cfg(cfg.test_pipeline)
|
norm_cfg = _parse_normalize_cfg(cfg.test_pipeline)
|
||||||
|
else:
|
||||||
|
norm_cfg = None
|
||||||
# convert model to onnx file
|
# convert model to onnx file
|
||||||
pytorch2onnx(
|
pytorch2onnx(
|
||||||
segmentor,
|
segmentor,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user