2026-03-11 16:13:59 +08:00

212 lines
6.6 KiB
Python

# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git
import sys
import onnx
import onnx.utils
from tools import other, helper, replacing
"""
Change onnx model from version 1.4 to version 1.6.
"""
def replace_all_attribute_to_const_node_in_pad_node(g):
    """Rewrite every attribute-style ``Pad`` node into its input-style form.

    For each node in ``g.node`` whose ``op_type`` is ``"Pad"``, the ``pads``
    and ``value`` attributes are turned into constant nodes built with
    ``helper.list_to_constant``, and a replacement ``Pad`` node is created
    that takes them as its second and third inputs.  The ``mode`` attribute
    is carried over (``"constant"`` when absent) and a scalar ``0.0`` is used
    when the node has no ``value`` attribute.  After the scan the old nodes
    are removed from ``g.node`` and the new nodes are appended to it.

    Args:
        g: graph whose ``node`` field is scanned and modified in place
            (presumably an ``onnx.GraphProto`` -- confirm against callers).

    Raises:
        ValueError: if a ``Pad`` node carries no ``pads`` attribute, since
            the replacement node cannot be built without it.
    """
    stale_nodes = []
    fresh_nodes = []
    for node in g.node:
        if node.op_type != "Pad":
            continue
        pad_mode = "constant"
        pad_loc_node = None  # must be provided by the "pads" attribute
        pad_value_node = None  # falls back to a scalar 0.0 below
        for att in node.attribute:
            if att.name == "mode":
                pad_mode = helper.get_var_attribute_by_name(
                    node, "mode", "string"
                )
            elif att.name == "pads":
                pad_loc_node = helper.list_to_constant(
                    node.name + "_pad_loc", [len(att.ints)], att.ints
                )
            elif att.name == "value":
                # The pad value input has to be a scalar, hence the [] shape.
                pad_value_node = helper.list_to_constant(
                    node.name + "_pad_value", [], [att.f]
                )
        if pad_loc_node is None:
            # Fail with a readable message instead of an AttributeError on
            # ``None.name`` when building the replacement node.
            raise ValueError(
                'Pad node "{}" has no "pads" attribute'.format(node.name)
            )
        if pad_value_node is None:
            pad_value_node = helper.list_to_constant(
                node.name + "_pad_value", [], [0.0]
            )
        new_node = onnx.helper.make_node(
            "Pad",
            [node.input[0], pad_loc_node.name, pad_value_node.name],
            [node.output[0]],
            name=node.output[0],
            mode=pad_mode,
        )
        stale_nodes.append(node)
        fresh_nodes.extend([new_node, pad_loc_node, pad_value_node])
    # The graph is only mutated once the scan is finished, so the iteration
    # above never runs over a container that is changing underneath it.
    for node in stale_nodes:
        g.node.remove(node)
    g.node.extend(fresh_nodes)
def upsampling_to_resize(g):
    """Replace every ``Upsample`` node in ``g`` with a ``Resize`` node.

    The replacement keeps the original data input, output name and ``mode``.
    It adds an empty ``roi`` constant as second input, reuses the constant
    node that produces the original ``scales`` input (``node.input[1]``) as
    third input, and sets ``coordinate_transformation_mode`` to
    ``"asymmetric"``.

    Args:
        g: graph whose ``node`` field is scanned and modified in place
            (presumably an ``onnx.GraphProto`` -- confirm against callers).

    Raises:
        TypeError: if the ``scales`` input of an ``Upsample`` node is not
            produced by a ``Constant`` node.
    """
    # Iterate over a snapshot: the loop body removes from and appends to
    # ``g.node``, and mutating the container being iterated would make the
    # loop skip the node that follows each replaced Upsample.
    for node in list(g.node):
        if node.op_type != "Upsample":
            continue
        upsampling_mode = helper.get_var_attribute_by_name(
            node, "mode", "string"
        )
        scale_value_node = helper.find_node_by_output_name(g, node.input[1])
        # NOTE(review): find_node_by_output_name presumably returns None when
        # no node produces that name -- treated as "not a constant" here.
        if scale_value_node is None or scale_value_node.op_type != "Constant":
            raise TypeError(
                'seems there is a dynamic "scales" param in Upsampling node: '
                + node.name
                + " , you might need to do constant folding first"
            )
        # Empty roi constant (shape [0], no values) for the second input.
        roi_node = helper.list_to_constant(node.name + "_roi_value", [0], [])
        new_node = onnx.helper.make_node(
            "Resize",
            [node.input[0], roi_node.name, scale_value_node.name],
            [node.output[0]],
            name=node.output[0],
            mode=upsampling_mode,
            coordinate_transformation_mode="asymmetric",
        )
        g.node.remove(node)
        g.node.extend([new_node])
        g.node.extend([roi_node])
def replace_all_attribute_to_const_node_in_slice_node(g):
    """Move the attributes of every ``Slice`` node into constant inputs.

    For each node in ``g.node`` whose ``op_type`` is ``"Slice"``, the
    ``starts``, ``ends``, ``axes`` and ``steps`` attributes that are present
    are converted to constant nodes with ``helper.list_to_constant``.  All
    attributes are then removed from the node, and the constants are
    appended to ``g.node`` and to the node's inputs in the order
    starts, ends, axes, steps.

    Args:
        g: graph whose ``node`` field is scanned and modified in place
            (presumably an ``onnx.GraphProto`` -- confirm against callers).
    """
    # According to the spec, the inputs have to be added in this order.
    input_order = ("starts", "ends", "axes", "steps")
    for node in g.node:
        if node.op_type != "Slice":
            continue
        constants = {}
        for att in node.attribute:
            if att.name in input_order:
                constants[att.name] = helper.list_to_constant(
                    node.name + "_" + att.name + "_value",
                    [len(att.ints)],
                    att.ints,
                )
        # Drop every attribute now that the values live in constant nodes.
        del node.attribute[:]
        for key in input_order:
            const_node = constants.get(key)
            if const_node is None:
                continue
            g.node.extend([const_node])
            node.input.extend([const_node.name])
def replace_min_max_attribute_to_const_node_in_clip_node(g):
    """Move the ``min``/``max`` attributes of every ``Clip`` node to inputs.

    For each node in ``g.node`` whose ``op_type`` is ``"Clip"``, two scalar
    constant nodes are built with ``helper.list_to_constant``, appended to
    ``g.node`` and added to the node's inputs in the order min, max.  All
    attributes are removed from the node afterwards.

    ``min`` and ``max`` are optional attributes in the attribute-style Clip.
    A bound that is absent is materialised with the default documented for
    that operator version (the lowest / largest finite float32 value), so a
    node carrying only one bound, or none, is converted instead of failing.

    Args:
        g: graph whose ``node`` field is scanned and modified in place
            (presumably an ``onnx.GraphProto`` -- confirm against callers).
    """
    # Largest finite float32 value: the default upper bound of the
    # attribute-style Clip; its negation is the default lower bound.
    float32_max = 3.4028234663852886e38
    # Snapshot so appending the constants does not affect the iteration.
    for node in list(g.node):
        if node.op_type != "Clip":
            continue
        min_value = -float32_max
        max_value = float32_max
        for att in node.attribute:
            if att.name == "min":
                min_value = att.f
            elif att.name == "max":
                max_value = att.f
        # The bounds have to be scalars, hence the [] shape.
        min_const_node = helper.list_to_constant(
            node.name + "_min_value", [], [min_value]
        )
        max_const_node = helper.list_to_constant(
            node.name + "_max_value", [], [max_value]
        )
        # Remove every attribute, however many the node actually carried.
        del node.attribute[:]
        # According to the spec, the inputs are added in min, max order.
        g.node.extend([min_const_node, max_const_node])
        node.input.extend([min_const_node.name, max_const_node.name])
def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
    """Update ir_version from 4 to 6 and update opset from 9 to 11.

    The graph is modified in place: nodes get names, initializers are turned
    into Constant nodes, then the Clip, Slice, Pad and Upsample nodes are
    rewritten, the graph is topologically sorted, and the model is polished.
    The input versions are not verified beyond the "already opset 11" check.

    Args:
        model (onnx.ModelProto): input onnx model.

    Returns:
        onnx.ModelProto: updated onnx model.

    Raises:
        SystemExit: with exit code 1 when the default-domain opset of the
            input model is already 11.
    """
    graph = model.graph
    # Pick the opset entry of the default ONNX domain ("" or "ai.onnx")
    # rather than assuming it is the first entry; a model may list a custom
    # domain first.
    default_opset = next(
        (
            opset
            for opset in model.opset_import
            if opset.domain in ("", "ai.onnx")
        ),
        None,
    )
    if default_opset is None:
        # No default-domain entry found: fall back to the first entry.
        default_opset = model.opset_import[0]
    if default_opset.version == 11:
        print("(Stop) the input model is already opset 11, no need to upgrade")
        # sys.exit instead of the ``site`` builtin ``exit``, which is meant
        # for the interactive shell and is not guaranteed to exist.
        sys.exit(1)
    # deal with empty node name issue
    other.add_name_to_node(graph)
    # simplify the node param type from initializer to constant
    replacing.replace_initializer_with_Constant(graph)
    # Modify the nodes.
    replace_min_max_attribute_to_const_node_in_clip_node(graph)
    replace_all_attribute_to_const_node_in_slice_node(graph)
    replace_all_attribute_to_const_node_in_pad_node(graph)
    upsampling_to_resize(graph)
    other.topological_sort(graph)
    # Change model properties.
    model.ir_version = 6
    default_opset.version = 11
    model = other.polish_model(model)
    return model
if __name__ == "__main__":
    # Command-line entry point: load the model at argv[1], upgrade it and
    # save the result to argv[2].
    if len(sys.argv) != 3:
        print("Usage:{} file_in file_out".format(sys.argv[0]))
        # sys.exit instead of the ``site`` builtin ``exit``, which is meant
        # for the interactive shell.
        sys.exit(1)
    model = onnx.load(sys.argv[1])
    model = onnx1_4to1_6(model)
    onnx.save(model, sys.argv[2])