# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git
import sys
|
|
import onnx
|
|
import onnx.utils
|
|
from tools import other, helper, replacing
|
|
|
|
"""
Change an ONNX model from version 1.4 (IR version 4, opset 9) to
version 1.6 (IR version 6, opset 11).
"""
|
|
|
|
|
|
def replace_all_attribute_to_const_node_in_pad_node(g):
    """Rewrite attribute-style Pad nodes into the input-style form.

    Each ``Pad`` node in ``g`` is replaced by a new ``Pad`` node whose
    ``pads`` and pad value come from newly created Constant nodes (inputs 1
    and 2), keeping only ``mode`` as an attribute. The replacement node is
    named after the original node's first output.

    Args:
        g (onnx.GraphProto): graph modified in place.

    Raises:
        ValueError: if a Pad node has no ``pads`` attribute, so there is
            nothing to build the ``pads`` input from.
    """
    node_to_remove = []
    node_to_extend = []
    for node in g.node:
        if node.op_type != "Pad":
            continue

        pad_loc_node = None  # required; built from the "pads" attribute
        pad_mode = "constant"
        pad_value = 0.0  # used when the node has no "value" attribute
        for att in node.attribute:
            if att.name == "mode":
                pad_mode = helper.get_var_attribute_by_name(
                    node, "mode", "string"
                )
            elif att.name == "pads":
                pad_loc_node = helper.list_to_constant(
                    node.name + "_pad_loc", [len(att.ints)], att.ints
                )
            elif att.name == "value":
                pad_value = att.f

        if pad_loc_node is None:
            raise ValueError(
                'Pad node "{}" has no "pads" attribute to convert'.format(
                    node.name
                )
            )

        # The pad value must be a scalar, hence the empty shape.
        pad_value_node = helper.list_to_constant(
            node.name + "_pad_value", [], [pad_value]
        )

        # Wire by output tensor name: node names and tensor names are
        # separate namespaces in ONNX.
        new_node = onnx.helper.make_node(
            "Pad",
            [node.input[0], pad_loc_node.output[0], pad_value_node.output[0]],
            [node.output[0]],
            name=node.output[0],
            mode=pad_mode,
        )

        node_to_remove.append(node)
        node_to_extend.extend([new_node, pad_loc_node, pad_value_node])

    # Mutate the graph only after iteration so no node is skipped.
    for node in node_to_remove:
        g.node.remove(node)
    g.node.extend(node_to_extend)
|
|
|
|
|
|
def upsampling_to_resize(g):
    """Replace every ``Upsample`` node in ``g`` with an equivalent ``Resize``.

    The new ``Resize`` node takes ``(X, roi, scales)``, where ``roi`` is a
    new empty Constant and ``scales`` is the original Upsample scales
    tensor. It uses ``coordinate_transformation_mode="asymmetric"``. For
    nearest mode it also uses ``nearest_mode="floor"``, matching Upsample's
    floor index mapping instead of Resize-11's ``round_prefer_floor``
    default. The replacement is named after the original node's first
    output.

    Args:
        g (onnx.GraphProto): graph modified in place.

    Raises:
        TypeError: if an Upsample node's ``scales`` input is not produced by
            a Constant node (e.g. constant folding has not been run).
    """
    # Collect first: removing from g.node while iterating it would skip the
    # node that follows each removed one (e.g. back-to-back Upsamples).
    upsample_nodes = [node for node in g.node if node.op_type == "Upsample"]
    for node in upsample_nodes:
        upsampling_mode = helper.get_var_attribute_by_name(
            node, "mode", "string"
        )

        scale_value_node = helper.find_node_by_output_name(g, node.input[1])
        if scale_value_node is None or scale_value_node.op_type != "Constant":
            raise TypeError(
                'seems there is a dynamic "scales" param in Upsampling node: '
                + node.name
                + " , you might need to do constant folding first"
            )

        roi_node = helper.list_to_constant(node.name + "_roi_value", [0], [])

        resize_attrs = {
            "mode": upsampling_mode,
            "coordinate_transformation_mode": "asymmetric",
        }
        # A missing mode means Upsample's default, "nearest". The helper may
        # return str or bytes -- NOTE(review): confirm which.
        if upsampling_mode in (None, "nearest", b"nearest"):
            resize_attrs["nearest_mode"] = "floor"

        # node.input[1] is the scales tensor name; the producer's node name
        # is not guaranteed to equal it.
        new_node = onnx.helper.make_node(
            "Resize",
            [node.input[0], roi_node.output[0], node.input[1]],
            [node.output[0]],
            name=node.output[0],
            **resize_attrs
        )

        g.node.remove(node)
        g.node.extend([new_node, roi_node])
|
|
|
|
|
|
def replace_all_attribute_to_const_node_in_slice_node(g):
    """Move Slice ``starts``/``ends``/``axes``/``steps`` attributes to inputs.

    For each ``Slice`` node, a Constant node is created for every one of
    these attributes that is present. All attributes are then removed, and
    the Constants' outputs are appended as inputs in spec order
    (starts, ends, axes, steps).

    Args:
        g (onnx.GraphProto): graph modified in place.
    """
    # Positional order of the optional inputs after "data".
    input_order = ("starts", "ends", "axes", "steps")

    slice_nodes = [node for node in g.node if node.op_type == "Slice"]
    for node in slice_nodes:
        const_nodes = {}
        for att in node.attribute:
            if att.name in input_order:
                const_nodes[att.name] = helper.list_to_constant(
                    node.name + "_" + att.name + "_value",
                    [len(att.ints)],
                    att.ints,
                )

        # Drop every attribute; they are all replaced by inputs.
        del node.attribute[:]

        if not const_nodes:
            continue

        # Inputs are positional: an absent input that precedes a present one
        # must be kept as an empty name ("") so later inputs stay in place.
        last_index = max(input_order.index(name) for name in const_nodes)
        for name in input_order[: last_index + 1]:
            const_node = const_nodes.get(name)
            if const_node is None:
                node.input.append("")
            else:
                g.node.extend([const_node])
                node.input.append(const_node.output[0])
|
|
|
|
|
|
def replace_min_max_attribute_to_const_node_in_clip_node(g):
    """Move Clip ``min``/``max`` attributes to scalar Constant inputs.

    For each ``Clip`` node, scalar Constant nodes are created for whichever
    of ``min`` and ``max`` is present. All attributes are then removed, and
    the Constants' outputs are appended as inputs 1 (min) and 2 (max).
    Either bound may be missing: an absent ``max`` is simply left off, and an
    absent ``min`` is written as an empty input name when ``max`` is given.

    Args:
        g (onnx.GraphProto): graph modified in place.
    """
    clip_nodes = [node for node in g.node if node.op_type == "Clip"]
    for node in clip_nodes:
        min_const_node = None
        max_const_node = None
        for att in node.attribute:
            if att.name == "min":
                min_const_node = helper.list_to_constant(
                    node.name + "_min_value", [], [att.f]
                )
            elif att.name == "max":
                max_const_node = helper.list_to_constant(
                    node.name + "_max_value", [], [att.f]
                )

        # Drop every attribute; they are all replaced by inputs.
        del node.attribute[:]

        if min_const_node is None and max_const_node is None:
            continue

        # Input order is fixed by the spec: (input, min, max).
        if min_const_node is not None:
            g.node.extend([min_const_node])
            node.input.append(min_const_node.output[0])
        else:
            node.input.append("")  # keep max in the third position
        if max_const_node is not None:
            g.node.extend([max_const_node])
            node.input.append(max_const_node.output[0])
|
|
|
|
|
|
def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
    """Update ir_version from 4 to 6 and update opset from 9 to 11.

    Names unnamed nodes, turns initializers into Constant nodes, and rewrites
    Clip, Slice, Pad and Upsample nodes into their opset-11 forms. It then
    re-sorts the graph, bumps the IR and opset versions, and polishes the
    model.

    Args:
        model (onnx.ModelProto): input onnx model.

    Returns:
        onnx.ModelProto: updated onnx model.

    Raises:
        SystemExit: with status 1 if the model's default-domain opset is
            already 11 or newer (nothing to upgrade).
    """
    graph = model.graph

    # Use the opset entry of the default ONNX domain ("" or "ai.onnx").
    # Fall back to the first entry if none is tagged that way.
    default_opset = next(
        (op for op in model.opset_import if op.domain in ("", "ai.onnx")),
        model.opset_import[0],
    )
    # ">=" rather than "==": an opset-12+ model must not be "downgraded" to 11.
    if default_opset.version >= 11:
        print(
            "(Stop) the input model is already opset {}, "
            "no need to upgrade".format(default_opset.version)
        )
        sys.exit(1)

    # deal with empty node name issue
    other.add_name_to_node(graph)
    # simplify the node param type from initializer to constant
    replacing.replace_initializer_with_Constant(graph)

    # Modify the nodes.
    replace_min_max_attribute_to_const_node_in_clip_node(graph)
    replace_all_attribute_to_const_node_in_slice_node(graph)
    replace_all_attribute_to_const_node_in_pad_node(graph)
    upsampling_to_resize(graph)
    other.topological_sort(graph)

    # Change model properties.
    model.ir_version = 6
    default_opset.version = 11

    model = other.polish_model(model)
    return model
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry: convert file_in (opset 9) and write file_out (opset 11).
    if len(sys.argv) != 3:
        print("Usage:{} file_in file_out".format(sys.argv[0]))
        # sys.exit, not the interactive-only builtin exit() from "site".
        sys.exit(1)

    model = onnx.load(sys.argv[1])
    model = onnx1_4to1_6(model)

    onnx.save(model, sys.argv[2])
|