185 lines
6.3 KiB
Python
185 lines
6.3 KiB
Python
# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git
|
|
|
|
import sys
|
|
import onnx
|
|
import onnx.utils
|
|
import numpy as np
|
|
from onnx import numpy_helper
|
|
from tools import other, helper, replacing
|
|
|
|
"""
|
|
Convert an ONNX model from version 1.4 (IR version 4, opset 9) to version 1.6 (IR version 6, opset 11).
|
|
"""
|
|
|
|
def replace_all_attribute_to_const_node_in_pad_node(g):
    """Rewrite every attribute-style Pad node in ``g`` into the opset-11 input form.

    The old Pad carries ``pads``, ``mode`` and ``value`` as attributes. The
    opset-11 Pad takes the pads and the constant pad value as inputs. For each
    Pad node this creates two Constant nodes (pads, scalar pad value), then
    replaces the node with a new Pad that reads them as inputs. ``mode`` stays
    an attribute.

    Args:
        g (onnx.GraphProto): graph to modify in place.

    Raises:
        ValueError: if a Pad node has no ``pads`` attribute (required by the
            attribute-style Pad), instead of failing later on a ``None`` node.
    """
    node_to_remove = []
    node_to_extend = []
    for node in g.node:
        if node.op_type != 'Pad':
            continue

        pad_loc_node = None  # built from the mandatory 'pads' attribute
        pad_mode = 'constant'
        pad_value = 0.0  # used unless a 'value' attribute overrides it
        for att in node.attribute:
            if att.name == 'mode':
                pad_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
            elif att.name == 'pads':
                pad_loc_node = helper.list_to_constant(
                    node.name + '_pad_loc', [len(att.ints)], att.ints)
            elif att.name == 'value':
                pad_value = att.f

        if pad_loc_node is None:
            raise ValueError('Pad node ' + node.name + ' has no "pads" attribute')

        # Shape [] -> scalar constant value.
        pad_value_node = helper.list_to_constant(node.name + '_pad_value', [], [pad_value])

        new_node = onnx.helper.make_node(
            "Pad",
            # Node inputs refer to tensor names, i.e. the Constants' outputs.
            [node.input[0], pad_loc_node.output[0], pad_value_node.output[0]],
            [node.output[0]],
            name=node.output[0],
            mode=pad_mode,
        )
        node_to_remove.append(node)
        node_to_extend.extend([new_node, pad_loc_node, pad_value_node])

    # Mutate the graph only after the scan so the iteration above is not disturbed.
    for node in node_to_remove:
        g.node.remove(node)
    g.node.extend(node_to_extend)
|
|
|
|
|
|
def upsampling_to_resize(g):
    """Replace every Upsample node in ``g`` with an opset-11 Resize node.

    The Upsample's ``scales`` input must be produced by a Constant node; that
    tensor is reused as the Resize ``scales`` input. An empty Constant is
    created for the Resize ``roi`` input. Coordinates are mapped with
    ``coordinate_transformation_mode='asymmetric'``. In 'nearest' mode,
    ``nearest_mode='floor'`` is also set so the output index is
    floor(i / scale), as in Upsample. Resize-11's default
    ('round_prefer_floor') would pick different pixels for e.g. scale 3.

    Args:
        g (onnx.GraphProto): graph to modify in place.

    Raises:
        TypeError: if the ``scales`` input of an Upsample node is not produced
            by a Constant node (computed at runtime, or a graph input).
    """
    node_to_remove = []
    node_to_extend = []
    for node in g.node:
        if node.op_type != 'Upsample':
            continue

        # 'mode' is optional on Upsample and defaults to 'nearest'.
        upsampling_mode = 'nearest'
        if any(att.name == 'mode' for att in node.attribute):
            upsampling_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
        if isinstance(upsampling_mode, bytes):
            upsampling_mode = upsampling_mode.decode()

        scale_value_node = helper.find_node_by_output_name(g, node.input[1])
        if scale_value_node is None or scale_value_node.op_type != "Constant":
            raise TypeError('seems there is a dynamic "scales" param in Upsampling node: ' + node.name + ' , you might need to do constant folding first')

        roi_node = helper.list_to_constant(node.name + '_roi_value', [0], [])

        resize_attrs = {
            'mode': upsampling_mode,
            'coordinate_transformation_mode': 'asymmetric',
        }
        if upsampling_mode == 'nearest':
            resize_attrs['nearest_mode'] = 'floor'

        new_node = onnx.helper.make_node(
            "Resize",
            # node.input[1] is exactly the scales tensor the Upsample consumed.
            [node.input[0], roi_node.output[0], node.input[1]],
            [node.output[0]],
            name=node.output[0],
            **resize_attrs
        )
        node_to_remove.append(node)
        node_to_extend.extend([new_node, roi_node])

    # Removing from g.node while iterating it would skip the following node,
    # so all mutation happens after the scan.
    for node in node_to_remove:
        g.node.remove(node)
    g.node.extend(node_to_extend)
|
|
|
|
|
|
def replace_all_attribute_to_const_node_in_slice_node(g):
    """Move the ``starts``/``ends``/``axes``/``steps`` attributes of Slice nodes into inputs.

    For every Slice node, each of those integer-list attributes becomes a 1-D
    Constant node (named ``<node.name>_<attr>_value``). All attributes are then
    removed from the node, and the Constants' outputs are appended as inputs in
    the positional order starts, ends, axes, steps.

    Args:
        g (onnx.GraphProto): graph to modify in place.
    """
    input_order = ('starts', 'ends', 'axes', 'steps')
    # Iterate over a snapshot because Constant nodes are appended to g.node below.
    for node in list(g.node):
        if node.op_type != 'Slice':
            continue

        # Build the Constants before clearing the attributes they are read from.
        const_nodes = {}
        for att in node.attribute:
            if att.name in input_order:
                const_nodes[att.name] = helper.list_to_constant(
                    node.name + '_' + att.name + '_value', [len(att.ints)], att.ints)

        node.ClearField('attribute')

        for key in input_order:
            const_node = const_nodes.get(key)
            if const_node is not None:
                g.node.extend([const_node])
                # Inputs refer to tensor names, i.e. the Constant's output.
                node.input.extend([const_node.output[0]])
|
|
|
|
|
|
def replace_min_max_attribute_to_const_node_in_clip_node(g):
    """Move the ``min``/``max`` attributes of Clip nodes into scalar Constant inputs.

    Clip-6 makes both ``min`` and ``max`` optional attributes; the old code
    assumed both were present and crashed otherwise. A missing bound is now
    filled with the Clip-6 default: the lowest / highest float32 value. All
    attributes are removed, and the min then max Constants are appended as inputs.

    Args:
        g (onnx.GraphProto): graph to modify in place.
    """
    default_min = float(np.finfo(np.float32).min)
    default_max = float(np.finfo(np.float32).max)
    # Iterate over a snapshot because Constant nodes are appended to g.node below.
    for node in list(g.node):
        if node.op_type != 'Clip':
            continue

        min_value = default_min
        max_value = default_max
        for att in node.attribute:
            if att.name == 'min':
                min_value = att.f
            elif att.name == 'max':
                max_value = att.f

        # Shape [] -> scalar constants.
        min_const_node = helper.list_to_constant(node.name + '_min_value', [], [min_value])
        max_const_node = helper.list_to_constant(node.name + '_max_value', [], [max_value])

        node.ClearField('attribute')

        # Positional order of the new inputs: input, min, max.
        g.node.extend([min_const_node, max_const_node])
        node.input.extend([min_const_node.output[0], max_const_node.output[0]])
|
|
|
|
def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
    """Update ir_version from 4 to 6 and update opset from 9 to 11.

    Rewrites attribute-style Clip, Slice and Pad nodes into their input-style
    form, replaces Upsample with Resize, sorts the graph, then bumps the IR
    version and the default-domain opset version before polishing the model.

    Args:
        model (onnx.ModelProto): input onnx model.

    Returns:
        onnx.ModelProto: updated onnx model.

    Raises:
        SystemExit: with code 1 when the default-domain opset is already 11 or newer.
    """
    graph = model.graph

    # The default operator set is the entry with domain '' (alias 'ai.onnx');
    # other entries belong to extra domains and must not be checked or bumped.
    # Fall back to the first entry if no default-domain entry is declared.
    default_opset = next(
        (opset for opset in model.opset_import if opset.domain in ('', 'ai.onnx')),
        model.opset_import[0],
    )

    if default_opset.version >= 11:
        print("(Stop) the input model is already opset {}, no need to upgrade".format(default_opset.version))
        sys.exit(1)

    # Give every node a name; the converters below derive Constant names from node.name.
    other.add_name_to_node(graph)
    # Turn initializers into Constant nodes; upsampling_to_resize requires its
    # 'scales' input to be produced by a Constant node.
    replacing.replace_initializer_with_Constant(graph)

    # Modify the nodes.
    replace_min_max_attribute_to_const_node_in_clip_node(graph)
    replace_all_attribute_to_const_node_in_slice_node(graph)
    replace_all_attribute_to_const_node_in_pad_node(graph)
    upsampling_to_resize(graph)
    # The converters append new nodes at the end of g.node, so restore order.
    other.topological_sort(graph)

    # Change model properties.
    model.ir_version = 6
    default_opset.version = 11

    model = other.polish_model(model)
    return model
|
|
|
|
if __name__ == "__main__":
    # Command-line entry: convert file_in (opset 9) and write the result to file_out.
    if len(sys.argv) != 3:
        print("Usage:{} file_in file_out".format(sys.argv[0]))
        # sys.exit, not the site-injected builtin exit(), which is meant for
        # interactive use and is missing under `python -S`.
        sys.exit(1)

    model = onnx.load(sys.argv[1])
    model = onnx1_4to1_6(model)
    onnx.save(model, sys.argv[2])
|