chore: normalization-in-onnx as an argument in pytorch2onnx_kneron
This commit is contained in:
parent
b6a868c3e2
commit
caaa31c231
@ -208,17 +208,9 @@ def pytorch2onnx(model,
|
||||
print('The outputs are same between Pytorch and ONNX')
|
||||
|
||||
if norm_cfg is not None:
|
||||
print("Prepending BatchNorm layer to ONNX as data normalization...")
|
||||
mean = norm_cfg['mean']
|
||||
std = norm_cfg['std']
|
||||
# TODO: figure out should we add the lines below?
|
||||
'''
|
||||
if all(_ == 128. for _ in mean) and all(_ == 256. for _ in std):
|
||||
print("normalization config perfectly matches kneron "
|
||||
"quantization requirement; not prepending batch norm.")
|
||||
return
|
||||
print("normalization config does not match kneron "
|
||||
"quantization requirement; prepending batch norm...")
|
||||
'''
|
||||
i_n = m.graph.input[0]
|
||||
if (
|
||||
i_n.type.tensor_type.shape.dim[1].dim_value != len(mean)
|
||||
@ -272,6 +264,13 @@ def parse_args():
|
||||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
||||
'Note that the quotation marks are necessary and that no white space '
|
||||
'is allowed.')
|
||||
parser.add_argument(
|
||||
'--normalization-in-onnx',
|
||||
action='store_true',
|
||||
help='Prepend BatchNorm layer to onnx model as a role of data '
|
||||
'normalization according to the mean and std value in the given'
|
||||
'cfg file.'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@ -327,7 +326,10 @@ if __name__ == '__main__':
|
||||
else:
|
||||
img = _demo_mm_inputs(input_shape)
|
||||
|
||||
norm_cfg = _parse_normalize_cfg(cfg.test_pipeline)
|
||||
if args.normalization_in_onnx:
|
||||
norm_cfg = _parse_normalize_cfg(cfg.test_pipeline)
|
||||
else:
|
||||
norm_cfg = None
|
||||
# convert model to onnx file
|
||||
pytorch2onnx(
|
||||
segmentor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user