STDC/tools/optimizer_scripts/norm_on_scaled_onnx.py
2022-04-12 14:26:54 +08:00

55 lines
1.5 KiB
Python

import onnx
import sys
import json
from tools import special
# Command-line script:
#
#     python norm_on_scaled_onnx.py input.onnx input.json
#
# 1. Loads the ONNX model, passes it through
#    ``special.add_0_5_to_normalized_input`` and saves the result next to the
#    input as ``<input minus last 4 chars>norm.onnx``.
# 2. Loads the accompanying JSON, forces
#    ``input_node.output_datapath_radix`` to ``[8]`` and writes a new JSON
#    file (``<saved onnx path>.json``) that starts with a hard-coded
#    ``"input_norm"`` entry followed by the (modified) original entries.
if len(sys.argv) != 3:
    print("python norm_on_scaled_onnx.py input.onnx input.json")
    # sys.exit is always available; the bare ``exit`` helper is injected by
    # the ``site`` module and is missing when Python runs without it.
    sys.exit(1)

onnx_path = sys.argv[1]
json_path = sys.argv[2]
# Dropping the last four characters removes a trailing "onnx" but keeps the
# dot, so "model.onnx" becomes "model.norm.onnx".
# NOTE(review): assumes the input path really ends in "onnx" -- confirm.
norm_onnx_path = onnx_path[:-4] + "norm.onnx"

# Modify the ONNX model and store it under the derived name.
# NOTE(review): the helper presumably inserts a +0.5 offset on the normalized
# input, judging by its name; its implementation is not visible here.
m = onnx.load(onnx_path)
special.add_0_5_to_normalized_input(m)
onnx.save(m, norm_onnx_path)

# Update the input node description of the original JSON.  The context
# manager closes the file even when parsing fails.
with open(json_path, "r") as origin_file:
    origin_json = json.load(origin_file)
origin_json["input_node"]["output_datapath_radix"] = [8]
new_json_str = json.dumps(origin_json)

# Hard-coded entry for the new "input_norm" node.  The text opens the
# top-level JSON object and ends with a comma so that the serialized
# original object (minus its own opening brace) can be appended directly.
s = """{{
\"{0}\" :
{{
\"bias_bitwidth\" : 16,
\"{0}_bias\" : [15],
\"{0}_weight\" : [3,3,3],
\"conv_coarse_shift\" : [-4,-4,-4],
\"conv_fine_shift\" : [0,0,0],
\"conv_total_shift\" : [-4,-4,-4],
\"cpu_mode\" : false,
\"delta_input_bitwidth\" : [0],
\"delta_output_bitwidth\" : 8,
\"flag_radix_bias_eq_output\" : true,
\"input_scale\" : [[1.0,1.0,1.0]],
\"output_scale\" : [1.0, 1.0, 1.0],
\"psum_bitwidth\" : 16,
\"weight_bitwidth\" : 8,
\"input_datapath_bitwidth\" : [8],
\"input_datapath_radix\" : [8],
\"working_input_bitwidth\" : 8,
\"working_input_radix\" : [8],
\"working_output_bitwidth\" : 16,
\"working_output_radix\" : 15,
\"output_datapath_bitwidth\" : 8,
\"output_datapath_radix\" : 7
}},\n""".format(
    "input_norm"
)

# Write the merged JSON.  new_json_str[1:] skips the leading "{" of the
# serialized original object because the template already opened it.
with open(norm_onnx_path + ".json", "w") as out_file:
    out_file.write(s + new_json_str[1:])