# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git
"""Change onnx model from version 1.4 to version 1.6.

The model is upgraded from IR version 4 / opset 9 to IR version 6 / opset 11.
Operators whose attributes became inputs in the newer opsets (Clip, Slice,
Pad) are rewritten to take Constant-node inputs, and Upsample is replaced by
Resize.
"""
import sys

import onnx
import onnx.utils
import numpy as np
from onnx import numpy_helper

from tools import other, helper, replacing

# Operator-set domains that denote the default ONNX operator set.
_DEFAULT_DOMAINS = ('', 'ai.onnx')


def _append_const_inputs(g, node, const_nodes, input_order):
    """Add Constant nodes to ``g`` and wire them as positional inputs of ``node``.

    Args:
        g (onnx.GraphProto): graph that receives the Constant nodes.
        node (onnx.NodeProto): node whose input list is extended in place.
        const_nodes (dict): input slot name -> Constant node for that slot.
        input_order (tuple): slot names in the operator's input order.

    A slot without a constant is written as an empty input name (ONNX's
    marker for an omitted optional input) when a later slot is present, so
    the following inputs keep their positions. Trailing missing slots are
    simply left off.
    """
    present = [i for i, slot in enumerate(input_order) if slot in const_nodes]
    if not present:
        return
    for slot in input_order[:present[-1] + 1]:
        const_node = const_nodes.get(slot)
        if const_node is None:
            node.input.extend([''])
        else:
            g.node.extend([const_node])
            node.input.extend([const_node.name])


def replace_all_attribute_to_const_node_in_pad_node(g):
    """Move the attributes of every Pad node into Constant-node inputs.

    Opset 11 Pad takes ``pads`` and ``constant_value`` as inputs. Each old
    Pad node is replaced by a new Pad node (named after its first output)
    fed by two new Constant nodes; ``mode`` stays an attribute.

    Args:
        g (onnx.GraphProto): graph modified in place.

    Raises:
        ValueError: if a Pad node has no ``pads`` attribute.
    """
    node_to_remove = []
    node_to_extend = []
    for node in g.node:
        if node.op_type != 'Pad':
            continue
        pad_loc_node = None  # must be provided by the 'pads' attribute
        pad_mode = 'constant'  # used when the node has no 'mode' attribute
        # need scalar: shape [] constant; 0.0 when no 'value' attribute is set
        pad_value_node = helper.list_to_constant(node.name + '_pad_value', [], [0.0])
        for att in node.attribute:
            if att.name == 'mode':
                pad_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
            elif att.name == 'pads':
                pad_loc_node = helper.list_to_constant(
                    node.name + '_pad_loc', [len(att.ints)], att.ints)
            elif att.name == 'value':
                pad_value_node = helper.list_to_constant(
                    node.name + '_pad_value', [], [att.f])
        if pad_loc_node is None:
            raise ValueError('Pad node ' + node.name + ' has no "pads" attribute')

        new_node = onnx.helper.make_node(
            "Pad",
            [node.input[0], pad_loc_node.name, pad_value_node.name],
            [node.output[0]],
            name=node.output[0],
            mode=pad_mode,
        )
        node_to_remove.append(node)
        node_to_extend.extend([new_node, pad_loc_node, pad_value_node])

    # Mutate only after the scan so iteration over g.node is not disturbed.
    for node in node_to_remove:
        g.node.remove(node)
    g.node.extend(node_to_extend)


def upsampling_to_resize(g):
    """Replace every Upsample node with an opset 11 Resize node.

    The Upsample ``scales`` input must come from a Constant node. It is reused
    as the Resize ``scales`` input, and an empty Constant is added for the
    ``roi`` input. ``coordinate_transformation_mode='asymmetric'`` (plus
    ``nearest_mode='floor'`` for nearest mode) reproduces Upsample's
    ``x_in = x_out / scale`` sampling.

    Args:
        g (onnx.GraphProto): graph modified in place.

    Raises:
        TypeError: if the scales input is not produced by a Constant node.
    """
    # Snapshot first: removing from g.node while iterating it would skip
    # the node that follows each removed one.
    upsample_nodes = [node for node in g.node if node.op_type == 'Upsample']
    for node in upsample_nodes:
        upsampling_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
        scale_value_node = helper.find_node_by_output_name(g, node.input[1])
        if scale_value_node is None or scale_value_node.op_type != "Constant":
            raise TypeError('seems there is a dynamic "scales" param in Upsampling node: '
                            + node.name + ' , you might need to do constant folding first')

        roi_node = helper.list_to_constant(node.name + '_roi_value', [0], [])

        resize_attrs = {
            'mode': upsampling_mode,
            'coordinate_transformation_mode': 'asymmetric',
        }
        # The helper may hand back str or bytes; compare on a normalized str.
        mode_str = (upsampling_mode.decode()
                    if isinstance(upsampling_mode, bytes) else upsampling_mode)
        # Upsample's default mode is 'nearest'; Resize-11's default
        # nearest_mode ('round_prefer_floor') would pick different pixels.
        if mode_str in (None, 'nearest'):
            resize_attrs['nearest_mode'] = 'floor'

        new_node = onnx.helper.make_node(
            "Resize",
            [node.input[0], roi_node.name, scale_value_node.name],
            [node.output[0]],
            name=node.output[0],
            **resize_attrs
        )
        g.node.remove(node)
        g.node.extend([new_node])
        g.node.extend([roi_node])


def replace_all_attribute_to_const_node_in_slice_node(g):
    """Move the starts/ends/axes/steps attributes of Slice nodes into inputs.

    All attributes of each Slice node are cleared. The values become
    Constant-node inputs in spec order (starts, ends, axes, steps), with an
    empty input name for an omitted optional slot that precedes a present one.

    Args:
        g (onnx.GraphProto): graph modified in place.
    """
    slots = ('starts', 'ends', 'axes', 'steps')
    slice_nodes = [node for node in g.node if node.op_type == 'Slice']
    for node in slice_nodes:
        const_nodes = {}
        for att in node.attribute:
            if att.name in slots:
                const_nodes[att.name] = helper.list_to_constant(
                    node.name + '_' + att.name + '_value', [len(att.ints)], att.ints)
        del node.attribute[:]
        ## according the spec, we need to add node in specific order
        _append_const_inputs(g, node, const_nodes, slots)


def replace_min_max_attribute_to_const_node_in_clip_node(g):
    """Move the min/max attributes of Clip nodes into scalar Constant inputs.

    Both bounds are optional. A missing ``min`` becomes an empty input name
    when ``max`` is present, and a missing ``max`` is simply omitted. Clip
    nodes without attributes are left unchanged.

    Args:
        g (onnx.GraphProto): graph modified in place.
    """
    slots = ('min', 'max')
    clip_nodes = [node for node in g.node if node.op_type == 'Clip']
    for node in clip_nodes:
        const_nodes = {}
        for att in node.attribute:
            if att.name in slots:
                const_nodes[att.name] = helper.list_to_constant(
                    node.name + '_' + att.name + '_value', [], [att.f])
        del node.attribute[:]
        ## according the spec, we need to add node in specific order
        _append_const_inputs(g, node, const_nodes, slots)


def _default_opset_import(model):
    """Return the ``opset_import`` entry of the default ONNX domain.

    Falls back to the first entry when no default-domain entry is listed.
    """
    for opset in model.opset_import:
        if opset.domain in _DEFAULT_DOMAINS:
            return opset
    return model.opset_import[0]


def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
    """Update ir_version from 4 to 6 and update opset from 9 to 11.

    Args:
        model (onnx.ModelProto): input onnx model.

    Returns:
        onnx.ModelProto: updated onnx model.

    The process exits with status 1 (``sys.exit``) when the default-domain
    opset is already 11 or newer.
    """
    graph = model.graph
    default_opset = _default_opset_import(model)
    if default_opset.version >= 11:
        print("(Stop) the input model is already opset {}, no need to upgrade"
              .format(default_opset.version))
        sys.exit(1)

    # deal with empty node name issue
    other.add_name_to_node(graph)
    # simplify the node param type from initializer to constant
    replacing.replace_initializer_with_Constant(graph)

    # Modify the nodes.
    replace_min_max_attribute_to_const_node_in_clip_node(graph)
    replace_all_attribute_to_const_node_in_slice_node(graph)
    replace_all_attribute_to_const_node_in_pad_node(graph)
    upsampling_to_resize(graph)
    other.topological_sort(graph)

    # Change model properties.
    model.ir_version = 6
    default_opset.version = 11
    model = other.polish_model(model)
    return model


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage:{} file_in file_out".format(sys.argv[0]))
        sys.exit(1)
    model = onnx.load(sys.argv[1])
    model = onnx1_4to1_6(model)
    onnx.save(model, sys.argv[2])