diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 801f99f..3778c4d 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -316,7 +316,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): session_options.register_custom_ops_library(ort_custom_op_path) providers = ['CPUExecutionProvider'] provider_options = [{}] - is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available() + is_cuda_available = ( + ort.get_device() == 'GPU' and torch.cuda.is_available() + ) if is_cuda_available: providers.insert(0, 'CUDAExecutionProvider') device_id = device_id or 0 @@ -334,7 +336,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor): self.output_name_list = [sess_outputs[0].name] self.cfg = cfg # TODO: necessary? self.test_cfg = cfg.model.test_cfg - self.test_mode = self.test_cfg.mode # NOTE: should be 'whole' or 'slide' + self.test_mode = self.test_cfg.mode # NOTE: either 'whole' or 'slide' self.is_cuda_available = is_cuda_available self.count_mat = None try: diff --git a/setup.py b/setup.py index 4028c18..dc758e2 100755 --- a/setup.py +++ b/setup.py @@ -171,7 +171,8 @@ if __name__ == '__main__': setup( name='mmsegmentation', version=get_version(), - description='Open MMLab Semantic Segmentation Toolbox and Benchmark (Kneron Edition)', + description='Open MMLab Semantic Segmentation Toolbox ' + 'and Benchmark (Kneron Edition)', long_description=readme(), long_description_content_type='text/markdown', author='MMSegmentation Contributors and Kneron', diff --git a/tools/deploy_test_kneron.py b/tools/deploy_test_kneron.py index 84043c6..5be25a5 100644 --- a/tools/deploy_test_kneron.py +++ b/tools/deploy_test_kneron.py @@ -163,9 +163,9 @@ def main(): efficient_test = eval_kwargs.get('efficient_test', False) if efficient_test: warnings.warn( - '``efficient_test=True`` does not have effect in tools/test_kneron.py, ' - 'the evaluation and format results are CPU memory efficient by ' - 'default') + 
'"efficient_test=True" does not have effect in ' + 'tools/test_kneron.py, the evaluation and format ' + 'results are CPU memory efficient by default') eval_on_format_results = ( args.eval is not None and 'cityscapes' in args.eval) diff --git a/tools/optimizer_scripts/consecutive_conv_opt.py b/tools/optimizer_scripts/consecutive_conv_opt.py index c7d4068..0ed4a28 100644 --- a/tools/optimizer_scripts/consecutive_conv_opt.py +++ b/tools/optimizer_scripts/consecutive_conv_opt.py @@ -5,55 +5,81 @@ import sys from tools.other import topological_sort from tools import helper + def fuse_bias_in_consecutive_1x1_conv(g): for second in g.node: # Find two conv - if second.op_type != 'Conv': + if second.op_type != "Conv": continue first = helper.find_node_by_output_name(g, second.input[0]) - if first is None or first.op_type != 'Conv': + if first is None or first.op_type != "Conv": continue # Check if the first one has only one folloing node - if len(helper.find_following_nodes_by_input_value_name(g, first.output[0])) != 1: + if ( + len( + helper.find_following_nodes_by_input_value_name( + g, first.output[0] + ) + ) + != 1 + ): continue # If first node has no bias, continue if len(first.input) == 2: continue # Check their kernel size - first_kernel_shape = helper.get_list_attribute_by_name(first, 'kernel_shape', 'int') - second_kernel_shape = helper.get_list_attribute_by_name(second, 'kernel_shape', 'int') - prod = first_kernel_shape[0] * first_kernel_shape[1] * second_kernel_shape[0] * second_kernel_shape[1] + first_kernel_shape = helper.get_list_attribute_by_name( + first, "kernel_shape", "int" + ) + second_kernel_shape = helper.get_list_attribute_by_name( + second, "kernel_shape", "int" + ) + prod = ( + first_kernel_shape[0] + * first_kernel_shape[1] + * second_kernel_shape[0] + * second_kernel_shape[1] + ) if prod != 1: continue - print('Found: ', first.name, ' ', second.name) + print("Found: ", first.name, " ", second.name) # Get bias of the nodes first_bias_node = 
helper.find_node_by_output_name(g, first.input[2]) - second_weight_node = helper.find_node_by_output_name(g, second.input[1]) + second_weight_node = helper.find_node_by_output_name( + g, second.input[1] + ) second_bias_node = helper.find_node_by_output_name(g, second.input[2]) first_bias = helper.constant_to_numpy(first_bias_node) second_weight = helper.constant_to_numpy(second_weight_node) second_bias = helper.constant_to_numpy(second_bias_node) # Calculate the weight for second node first_bias = np.reshape(first_bias, (1, first_bias.size)) - second_weight = np.reshape(second_weight, (second_weight.shape[0], second_weight.shape[1])) + second_weight = np.reshape( + second_weight, (second_weight.shape[0], second_weight.shape[1]) + ) second_weight = np.transpose(second_weight) new_second_bias = second_bias + np.matmul(first_bias, second_weight) new_second_bias = np.reshape(new_second_bias, (new_second_bias.size,)) # Generate new weight - new_first_bias = np.reshape(first_bias, (first_bias.size, )) + new_first_bias = np.reshape(first_bias, (first_bias.size,)) for i in range(new_first_bias.shape[0]): new_first_bias[i] = 0.0 - new_first_bias_node = helper.numpy_to_constant(first_bias_node.output[0], new_first_bias) - new_second_bias_node = helper.numpy_to_constant(second_bias_node.output[0], new_second_bias) + new_first_bias_node = helper.numpy_to_constant( + first_bias_node.output[0], new_first_bias + ) + new_second_bias_node = helper.numpy_to_constant( + second_bias_node.output[0], new_second_bias + ) # Delete old weight and add new weights g.node.remove(first_bias_node) g.node.remove(second_bias_node) g.node.extend([new_first_bias_node, new_second_bias_node]) topological_sort(g) + if __name__ == "__main__": if len(sys.argv) != 3: exit(1) m = onnx.load(sys.argv[1]) fuse_bias_in_consecutive_1x1_conv(m.graph) - onnx.save(m, sys.argv[2]) \ No newline at end of file + onnx.save(m, sys.argv[2]) diff --git a/tools/optimizer_scripts/editor.py 
b/tools/optimizer_scripts/editor.py index 8ccc6ca..b04183c 100644 --- a/tools/optimizer_scripts/editor.py +++ b/tools/optimizer_scripts/editor.py @@ -1,5 +1,6 @@ import onnx import onnx.utils + try: from onnx import optimizer except ImportError: @@ -9,23 +10,107 @@ import argparse import tools.modhelper as helper import tools.other as other import tools.replacing as replacing + # Main process # Argument parser -parser = argparse.ArgumentParser(description="Edit an ONNX model.\nThe processing sequense is 'delete nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting cannot be done with other operations together") -parser.add_argument('in_file', type=str, help='input ONNX FILE') -parser.add_argument('out_file', type=str, help="ouput ONNX FILE") -parser.add_argument('-c', '--cut', dest='cut_node', type=str, nargs='+', help="remove nodes from the given nodes(inclusive)") -parser.add_argument('--cut-type', dest='cut_type', type=str, nargs='+', help="remove nodes by type from the given nodes(inclusive)") -parser.add_argument('-d', '--delete', dest='delete_node', type=str, nargs='+', help="delete nodes by names and only those nodes") -parser.add_argument('--delete-input', dest='delete_input', type=str, nargs='+', help="delete inputs by names") -parser.add_argument('--delete-output', dest='delete_output', type=str, nargs='+', help="delete outputs by names") -parser.add_argument('-i', '--input', dest='input_change', type=str, nargs='+', help="change input shape (e.g. -i 'input_0 1 3 224 224')") -parser.add_argument('-o', '--output', dest='output_change', type=str, nargs='+', help="change output shape (e.g. 
-o 'input_0 1 3 224 224')") -parser.add_argument('--add-conv', dest='add_conv', type=str, nargs='+', help='add nop conv using specific input') -parser.add_argument('--add-bn', dest='add_bn', type=str, nargs='+', help='add nop bn using specific input') -parser.add_argument('--rename-output', dest='rename_output', type=str, nargs='+', help='Rename the specific output(e.g. --rename-output old_name new_name)') -parser.add_argument('--pixel-bias-value', dest='pixel_bias_value', type=str, nargs='+', help='(per channel) set pixel value bias bn layer at model front for normalization( e.g. --pixel_bias_value "[104.0, 117.0, 123.0]" )') -parser.add_argument('--pixel-scale-value', dest='pixel_scale_value', type=str, nargs='+', help='(per channel) set pixel value scale bn layer at model front for normalization( e.g. --pixel_scale_value "[0.0078125, 0.0078125, 0.0078125]" )') +parser = argparse.ArgumentParser( + description="Edit an ONNX model.\nThe processing sequense is 'delete " + "nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting " + "cannot be done with other operations together" +) +parser.add_argument("in_file", type=str, help="input ONNX FILE") +parser.add_argument("out_file", type=str, help="ouput ONNX FILE") +parser.add_argument( + "-c", + "--cut", + dest="cut_node", + type=str, + nargs="+", + help="remove nodes from the given nodes(inclusive)", +) +parser.add_argument( + "--cut-type", + dest="cut_type", + type=str, + nargs="+", + help="remove nodes by type from the given nodes(inclusive)", +) +parser.add_argument( + "-d", + "--delete", + dest="delete_node", + type=str, + nargs="+", + help="delete nodes by names and only those nodes", +) +parser.add_argument( + "--delete-input", + dest="delete_input", + type=str, + nargs="+", + help="delete inputs by names", +) +parser.add_argument( + "--delete-output", + dest="delete_output", + type=str, + nargs="+", + help="delete outputs by names", +) +parser.add_argument( + "-i", + "--input", + dest="input_change", + 
type=str, + nargs="+", + help="change input shape (e.g. -i 'input_0 1 3 224 224')", +) +parser.add_argument( + "-o", + "--output", + dest="output_change", + type=str, + nargs="+", + help="change output shape (e.g. -o 'input_0 1 3 224 224')", +) +parser.add_argument( + "--add-conv", + dest="add_conv", + type=str, + nargs="+", + help="add nop conv using specific input", +) +parser.add_argument( + "--add-bn", + dest="add_bn", + type=str, + nargs="+", + help="add nop bn using specific input", +) +parser.add_argument( + "--rename-output", + dest="rename_output", + type=str, + nargs="+", + help="Rename the specific output(e.g. --rename-output old_name new_name)", +) +parser.add_argument( + "--pixel-bias-value", + dest="pixel_bias_value", + type=str, + nargs="+", + help='(per channel) set pixel value bias bn layer at model front for ' + 'normalization( e.g. --pixel_bias_value "[104.0, 117.0, 123.0]" )', +) +parser.add_argument( + "--pixel-scale-value", + dest="pixel_scale_value", + type=str, + nargs="+", + help='(per channel) set pixel value scale bn layer at model front for ' + 'normalization( e.g. 
--pixel_scale_value ' + '"[0.0078125, 0.0078125, 0.0078125]" )', +) args = parser.parse_args() @@ -60,23 +145,48 @@ if args.add_bn is not None: if args.pixel_bias_value is not None or args.pixel_scale_value is not None: if len(g.input) > 1: - raise ValueError(" '--pixel-bias-value' and '--pixel-scale-value' only support one input node model currently") - + raise ValueError( + " '--pixel-bias-value' and '--pixel-scale-value' " + "only support one input node model currently" + ) + i_n = g.input[0] pixel_bias_value = [0] * i_n.type.tensor_type.shape.dim[1].dim_value pixel_scale_value = [1] * i_n.type.tensor_type.shape.dim[1].dim_value if args.pixel_bias_value is not None and len(args.pixel_bias_value) == 1: - pixel_bias_value = [float(n) for n in args.pixel_bias_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')] + pixel_bias_value = [ + float(n) + for n in args.pixel_bias_value[0] + .replace("[", "") + .replace("]", "") + .split(",") + ] if args.pixel_scale_value is not None and len(args.pixel_scale_value) == 1: - pixel_scale_value = [float(n) for n in args.pixel_scale_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')] + pixel_scale_value = [ + float(n) + for n in args.pixel_scale_value[0] + .replace("[", "") + .replace("]", "") + .split(",") + ] - - if i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_bias_value) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value): - raise ValueError("--pixel-bias-value (" + str(pixel_bias_value) + ") and --pixel-scale-value (" + str(pixel_scale_value) + ") should be same as input dimension:" + str(i_n.type.tensor_type.shape.dim[1].dim_value) ) - other.add_bias_scale_bn_after(g, i_n.name, pixel_bias_value, pixel_scale_value) + if i_n.type.tensor_type.shape.dim[1].dim_value != len( + pixel_bias_value + ) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value): + raise ValueError( + "--pixel-bias-value (" + + str(pixel_bias_value) + + ") and --pixel-scale-value (" + + 
str(pixel_scale_value) + + ") should be same as input dimension:" + + str(i_n.type.tensor_type.shape.dim[1].dim_value) + ) + other.add_bias_scale_bn_after( + g, i_n.name, pixel_bias_value, pixel_scale_value + ) # Change input and output shapes as requested if args.input_change is not None: @@ -100,14 +210,21 @@ if args.rename_output: print("Rename output should be paires of names.") else: for i in range(0, len(args.rename_output), 2): - other.rename_output_name(g, args.rename_output[i], args.rename_output[i + 1]) + other.rename_output_name( + g, args.rename_output[i], args.rename_output[i + 1] + ) # Remove useless nodes -if args.delete_node or args.delete_input or args.input_change or args.output_change: +if ( + args.delete_node + or args.delete_input + or args.input_change + or args.output_change +): # If shape changed during the modification, redo shape inference. - while(len(g.value_info) > 0): + while len(g.value_info) > 0: g.value_info.pop() -passes = ['extract_constant_to_initializer'] +passes = ["extract_constant_to_initializer"] m = optimizer.optimize(m, passes) g = m.graph replacing.replace_initializer_with_Constant(g) @@ -115,4 +232,4 @@ other.topological_sort(g) # Polish and output m = other.polish_model(m) other.add_output_to_value_info(m.graph) -onnx.save(m, args.out_file) \ No newline at end of file +onnx.save(m, args.out_file) diff --git a/tools/optimizer_scripts/norm_on_scaled_onnx.py b/tools/optimizer_scripts/norm_on_scaled_onnx.py index f99a866..7d462c2 100644 --- a/tools/optimizer_scripts/norm_on_scaled_onnx.py +++ b/tools/optimizer_scripts/norm_on_scaled_onnx.py @@ -11,42 +11,44 @@ if len(sys.argv) != 3: # Modify onnx m = onnx.load(sys.argv[1]) special.add_0_5_to_normalized_input(m) -onnx.save(m, sys.argv[1][:-4] + 'norm.onnx') +onnx.save(m, sys.argv[1][:-4] + "norm.onnx") # Change input node -origin_file = open(sys.argv[2], 'r') +origin_file = open(sys.argv[2], "r") origin_json = json.load(origin_file) 
origin_json["input_node"]["output_datapath_radix"] = [8] new_json_str = json.dumps(origin_json) # Modify json -file = open(sys.argv[1][:-4] + 'norm.onnx' + '.json', 'w') +file = open(sys.argv[1][:-4] + "norm.onnx" + ".json", "w") s = """{{ - \"{0}\" : - {{ - \"bias_bitwidth\" : 16, - \"{0}_bias\" : [15], - \"{0}_weight\" : [3,3,3], - \"conv_coarse_shift\" : [-4,-4,-4], - \"conv_fine_shift\" : [0,0,0], - \"conv_total_shift\" : [-4,-4,-4], - \"cpu_mode\" : false, - \"delta_input_bitwidth\" : [0], - \"delta_output_bitwidth\" : 8, - \"flag_radix_bias_eq_output\" : true, - \"input_scale\" : [[1.0,1.0,1.0]], - \"output_scale\" : [1.0, 1.0, 1.0], - \"psum_bitwidth\" : 16, - \"weight_bitwidth\" : 8, - \"input_datapath_bitwidth\" : [8], - \"input_datapath_radix\" : [8], - \"working_input_bitwidth\" : 8, - \"working_input_radix\" : [8], - \"working_output_bitwidth\" : 16, - \"working_output_radix\" : 15, - \"output_datapath_bitwidth\" : 8, - \"output_datapath_radix\" : 7 - }},\n""".format('input_norm') + \"{0}\" : + {{ + \"bias_bitwidth\" : 16, + \"{0}_bias\" : [15], + \"{0}_weight\" : [3,3,3], + \"conv_coarse_shift\" : [-4,-4,-4], + \"conv_fine_shift\" : [0,0,0], + \"conv_total_shift\" : [-4,-4,-4], + \"cpu_mode\" : false, + \"delta_input_bitwidth\" : [0], + \"delta_output_bitwidth\" : 8, + \"flag_radix_bias_eq_output\" : true, + \"input_scale\" : [[1.0,1.0,1.0]], + \"output_scale\" : [1.0, 1.0, 1.0], + \"psum_bitwidth\" : 16, + \"weight_bitwidth\" : 8, + \"input_datapath_bitwidth\" : [8], + \"input_datapath_radix\" : [8], + \"working_input_bitwidth\" : 8, + \"working_input_radix\" : [8], + \"working_output_bitwidth\" : 16, + \"working_output_radix\" : 15, + \"output_datapath_bitwidth\" : 8, + \"output_datapath_radix\" : 7 + }},\n""".format( + "input_norm" +) file.write(s + new_json_str[1:]) file.close() origin_file.close() diff --git a/tools/optimizer_scripts/onnx1_3to1_4.py b/tools/optimizer_scripts/onnx1_3to1_4.py index 64b72b5..6c6613f 100644 --- 
a/tools/optimizer_scripts/onnx1_3to1_4.py +++ b/tools/optimizer_scripts/onnx1_3to1_4.py @@ -2,33 +2,33 @@ import sys import onnx -import numpy as np -from onnx import numpy_helper from tools import other, helper """ -Change onnx model from version 1.3 to version 1.4. -Modify the BN node by removing the spatial attribute -Modify the Upsample node by removing the 'scales' attribute, and adding a constant node instead. -Model's ir_version and opset_import are updated. +Change onnx model from version 1.3 to version 1.4. +- Modify the BN node by removing the spatial attribute +- Modify the Upsample node by removing the 'scales' attribute, + and adding a constant node instead. +- Model's ir_version and opset_import are updated. """ + def remove_BN_spatial(g): for node in g.node: - if node.op_type != 'BatchNormalization': + if node.op_type != "BatchNormalization": continue for att in node.attribute: - if att.name == 'spatial': + if att.name == "spatial": node.attribute.remove(att) def upsample_attribute_to_const(g): for node in g.node: - if node.op_type != 'Upsample': + if node.op_type != "Upsample": continue scales_exist = False for att in node.attribute: - if att.name == 'scales': + if att.name == "scales": scales_exist = True break if not scales_exist: @@ -36,18 +36,23 @@ def upsample_attribute_to_const(g): shape = [len(att.floats)] node.attribute.remove(att) - new_node = helper.list_to_constant(node.name+'_input', shape, att.floats) + new_node = helper.list_to_constant( + node.name + "_input", shape, att.floats + ) g.node.extend([new_node]) - value_info = onnx.helper.make_tensor_value_info(node.name+'_input', onnx.TensorProto.FLOAT, shape) - node.input.extend([node.name+'_input']) + value_info = onnx.helper.make_tensor_value_info( + node.name + "_input", onnx.TensorProto.FLOAT, shape + ) + node.input.extend([node.name + "_input"]) g.value_info.extend([value_info]) + def relu6_to_clip(g): for node in g.node: - if node.op_type != 'Relu': + if node.op_type != "Relu": 
continue - max_val = helper.get_var_attribute_by_name(node, 'max', 'float') + max_val = helper.get_var_attribute_by_name(node, "max", "float") if max_val is None: continue new_node = onnx.helper.make_node( @@ -56,11 +61,12 @@ def relu6_to_clip(g): node.output, name=node.name, max=max_val, - min=0.0 + min=0.0, ) g.node.remove(node) g.node.extend([new_node]) + def PRelu_weight_reshape(g): # For PRelu with single dimension weight. Expand it to 1, x, 1, 1 for node in g.node: @@ -91,16 +97,18 @@ def PRelu_weight_reshape(g): new_input = onnx.helper.make_tensor_value_info( node.input[1], input_value.type.tensor_type.elem_type, - (1, slope.dims[1], 1, 1)) + (1, slope.dims[1], 1, 1), + ) g.input.remove(input_value) g.input.append(new_input) value_info = helper.find_value_by_name(g, node.input[1]) if value_info is not None: g.value_info.remove(value_info) + def do_convert(m): graph = m.graph - + # Modify the nodes. remove_BN_spatial(graph) upsample_attribute_to_const(graph) @@ -113,6 +121,7 @@ def do_convert(m): m.opset_import[0].version = 9 return m + if __name__ == "__main__": if len(sys.argv) != 3: print("Usage:{} file_in file_out".format(sys.argv[0])) diff --git a/tools/optimizer_scripts/onnx1_4to1_6.py b/tools/optimizer_scripts/onnx1_4to1_6.py index 825b3cd..caa5540 100644 --- a/tools/optimizer_scripts/onnx1_4to1_6.py +++ b/tools/optimizer_scripts/onnx1_4to1_6.py @@ -3,31 +3,38 @@ import sys import onnx import onnx.utils -import numpy as np -from onnx import numpy_helper from tools import other, helper, replacing """ Change onnx model from version 1.4 to version 1.6. 
""" + def replace_all_attribute_to_const_node_in_pad_node(g): node_to_remove = [] node_to_extend = [] for node in g.node: - if node.op_type != 'Pad': + if node.op_type != "Pad": continue - pad_loc_node = None # must have - pad_mode = 'constant' - pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [0.0]) # need scalar + pad_loc_node = None # must have + pad_mode = "constant" + pad_value_node = helper.list_to_constant( + node.name + "_pad_value", [], [0.0] + ) # need scalar for att in node.attribute: - if att.name == 'mode': - pad_mode = helper.get_var_attribute_by_name(node, 'mode', 'string') - if att.name == 'pads': - pad_loc_node = helper.list_to_constant(node.name+'_pad_loc', [len(att.ints)], att.ints) - if att.name == 'value': - pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [att.f]) + if att.name == "mode": + pad_mode = helper.get_var_attribute_by_name( + node, "mode", "string" + ) + if att.name == "pads": + pad_loc_node = helper.list_to_constant( + node.name + "_pad_loc", [len(att.ints)], att.ints + ) + if att.name == "value": + pad_value_node = helper.list_to_constant( + node.name + "_pad_value", [], [att.f] + ) new_node = onnx.helper.make_node( "Pad", @@ -40,24 +47,30 @@ def replace_all_attribute_to_const_node_in_pad_node(g): node_to_extend.append(new_node) node_to_extend.append(pad_loc_node) node_to_extend.append(pad_value_node) - - for node in node_to_remove: + + for node in node_to_remove: g.node.remove(node) - for node in node_to_extend: + for node in node_to_extend: g.node.extend([node]) def upsampling_to_resize(g): for node in g.node: - if node.op_type != 'Upsample': + if node.op_type != "Upsample": continue - upsampling_mode = helper.get_var_attribute_by_name(node, 'mode', 'string') + upsampling_mode = helper.get_var_attribute_by_name( + node, "mode", "string" + ) scale_value_node = helper.find_node_by_output_name(g, node.input[1]) if scale_value_node.op_type != "Constant": - raise TypeError('seems there is a 
dynamic "scales" param in Upsampling node: ' + node.name + ' , you might need to do constant folding first') + raise TypeError( + 'seems there is a dynamic "scales" param in Upsampling node: ' + + node.name + + " , you might need to do constant folding first" + ) - roi_node = helper.list_to_constant(node.name+'_roi_value', [0], []) + roi_node = helper.list_to_constant(node.name + "_roi_value", [0], []) new_node = onnx.helper.make_node( "Resize", @@ -65,7 +78,7 @@ def upsampling_to_resize(g): [node.output[0]], name=node.output[0], mode=upsampling_mode, - coordinate_transformation_mode = 'asymmetric' + coordinate_transformation_mode="asymmetric", ) g.node.remove(node) @@ -75,7 +88,7 @@ def upsampling_to_resize(g): def replace_all_attribute_to_const_node_in_slice_node(g): for node in g.node: - if node.op_type != 'Slice': + if node.op_type != "Slice": continue axes_const_node = None @@ -83,62 +96,75 @@ def replace_all_attribute_to_const_node_in_slice_node(g): starts_const_node = None steps_const_node = None for att in node.attribute: - if att.name == 'axes': - axes_const_node = helper.list_to_constant(node.name+'_axes_value', [len(att.ints)], att.ints) - - if att.name == 'ends': - ends_const_node = helper.list_to_constant(node.name+'_ends_value', [len(att.ints)], att.ints) + if att.name == "axes": + axes_const_node = helper.list_to_constant( + node.name + "_axes_value", [len(att.ints)], att.ints + ) - if att.name == 'starts': - starts_const_node = helper.list_to_constant(node.name+'_starts_value', [len(att.ints)], att.ints) + if att.name == "ends": + ends_const_node = helper.list_to_constant( + node.name + "_ends_value", [len(att.ints)], att.ints + ) - if att.name == 'steps': - steps_const_node = helper.list_to_constant(node.name+'_steps_value',[ len(att.ints)], att.ints) + if att.name == "starts": + starts_const_node = helper.list_to_constant( + node.name + "_starts_value", [len(att.ints)], att.ints + ) - ## pop out from back + if att.name == "steps": + 
steps_const_node = helper.list_to_constant( + node.name + "_steps_value", [len(att.ints)], att.ints + ) + + # pop out from back attr_len = len(node.attribute) for i in range(attr_len): - node.attribute.remove(node.attribute[ attr_len -1 - i ]) + node.attribute.remove(node.attribute[attr_len - 1 - i]) - ## according the spec, we need to add node in specific order - if starts_const_node != None: + # according the spec, we need to add node in specific order + if starts_const_node is not None: g.node.extend([starts_const_node]) node.input.extend([starts_const_node.name]) - if ends_const_node != None: + if ends_const_node is not None: g.node.extend([ends_const_node]) - node.input.extend([ends_const_node.name]) - if axes_const_node != None: + node.input.extend([ends_const_node.name]) + if axes_const_node is not None: g.node.extend([axes_const_node]) node.input.extend([axes_const_node.name]) - if steps_const_node != None: + if steps_const_node is not None: g.node.extend([steps_const_node]) node.input.extend([steps_const_node.name]) - + def replace_min_max_attribute_to_const_node_in_clip_node(g): for node in g.node: - if node.op_type != 'Clip': + if node.op_type != "Clip": continue max_const_node = None min_const_node = None for att in node.attribute: - if att.name == 'max': - max_const_node = helper.list_to_constant(node.name+'_max_value', [], [att.f]) - - if att.name == 'min': - min_const_node = helper.list_to_constant(node.name+'_min_value', [], [att.f]) + if att.name == "max": + max_const_node = helper.list_to_constant( + node.name + "_max_value", [], [att.f] + ) - ## pop out from back - node.attribute.remove(node.attribute[1]) - node.attribute.remove(node.attribute[0]) - - ## according the spec, we need to add node in specific order + if att.name == "min": + min_const_node = helper.list_to_constant( + node.name + "_min_value", [], [att.f] + ) + + # pop out from back + node.attribute.remove(node.attribute[1]) + node.attribute.remove(node.attribute[0]) + + # according 
the spec, we need to add node in specific order g.node.extend([min_const_node]) g.node.extend([max_const_node]) node.input.extend([min_const_node.name]) node.input.extend([max_const_node.name]) + def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto: """Update ir_version from 4 to 6 and update opset from 9 to 11. @@ -173,6 +199,7 @@ def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto: model = other.polish_model(model) return model + if __name__ == "__main__": if len(sys.argv) != 3: print("Usage:{} file_in file_out".format(sys.argv[0])) diff --git a/tools/optimizer_scripts/onnx2onnx.py b/tools/optimizer_scripts/onnx2onnx.py index b820378..884dd2b 100644 --- a/tools/optimizer_scripts/onnx2onnx.py +++ b/tools/optimizer_scripts/onnx2onnx.py @@ -1,45 +1,51 @@ import onnx import onnx.utils -try: - from onnx import optimizer -except ImportError: - import onnxoptimizer as optimizer -import sys + import argparse import logging from tools import eliminating -from tools import fusing -from tools import replacing from tools import other from tools import special from tools import combo -from tools.helper import logger + # from tools import temp -def onnx2onnx_flow(m: onnx.ModelProto, - disable_fuse_bn=False, - bn_on_skip=False, - bn_before_add=False, - bgr=False, - norm=False, - rgba2yynn=False, - eliminate_tail=False, - opt_matmul=False, - duplicate_shared_weights=True) -> onnx.ModelProto: + +def onnx2onnx_flow( + m: onnx.ModelProto, + disable_fuse_bn=False, + bn_on_skip=False, + bn_before_add=False, + bgr=False, + norm=False, + rgba2yynn=False, + eliminate_tail=False, + opt_matmul=False, + duplicate_shared_weights=True, +) -> onnx.ModelProto: """Optimize the onnx. Args: m (ModelProto): the input onnx ModelProto - disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False. - bn_on_skip (bool, optional): add BN operator on skip branches. Defaults to False. - bn_before_add (bool, optional): add BN before Add node on every branches. 
Defaults to False. - bgr (bool, optional): add an Conv layer to convert rgb input to bgr. Defaults to False. - norm (bool, optional): add an Conv layer to add 0.5 tp the input. Defaults to False. - rgba2yynn (bool, optional): add an Conv layer to convert rgb input to yynn . Defaults to False. - eliminate_tail (bool, optional): remove the trailing NPU unsupported nodes. Defaults to False. - opt_matmul(bool, optional): optimize the MatMul layers according to the NPU limit. Defaults to False. - duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True. + disable_fuse_bn (bool, optional): do not fuse BN into Conv. + Defaults to False. + bn_on_skip (bool, optional): add BN operator on skip branches. + Defaults to False. + bn_before_add (bool, optional): add BN before Add node on every branch. + Defaults to False. + bgr (bool, optional): add an Conv layer to convert rgb input to bgr. + Defaults to False. + norm (bool, optional): add an Conv layer to add 0.5 tp the input. + Defaults to False. + rgba2yynn (bool, optional): add an Conv layer to convert rgb to yynn. + Defaults to False. + eliminate_tail (bool, optional): remove trailing NPU unsupported nodes. + Defaults to False. + opt_matmul(bool, optional): optimize MatMul layers due to NPU limit. + Defaults to False. + duplicate_shared_weights(bool, optional): duplicate shared weights. + Defaults to True. Returns: ModelProto: the optimized onnx model object. 
@@ -79,28 +85,83 @@ def onnx2onnx_flow(m: onnx.ModelProto, return m + # Main process if __name__ == "__main__": # Argument parser - parser = argparse.ArgumentParser(description="Optimize an ONNX model for Kneron compiler") - parser.add_argument('in_file', help='input ONNX FILE') - parser.add_argument('-o', '--output', dest='out_file', type=str, help="ouput ONNX FILE") - parser.add_argument('--log', default='i', type=str, help="set log level") - parser.add_argument('--bgr', action='store_true', default=False, help="set if the model is trained in BGR mode") - parser.add_argument('--norm', action='store_true', default=False, help="set if you have the input -0.5~0.5") - parser.add_argument('--rgba2yynn', action='store_true', default=False, help="set if the model has yynn input but you want to take rgba images") - parser.add_argument('--add-bn-on-skip', dest='bn_on_skip', action='store_true', default=False, - help="set if you only want to add BN on skip branches") - parser.add_argument('--add-bn', dest='bn_before_add', action='store_true', default=False, - help="set if you want to add BN before Add") - parser.add_argument('-t', '--eliminate-tail-unsupported', dest='eliminate_tail', action='store_true', default=False, - help='whether remove the last unsupported node for hardware') - parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False, - help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.") - parser.add_argument('--opt-matmul', dest='opt_matmul', action='store_true', default=False, - help="set if you want to optimize the MatMul operations for the kneron hardware.") - parser.add_argument('--no-duplicate-shared-weights', dest='no_duplicate_shared_weights', action='store_true', default=False, - help='do not duplicate shared weights. 
Defaults to False.') + parser = argparse.ArgumentParser( + description="Optimize an ONNX model for Kneron compiler" + ) + parser.add_argument("in_file", help="input ONNX FILE") + parser.add_argument( + "-o", "--output", dest="out_file", type=str, help="ouput ONNX FILE" + ) + parser.add_argument("--log", default="i", type=str, help="set log level") + parser.add_argument( + "--bgr", + action="store_true", + default=False, + help="set if the model is trained in BGR mode", + ) + parser.add_argument( + "--norm", + action="store_true", + default=False, + help="set if you have the input -0.5~0.5", + ) + parser.add_argument( + "--rgba2yynn", + action="store_true", + default=False, + help="set if the model has yynn input but you want " + "to take rgba images", + ) + parser.add_argument( + "--add-bn-on-skip", + dest="bn_on_skip", + action="store_true", + default=False, + help="set if you only want to add BN on skip branches", + ) + parser.add_argument( + "--add-bn", + dest="bn_before_add", + action="store_true", + default=False, + help="set if you want to add BN before Add", + ) + parser.add_argument( + "-t", + "--eliminate-tail-unsupported", + dest="eliminate_tail", + action="store_true", + default=False, + help="whether remove the last unsupported node for hardware", + ) + parser.add_argument( + "--no-bn-fusion", + dest="disable_fuse_bn", + action="store_true", + default=False, + help="set if you have met errors which related to inferenced " + "shape mismatch. This option will prevent fusing " + "BatchNormalization into Conv.", + ) + parser.add_argument( + "--opt-matmul", + dest="opt_matmul", + action="store_true", + default=False, + help="set if you want to optimize MatMul operations " + "for kneron hardware.", + ) + parser.add_argument( + "--no-duplicate-shared-weights", + dest="no_duplicate_shared_weights", + action="store_true", + default=False, + help="do not duplicate shared weights. 
Defaults to False.", + ) args = parser.parse_args() if args.out_file is None: @@ -108,11 +169,11 @@ if __name__ == "__main__": else: outfile = args.out_file - if args.log == 'w': + if args.log == "w": logging.basicConfig(level=logging.WARN) - elif args.log == 'd': + elif args.log == "d": logging.basicConfig(level=logging.DEBUG) - elif args.log == 'e': + elif args.log == "e": logging.basicConfig(level=logging.ERROR) else: logging.basicConfig(level=logging.INFO) @@ -131,6 +192,17 @@ if __name__ == "__main__": # Basic model organize m = onnx.load(args.in_file) - m = onnx2onnx_flow(m, args.disable_fuse_bn, args.bn_on_skip, args.bn_before_add, args.bgr, args.norm, args.rgba2yynn, args.eliminate_tail, args.opt_matmul, not args.no_duplicate_shared_weights) + m = onnx2onnx_flow( + m, + args.disable_fuse_bn, + args.bn_on_skip, + args.bn_before_add, + args.bgr, + args.norm, + args.rgba2yynn, + args.eliminate_tail, + args.opt_matmul, + not args.no_duplicate_shared_weights, + ) onnx.save(m, outfile) diff --git a/tools/optimizer_scripts/onnx_vs_onnx.py b/tools/optimizer_scripts/onnx_vs_onnx.py index c04c65b..d416045 100644 --- a/tools/optimizer_scripts/onnx_vs_onnx.py +++ b/tools/optimizer_scripts/onnx_vs_onnx.py @@ -5,12 +5,30 @@ import numpy as np from tools import helper -onnx2np_dtype = {0: 'float', 1: 'float32', 2: 'uint8', 3: 'int8', 4: 'uint16', 5: 'int16', 6: 'int32', 7: 'int64', 8: 'str', 9: 'bool', 10: 'float16', 11: 'double', 12: 'uint32', 13: 'uint64', 14: 'complex64', 15: 'complex128', 16: 'float'} +onnx2np_dtype = { + 0: "float", + 1: "float32", + 2: "uint8", + 3: "int8", + 4: "uint16", + 5: "int16", + 6: "int32", + 7: "int64", + 8: "str", + 9: "bool", + 10: "float16", + 11: "double", + 12: "uint32", + 13: "uint64", + 14: "complex64", + 15: "complex128", + 16: "float", +} def onnx_model_results(path_a, path_b, total_times=10): - """ using onnxruntime to inference two onnx models' ouputs - + """using onnxruntime to inference two onnx models' ouputs + :onnx model 
paths: two model paths :total_times: inference times, default to be 10 :returns: inference results of two models @@ -22,13 +40,20 @@ def onnx_model_results(path_a, path_b, total_times=10): outputs_b = session_b.get_outputs() # check outputs - assert len(outputs_a) == len(outputs_b), 'Two models have different output numbers.' + assert len(outputs_a) == len( + outputs_b + ), "Two models have different output numbers." for i in range(len(outputs_a)): out_shape_a, out_shape_b = outputs_a[i].shape, outputs_b[i].shape - out_shape_a = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_a)) - out_shape_b = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_b)) - assert out_shape_a == out_shape_b, 'Output {} has unmatched shapes'.format(i) - + out_shape_a = list( + map(lambda x: x if isinstance(x, int) else 1, out_shape_a) + ) + out_shape_b = list( + map(lambda x: x if isinstance(x, int) else 1, out_shape_b) + ) + assert ( + out_shape_a == out_shape_b + ), "Output {} has unmatched shapes".format(i) # load onnx graph_a and graph_b, to find the initializer and inputs # then compare to remove the items in the inputs which will be initialized @@ -38,9 +63,16 @@ def onnx_model_results(path_a, path_b, total_times=10): init_a, init_b = graph_a.initializer, graph_b.initializer # remove initializer from raw inputs - input_names_a, input_names_b = set([ele.name for ele in inputs_a]), set([ele.name for ele in inputs_b]) - init_names_a, init_names_b = set([ele.name for ele in init_a]), set([ele.name for ele in init_b]) - real_inputs_names_a, real_inputs_names_b = input_names_a - init_names_a, input_names_b - init_names_b + input_names_a, input_names_b = set([ele.name for ele in inputs_a]), set( + [ele.name for ele in inputs_b] + ) + init_names_a, init_names_b = set([ele.name for ele in init_a]), set( + [ele.name for ele in init_b] + ) + real_inputs_names_a, real_inputs_names_b = ( + input_names_a - init_names_a, + input_names_b - init_names_b, + ) # prepare and 
figure out matching of real inputs a and real inputs b # try to keep original orders of each inputs @@ -61,17 +93,20 @@ def onnx_model_results(path_a, path_b, total_times=10): for item_a in real_inputs_a: size, shape = helper.find_size_shape_from_value(item_a) if size: - assert real_single_input_a is None, 'Multiple inputs of first model, single input expected.' + assert ( + real_single_input_a is None + ), "Multiple inputs of first model, single input expected." real_single_input_a = item_a size_a, shape_a = size, shape for item_b in real_inputs_b: size, shape = helper.find_size_shape_from_value(item_b) if size: - assert real_single_input_b is None, 'Multiple inputs of second model, single input expected.' + assert ( + real_single_input_b is None + ), "Multiple inputs of second model, single input expected." real_single_input_b = item_b size_b, shape_b = size, shape - assert size_a == size_b, 'Sizes of two models do not match.' - + assert size_a == size_b, "Sizes of two models do not match." 
# construct inputs tensors input_data_type_a = real_single_input_a.type.tensor_type.elem_type @@ -84,7 +119,7 @@ def onnx_model_results(path_a, path_b, total_times=10): results_a = [[] for i in range(len(outputs_a))] results_b = [[] for i in range(len(outputs_b))] while times < total_times: - # initialize inputs by random data, default to be uniform + # initialize inputs by random data, default to be uniform data = np.random.random(size_a) input_a = np.reshape(data, shape_a).astype(input_data_type_a) input_b = np.reshape(data, shape_b).astype(input_data_type_b) @@ -93,12 +128,18 @@ def onnx_model_results(path_a, path_b, total_times=10): input_dict_b = {} for item_a in real_inputs_a: item_type_a = onnx2np_dtype[item_a.type.tensor_type.elem_type] - input_dict_a[item_a.name] = np.array([]).astype(item_type_a) \ - if item_a.name != real_single_input_a.name else input_a + input_dict_a[item_a.name] = ( + np.array([]).astype(item_type_a) + if item_a.name != real_single_input_a.name + else input_a + ) for item_b in real_inputs_b: item_type_b = onnx2np_dtype[item_b.type.tensor_type.elem_type] - input_dict_b[item_b.name] = np.array([]).astype(item_type_b) \ - if item_b.name != real_single_input_b.name else input_b + input_dict_b[item_b.name] = ( + np.array([]).astype(item_type_b) + if item_b.name != real_single_input_b.name + else input_b + ) ra = session_a.run([], input_dict_a) rb = session_b.run([], input_dict_b) @@ -109,26 +150,32 @@ def onnx_model_results(path_a, path_b, total_times=10): return results_a, results_b -if __name__ == '__main__': + +if __name__ == "__main__": # Argument parser. - parser = argparse.ArgumentParser(description="Compare two ONNX models to check if they have the same output.") - parser.add_argument('in_file_a', help='input ONNX file a') - parser.add_argument('in_file_b', help='input ONNX file b') + parser = argparse.ArgumentParser( + description="Compare two ONNX models to check if " + "they have the same output." 
+ ) + parser.add_argument("in_file_a", help="input ONNX file a") + parser.add_argument("in_file_b", help="input ONNX file b") args = parser.parse_args() - results_a, results_b = onnx_model_results(args.in_file_a, args.in_file_b, total_times=10) + results_a, results_b = onnx_model_results( + args.in_file_a, args.in_file_b, total_times=10 + ) ra_flat = helper.flatten_with_depth(results_a, 0) rb_flat = helper.flatten_with_depth(results_b, 0) shape_a = [item[1] for item in ra_flat] shape_b = [item[1] for item in rb_flat] - assert shape_a == shape_b, 'two results data shape doesn\'t match' + assert shape_a == shape_b, "two results data shape doesn't match" ra_raw = [item[0] for item in ra_flat] rb_raw = [item[0] for item in rb_flat] try: np.testing.assert_almost_equal(ra_raw, rb_raw, 4) - print('Two models have the same behaviour.') + print("Two models have the same behaviour.") except Exception as mismatch: print(mismatch) exit(1) diff --git a/tools/optimizer_scripts/onnx_vs_onnx_opt.py b/tools/optimizer_scripts/onnx_vs_onnx_opt.py index b660cf4..5ac4e6b 100644 --- a/tools/optimizer_scripts/onnx_vs_onnx_opt.py +++ b/tools/optimizer_scripts/onnx_vs_onnx_opt.py @@ -1,4 +1,3 @@ -import onnx import argparse import glob import csv @@ -8,214 +7,242 @@ import matplotlib.pyplot as plt from tools import helper import onnx_vs_onnx as onnx_tester + def compare_results(results_a, results_b): - """ compare onnx model inference results - calculate basic statistical values - results: results from inference multiple times - returns: list of basic statistical values - """ - # input results data can be of nonuniform shape - # get flatten data to compare - ra_flat = helper.flatten_with_depth(results_a, 0) - rb_flat = helper.flatten_with_depth(results_b, 0) - shape_a = [item[1] for item in ra_flat] - shape_b = [item[1] for item in rb_flat] - assert shape_a == shape_b, 'two results data shape doesn\'t match' - ra_raw = [item[0] for item in ra_flat] - rb_raw = [item[0] for item in rb_flat] 
+ """compare onnx model inference results + calculate basic statistical values + results: results from inference multiple times + returns: list of basic statistical values + """ + # input results data can be of nonuniform shape + # get flatten data to compare + ra_flat = helper.flatten_with_depth(results_a, 0) + rb_flat = helper.flatten_with_depth(results_b, 0) + shape_a = [item[1] for item in ra_flat] + shape_b = [item[1] for item in rb_flat] + assert shape_a == shape_b, "two results data shape doesn't match" + ra_raw = [item[0] for item in ra_flat] + rb_raw = [item[0] for item in rb_flat] - # the statistical values - max_rel_diff = 0 # defined to be max( { abs(diff)/max(abs(ra), abs(rb) ) } ) - max_abs_diff = 0 # defined to be max( { abs(ra-rb) } ) - mean_rel_diff = 0 - mean_abs_diff = 0 - std_rel_diff = 0 - std_abs_diff = 0 - acc_with_diff_precision = [] - rel_diff = [] - abs_diff_percentiles = [] # rel_diff percentiles - rel_diff_percentiles = [] # abs_diff precentiles + # the statistical values + max_rel_diff = ( + 0 # defined to be max( { abs(diff)/max(abs(ra), abs(rb) ) } ) + ) + max_abs_diff = 0 # defined to be max( { abs(ra-rb) } ) + mean_rel_diff = 0 + mean_abs_diff = 0 + std_rel_diff = 0 + std_abs_diff = 0 + acc_with_diff_precision = [] + rel_diff = [] + abs_diff_percentiles = [] # rel_diff percentiles + rel_diff_percentiles = [] # abs_diff precentiles - raw_diff = [ra_raw[i]-rb_raw[i] for i in range(len(ra_raw))] - abs_diff = [abs(num) for num in raw_diff] - for i in range(len(ra_raw)): - divider = max([abs(ra_raw[i]), abs(rb_raw[i])]) - val = abs_diff[i]/divider if divider != 0 else 0 - rel_diff.append(val) - - max_rel_diff = max(rel_diff) - max_abs_diff = max(abs_diff) - mean_rel_diff = np.average(rel_diff) - mean_abs_diff = np.average(abs_diff) - std_rel_diff = np.std(rel_diff) - std_abs_diff = np.std(abs_diff) - - # calculate accuracy with different precison - for digit in range(8): - correct = 0 + raw_diff = [ra_raw[i] - rb_raw[i] for i in 
range(len(ra_raw))] + abs_diff = [abs(num) for num in raw_diff] for i in range(len(ra_raw)): - if format(ra_raw[i], '.'+str(digit)+'f')\ - == format(rb_raw[i], '.'+str(digit)+'f'): - correct += 1 - acc_with_diff_precision.append([digit, float(format(correct/len(ra_raw), '.3f'))]) + divider = max([abs(ra_raw[i]), abs(rb_raw[i])]) + val = abs_diff[i] / divider if divider != 0 else 0 + rel_diff.append(val) - # analyze rel_diff distribution - rel_diff.sort() - abs_diff.sort() - for i in range(20): - rel_diff_percentiles.append(['{}%'.format(i*5), rel_diff[int((i/20)*len(rel_diff))]]) - abs_diff_percentiles.append(['{}%'.format(i*5), abs_diff[int((i/20)*len(abs_diff))]]) + max_rel_diff = max(rel_diff) + max_abs_diff = max(abs_diff) + mean_rel_diff = np.average(rel_diff) + mean_abs_diff = np.average(abs_diff) + std_rel_diff = np.std(rel_diff) + std_abs_diff = np.std(abs_diff) - results = [ - ['max_rel_diff', max_rel_diff], - ['max_abs_diff', max_abs_diff], - ['mean_rel_diff', mean_rel_diff], - ['mean_abs_diff', mean_abs_diff], - ['std_rel_diff', std_rel_diff], - ['std_abs_diff', std_abs_diff], - ['acc_with_diff_precision', acc_with_diff_precision], - ['rel_diff_percentiles', rel_diff_percentiles], - ['abs_diff_percentiles', abs_diff_percentiles] - ] - - return results + # calculate accuracy with different precison + for digit in range(8): + correct = 0 + for i in range(len(ra_raw)): + if format(ra_raw[i], "." + str(digit) + "f") == format( + rb_raw[i], "." + str(digit) + "f" + ): + correct += 1 + acc_with_diff_precision.append( + [digit, float(format(correct / len(ra_raw), ".3f"))] + ) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='test model optimization results') - - parser.add_argument('dir', type=str, help='the directory that stores onnx models') - parser.add_argument('ending1', type=str, help='model file name ending(eg, .onnx)') - parser.add_argument('ending2', type=str, help='opt model file name ending(eg. 
_opt.onnx)') - parser.add_argument('out_file', type=str, help='output csv file name') - parser.add_argument('-p', '--plot', default='N', help='get plots (Y/N)') - parser.add_argument('-i', '--iter_times', default=10, type=int, help='inference times') + # analyze rel_diff distribution + rel_diff.sort() + abs_diff.sort() + for i in range(20): + rel_diff_percentiles.append( + ["{}%".format(i * 5), rel_diff[int((i / 20) * len(rel_diff))]] + ) + abs_diff_percentiles.append( + ["{}%".format(i * 5), abs_diff[int((i / 20) * len(abs_diff))]] + ) - args = parser.parse_args() + results = [ + ["max_rel_diff", max_rel_diff], + ["max_abs_diff", max_abs_diff], + ["mean_rel_diff", mean_rel_diff], + ["mean_abs_diff", mean_abs_diff], + ["std_rel_diff", std_rel_diff], + ["std_abs_diff", std_abs_diff], + ["acc_with_diff_precision", acc_with_diff_precision], + ["rel_diff_percentiles", rel_diff_percentiles], + ["abs_diff_percentiles", abs_diff_percentiles], + ] - old_models_paths = glob.glob(args.dir+'*'+args.ending1) - new_models_paths = glob.glob(args.dir+'*'+args.ending2) - - stats_table = [[ - 'Model', - 'max_rel_diff', - 'max_abs_diff', - 'mean_rel_diff', - 'mean_abs_diff', - 'std_rel_diff', - 'std_abs_diff', - 'acc_with_diff_precision', - 'rel_diff_percentiles', - 'abs_diff_percentiles' - ]] - - for new_model_path in new_models_paths: - old_model_path = new_model_path[:-len(args.ending2)] + args.ending1 - if old_model_path not in old_models_paths: - continue - - # run inference - results_a, results_b = onnx_tester.onnx_model_results(old_model_path, new_model_path, total_times=args.iter_times) - - # compare inference results - comparision = compare_results(results_a, results_b) - - new_line = [old_model_path.split('/')[-1]] - for item in comparision: - new_line.append(item[1]) - - stats_table.append(new_line) - - # try to read existing file - old_stats_table = [] - try: - old_file = open(args.out_file, 'r') - reader = csv.reader(old_file) - old_header = reader.__next__() - for row 
in reader: - old_stats_table.append(row) - old_file.close() - except: - pass - - # compare and merge possible old stat data file with new stat data file - header = stats_table[0] - stats_table = stats_table[1:] - new_model_names = set([item[0] for item in stats_table]) - for row in old_stats_table: - if row[0] not in new_model_names: - stats_table.append(row) - stats_table.insert(0, header) - - # write a new stat data file, overwrite old file - new_file = open(args.out_file, 'w', newline='') - writer = csv.writer(new_file) - for row in stats_table: - writer.writerow(row) - new_file.close() - - # make some plots - if args.plot == 'Y': - if len(stats_table) < 2: - exit(0) - - sample_table = stats_table[1:] if len(stats_table) < 6 else stats_table[1:6] - - max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]] - plt.hist(max_rel_diffs, bins=15) - plt.title('Max Relavtive Difference Histogram') - plt.xlabel('Max Relative Difference') - plt.ylabel('Counts') - plt.savefig('max_rel_diff_hist.png') - plt.close() - - max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]] - plt.hist(max_abs_diffs, bins=15) - plt.title('Max Absolute Difference Histogram') - plt.xlabel('Max Absolute Difference') - plt.ylabel('Counts') - plt.savefig('max_abs_diff_hist.png') - plt.close() - - for line in sample_table: - model_name = line[0] - percentiles = line[-2] - x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))] - y = [ele[1] for ele in percentiles] - plt.plot(x, y, label=model_name) - plt.title('Rel_diff Percentiles of Raw and Optimized Models') - plt.xlabel('percentage') - plt.ylabel('relative difference') - plt.legend() - plt.savefig('rel_diff_percentiles.png') - plt.close() - - for line in sample_table: - model_name = line[0] - percentiles = line[-1] - x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))] - y = [ele[1] for ele in percentiles] - plt.plot(x, y, label=model_name) - plt.title('Abs_diff Percentiles of Raw 
and Optimized Models') - plt.xlabel('percentage') - plt.ylabel('absolute difference') - plt.legend() - plt.savefig('abs_diff_percentiles.png') - plt.close() - - for line in sample_table: - model_name = line[0] - accuracies = line[-3] - x = [acc[0] for acc in accuracies] - y = [acc[1] for acc in accuracies] - plt.plot(x, y, label=model_name) - plt.title('Accuracies with Different Precisions') - plt.xlabel('Decimals') - plt.ylabel('Precision') - plt.legend() - plt.savefig('precisions.png') - plt.close() + return results - +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="test model optimization results" + ) + parser.add_argument( + "dir", type=str, help="the directory that stores onnx models" + ) + parser.add_argument( + "ending1", type=str, help="model file name ending(eg, .onnx)" + ) + parser.add_argument( + "ending2", type=str, help="opt model file name ending(eg. _opt.onnx)" + ) + parser.add_argument("out_file", type=str, help="output csv file name") + parser.add_argument("-p", "--plot", default="N", help="get plots (Y/N)") + parser.add_argument( + "-i", "--iter_times", default=10, type=int, help="inference times" + ) + args = parser.parse_args() + + old_models_paths = glob.glob(args.dir + "*" + args.ending1) + new_models_paths = glob.glob(args.dir + "*" + args.ending2) + + stats_table = [ + [ + "Model", + "max_rel_diff", + "max_abs_diff", + "mean_rel_diff", + "mean_abs_diff", + "std_rel_diff", + "std_abs_diff", + "acc_with_diff_precision", + "rel_diff_percentiles", + "abs_diff_percentiles", + ] + ] + + for new_model_path in new_models_paths: + old_model_path = new_model_path[: -len(args.ending2)] + args.ending1 + if old_model_path not in old_models_paths: + continue + + # run inference + results_a, results_b = onnx_tester.onnx_model_results( + old_model_path, new_model_path, total_times=args.iter_times + ) + + # compare inference results + comparision = compare_results(results_a, results_b) + + new_line = 
[old_model_path.split("/")[-1]] + for item in comparision: + new_line.append(item[1]) + + stats_table.append(new_line) + + # try to read existing file + old_stats_table = [] + try: + old_file = open(args.out_file, "r") + reader = csv.reader(old_file) + old_header = reader.__next__() + for row in reader: + old_stats_table.append(row) + old_file.close() + except Exception: + pass + + # compare and merge possible old stat data file with new stat data file + header = stats_table[0] + stats_table = stats_table[1:] + new_model_names = set([item[0] for item in stats_table]) + for row in old_stats_table: + if row[0] not in new_model_names: + stats_table.append(row) + stats_table.insert(0, header) + + # write a new stat data file, overwrite old file + new_file = open(args.out_file, "w", newline="") + writer = csv.writer(new_file) + for row in stats_table: + writer.writerow(row) + new_file.close() + + # make some plots + if args.plot == "Y": + if len(stats_table) < 2: + exit(0) + + sample_table = ( + stats_table[1:] if len(stats_table) < 6 else stats_table[1:6] + ) + + max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]] + plt.hist(max_rel_diffs, bins=15) + plt.title("Max Relavtive Difference Histogram") + plt.xlabel("Max Relative Difference") + plt.ylabel("Counts") + plt.savefig("max_rel_diff_hist.png") + plt.close() + + max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]] + plt.hist(max_abs_diffs, bins=15) + plt.title("Max Absolute Difference Histogram") + plt.xlabel("Max Absolute Difference") + plt.ylabel("Counts") + plt.savefig("max_abs_diff_hist.png") + plt.close() + + for line in sample_table: + model_name = line[0] + percentiles = line[-2] + x = [ + round(i * (1 / len(percentiles)), 2) + for i in range(len(percentiles)) + ] + y = [ele[1] for ele in percentiles] + plt.plot(x, y, label=model_name) + plt.title("Rel_diff Percentiles of Raw and Optimized Models") + plt.xlabel("percentage") + plt.ylabel("relative difference") + 
plt.legend() + plt.savefig("rel_diff_percentiles.png") + plt.close() + + for line in sample_table: + model_name = line[0] + percentiles = line[-1] + x = [ + round(i * (1 / len(percentiles)), 2) + for i in range(len(percentiles)) + ] + y = [ele[1] for ele in percentiles] + plt.plot(x, y, label=model_name) + plt.title("Abs_diff Percentiles of Raw and Optimized Models") + plt.xlabel("percentage") + plt.ylabel("absolute difference") + plt.legend() + plt.savefig("abs_diff_percentiles.png") + plt.close() + + for line in sample_table: + model_name = line[0] + accuracies = line[-3] + x = [acc[0] for acc in accuracies] + y = [acc[1] for acc in accuracies] + plt.plot(x, y, label=model_name) + plt.title("Accuracies with Different Precisions") + plt.xlabel("Decimals") + plt.ylabel("Precision") + plt.legend() + plt.savefig("precisions.png") + plt.close() diff --git a/tools/optimizer_scripts/pytorch2onnx.py b/tools/optimizer_scripts/pytorch2onnx.py index 0f2c559..9dd79ec 100644 --- a/tools/optimizer_scripts/pytorch2onnx.py +++ b/tools/optimizer_scripts/pytorch2onnx.py @@ -1,21 +1,10 @@ import onnx import onnx.utils -try: - from onnx import optimizer -except ImportError: - import onnxoptimizer as optimizer + import sys -import numpy as np -import struct import logging import argparse -from tools import eliminating -from tools import fusing -from tools import replacing -from tools import other -from tools import combo -from tools import special from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow # Debug use @@ -25,13 +14,28 @@ from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow # Generate a prototype onnx # ###################################### -parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler") -parser.add_argument('in_file', help='input ONNX or PTH FILE') -parser.add_argument('out_file', help="ouput ONNX FILE") -parser.add_argument('--input-size', dest='input_size', nargs=3, - help='if you 
using pth, please use this argument to set up the input size of the model. It should be in \'CH H W\' format, e.g. \'--input-size 3 256 512\'.')
-parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
-                    help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.")
+parser = argparse.ArgumentParser(
+    description="Optimize a Pytorch generated model for Kneron compiler"
+)
+parser.add_argument("in_file", help="input ONNX or PTH FILE")
+parser.add_argument("out_file", help="output ONNX FILE")
+parser.add_argument(
+    "--input-size",
+    dest="input_size",
+    nargs=3,
+    help="if you are using pth, please use this argument to set up the input "
+    "size of the model. It should be in 'CH H W' format, "
+    "e.g. '--input-size 3 256 512'.",
+)
+parser.add_argument(
+    "--no-bn-fusion",
+    dest="disable_fuse_bn",
+    action="store_true",
+    default=False,
+    help="set if you have met errors which related to inferenced shape "
+    "mismatch. This option will prevent fusing BatchNormalization "
+    "into Conv.",
+)
 
 args = parser.parse_args()
 
@@ -39,7 +43,7 @@ if len(args.in_file) <= 4:
     # When the filename is too short.
     logging.error("Invalid input file: {}".format(args.in_file))
     exit(1)
-elif args.in_file[-4:] == '.pth':
+elif args.in_file[-4:] == ".pth":
     # Pytorch pth case
     logging.warning("Converting from pth to onnx is not recommended.")
     onnx_in = args.out_file
@@ -47,21 +51,29 @@ elif args.in_file[-4:] == '.pth':
     from torch.autograd import Variable
     import torch
     import torch.onnx
+
     # import torchvision
 
     # Standard ImageNet input - 3 channels, 224x224.
     # Values don't matter as we care about network structure.
    # But they can also be real inputs.
if args.input_size is None: - logging.error("\'--input-size\' is required for the pth input file.") + logging.error("'--input-size' is required for the pth input file.") exit(1) - dummy_input = Variable(torch.randn(1, int(args.input_size[0]), int(args.input_size[1]), int(args.input_size[2]))) + dummy_input = Variable( + torch.randn( + 1, + int(args.input_size[0]), + int(args.input_size[1]), + int(args.input_size[2]), + ) + ) # Obtain your model, it can be also constructed in your script explicitly. - model = torch.load(sys.argv[1], map_location='cpu') + model = torch.load(sys.argv[1], map_location="cpu") # model = torchvision.models.resnet34(pretrained=True) # Invoke export. # torch.save(model, "resnet34.pth") torch.onnx.export(model, dummy_input, args.out_file, opset_version=11) -elif args.in_file[-4:] == 'onnx': +elif args.in_file[-4:] == "onnx": onnx_in = args.in_file else: # When the file is neither an onnx or a pytorch pth. diff --git a/tools/optimizer_scripts/pytorch_exported_onnx_preprocess.py b/tools/optimizer_scripts/pytorch_exported_onnx_preprocess.py index 509db82..356f0e3 100644 --- a/tools/optimizer_scripts/pytorch_exported_onnx_preprocess.py +++ b/tools/optimizer_scripts/pytorch_exported_onnx_preprocess.py @@ -1,29 +1,22 @@ import onnx import onnx.utils -try: - from onnx import optimizer -except ImportError: - import onnxoptimizer as optimizer -import sys -import numpy as np -import struct + import logging import argparse -from .tools import eliminating -from .tools import fusing -from .tools import replacing -from .tools import other from .tools import combo -from .tools import special + # Define general pytorch exported onnx optimize process -def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.ModelProto: +def torch_exported_onnx_flow( + m: onnx.ModelProto, disable_fuse_bn=False +) -> onnx.ModelProto: """Optimize the Pytorch exported onnx. 
Args:
         m (ModelProto): the input onnx model
-        disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.
+        disable_fuse_bn (bool, optional): do not fuse BN into Conv.
+            Defaults to False.
 
     Returns:
         ModelProto: the optimized onnx model
@@ -38,20 +31,29 @@ def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.
 
 # Main Process
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
-    parser.add_argument('in_file', help='input ONNX')
-    parser.add_argument('out_file', help="ouput ONNX FILE")
-    parser.add_argument('--log', default='i', type=str, help="set log level")
-    parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
-                        help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.")
+    parser = argparse.ArgumentParser(
+        description="Optimize a Pytorch generated model for Kneron compiler"
+    )
+    parser.add_argument("in_file", help="input ONNX")
+    parser.add_argument("out_file", help="output ONNX FILE")
+    parser.add_argument("--log", default="i", type=str, help="set log level")
+    parser.add_argument(
+        "--no-bn-fusion",
+        dest="disable_fuse_bn",
+        action="store_true",
+        default=False,
+        help="set if you have met errors which related to inferenced shape "
+        "mismatch. This option will prevent fusing BatchNormalization "
+        "into Conv.",
+    )
     args = parser.parse_args()
 
-    if args.log == 'w':
+    if args.log == "w":
         logging.basicConfig(level=logging.WARN)
-    elif args.log == 'd':
+    elif args.log == "d":
         logging.basicConfig(level=logging.DEBUG)
-    elif args.log == 'e':
+    elif args.log == "e":
         logging.basicConfig(level=logging.ERROR)
     else:
         logging.basicConfig(level=logging.INFO)
@@ -60,7 +62,7 @@ if __name__ == "__main__":
         # When the filename is too short.
logging.error("Invalid input file: {}".format(args.in_file)) exit(1) - elif args.in_file[-4:] == 'onnx': + elif args.in_file[-4:] == "onnx": onnx_in = args.in_file else: # When the file is not an onnx file. diff --git a/tools/optimizer_scripts/tensorflow2onnx.py b/tools/optimizer_scripts/tensorflow2onnx.py index 13c0dab..44b8667 100644 --- a/tools/optimizer_scripts/tensorflow2onnx.py +++ b/tools/optimizer_scripts/tensorflow2onnx.py @@ -8,7 +8,8 @@ import onnx.utils from tensorflow.python.platform import gfile from tools import combo, eliminating, replacing, other -def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto: + +def tf2onnx_flow(pb_path: str, test_mode=False) -> onnx.ModelProto: """Convert frozen graph pb file into onnx Args: @@ -21,34 +22,45 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto: Returns: onnx.ModelProto: converted onnx """ - TF2ONNX_VERSION = int(tf2onnx.version.version.replace('.', '')) + TF2ONNX_VERSION = int(tf2onnx.version.version.replace(".", "")) if 160 <= TF2ONNX_VERSION: from tf2onnx import tf_loader else: from tf2onnx import loader as tf_loader - - if pb_path[-3:] == '.pb': - model_name = pb_path.split('/')[-1][:-3] - # always reset tensorflow session at begin + if pb_path[-3:] == ".pb": + model_name = pb_path.split("/")[-1][:-3] + + # always reset tensorflow session at begin tf.reset_default_graph() - + with tf.Session() as sess: - with gfile.FastGFile(pb_path, 'rb') as f: + with gfile.FastGFile(pb_path, "rb") as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() - tf.import_graph_def(graph_def, name='') + tf.import_graph_def(graph_def, name="") - if 160 <= int(tf2onnx.version.version.replace('.', '')): - onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions = tf2onnx.tf_utils.tflist_to_onnx( - sess.graph, - {}) + if 160 <= int(tf2onnx.version.version.replace(".", "")): + ( + onnx_nodes, + op_cnt, + attr_cnt, + output_shapes, + dtypes, + functions, + ) 
= tf2onnx.tf_utils.tflist_to_onnx(sess.graph, {}) else: - onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tf2onnx.tfonnx.tflist_to_onnx( - sess.graph.get_operations(), - {}) + ( + onnx_nodes, + op_cnt, + attr_cnt, + output_shapes, + dtypes, + ) = tf2onnx.tfonnx.tflist_to_onnx( + sess.graph.get_operations(), {} + ) for n in onnx_nodes: if len(n.output) == 0: @@ -59,12 +71,12 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto: nodes_outputs = set() for n in onnx_nodes: - if n.op_type == 'Placeholder': + if n.op_type == "Placeholder": continue for input in n.input: nodes_inputs.add(input) for output in n.output: - nodes_outputs.add(output) + nodes_outputs.add(output) graph_input_names = set() for input_name in nodes_inputs: @@ -76,35 +88,43 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto: if n.input and n.input[0] not in nodes_outputs: continue if len(n.output) == 0: - n.output.append(n.name + ':0') + n.output.append(n.name + ":0") graph_output_names.add(n.output[0]) else: output_name = n.output[0] - if (output_name not in nodes_inputs) and (0 < len(n.input)): + if (output_name not in nodes_inputs) and ( + 0 < len(n.input) + ): graph_output_names.add(output_name) - logging.info('Model Inputs: %s', str(list(graph_input_names))) - logging.info('Model Outputs: %s', str(list(graph_output_names))) + logging.info("Model Inputs: %s", str(list(graph_input_names))) + logging.info("Model Outputs: %s", str(list(graph_output_names))) - graph_def, inputs, outputs = tf_loader.from_graphdef(model_path=pb_path, - input_names=list(graph_input_names), - output_names=list(graph_output_names)) + graph_def, inputs, outputs = tf_loader.from_graphdef( + model_path=pb_path, + input_names=list(graph_input_names), + output_names=list(graph_output_names), + ) with tf.Graph().as_default() as tf_graph: - tf.import_graph_def(graph_def, name='') + tf.import_graph_def(graph_def, name="") if 160 <= TF2ONNX_VERSION: with 
tf_loader.tf_session(graph=tf_graph): - onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph, - input_names=inputs, - output_names=outputs, - opset=11) + onnx_graph = tf2onnx.tfonnx.process_tf_graph( + tf_graph=tf_graph, + input_names=inputs, + output_names=outputs, + opset=11, + ) else: with tf.Session(graph=tf_graph): - onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph, - input_names=inputs, - output_names=outputs, - opset=11) + onnx_graph = tf2onnx.tfonnx.process_tf_graph( + tf_graph=tf_graph, + input_names=inputs, + output_names=outputs, + opset=11, + ) # Optimize with tf2onnx.optimizer onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph) @@ -115,7 +135,9 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto: model_proto = other.polish_model(model_proto) else: - raise Exception('expect .pb file as input, but got "' + str(pb_path) + '"') + raise Exception( + 'expect .pb file as input, but got "' + str(pb_path) + '"' + ) # rename m = model_proto @@ -133,15 +155,26 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto: return m - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Convert tensorflow pb file to onnx file and optimized onnx file. Or just optimize tensorflow onnx file.') - parser.add_argument('in_file', help='input file') - parser.add_argument('out_file', help='output optimized model file') - parser.add_argument('-t', '--test_mode', default=False, help='test mode will not eliminate shape changes after input') +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert tensorflow pb file to onnx file and optimized " + "onnx file. Or just optimize tensorflow onnx file." 
+ ) + parser.add_argument("in_file", help="input file") + parser.add_argument("out_file", help="output optimized model file") + parser.add_argument( + "-t", + "--test_mode", + default=False, + help="test mode will not eliminate shape changes after input", + ) args = parser.parse_args() - logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO) + logging.basicConfig( + stream=sys.stdout, + format="[%(asctime)s] %(levelname)s: %(message)s", + level=logging.INFO, + ) m = tf2onnx_flow(args.in_file, args.test_mode) onnx.save(m, args.out_file) - logging.info('Save Optimized ONNX: %s', args.out_file) + logging.info("Save Optimized ONNX: %s", args.out_file) diff --git a/tools/optimizer_scripts/tflite_vs_onnx.py b/tools/optimizer_scripts/tflite_vs_onnx.py index ffeecea..e8405cf 100644 --- a/tools/optimizer_scripts/tflite_vs_onnx.py +++ b/tools/optimizer_scripts/tflite_vs_onnx.py @@ -6,6 +6,7 @@ import onnxruntime from tools import helper + def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10): # Setup onnx session and get meta data onnx_session = onnxruntime.InferenceSession(onnx_file, None) @@ -21,21 +22,32 @@ def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10): tflite_session.allocate_tensors() tflite_inputs = tflite_session.get_input_details() tflite_outputs = tflite_session.get_output_details() - tflite_input_shape = tflite_inputs[0]['shape'] + tflite_input_shape = tflite_inputs[0]["shape"] # Compare input shape - assert(len(onnx_input_shape) == len(tflite_input_shape)), "TFLite and ONNX shape unmatch." - assert(onnx_input_shape == [tflite_input_shape[0], tflite_input_shape[3], tflite_input_shape[1], tflite_input_shape[2]]), "TFLite and ONNX shape unmatch." + assert len(onnx_input_shape) == len( + tflite_input_shape + ), "TFLite and ONNX shape unmatch." 
+ assert onnx_input_shape == [ + tflite_input_shape[0], + tflite_input_shape[3], + tflite_input_shape[1], + tflite_input_shape[2], + ], "TFLite and ONNX shape unmatch." # Generate random number and run tflite_results = [] onnx_results = [] for _ in range(total_times): # Generate input - tflite_input_data = np.array(np.random.random_sample(tflite_input_shape), dtype=np.float32) + tflite_input_data = np.array( + np.random.random_sample(tflite_input_shape), dtype=np.float32 + ) onnx_input_data = np.transpose(tflite_input_data, [0, 3, 1, 2]) # Run tflite - tflite_session.set_tensor(tflite_inputs[0]['index'], tflite_input_data) + tflite_session.set_tensor(tflite_inputs[0]["index"], tflite_input_data) tflite_session.invoke() - tflite_results.append(tflite_session.get_tensor(tflite_outputs[0]['index'])) + tflite_results.append( + tflite_session.get_tensor(tflite_outputs[0]["index"]) + ) # Run onnx onnx_input_dict = {onnx_inputs[0].name: onnx_input_data} onnx_results.append(onnx_session.run([], onnx_input_dict)[0]) @@ -43,26 +55,31 @@ def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10): return tflite_results, onnx_results -if __name__ == '__main__': +if __name__ == "__main__": # Argument parser. - parser = argparse.ArgumentParser(description="Compare a TFLite model and an ONNX model to check if they have the same output.") - parser.add_argument('tflite_file', help='input tflite file') - parser.add_argument('onnx_file', help='input ONNX file') + parser = argparse.ArgumentParser( + description="Compare a TFLite model and an ONNX model to check " + "if they have the same output." 
+ ) + parser.add_argument("tflite_file", help="input tflite file") + parser.add_argument("onnx_file", help="input ONNX file") args = parser.parse_args() - results_a, results_b = compare_tflite_and_onnx(args.tflite_file, args.onnx_file, total_times=10) + results_a, results_b = compare_tflite_and_onnx( + args.tflite_file, args.onnx_file, total_times=10 + ) ra_flat = helper.flatten_with_depth(results_a, 0) rb_flat = helper.flatten_with_depth(results_b, 0) shape_a = [item[1] for item in ra_flat] shape_b = [item[1] for item in rb_flat] - assert shape_a == shape_b, 'two results data shape doesn\'t match' + assert shape_a == shape_b, "two results data shape doesn't match" ra_raw = [item[0] for item in ra_flat] rb_raw = [item[0] for item in rb_flat] try: np.testing.assert_almost_equal(ra_raw, rb_raw, 8) - print('Two models have the same behaviour.') + print("Two models have the same behaviour.") except Exception as mismatch: print(mismatch) - exit(1) \ No newline at end of file + exit(1) diff --git a/tools/optimizer_scripts/tools/combo.py b/tools/optimizer_scripts/tools/combo.py index adadecb..1a20ebb 100644 --- a/tools/optimizer_scripts/tools/combo.py +++ b/tools/optimizer_scripts/tools/combo.py @@ -2,7 +2,7 @@ """ import logging -import onnx.utils + try: from onnx import optimizer except ImportError: @@ -15,16 +15,19 @@ from . import eliminating from . import fusing from . import constant_folding from . import removing_transpose -from . import modhelper from .common_pattern import torch_pattern_match, tf_pattern_match from .helper import logger -def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True): + +def preprocess( + model_proto, disable_fuse_bn=False, duplicate_shared_weights=True +): """The most common used functions before other processing. Args: model_proto: the original model input - duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True. 
+ duplicate_shared_weights(bool, optional): duplicate shared weights. + Defaults to True. Return: the new model after preprocessing @@ -65,22 +68,28 @@ def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True replacing.replace_initializer_with_Constant(model_proto.graph) other.topological_sort(model_proto.graph) m = other.polish_model(model_proto) - passes = ['extract_constant_to_initializer', - 'eliminate_nop_dropout', - 'eliminate_deadend', - 'fuse_matmul_add_bias_into_gemm', - 'fuse_pad_into_conv'] + passes = [ + "extract_constant_to_initializer", + "eliminate_nop_dropout", + "eliminate_deadend", + "fuse_matmul_add_bias_into_gemm", + "fuse_pad_into_conv", + ] if not disable_fuse_bn: - passes.append('fuse_bn_into_conv') + passes.append("fuse_bn_into_conv") m = optimizer.optimize(m, passes) g = m.graph - # Add name again since onnx optimizer higher than 1.7 may remove node names. + # Add name again since onnx optimizer higher than 1.7 may remove node names other.add_name_to_node(g) if duplicate_shared_weights: - replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=True) + replacing.replace_initializer_with_Constant( + g, duplicate_shared_weights=True + ) other.duplicate_param_shared_constant(g) else: - replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=False) + replacing.replace_initializer_with_Constant( + g, duplicate_shared_weights=False + ) other.topological_sort(g) m = other.polish_model(m) g = m.graph @@ -161,12 +170,12 @@ def pytorch_constant_folding(m): other.topological_sort(m.graph) while len(m.graph.value_info) != 0: m.graph.value_info.pop() - + m = other.inference_shapes(m) replacing.replace_shape_with_constant(m.graph) other.topological_sort(m.graph) m = torch_pattern_match(m) - m = optimizer.optimize(m, ['eliminate_deadend']) + m = optimizer.optimize(m, ["eliminate_deadend"]) return m @@ -206,7 +215,7 @@ def tensorflow_optimization(m): replacing.replace_shape_with_constant(m.graph) 
other.topological_sort(m.graph) m = tf_pattern_match(m) - m = optimizer.optimize(m, ['eliminate_deadend']) + m = optimizer.optimize(m, ["eliminate_deadend"]) eliminating.eliminate_consecutive_reshape(m.graph) eliminating.eliminate_Squeeze_before_Reshape(m.graph) @@ -253,6 +262,6 @@ def postprocess(m): m = other.polish_model(m) other.add_output_to_value_info(m.graph) - m = optimizer.optimize(m, ['eliminate_deadend']) - m.producer_name = 'kneron_formatter' + m = optimizer.optimize(m, ["eliminate_deadend"]) + m.producer_name = "kneron_formatter" return m diff --git a/tools/optimizer_scripts/tools/common_pattern.py b/tools/optimizer_scripts/tools/common_pattern.py index b65d5bd..19d4b35 100644 --- a/tools/optimizer_scripts/tools/common_pattern.py +++ b/tools/optimizer_scripts/tools/common_pattern.py @@ -3,19 +3,20 @@ import numpy as np import onnx.helper import onnx.utils -from . import modhelper from . import helper from . import other + def torch_pattern_match(m): # Create a map from optype to the nodes. optype2node = defaultdict(list) for node in m.graph.node: optype2node[node.op_type].append(node) - for matmul_node in optype2node['MatMul']: + for matmul_node in optype2node["MatMul"]: pattern_matmul_mul_add(m.graph, matmul_node) - for resize_node in optype2node['Resize']: - # torch nn.UpsamplingBilinear2d will be given us 4 input: "X, roi, scales, sizes" + for resize_node in optype2node["Resize"]: + # torch nn.UpsamplingBilinear2d will be given us 4 input: + # "X, roi, scales, sizes" if len(resize_node.input) != 4: continue make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name) @@ -24,15 +25,17 @@ def torch_pattern_match(m): m = other.polish_model(m) return m + def tf_pattern_match(m): # Create a map from optype to the nodes. 
optype2node = defaultdict(list)
     for node in m.graph.node:
         optype2node[node.op_type].append(node)
-    for matmul_node in optype2node['MatMul']:
+    for matmul_node in optype2node["MatMul"]:
         pattern_matmul_mul_add(m.graph, matmul_node)
-    for resize_node in optype2node['Resize']:
-        # In tensorflow2onnx, ReizeXXX will be given us 4 input: "X, roi, scales, sizes"
+    for resize_node in optype2node["Resize"]:
+        # In tensorflow2onnx, ResizeXXX will be given us 4 input:
+        # "X, roi, scales, sizes"
         # and node output name will be given the "node name + :0"
         if len(resize_node.input) != 4:
             continue
@@ -42,24 +45,25 @@ def tf_pattern_match(m):
     m = other.polish_model(m)
     return m
 
+
 def pattern_matmul_mul_add(g, matmul_node):
     # Check node match - Mul node
     next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
     if len(next_nodes) != 1:
         return
-    if next_nodes[0].op_type != 'Mul':
+    if next_nodes[0].op_type != "Mul":
         return
     mul_node = next_nodes[0]
     # Check node match - Add node
     next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
     if len(next_nodes) != 1:
         return
-    if next_nodes[0].op_type != 'Add':
+    if next_nodes[0].op_type != "Add":
         return
     add_node = next_nodes[0]
     # Check Mul weight
     mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
-    if mul_weight_node.op_type != 'Constant':
+    if mul_weight_node.op_type != "Constant":
         return
     weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
     for i in mul_weight:
@@ -68,15 +72,19 @@ def pattern_matmul_mul_add(g, matmul_node):
         channel = weight_size[0]
     # Check Add weight
     add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
-    if add_weight_node.op_type != 'Constant':
+    if add_weight_node.op_type != "Constant":
         return
     # Check MatMul weight to see if it need weight broadcast
-    matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1])
+    matmul_weight_node = helper.find_node_by_output_name(
+        g, matmul_node.input[1]
+    )
    matmul_weight = 
helper.constant_to_numpy(matmul_weight_node) if matmul_weight.shape[1] == 1: # Weight broadcast new_matmul_weight = np.tile(matmul_weight, channel) - new_matmul_weight_node = helper.numpy_to_constant(matmul_weight_node.name, new_matmul_weight) + new_matmul_weight_node = helper.numpy_to_constant( + matmul_weight_node.name, new_matmul_weight + ) g.node.remove(matmul_weight_node) g.node.extend([new_matmul_weight_node]) value = helper.find_value_by_name(g, matmul_weight_node.output[0]) @@ -93,14 +101,14 @@ def pattern_matmul_mul_add(g, matmul_node): g.value_info.remove(value) # Fuse Matmul and Add gemm_node = onnx.helper.make_node( - 'Gemm', + "Gemm", [matmul_node.input[0], matmul_node.input[1], add_node.input[1]], [add_node.output[0]], - name = matmul_node.name, - alpha = 1.0, - beta = 1.0, - transA = 0, - transB = 0 + name=matmul_node.name, + alpha=1.0, + beta=1.0, + transA=0, + transB=0, ) g.node.extend([gemm_node]) # Clean up @@ -111,6 +119,7 @@ def pattern_matmul_mul_add(g, matmul_node): g.value_info.remove(value) other.topological_sort(g) + def make_UpsamplingBilinear2d_value_info(g, resize_node_name): resize_node = helper.find_node_by_node_name(g, resize_node_name) @@ -124,34 +133,45 @@ def make_UpsamplingBilinear2d_value_info(g, resize_node_name): new_output_value_info = onnx.helper.make_tensor_value_info( resize_node.output[0], onnx.helper.TensorProto.FLOAT, - shape_data.tolist() + shape_data.tolist(), ) g.value_info.extend([new_output_value_info]) + def polish_RESIZE_input_param_node(g, resize_node_name): resize_node = helper.find_node_by_node_name(g, resize_node_name) shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3]) shape_data = helper.constant_to_numpy(shape_data_node).astype(int) - - # handle 0 batch size which is invalid + + # handle 0 batch size which is invalid if shape_data[0] == 0: shape_data[0] = 1 - pre_node_output_value_info = helper.find_value_by_name(g, resize_node.input[0]) - ori_shape = 
np.array([pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value, - pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value, - pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value, - pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value]) - - resize_node.input.remove(resize_node.input[3]) - + pre_node_output_value_info = helper.find_value_by_name( + g, resize_node.input[0] + ) + ori_shape = np.array( + [ + pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value, + pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value, + pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value, + pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value, + ] + ) - resize_scales = np.array(shape_data/ori_shape).astype(float) - resize_scale_node = helper.list_to_constant('resize_scales_node_' + resize_node.name, resize_scales.shape, resize_scales, data_type=onnx.helper.TensorProto.FLOAT) + resize_node.input.remove(resize_node.input[3]) + + resize_scales = np.array(shape_data / ori_shape).astype(float) + resize_scale_node = helper.list_to_constant( + "resize_scales_node_" + resize_node.name, + resize_scales.shape, + resize_scales, + data_type=onnx.helper.TensorProto.FLOAT, + ) resize_node.input[2] = resize_scale_node.name g.node.extend([resize_scale_node]) - + other.topological_sort(g) diff --git a/tools/optimizer_scripts/tools/constant_folding.py b/tools/optimizer_scripts/tools/constant_folding.py index 8149628..45ef674 100644 --- a/tools/optimizer_scripts/tools/constant_folding.py +++ b/tools/optimizer_scripts/tools/constant_folding.py @@ -5,15 +5,14 @@ import logging import traceback from . 
import helper -from .general_graph import Graph, Node from .other import topological_sort -from .replacing import replace_shape_with_constant from .helper import logger + def are_all_inputs_Constant_with_one_child(g, node): for input_name in node.input: input_node = helper.find_node_by_output_name(g, input_name) - if input_node is None or input_node.op_type != 'Constant': + if input_node is None or input_node.op_type != "Constant": return False relative_outputs = helper.find_nodes_by_input_name(g, input_name) if len(relative_outputs) > 1: @@ -28,7 +27,7 @@ def constant_folding(g): :return: If any node is folded, return True. Otherwise, return False. """ keep_folding = True # Keep the while loop - folded = False # Return value + folded = False # Return value try: # Before constant folding, duplicate the constant nodes. duplicate_constant_node(g) @@ -38,37 +37,47 @@ def constant_folding(g): # Check if the node is foldable if node.op_type not in constant_folding_nodes.keys(): continue - # Check if the parents of the node are all single follower constant node. + # Check if parents of the node are all + # single follower constant node. if not are_all_inputs_Constant_with_one_child(g, node): continue # Constant folding for the specific node if constant_folding_nodes[node.op_type](g, node): - logging.debug("Constant nodes and %s %s are folded.", - node.op_type, node.name) + logging.debug( + "Constant nodes and %s %s are folded.", + node.op_type, + node.name, + ) folded = True keep_folding = True else: logging.debug( - "Constant nodes and %s %s are skipped.", node.op_type, node.name) - except Exception as e: + "Constant nodes and %s %s are skipped.", + node.op_type, + node.name, + ) + except Exception: logger.error("An exception is raised while constant folding.") logger.error(traceback.format_exc()) return folded - def duplicate_constant_node(g): - """ Duplicate the constant node if its following nodes contain constant folding - nodes. 
Create and link the new constant nodes to the constant folding nodes.
+    """
+    Duplicate the constant node if its following nodes contain
+    constant folding nodes. Create and link the new constant nodes
+    to the constant folding nodes.
     """
     for node in g.node:
         # Find a valid constant node
-        if node.op_type != 'Constant':
+        if node.op_type != "Constant":
             continue
         output_val_info = helper.find_value_by_name(g, node.output[0])
         if output_val_info is None:
-            print("Cannot inference the shape of Const node output: " +
-                  node.output[0])
+            print(
+                "Cannot inference the shape of Const node output: "
+                + node.output[0]
+            )
             exit(1)
         data_shape = helper.get_shape_from_value_info(output_val_info)
         output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
@@ -78,30 +87,37 @@ def duplicate_constant_node(g):
             continue
 
         # Check if its following nodes are foldable
-        foldable_output_nodes = list(filter(lambda n: n.op_type in
-                                            constant_folding_nodes.keys(), output_nodes))
+        foldable_output_nodes = list(
+            filter(
+                lambda n: n.op_type in constant_folding_nodes.keys(),
+                output_nodes,
+            )
+        )
         if not foldable_output_nodes:
             continue
 
         # Duplicate the node needed by foldable nodes
         for i in range(len(foldable_output_nodes)):
-            logging.debug("Found constant %s and %s %s are availble for folding. Duplicate constant.",
-                          node.name, foldable_output_nodes[i].op_type, foldable_output_nodes[i].name)
-            output_name = node.output[0] + '_dup_' + str(i)
+            logging.debug(
+                f"Found constant {node.name} and "
+                f"{foldable_output_nodes[i].op_type} "
+                f"{foldable_output_nodes[i].name} are available for folding. 
" + "Duplicate constant.", + ) + output_name = node.output[0] + "_dup_" + str(i) new_constant_node = onnx.helper.make_node( - 'Constant', + "Constant", [], [output_name], name=output_name, - value=node.attribute[0].t + value=node.attribute[0].t, ) new_val_info = onnx.helper.make_tensor_value_info( - output_name, - node.attribute[0].t.data_type, - data_shape + output_name, node.attribute[0].t.data_type, data_shape ) input_ind = list(foldable_output_nodes[i].input).index( - node.output[0]) + node.output[0] + ) foldable_output_nodes[i].input[input_ind] = output_name g.node.extend([new_constant_node]) @@ -116,6 +132,7 @@ def duplicate_constant_node(g): return + def slice_constant_folding(g, node): op_version = helper.get_current_opset_version() # only support opset 9 & 11 @@ -124,9 +141,9 @@ def slice_constant_folding(g, node): elif op_version == 9: return slice_constant_folding_Opset_9(g, node) + def slice_constant_folding_Opset_11(g, node): - """ Fold constant and slice nodes to a single constant node. 
- """ + """Fold constant and slice nodes to a single constant node.""" pre_node = helper.find_node_by_output_name(g, node.input[0]) pre_shape, data_list = helper.constant_to_list(pre_node) @@ -136,20 +153,26 @@ def slice_constant_folding_Opset_11(g, node): ends_node = helper.find_node_by_output_name(g, node.input[2]) _, ends = helper.constant_to_list(ends_node) - - axes_node = None if len(node.input) <= 3 else helper.find_node_by_output_name(g, node.input[3]) + axes_node = ( + None + if len(node.input) <= 3 + else helper.find_node_by_output_name(g, node.input[3]) + ) if not axes_node: axes = list(range(len(helper.get_shape(data_list)))) else: _, axes = helper.constant_to_list(axes_node) - steps_node = None if len(node.input) <= 4 else helper.find_node_by_output_name(g, node.input[4]) + steps_node = ( + None + if len(node.input) <= 4 + else helper.find_node_by_output_name(g, node.input[4]) + ) if not steps_node: - steps = [1]*len(helper.get_shape(data_list)) + steps = [1] * len(helper.get_shape(data_list)) else: _, steps = helper.constant_to_list(steps_node) - data_list = list(map(int, data_list)) starts = list(map(int, starts)) ends = list(map(int, ends)) @@ -160,10 +183,15 @@ def slice_constant_folding_Opset_11(g, node): new_data = None for idx, _ in enumerate(axes): - new_data = np.apply_along_axis( lambda x: x[starts[idx] : ends[idx] : steps[idx]], idx, data_list ) + new_data = np.apply_along_axis( + lambda x: x[starts[idx]:ends[idx]:steps[idx]], idx, data_list + ) - new_node = helper.list_to_constant(node.output[0], helper.get_shape( - new_data), helper.flatten_to_list(new_data)) + new_node = helper.list_to_constant( + node.output[0], + helper.get_shape(new_data), + helper.flatten_to_list(new_data), + ) g.node.extend([new_node]) value_info = helper.find_value_by_name(g, pre_node.output[0]) if value_info is not None: @@ -173,16 +201,16 @@ def slice_constant_folding_Opset_11(g, node): return True + def slice_constant_folding_Opset_9(g, node): - """ Fold constant 
and slice nodes to a single constant node. - """ + """Fold constant and slice nodes to a single constant node.""" pre_node = helper.find_node_by_output_name(g, node.input[0]) pre_shape, data_list = helper.constant_to_list(pre_node) data_list = np.reshape(data_list, pre_shape) - axes = helper.get_attribute_by_name(node, 'axes') - ends = list(helper.get_attribute_by_name(node, 'ends').ints) - starts = list(helper.get_attribute_by_name(node, 'starts').ints) + axes = helper.get_attribute_by_name(node, "axes") + ends = list(helper.get_attribute_by_name(node, "ends").ints) + starts = list(helper.get_attribute_by_name(node, "starts").ints) if not axes: axes = list(range(len(helper.get_shape(data_list)))) @@ -190,8 +218,11 @@ def slice_constant_folding_Opset_9(g, node): axes = list(axes.ints) new_data = helper.slice_data(data_list, starts, ends, axes) - new_node = helper.list_to_constant(node.output[0], helper.get_shape( - new_data), helper.flatten_to_list(new_data)) + new_node = helper.list_to_constant( + node.output[0], + helper.get_shape(new_data), + helper.flatten_to_list(new_data), + ) g.node.extend([new_node]) value_info = helper.find_value_by_name(g, pre_node.output[0]) if value_info is not None: @@ -201,9 +232,9 @@ def slice_constant_folding_Opset_9(g, node): return True + def cast_constant_folding(g, node): - """ Fold constant and cast node to a single constant node. 
- """ + """Fold constant and cast node to a single constant node.""" pre_node = helper.find_node_by_output_name(g, node.input[0]) shape, data = helper.constant_to_list(pre_node) data_type = node.attribute[0].i @@ -212,28 +243,24 @@ def cast_constant_folding(g, node): elif data_type == onnx.helper.TensorProto.FLOAT: data = list(map(float, data)) else: - raise RuntimeError('data type not supported') + raise RuntimeError("data type not supported") if shape == 1: tensor = onnx.helper.make_tensor( name=pre_node.attribute[0].name, data_type=data_type, dims=[], - vals=data + vals=data, ) else: tensor = onnx.helper.make_tensor( name=pre_node.attribute[0].name, data_type=data_type, dims=shape, - vals=helper.flatten_to_list(data) + vals=helper.flatten_to_list(data), ) new_node = onnx.helper.make_node( - 'Constant', - [], - [node.output[0]], - name=node.output[0], - value=tensor + "Constant", [], [node.output[0]], name=node.output[0], value=tensor ) g.node.extend([new_node]) @@ -250,15 +277,14 @@ def cast_constant_folding(g, node): def reduceprod_constant_folding(g, node): - """ Fold constant and reduceprod nodes to a single constant node. 
- """ + """Fold constant and reduceprod nodes to a single constant node.""" pre_node = helper.find_node_by_output_name(g, node.input[0]) shape, data_set = helper.constant_to_list(pre_node) tensor = pre_node.attribute[0].t data_set = np.reshape(data_set, shape) for att in node.attribute: - if att.name == 'axes': + if att.name == "axes": axes = list(att.ints) else: keepdims = int(att.i) @@ -270,14 +296,10 @@ def reduceprod_constant_folding(g, node): name=node.output[0], data_type=tensor.data_type, dims=new_shape, - vals=new_flat_data + vals=new_flat_data, ) new_node = onnx.helper.make_node( - 'Constant', - [], - [node.output[0]], - name=node.output[0], - value=new_tensor + "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor ) g.node.extend([new_node]) @@ -294,8 +316,7 @@ def reduceprod_constant_folding(g, node): def reshape_constant_input_folding(g, node): - """ Fold constant and reshape nodes to a single constant node. - """ + """Fold constant and reshape nodes to a single constant node.""" pre_data_node = helper.find_node_by_output_name(g, node.input[0]) pre_shape_node = helper.find_node_by_output_name(g, node.input[1]) @@ -307,14 +328,10 @@ def reshape_constant_input_folding(g, node): name=node.output[0], data_type=pre_data_node.attribute[0].t.data_type, dims=new_data.shape, - vals=helper.flatten_to_list(new_data) + vals=helper.flatten_to_list(new_data), ) new_node = onnx.helper.make_node( - 'Constant', - [], - [node.output[0]], - name=node.output[0], - value=new_tensor + "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor ) g.node.extend([new_node]) @@ -332,8 +349,7 @@ def reshape_constant_input_folding(g, node): def concat_constant_folding(g, node): - """ Fold constant and concat nodes to a single constant node. 
- """ + """Fold constant and concat nodes to a single constant node.""" node_to_del = [] valid_inputs = True for input_name in node.input: @@ -342,7 +358,7 @@ def concat_constant_folding(g, node): if len(input_node_output) > 1: valid_inputs = False break - if input_node.op_type != 'Constant': + if input_node.op_type != "Constant": valid_inputs = False break @@ -370,7 +386,7 @@ def concat_constant_folding(g, node): node.output[0], helper.get_shape(concat_data), helper.flatten_to_list(concat_data), - data_type=node_data_type + data_type=node_data_type, ) g.node.extend([new_node]) node_to_del.append(node) @@ -388,8 +404,7 @@ def concat_constant_folding(g, node): def transpose_constant_folding(g, node): - """Fold constant and transpose nodes to a single constant node. - """ + """Fold constant and transpose nodes to a single constant node.""" node_to_del = [] pre_node = helper.find_node_by_output_name(g, node.input[0]) shape, data = helper.constant_to_list(pre_node) @@ -402,7 +417,7 @@ def transpose_constant_folding(g, node): node.output[0], new_shape, new_data.flatten().tolist(), - data_type=pre_node.attribute[0].t.data_type + data_type=pre_node.attribute[0].t.data_type, ) g.node.extend([new_node]) @@ -415,9 +430,7 @@ def transpose_constant_folding(g, node): g.value_info.remove(next_val_info) new_val_info = onnx.helper.make_tensor_value_info( - node.output[0], - pre_node.attribute[0].t.data_type, - new_shape + node.output[0], pre_node.attribute[0].t.data_type, new_shape ) g.value_info.extend([new_val_info]) @@ -430,8 +443,7 @@ def transpose_constant_folding(g, node): def unsqueeze_constant_folding(g, node): - """Fold constant and unsqueeze nodes to a single constant node. 
- """ + """Fold constant and unsqueeze nodes to a single constant node.""" node_to_del = [] pre_node = helper.find_node_by_output_name(g, node.input[0]) shape, data = helper.constant_to_list(pre_node) @@ -449,7 +461,7 @@ def unsqueeze_constant_folding(g, node): node.output[0], new_shape, np_data.flatten().tolist(), - data_type=pre_node.attribute[0].t.data_type + data_type=pre_node.attribute[0].t.data_type, ) g.node.extend([new_node]) node_to_del.extend([node, pre_node]) @@ -464,9 +476,7 @@ def unsqueeze_constant_folding(g, node): g.value_info.remove(next_val_info) new_val_info = onnx.helper.make_tensor_value_info( - node.output[0], - pre_node.attribute[0].t.data_type, - new_shape + node.output[0], pre_node.attribute[0].t.data_type, new_shape ) g.value_info.extend([new_val_info]) @@ -478,8 +488,7 @@ def unsqueeze_constant_folding(g, node): def gather_constant_folding(g, node): - """Fold constant and gather nodes to a single constant node. - """ + """Fold constant and gather nodes to a single constant node.""" node_to_del = [] pre_data_node = helper.find_node_by_output_name(g, node.input[0]) @@ -502,7 +511,7 @@ def gather_constant_folding(g, node): node.output[0], new_shape, new_data.flatten().tolist(), - data_type=pre_data_node.attribute[0].t.data_type + data_type=pre_data_node.attribute[0].t.data_type, ) node_to_del.extend([node, pre_data_node, pre_indices_node]) @@ -512,9 +521,7 @@ def gather_constant_folding(g, node): val_info_2 = helper.find_value_by_name(g, node.input[1]) val_info_3 = helper.find_value_by_name(g, node.output[0]) new_val_info = onnx.helper.make_tensor_value_info( - new_node.output[0], - pre_data_node.attribute[0].t.data_type, - new_shape + new_node.output[0], pre_data_node.attribute[0].t.data_type, new_shape ) if val_info_1 is not None: @@ -533,8 +540,7 @@ def gather_constant_folding(g, node): def add_constant_folding(g, node): - """Fold constant and add nodes to a single constant node. 
- """ + """Fold constant and add nodes to a single constant node.""" node_to_del = [] pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) @@ -547,14 +553,14 @@ def add_constant_folding(g, node): np_data2 = np.reshape(data2, shape2) try: new_data = np.add(np_data1, np_data2) - except: - raise RuntimeError('can\'t broadcast and add two data sets') + except Exception: + raise RuntimeError("can't broadcast and add two data sets") new_node = helper.list_to_constant( node.output[0], new_data.shape, new_data.flatten().tolist(), - data_type=pre_node_1.attribute[0].t.data_type + data_type=pre_node_1.attribute[0].t.data_type, ) g.node.extend([new_node]) @@ -571,8 +577,7 @@ def add_constant_folding(g, node): def sqrt_constant_folding(g, node): - """ Fold constant and sqrt nodes to a single node. - """ + """Fold constant and sqrt nodes to a single node.""" node_to_del = [] pre_node = helper.find_node_by_output_name(g, node.input[0]) shape, data = helper.constant_to_list(pre_node) @@ -582,17 +587,13 @@ def sqrt_constant_folding(g, node): data_type = output_val_info.type.tensor_type.elem_type new_tensor = onnx.helper.make_tensor( - name=node.output[0]+'_data', + name=node.output[0] + "_data", data_type=data_type, dims=shape, - vals=np_data.flatten().tolist() + vals=np_data.flatten().tolist(), ) new_node = onnx.helper.make_node( - 'Constant', - [], - [node.output[0]], - name=node.output[0], - value=new_tensor + "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor ) g.value_info.remove(input_val_info) @@ -607,13 +608,12 @@ def sqrt_constant_folding(g, node): def reciprocal_constant_folding(g, node): - """ Fold constant and reciprocal nodes to a single constant node. 
- """ + """Fold constant and reciprocal nodes to a single constant node.""" node_to_del = [] pre_node = helper.find_node_by_output_name(g, node.input[0]) shape, data = helper.constant_to_list(pre_node) - data = list(map(lambda x: x if abs(x) > 1.e-8 else 1.e-8, data)) + data = list(map(lambda x: x if abs(x) > 1.0e-8 else 1.0e-8, data)) np_data = np.reshape(data, shape) np_data = np.reciprocal(np_data) @@ -622,17 +622,13 @@ def reciprocal_constant_folding(g, node): data_type = output_val_info.type.tensor_type.elem_type new_tensor = onnx.helper.make_tensor( - name=node.output[0]+'_data', + name=node.output[0] + "_data", data_type=data_type, dims=shape, - vals=np_data.flatten().tolist() + vals=np_data.flatten().tolist(), ) new_node = onnx.helper.make_node( - 'Constant', - [], - [node.output[0]], - name=node.output[0], - value=new_tensor + "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor ) node_to_del.extend([node, pre_node]) @@ -648,8 +644,7 @@ def reciprocal_constant_folding(g, node): def mul_constant_folding(g, node): - """ Fold constant and mul nodes to a single constant node. - """ + """Fold constant and mul nodes to a single constant node.""" node_to_del = [] pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) @@ -666,8 +661,8 @@ def mul_constant_folding(g, node): try: new_data = np.multiply(np_data1, np_data2) - except: - raise RuntimeError('can not broadcast and multiply two data sets') + except Exception: + raise RuntimeError("can not broadcast and multiply two data sets") # Special shape for single element. 
if shape1 == 1 and shape2 == 1: @@ -676,17 +671,13 @@ def mul_constant_folding(g, node): new_shape = new_data.shape new_tensor = onnx.helper.make_tensor( - name=node.output[0]+'_data', + name=node.output[0] + "_data", data_type=pre_node_1.attribute[0].t.data_type, dims=new_shape, - vals=new_data.flatten().tolist() + vals=new_data.flatten().tolist(), ) new_node = onnx.helper.make_node( - 'Constant', - [], - [node.output[0]], - name=node.output[0], - value=new_tensor + "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor ) node_to_del.extend([node, pre_node_1, pre_node_2]) @@ -703,8 +694,7 @@ def mul_constant_folding(g, node): def div_constant_folding(g, node): - """ Fold constant and mul nodes to a single constant node. - """ + """Fold constant and mul nodes to a single constant node.""" node_to_del = [] pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) @@ -721,8 +711,8 @@ def div_constant_folding(g, node): try: new_data = np.divide(np_data1, np_data2) - except: - raise RuntimeError('can not broadcast and multiply two data sets') + except Exception: + raise RuntimeError("can not broadcast and multiply two data sets") # Special shape for single element. 
if shape1 == 1 and shape2 == 1: @@ -732,20 +722,16 @@ def div_constant_folding(g, node): # Check data type if it is int if pre_node_1.attribute[0].t.data_type == 7: - new_data = new_data.astype('int64') + new_data = new_data.astype("int64") new_tensor = onnx.helper.make_tensor( - name=node.output[0]+'_data', + name=node.output[0] + "_data", data_type=pre_node_1.attribute[0].t.data_type, dims=new_shape, - vals=new_data.flatten().tolist() + vals=new_data.flatten().tolist(), ) new_node = onnx.helper.make_node( - 'Constant', - [], - [node.output[0]], - name=node.output[0], - value=new_tensor + "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor ) node_to_del.extend([node, pre_node_1, pre_node_2]) @@ -762,8 +748,7 @@ def div_constant_folding(g, node): def sub_constant_folding(g, node): - """ Fold constant and sub nodes to a single node. - """ + """Fold constant and sub nodes to a single node.""" node_to_del = [] pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) @@ -781,17 +766,13 @@ def sub_constant_folding(g, node): new_shape = new_data.shape new_tensor = onnx.helper.make_tensor( - name=node.output[0]+'_data', + name=node.output[0] + "_data", data_type=pre_node_1.attribute[0].t.data_type, dims=new_shape, - vals=helper.flatten_to_list(new_data) + vals=helper.flatten_to_list(new_data), ) new_node = onnx.helper.make_node( - 'Constant', - [], - [node.output[0]], - name=node.output[0], - value=new_tensor + "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor ) g.node.extend([new_node]) @@ -815,17 +796,13 @@ def neg_constant_folding(g, node): new_data_list = [-num for num in data_list] new_tensor = onnx.helper.make_tensor( - name=pre_node.name+'_neg_tensor', + name=pre_node.name + "_neg_tensor", data_type=pre_node.attribute[0].t.data_type, dims=shape, - vals=new_data_list + vals=new_data_list, ) new_node = onnx.helper.make_node( - 'Constant', - [], - 
[node.output[0]], - name=node.output[0], - value=new_tensor + "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor ) g.node.extend([new_node]) @@ -851,17 +828,13 @@ def floor_constant_folding(g, node): new_shape = shape new_tensor = onnx.helper.make_tensor( - name=node.output[0]+'_data', + name=node.output[0] + "_data", data_type=pre_node.attribute[0].t.data_type, dims=new_shape, - vals=helper.flatten_to_list(new_data) + vals=helper.flatten_to_list(new_data), ) new_node = onnx.helper.make_node( - 'Constant', - [], - [node.output[0]], - name=node.output[0], - value=new_tensor + "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor ) g.node.extend([new_node]) @@ -877,8 +850,7 @@ def floor_constant_folding(g, node): def bn_constant_folding(g, node): - """ Fold constant and mul nodes to a single constant node. - """ + """Fold constant and mul nodes to a single constant node.""" # Prepare data node_to_del = [] input_node = helper.find_node_by_output_name(g, node.input[0]) @@ -900,17 +872,22 @@ def bn_constant_folding(g, node): mean_data = helper.constant_to_numpy(mean_node) var_data = helper.constant_to_numpy(var_node) - epsilon = helper.get_var_attribute_by_name(node, 'epsilon', 'float') + epsilon = helper.get_var_attribute_by_name(node, "epsilon", "float") if epsilon is None: epsilon = 0.00001 # Calculate new node - new_data = scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon) + bias_data + new_data = ( + scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon) + + bias_data + ) new_node = helper.numpy_to_constant(node.output[0], new_data) # Reconnect the graph - node_to_del.extend([node, input_node, scale_node, bias_node, mean_node, var_node]) + node_to_del.extend( + [node, input_node, scale_node, bias_node, mean_node, var_node] + ) g.node.extend([new_node]) for value in input_value_info: @@ -925,8 +902,7 @@ def bn_constant_folding(g, node): def DequantizeLinear_constant_folding(g, node): - """ Fold 
constant and mul nodes to a single constant node. - """ + """Fold constant and mul nodes to a single constant node.""" # Prepare data node_to_del = [] x_node = helper.find_node_by_output_name(g, node.input[0]) @@ -951,7 +927,9 @@ def DequantizeLinear_constant_folding(g, node): x_zero_point_data = np.array([0.0]) # Calculate new node - new_data = (x_data.astype(np.float32) - x_zero_point_data.astype(np.float32)) * x_scale_data + new_data = ( + x_data.astype(np.float32) - x_zero_point_data.astype(np.float32) + ) * x_scale_data new_node = helper.numpy_to_constant(node.output[0], new_data) @@ -974,22 +952,22 @@ def DequantizeLinear_constant_folding(g, node): # Available constant folding names to function map. constant_folding_nodes = { - 'Add': add_constant_folding, - 'BatchNormalization': bn_constant_folding, - 'Cast': cast_constant_folding, - 'Concat': concat_constant_folding, - 'DequantizeLinear': DequantizeLinear_constant_folding, - 'Div': div_constant_folding, - 'Floor': floor_constant_folding, - 'Gather': gather_constant_folding, - 'Mul': mul_constant_folding, - 'Reciprocal': reciprocal_constant_folding, - 'ReduceProd': reduceprod_constant_folding, - 'Reshape': reshape_constant_input_folding, - 'Slice': slice_constant_folding, - 'Sqrt': sqrt_constant_folding, - 'Transpose': transpose_constant_folding, - 'Unsqueeze': unsqueeze_constant_folding, - 'Sub': sub_constant_folding, - 'Neg': neg_constant_folding + "Add": add_constant_folding, + "BatchNormalization": bn_constant_folding, + "Cast": cast_constant_folding, + "Concat": concat_constant_folding, + "DequantizeLinear": DequantizeLinear_constant_folding, + "Div": div_constant_folding, + "Floor": floor_constant_folding, + "Gather": gather_constant_folding, + "Mul": mul_constant_folding, + "Reciprocal": reciprocal_constant_folding, + "ReduceProd": reduceprod_constant_folding, + "Reshape": reshape_constant_input_folding, + "Slice": slice_constant_folding, + "Sqrt": sqrt_constant_folding, + "Transpose": 
transpose_constant_folding, + "Unsqueeze": unsqueeze_constant_folding, + "Sub": sub_constant_folding, + "Neg": neg_constant_folding, } diff --git a/tools/optimizer_scripts/tools/eliminating.py b/tools/optimizer_scripts/tools/eliminating.py index bc22b2e..7871665 100644 --- a/tools/optimizer_scripts/tools/eliminating.py +++ b/tools/optimizer_scripts/tools/eliminating.py @@ -7,6 +7,7 @@ from . import helper from . import modhelper from .general_graph import Graph + def eliminate_Identify_and_Dropout(g): """ Eliminate Identify layers @@ -15,31 +16,46 @@ def eliminate_Identify_and_Dropout(g): """ node_to_remove = [] for node in g.node: - if node.op_type != 'Identity' and node.op_type != 'Dropout': + if node.op_type != "Identity" and node.op_type != "Dropout": continue - # If this node is the last node, leave it to `eliminate_useless_last node` + # If this node is the last, leave it to `eliminate_useless_last node` if helper.find_output_by_name(g, node.output[0]) is not None: continue # Replace the parents in all the following nodes - following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + ) for following_node in following_nodes: - modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + modhelper.replace_node_input( + following_node, node.output[0], node.input[0] + ) # Delete value info value_between = helper.find_value_by_name(g, node.output[0]) try: g.value_info.remove(value_between) - except: + except Exception: print("No value info to delete while eliminating identity layers.") # Node is waiting for elimination node_to_remove.append(node) for node in node_to_remove: g.node.remove(node) + # Remove last useless nodes def remove_useless_last_nodes(g): - """Remove useless nodes from the tail of the graph - """ - USELESS = ["Reshape", "Identity", "Transpose", "Flatten", "Dropout", "Mystery", "Constant", "Squeeze", "Unsqueeze", 
'Softmax'] + """Remove useless nodes from the tail of the graph""" + USELESS = [ + "Reshape", + "Identity", + "Transpose", + "Flatten", + "Dropout", + "Mystery", + "Constant", + "Squeeze", + "Unsqueeze", + "Softmax", + ] graph = Graph(g) todo = collections.deque() for node in graph.output_nodes: @@ -54,19 +70,30 @@ def remove_useless_last_nodes(g): if cur_node.proto.op_type not in USELESS: continue # Find the output - cur_node_output = helper.find_output_by_name(g, cur_node.proto.output[0]) + cur_node_output = helper.find_output_by_name( + g, cur_node.proto.output[0] + ) for cur_input in cur_node.parents: cur_input.children.remove(cur_node) if len(cur_input.children) == 0: todo.append(cur_input) if cur_node_output is not None: - cur_input_output = helper.find_value_by_name(g, cur_input.proto.output[0]) - cur_input_output_in_output = helper.find_output_by_name(g, cur_input.proto.output[0]) - if cur_input_output is not None and cur_input_output_in_output is None: + cur_input_output = helper.find_value_by_name( + g, cur_input.proto.output[0] + ) + cur_input_output_in_output = helper.find_output_by_name( + g, cur_input.proto.output[0] + ) + if ( + cur_input_output is not None + and cur_input_output_in_output is None + ): g.output.extend([cur_input_output]) node_to_remove.append(cur_node.proto) try: - g.value_info.remove(helper.find_value_by_name(g, cur_node.proto.output[0])) + g.value_info.remove( + helper.find_value_by_name(g, cur_node.proto.output[0]) + ) except ValueError: pass if cur_node_output is not None: @@ -76,10 +103,12 @@ def remove_useless_last_nodes(g): for node in node_to_remove: g.node.remove(node) + ###################################### # TF only optimization passes # ###################################### + def eliminate_shape_changing_after_input(g): """ Eliminate the Reshape node after input and reshape the input @@ -87,7 +116,14 @@ def eliminate_shape_changing_after_input(g): :param g: the onnx graph """ node_to_remove = [] - REMOVE_LIST = 
["Reshape", "Transpose", "Flatten", "Dropout", "Squeeze", "Unsqueeze"] + REMOVE_LIST = [ + "Reshape", + "Transpose", + "Flatten", + "Dropout", + "Squeeze", + "Unsqueeze", + ] for node in g.node: # Find an input and the shape node if node.op_type not in REMOVE_LIST: @@ -105,9 +141,9 @@ def eliminate_shape_changing_after_input(g): # Remove Weight if any. output_val_info = helper.find_value_by_name(g, node.output[0]) - if node.op_type == 'Reshape': + if node.op_type == "Reshape": shape_node = helper.find_node_by_output_name(g, node.input[1]) - if shape_node.op_type != 'Constant': + if shape_node.op_type != "Constant": continue # manuelly set the input shape @@ -117,25 +153,29 @@ def eliminate_shape_changing_after_input(g): _, new_shape = helper.constant_to_list(shape_node) for i in range(len(new_shape)): if new_shape[i] == -1: - dim = int(old_size//np.prod(new_shape)*(-1)) + dim = int(old_size // np.prod(new_shape) * (-1)) new_shape[i] = dim new_input = onnx.helper.make_tensor_value_info( output_val_info.name, output_val_info.type.tensor_type.elem_type, - new_shape + new_shape, ) node_to_remove.append(node) - shape_outputs = helper.find_nodes_by_input_name(g, shape_node.output[0]) + shape_outputs = helper.find_nodes_by_input_name( + g, shape_node.output[0] + ) if len(shape_outputs) == 1: node_to_remove.append(shape_node) - g.value_info.remove(helper.find_value_by_name(g, shape_node.output[0])) - + g.value_info.remove( + helper.find_value_by_name(g, shape_node.output[0]) + ) + g.input.remove(old_input) g.input.extend([new_input]) g.value_info.remove(output_val_info) - elif node.op_type == 'Transpose': + elif node.op_type == "Transpose": permutation = list(node.attribute[0].ints) pre_shape = helper.get_shape_from_value_info(old_input) new_shape = [pre_shape[i] for i in permutation] @@ -143,7 +183,7 @@ def eliminate_shape_changing_after_input(g): new_input = onnx.helper.make_tensor_value_info( output_val_info.name, output_val_info.type.tensor_type.elem_type, - new_shape 
+ new_shape, ) node_to_remove.append(node) @@ -151,7 +191,7 @@ def eliminate_shape_changing_after_input(g): g.input.remove(old_input) g.input.extend([new_input]) g.value_info.remove(output_val_info) - elif node.op_type == 'Flatten': + elif node.op_type == "Flatten": axis = node.attribute[0].int pre_shape = helper.get_shape_from_value_info(old_input) dim_1, dim_2 = 1, 1 @@ -166,7 +206,7 @@ def eliminate_shape_changing_after_input(g): new_input = onnx.helper.make_tensor_value_info( output_val_info.name, output_val_info.type.tensor_type.elem_type, - new_shape + new_shape, ) node_to_remove.append(node) @@ -174,18 +214,18 @@ def eliminate_shape_changing_after_input(g): g.input.remove(old_input) g.input.extend([new_input]) g.value_info.remove(output_val_info) - elif node.op_type == 'Dropout': + elif node.op_type == "Dropout": g.input.remove(old_input) g.input.extend([output_val_info]) g.value_info.remove(output_val_info) - + node_to_remove.append(node) - elif node.op_type == 'Squeeze': + elif node.op_type == "Squeeze": axis = list(node.attribute[0].ints) pre_shape = helper.get_shape_from_value_info(old_input) for pos in sorted(axis)[::-1]: if pre_shape[pos] != 1: - raise RuntimeError('invalid axis for squeeze') + raise RuntimeError("invalid axis for squeeze") else: pre_shape.pop(pos) new_shape = pre_shape @@ -193,7 +233,7 @@ def eliminate_shape_changing_after_input(g): new_input = onnx.helper.make_tensor_value_info( output_val_info.name, output_val_info.type.tensor_type.elem_type, - new_shape + new_shape, ) node_to_remove.append(node) @@ -201,7 +241,7 @@ def eliminate_shape_changing_after_input(g): g.input.remove(old_input) g.input.extend([new_input]) g.value_info.remove(output_val_info) - elif node.op_type == 'Unsqueeze': + elif node.op_type == "Unsqueeze": axis = list(node.attribute[0].ints) pre_shape = helper.get_shape_from_value_info(old_input) new_shape = pre_shape @@ -210,7 +250,7 @@ def eliminate_shape_changing_after_input(g): new_input = 
onnx.helper.make_tensor_value_info( output_val_info.name, output_val_info.type.tensor_type.elem_type, - new_shape + new_shape, ) node_to_remove.append(node) @@ -222,7 +262,7 @@ def eliminate_shape_changing_after_input(g): for node in node_to_remove: g.node.remove(node) - + other.topological_sort(g) @@ -231,15 +271,13 @@ def eliminate_Reshape_Cast(g): :param g: the onnx graph """ - #Find all reshape layers - node_to_remove = [] + # Find all reshape layers for node in g.node: - if node.op_type != 'Reshape': + if node.op_type != "Reshape": continue prev_node = helper.find_node_by_output_name(g, node.input[1]) - if prev_node.op_type != 'Cast': + if prev_node.op_type != "Cast": continue - # Now we find the cast weight pattern. Cast the weight, delete the cast. reshape_node = node cast_node = prev_node weight_node = helper.find_node_by_output_name(g, cast_node.input[0]) @@ -248,10 +286,12 @@ def eliminate_Reshape_Cast(g): weight_node.attribute[0].t.data_type = 7 if weight_node.attribute[0].t.raw_data: raw_data = weight_node.attribute[0].t.raw_data - int_data = [i[0] for i in struct.iter_unpack('i', raw_data)] - raw_data = struct.pack('q' * len(int_data), *int_data) - elif len(weight_node.attribute[0].t.int64_data) > 0\ - or len(weight_node.attribute[0].t.int32_data) > 0: + int_data = [i[0] for i in struct.iter_unpack("i", raw_data)] + raw_data = struct.pack("q" * len(int_data), *int_data) + elif ( + len(weight_node.attribute[0].t.int64_data) > 0 + or len(weight_node.attribute[0].t.int32_data) > 0 + ): # It's already int. 
Do nothing pass else: @@ -264,6 +304,7 @@ def eliminate_Reshape_Cast(g): g.value_info.remove(origin_weight_out) g.node.remove(cast_node) + def eliminate_Cast_after_input(g): """Eliminate the cast layer right after the input @@ -271,7 +312,7 @@ def eliminate_Cast_after_input(g): """ node_to_remove = [] for node in g.node: - if node.op_type != 'Cast': + if node.op_type != "Cast": continue old_input = helper.find_input_by_name(g, node.input[0]) if old_input is None: @@ -279,9 +320,7 @@ def eliminate_Cast_after_input(g): next_val_info = helper.find_value_by_name(g, node.output[0]) shape = helper.get_shape_from_value_info(next_val_info) new_val_info = onnx.helper.make_tensor_value_info( - next_val_info.name, - node.attribute[0].i, - shape + next_val_info.name, node.attribute[0].i, shape ) # Delete old value_info g.input.remove(old_input) @@ -293,6 +332,7 @@ def eliminate_Cast_after_input(g): for node in node_to_remove: g.node.remove(node) + def eliminate_consecutive_Cast(g): """If two cast is next to each other, remove the first cast @@ -300,10 +340,10 @@ def eliminate_consecutive_Cast(g): """ node_to_remove = [] for node in g.node: - if node.op_type != 'Cast': + if node.op_type != "Cast": continue first_node = helper.find_node_by_output_name(g, node.input[0]) - if first_node is None or first_node.op_type != 'Cast': + if first_node is None or first_node.op_type != "Cast": continue # Here we have two consecutive Cast Node # Reset the input of the later node @@ -315,6 +355,7 @@ def eliminate_consecutive_Cast(g): for node in node_to_remove: g.node.remove(node) + def eliminate_Squeeze_before_Reshape(g): """If Squeeze and Reshape is next to each other, remove the first node @@ -322,12 +363,12 @@ def eliminate_Squeeze_before_Reshape(g): """ node_to_remove = [] for node in g.node: - if node.op_type != 'Reshape': + if node.op_type != "Reshape": continue first_node = helper.find_node_by_output_name(g, node.input[0]) if not first_node: continue - if first_node.op_type != 
'Squeeze': + if first_node.op_type != "Squeeze": continue # Here we have two consecutive Cast Node # Reset the input of the later node @@ -339,9 +380,9 @@ def eliminate_Squeeze_before_Reshape(g): for node in node_to_remove: g.node.remove(node) + def eliminate_no_children_input(g): - """Eliminate inputs with no children at all. - """ + """Eliminate inputs with no children at all.""" # Create a set of input names input_names = set([i.name for i in g.input]) # If a name is used in any node, remove this name from the set. @@ -353,31 +394,33 @@ def eliminate_no_children_input(g): info = helper.find_input_by_name(g, i) g.input.remove(info) + def eliminate_consecutive_reshape(g): - """Replace consecutive reshape nodes by a single node. - """ + """Replace consecutive reshape nodes by a single node.""" node_to_del = [] for node in g.node: - if node.op_type != 'Reshape': + if node.op_type != "Reshape": continue pre_data_node = helper.find_node_by_output_name(g, node.input[0]) pre_shape_node = helper.find_node_by_output_name(g, node.input[1]) if not pre_data_node or not pre_shape_node: continue - if pre_shape_node.op_type != 'Constant': + if pre_shape_node.op_type != "Constant": continue - if pre_data_node.op_type != 'Reshape': + if pre_data_node.op_type != "Reshape": continue - - pre_pre_shape_node = helper.find_node_by_output_name(g, pre_data_node.input[1]) - if pre_pre_shape_node.op_type != 'Constant': + + pre_pre_shape_node = helper.find_node_by_output_name( + g, pre_data_node.input[1] + ) + if pre_pre_shape_node.op_type != "Constant": continue new_reshape_node = onnx.helper.make_node( - 'Reshape', + "Reshape", [pre_data_node.input[0], node.input[1]], [node.output[0]], - name = node.output[0] + name=node.output[0], ) g.node.extend([new_reshape_node]) @@ -394,6 +437,7 @@ def eliminate_consecutive_reshape(g): node = node_to_del.pop() g.node.remove(node) + def eliminate_single_input_Concat(g): """ Eliminate single input Concat layers @@ -402,12 +446,12 @@ def 
eliminate_single_input_Concat(g): """ node_to_remove = [] for node in g.node: - if node.op_type != 'Concat': + if node.op_type != "Concat": continue # If this node has more than 1 input, continue. if len(node.input) > 1: continue - # If this node is the output node, set its previous node as output nodes. + # If this node is output node, set its previous node as output nodes. if helper.find_output_by_name(g, node.output[0]) is not None: todel_output = helper.find_output_by_name(g, node.output[0]) the_input_value = helper.find_value_by_name(g, node.input[0]) @@ -416,20 +460,25 @@ def eliminate_single_input_Concat(g): node_to_remove.append(node) continue # Replace the parents in all the following nodes - following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + ) for following_node in following_nodes: - modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + modhelper.replace_node_input( + following_node, node.output[0], node.input[0] + ) # Delete value info value_between = helper.find_value_by_name(g, node.output[0]) try: g.value_info.remove(value_between) - except: + except Exception: print("No value info to delete while eliminating identity layers.") # Node is waiting for elimination node_to_remove.append(node) for node in node_to_remove: g.node.remove(node) + def eliminate_nop_Maxpool_and_AveragePool(g): """ Eliminate do nothing MaxPool and AveragePool layers. @@ -439,7 +488,7 @@ def eliminate_nop_Maxpool_and_AveragePool(g): """ node_to_remove = [] for node in g.node: - if node.op_type != 'MaxPool' and node.op_type != 'AveragePool': + if node.op_type != "MaxPool" and node.op_type != "AveragePool": continue # If this node is actually working, continue. 
kernel = helper.get_list_attribute_by_name(node, "kernel_shape", "int") @@ -447,7 +496,7 @@ def eliminate_nop_Maxpool_and_AveragePool(g): strides = helper.get_list_attribute_by_name(node, "strides", "int") if kernel != [1, 1] or pads != [0, 0, 0, 0] or strides != [1, 1]: continue - # If this node is the output node, set its previous node as output nodes. + # If this node is the output, set its previous node as output nodes. if helper.find_output_by_name(g, node.output[0]) is not None: todel_output = helper.find_output_by_name(g, node.output[0]) the_input_value = helper.find_value_by_name(g, node.input[0]) @@ -456,14 +505,18 @@ def eliminate_nop_Maxpool_and_AveragePool(g): node_to_remove.append(node) continue # Replace the parents in all the following nodes - following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + ) for following_node in following_nodes: - modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + modhelper.replace_node_input( + following_node, node.output[0], node.input[0] + ) # Delete value info value_between = helper.find_value_by_name(g, node.output[0]) try: g.value_info.remove(value_between) - except: + except Exception: print("No value info to delete while eliminating identity layers.") # Node is waiting for elimination node_to_remove.append(node) @@ -474,20 +527,20 @@ def eliminate_nop_Maxpool_and_AveragePool(g): def eliminate_trivial_maxpool(g): node_to_del = [] for node in g.node: - if node.op_type != 'MaxPool': + if node.op_type != "MaxPool": continue pads = None strides = None dilation = None kernel_shape = None for att in node.attribute: - if att.name == 'pads': + if att.name == "pads": pads = list(att.ints) - elif att.name == 'strides': + elif att.name == "strides": strides = list(att.ints) - elif att.name == 'kernel_shape': + elif att.name == "kernel_shape": kernel_shape = list(att.ints) - elif 
att.name == 'dilation': + elif att.name == "dilation": dilation = list(att.ints) else: pass @@ -504,7 +557,7 @@ def eliminate_trivial_maxpool(g): next_nodes = helper.find_nodes_by_input_name(g, node.output[0]) - if next_nodes[0] == None: + if next_nodes[0] is None: output_value = helper.find_output_by_name(g, node.output[0]) if not output_value: continue @@ -512,18 +565,21 @@ def eliminate_trivial_maxpool(g): pre_val_info = helper.find_value_by_name(g, node.input[0]) g.output.extend([pre_val_info]) g.output.remove(output_value) - + for next_node in next_nodes: - modhelper.replace_node_input(next_node, node.output[0], node.input[0]) - + modhelper.replace_node_input( + next_node, node.output[0], node.input[0] + ) + next_val_info = helper.find_value_by_name(g, node.output[0]) g.value_info.remove(next_val_info) while node_to_del: g.node.remove(node_to_del.pop()) - + other.topological_sort(g) + def eliminate_empty_value_infos(g): to_remove = [] for value_info in g.value_info: @@ -532,10 +588,11 @@ def eliminate_empty_value_infos(g): for value_info in to_remove: g.value_info.remove(value_info) + def eliminate_nop_pads(g): node_to_remove = [] for node in g.node: - if node.op_type != 'Pad': + if node.op_type != "Pad": continue # Check if the Pad is empty or not pads_node = helper.find_node_by_output_name(g, node.input[1]) @@ -546,11 +603,7 @@ def eliminate_nop_pads(g): all_zero = False if not all_zero: continue - # Check if it has the constant_value_node - constant_value_node = None - if len(node.input) > 2: - constant_value_node = helper.find_node_by_output_name(g, node.input[2]) - # If this node is the output node, set its previous node as output nodes. + # If this node is the output, set its previous node as output nodes. 
if helper.find_output_by_name(g, node.output[0]) is not None: todel_output = helper.find_output_by_name(g, node.output[0]) g.output.remove(todel_output) @@ -559,38 +612,44 @@ def eliminate_nop_pads(g): if the_input_value is not None: g.output.extend([the_input_value]) # Replace the parents in all the following nodes - following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + ) for following_node in following_nodes: - modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + modhelper.replace_node_input( + following_node, node.output[0], node.input[0] + ) # Delete value info value_between = helper.find_value_by_name(g, node.output[0]) try: g.value_info.remove(value_between) - except: - helper.logger.info("No value info to delete while eliminating identity layers.") + except Exception: + helper.logger.info( + "No value info to delete while eliminating identity layers." + ) # Node is waiting for elimination node_to_remove.append(node) for node in node_to_remove: g.node.remove(node) + def eliminate_trivial_elementwise_calculation(g): - """Eliminate Add, Sub, Mul, Sub nodes which do nothing. - """ + """Eliminate Add, Sub, Mul, Sub nodes which do nothing.""" node_to_remove = [] for node in g.node: weight_node = None - if node.op_type == 'Add' or node.op_type == 'Sub': + if node.op_type == "Add" or node.op_type == "Sub": # For add and sub, check if the weights are 0s. weight_node = helper.find_node_by_output_name(g, node.input[1]) - if weight_node is None or weight_node.op_type != 'Constant': + if weight_node is None or weight_node.op_type != "Constant": continue weight_np = helper.constant_to_numpy(weight_node) if np.any(weight_np): continue - elif node.op_type == 'Mul' or node.op_type == 'Div': + elif node.op_type == "Mul" or node.op_type == "Div": # For Mul and Div, check if the weights are 1s. 
weight_node = helper.find_node_by_output_name(g, node.input[1]) - if weight_node is None or weight_node.op_type != 'Constant': + if weight_node is None or weight_node.op_type != "Constant": continue weight_np = helper.constant_to_numpy(weight_node) weight_np = weight_np - 1 @@ -605,9 +664,13 @@ def eliminate_trivial_elementwise_calculation(g): if output_value_info is not None: g.value_info.remove(output_value_info) # Replace next node input if any. - following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + ) for following_node in following_nodes: - modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + modhelper.replace_node_input( + following_node, node.output[0], node.input[0] + ) todel_output = helper.find_output_by_name(g, node.output[0]) if todel_output is not None: g.output.remove(todel_output) @@ -616,38 +679,53 @@ def eliminate_trivial_elementwise_calculation(g): the_input_value = helper.find_value_by_name(g, node.input[0]) g.output.extend([the_input_value]) # Delete the constant node if it is not used by other nodes - constant_following_nodes = helper.find_following_nodes_by_input_value_name(g, weight_node.output[0]) + constant_following_nodes = ( + helper.find_following_nodes_by_input_value_name( + g, weight_node.output[0] + ) + ) if len(constant_following_nodes) == 1: node_to_remove.append(weight_node) - output_value_info = helper.find_value_by_name(g, weight_node.output[0]) + output_value_info = helper.find_value_by_name( + g, weight_node.output[0] + ) if output_value_info is not None: g.value_info.remove(output_value_info) for node in node_to_remove: g.node.remove(node) + def eliminate_nop_cast(g): - """Eliminate do nothing Cast nodes. 
- """ + """Eliminate do nothing Cast nodes.""" node_to_remove = [] for node in g.node: - if node.op_type != 'Cast': + if node.op_type != "Cast": continue # Get input value_info input_value = helper.find_value_by_name(g, node.input[0]) if input_value is None: - helper.logger.debug(f"Cannot find the input value_info for Cast node {node.name}. Skip elimination check.") + helper.logger.debug( + f"Cannot find the input value_info for Cast node {node.name}. " + "Skip elimination check." + ) continue # Get output value_info output_value = helper.find_value_by_name(g, node.output[0]) if output_value is None: output_value = helper.find_output_by_name(g, node.output[0]) if output_value is None: - helper.logger.debug(f"Cannot find the output value_info for Cast node {node.name}. Skip elimination check.") + helper.logger.debug( + f"Cannot find the output value_info for Cast node {node.name}." + " Skip elimination check." + ) continue # Compare the type. - if input_value.type.tensor_type.elem_type != output_value.type.tensor_type.elem_type: + if ( + input_value.type.tensor_type.elem_type + != output_value.type.tensor_type.elem_type + ): continue - # If this node is the output node, set its previous node as output nodes. + # If this node is the output, set its previous node as output nodes. 
if helper.find_output_by_name(g, node.output[0]) is not None: todel_output = helper.find_output_by_name(g, node.output[0]) g.output.remove(todel_output) @@ -656,9 +734,13 @@ def eliminate_nop_cast(g): if the_input_value is not None: g.output.extend([the_input_value]) # Replace the parents in all the following nodes - following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + ) for following_node in following_nodes: - modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + modhelper.replace_node_input( + following_node, node.output[0], node.input[0] + ) # Delete value info value_between = helper.find_value_by_name(g, node.output[0]) if value_between is not None: diff --git a/tools/optimizer_scripts/tools/fusing.py b/tools/optimizer_scripts/tools/fusing.py index 202a4c2..e19ca94 100644 --- a/tools/optimizer_scripts/tools/fusing.py +++ b/tools/optimizer_scripts/tools/fusing.py @@ -4,6 +4,7 @@ from . 
import helper from .other import topological_sort from .modhelper import delete_value_with_name_if_exists, replace_node_input + def fuse_Transpose_into_Constant(g): """ Fuse Transpose layers into the Constant layers before @@ -12,10 +13,10 @@ def fuse_Transpose_into_Constant(g): """ node_to_remove = [] for node in g.node: - if node.op_type != 'Transpose': + if node.op_type != "Transpose": continue prev_node = helper.find_node_by_output_name(g, node.input[0]) - if prev_node is None or prev_node.op_type != 'Constant': + if prev_node is None or prev_node.op_type != "Constant": continue pre_shape, data_list = helper.constant_to_list(prev_node) @@ -25,17 +26,17 @@ def fuse_Transpose_into_Constant(g): w = w.flatten() new_tensor = onnx.helper.make_tensor( - name=prev_node.name+'_data', + name=prev_node.name + "_data", data_type=prev_node.attribute[0].t.data_type, dims=new_shape, - vals=w.tolist() + vals=w.tolist(), ) new_node = onnx.helper.make_node( - 'Constant', + "Constant", [], [node.output[0]], name=node.output[0], - value=new_tensor + value=new_tensor, ) value_between = helper.find_value_by_name(g, prev_node.output[0]) @@ -48,13 +49,13 @@ def fuse_Transpose_into_Constant(g): if new_node.output[0] not in [i.name for i in g.value_info]: new_value = onnx.helper.make_tensor_value_info( - name=new_node.output[0], - elem_type=value_type, - shape=new_shape - ) + name=new_node.output[0], elem_type=value_type, shape=new_shape + ) g.value_info.extend([new_value]) if new_node.output[0]: - val_info_to_del = helper.find_value_by_name(g, new_node.output[0]) + val_info_to_del = helper.find_value_by_name( + g, new_node.output[0] + ) g.value_info.remove(val_info_to_del) for node in node_to_remove: @@ -62,6 +63,7 @@ def fuse_Transpose_into_Constant(g): topological_sort(g) + def fuse_Add_into_Conv(g): """ Fuse Transpose layers into the Constant layers before @@ -70,17 +72,17 @@ def fuse_Add_into_Conv(g): """ node_to_remove = [] for node in g.node: - if node.op_type != 'Add': + if 
node.op_type != "Add": continue conv_node = helper.find_node_by_output_name(g, node.input[0]) cons_node = helper.find_node_by_output_name(g, node.input[1]) if conv_node is None or cons_node is None: continue - if conv_node.op_type != 'Conv' or cons_node.op_type != 'Constant': + if conv_node.op_type != "Conv" or cons_node.op_type != "Constant": continue if len(conv_node.input) > 2: continue - # This layer should be fused. Connect constant node into convolution node. + # This layer should be fused. Connect constant node into convolution. add_node = node conv_node.input.extend([cons_node.output[0]]) old_value = helper.find_value_by_name(g, conv_node.output[0]) @@ -93,6 +95,7 @@ def fuse_Add_into_Conv(g): for node in node_to_remove: g.node.remove(node) + def fuse_BN_into_Gemm(g): """Fuse the following BN into the previous Gemm. @@ -101,14 +104,21 @@ def fuse_BN_into_Gemm(g): node_to_remove = [] for node in g.node: # Check for BN and Gemm - if node.op_type != 'BatchNormalization': + if node.op_type != "BatchNormalization": continue gemm_node = helper.find_node_by_output_name(g, node.input[0]) if gemm_node is None: continue - if gemm_node.op_type != 'Gemm': + if gemm_node.op_type != "Gemm": continue - if len(helper.find_following_nodes_by_input_value_name(g, gemm_node.output[0])) > 1: + if ( + len( + helper.find_following_nodes_by_input_value_name( + g, gemm_node.output[0] + ) + ) + > 1 + ): continue bn_node = node # Get original weights @@ -126,59 +136,67 @@ def fuse_BN_into_Gemm(g): bn_var = helper.constant_to_numpy(bn_var_node) # Apply attributes # epsilon - epsilon = helper.get_attribute_by_name(bn_node, 'epsilon') + epsilon = helper.get_attribute_by_name(bn_node, "epsilon") if epsilon is None: epsilon = 0.00001 else: epsilon = epsilon.f bn_var = bn_var + epsilon # alpha - alpha = helper.get_attribute_by_name(gemm_node, 'alpha') + alpha = helper.get_attribute_by_name(gemm_node, "alpha") if alpha is None: alpha = 1 else: alpha = alpha.f gemm_b = gemm_b * alpha # beta 
- beta = helper.get_attribute_by_name(gemm_node, 'beta') + beta = helper.get_attribute_by_name(gemm_node, "beta") if beta is None: beta = 1 else: beta = beta.f gemm_c = gemm_c * beta # transA - transA = helper.get_attribute_by_name(gemm_node, 'transA') + transA = helper.get_attribute_by_name(gemm_node, "transA") if transA is not None and transA.i == 1: raise RuntimeError("Do not support transA") # transB - transB = helper.get_attribute_by_name(gemm_node, 'transB') + transB = helper.get_attribute_by_name(gemm_node, "transB") if transB is not None and transB.i == 1: gemm_b = gemm_b.transpose() # Calculate new weights new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var) new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias # Replace original weights - new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused', new_gemm_b) - new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused', new_gemm_c) + new_gemm_b_node = helper.numpy_to_constant( + gemm_b_node.name + "_fused", new_gemm_b + ) + new_gemm_c_node = helper.numpy_to_constant( + gemm_c_node.name + "_fused", new_gemm_c + ) g.node.extend([new_gemm_b_node, new_gemm_c_node]) - node_to_remove.extend([gemm_b_node, - gemm_c_node, - bn_node, - bn_scale_node, - bn_bias_node, - bn_mean_node, - bn_var_node]) + node_to_remove.extend( + [ + gemm_b_node, + gemm_c_node, + bn_node, + bn_scale_node, + bn_bias_node, + bn_mean_node, + bn_var_node, + ] + ) # Modify attributes # alpha - alpha = helper.get_attribute_by_name(gemm_node, 'alpha') + alpha = helper.get_attribute_by_name(gemm_node, "alpha") if alpha is not None: alpha.f = 1.0 # beta - beta = helper.get_attribute_by_name(gemm_node, 'beta') + beta = helper.get_attribute_by_name(gemm_node, "beta") if beta is not None: beta.f = 1.0 # transB - transB = helper.get_attribute_by_name(gemm_node, 'transB') + transB = helper.get_attribute_by_name(gemm_node, "transB") if transB is not None: transB.i = 0 # Connect the new graph @@ -199,6 +217,7 @@ def 
fuse_BN_into_Gemm(g): g.node.remove(node) topological_sort(g) + def fuse_BN_with_Reshape_into_Gemm(g): """Fuse the following BN into the previous Gemm, even with Reshape or \\ Squeeze and Unsqueeze surrounding. @@ -209,7 +228,7 @@ def fuse_BN_with_Reshape_into_Gemm(g): for node in g.node: # Check for BN and Gemm pattern: Gemm A BN B # Find BatchNorm Node - if node.op_type != 'BatchNormalization': + if node.op_type != "BatchNormalization": continue bn_node = node # Find A Node @@ -218,10 +237,12 @@ def fuse_BN_with_Reshape_into_Gemm(g): continue # Find Gemm Node gemm_node = helper.find_node_by_output_name(g, a_node.input[0]) - if gemm_node is None or gemm_node.op_type != 'Gemm': + if gemm_node is None or gemm_node.op_type != "Gemm": continue # Find B Node - b_node_list = helper.find_following_nodes_by_input_value_name(g, bn_node.output[0]) + b_node_list = helper.find_following_nodes_by_input_value_name( + g, bn_node.output[0] + ) if len(b_node_list) == 0: the_output = helper.find_output_by_name(g, bn_node.output[0]) if the_output is None: @@ -232,17 +253,33 @@ def fuse_BN_with_Reshape_into_Gemm(g): else: b_node = b_node_list[0] # Check for branches - if len(helper.find_following_nodes_by_input_value_name(g, gemm_node.output[0])) > 1: + if ( + len( + helper.find_following_nodes_by_input_value_name( + g, gemm_node.output[0] + ) + ) + > 1 + ): continue - if len(helper.find_following_nodes_by_input_value_name(g, a_node.output[0])) > 1: + if ( + len( + helper.find_following_nodes_by_input_value_name( + g, a_node.output[0] + ) + ) + > 1 + ): continue # Check type of A - if a_node.op_type == 'Unsqueeze': - axes = helper.get_attribute_by_name(a_node, 'axes') + if a_node.op_type == "Unsqueeze": + axes = helper.get_attribute_by_name(a_node, "axes") if axes.ints != [2]: continue - elif a_node.op_type == 'Reshape': - a = helper.constant_to_list(helper.find_node_by_output_name(g, a_node.input[1]))[1] + elif a_node.op_type == "Reshape": + a = helper.constant_to_list( + 
helper.find_node_by_output_name(g, a_node.input[1]) + )[1] if len(a) != 3 or a[2] != 1: continue else: @@ -250,14 +287,16 @@ def fuse_BN_with_Reshape_into_Gemm(g): # Check type of B if b_node is None: pass - elif b_node.op_type == 'Flatten': + elif b_node.op_type == "Flatten": pass - elif b_node.op_type == 'Squeeze': - axes = helper.get_attribute_by_name(a_node, 'axes') + elif b_node.op_type == "Squeeze": + axes = helper.get_attribute_by_name(a_node, "axes") if axes.ints != [2]: continue - elif b_node.op_type == 'Reshape': - a = helper.constant_to_list(helper.find_node_by_output_name(g, b_node.input[1]))[1] + elif b_node.op_type == "Reshape": + a = helper.constant_to_list( + helper.find_node_by_output_name(g, b_node.input[1]) + )[1] if len(a) != 2: continue else: @@ -278,73 +317,85 @@ def fuse_BN_with_Reshape_into_Gemm(g): bn_var = helper.constant_to_numpy(bn_var_node) # Apply attributes # epsilon - epsilon = helper.get_attribute_by_name(bn_node, 'epsilon') + epsilon = helper.get_attribute_by_name(bn_node, "epsilon") if epsilon is None: epsilon = 0.00001 else: epsilon = epsilon.f bn_var = bn_var + epsilon # alpha - alpha = helper.get_attribute_by_name(gemm_node, 'alpha') + alpha = helper.get_attribute_by_name(gemm_node, "alpha") if alpha is None: alpha = 1 else: alpha = alpha.f gemm_b = gemm_b * alpha # beta - beta = helper.get_attribute_by_name(gemm_node, 'beta') + beta = helper.get_attribute_by_name(gemm_node, "beta") if beta is None: beta = 1 else: beta = beta.f gemm_c = gemm_c * beta # transA - transA = helper.get_attribute_by_name(gemm_node, 'transA') + transA = helper.get_attribute_by_name(gemm_node, "transA") if transA is not None and transA.i == 1: raise RuntimeError("Do not support transA") # transB - transB = helper.get_attribute_by_name(gemm_node, 'transB') + transB = helper.get_attribute_by_name(gemm_node, "transB") if transB is not None and transB.i == 1: gemm_b = gemm_b.transpose() # Calculate new weights new_gemm_b = gemm_b * bn_scale / 
np.sqrt(bn_var) new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias # Replace original weights - new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused', new_gemm_b) - new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused', new_gemm_c) + new_gemm_b_node = helper.numpy_to_constant( + gemm_b_node.name + "_fused", new_gemm_b + ) + new_gemm_c_node = helper.numpy_to_constant( + gemm_c_node.name + "_fused", new_gemm_c + ) g.node.extend([new_gemm_b_node, new_gemm_c_node]) # Modify attributes # alpha - alpha = helper.get_attribute_by_name(gemm_node, 'alpha') + alpha = helper.get_attribute_by_name(gemm_node, "alpha") if alpha is not None: alpha.f = 1.0 # beta - beta = helper.get_attribute_by_name(gemm_node, 'beta') + beta = helper.get_attribute_by_name(gemm_node, "beta") if beta is not None: beta.f = 1.0 # transB - transB = helper.get_attribute_by_name(gemm_node, 'transB') + transB = helper.get_attribute_by_name(gemm_node, "transB") if transB is not None: transB.i = 0 # Remove useless nodes - node_to_remove.extend([gemm_b_node, - gemm_c_node, - bn_node, - bn_scale_node, - bn_bias_node, - bn_mean_node, - bn_var_node, - a_node]) - if a_node.op_type == 'Reshape': - node_to_remove.append(helper.find_node_by_output_name(g, a_node.input[1])) + node_to_remove.extend( + [ + gemm_b_node, + gemm_c_node, + bn_node, + bn_scale_node, + bn_bias_node, + bn_mean_node, + bn_var_node, + a_node, + ] + ) + if a_node.op_type == "Reshape": + node_to_remove.append( + helper.find_node_by_output_name(g, a_node.input[1]) + ) if b_node is not None: node_to_remove.append(b_node) - if b_node.op_type == 'Reshape': - node_to_remove.append(helper.find_node_by_output_name(g, b_node.input[1])) + if b_node.op_type == "Reshape": + node_to_remove.append( + helper.find_node_by_output_name(g, b_node.input[1]) + ) # Delete useless value infos value = helper.find_value_by_name(g, a_node.output[0]) g.value_info.remove(value) - if a_node.op_type == 'Reshape': + if 
a_node.op_type == "Reshape": value = helper.find_value_by_name(g, a_node.input[1]) g.value_info.remove(value) for i in range(1, 5): @@ -356,7 +407,7 @@ def fuse_BN_with_Reshape_into_Gemm(g): if b_node is not None: value = helper.find_value_by_name(g, gemm_node.output[0]) g.value_info.remove(value) - if b_node.op_type == 'Reshape': + if b_node.op_type == "Reshape": value = helper.find_value_by_name(g, b_node.input[1]) g.value_info.remove(value) # Connect the new graph @@ -366,14 +417,20 @@ def fuse_BN_with_Reshape_into_Gemm(g): gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0]) gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0]) gemm_b_value.name = new_gemm_b_node.output[0] - gemm_b_value.type.tensor_type.shape.dim[0].dim_value = new_gemm_b.shape[0] - gemm_b_value.type.tensor_type.shape.dim[1].dim_value = new_gemm_b.shape[1] + gemm_b_value.type.tensor_type.shape.dim[ + 0 + ].dim_value = new_gemm_b.shape[0] + gemm_b_value.type.tensor_type.shape.dim[ + 1 + ].dim_value = new_gemm_b.shape[1] gemm_c_value.name = new_gemm_c_node.output[0] if b_node is None: # If b node is None, set the Gemm output as the graph output output_value = helper.find_output_by_name(g, bn_node.output[0]) g.output.remove(output_value) - g.output.extend([helper.find_value_by_name(g, gemm_node.output[0])]) + g.output.extend( + [helper.find_value_by_name(g, gemm_node.output[0])] + ) else: # Else, set node B output as gemm output gemm_node.output[0] = b_node.output[0] @@ -391,12 +448,12 @@ def fuse_Gemm_into_Gemm(g): node_to_remove = [] for node in g.node: # Check for Gemm and Gemm - if node.op_type != 'Gemm': + if node.op_type != "Gemm": continue prev_node = helper.find_node_by_output_name(g, node.input[0]) if prev_node is None: continue - if prev_node.op_type != 'Gemm': + if prev_node.op_type != "Gemm": continue # Get original weights prev_b_node = helper.find_node_by_output_name(g, prev_node.input[1]) @@ -409,68 +466,66 @@ def fuse_Gemm_into_Gemm(g): c = 
helper.constant_to_numpy(c_node) # Apply attributes # alpha - alpha = helper.get_attribute_by_name(node, 'alpha') + alpha = helper.get_attribute_by_name(node, "alpha") if alpha is None: alpha = 1 else: alpha = alpha.f b = b * alpha - alpha = helper.get_attribute_by_name(prev_node, 'alpha') + alpha = helper.get_attribute_by_name(prev_node, "alpha") if alpha is None: alpha = 1 else: alpha = alpha.f prev_b = prev_b * alpha # beta - beta = helper.get_attribute_by_name(node, 'beta') + beta = helper.get_attribute_by_name(node, "beta") if beta is None: beta = 1 else: beta = beta.f c = c * beta - beta = helper.get_attribute_by_name(prev_node, 'beta') + beta = helper.get_attribute_by_name(prev_node, "beta") if beta is None: beta = 1 else: beta = beta.f prev_c = prev_c * beta # transA - transA = helper.get_attribute_by_name(node, 'transA') + transA = helper.get_attribute_by_name(node, "transA") if transA is not None and transA.i == 1: raise RuntimeError("Do not support transA") - transA = helper.get_attribute_by_name(prev_node, 'transA') + transA = helper.get_attribute_by_name(prev_node, "transA") if transA is not None and transA.i == 1: raise RuntimeError("Do not support transA") # transB - transB = helper.get_attribute_by_name(node, 'transB') + transB = helper.get_attribute_by_name(node, "transB") if transB is not None and transB.i == 1: b = b.transpose() - transB = helper.get_attribute_by_name(prev_node, 'transB') + transB = helper.get_attribute_by_name(prev_node, "transB") if transB is not None and transB.i == 1: prev_b = prev_b.transpose() # Calculate new weights new_b = prev_b.dot(b) new_c = prev_c.dot(b) + c # Replace original weights - new_b_node = helper.numpy_to_constant(b_node.name + '_fused', new_b) - new_c_node = helper.numpy_to_constant(c_node.name + '_fused', new_c) + new_b_node = helper.numpy_to_constant(b_node.name + "_fused", new_b) + new_c_node = helper.numpy_to_constant(c_node.name + "_fused", new_c) g.node.extend([new_b_node, new_c_node]) - 
node_to_remove.extend([b_node, - c_node, - prev_b_node, - prev_c_node, - prev_node]) + node_to_remove.extend( + [b_node, c_node, prev_b_node, prev_c_node, prev_node] + ) # Modify attributes # alpha - alpha = helper.get_attribute_by_name(node, 'alpha') + alpha = helper.get_attribute_by_name(node, "alpha") if alpha is not None: alpha.f = 1.0 # beta - beta = helper.get_attribute_by_name(node, 'beta') + beta = helper.get_attribute_by_name(node, "beta") if beta is not None: beta.f = 1.0 # transB - transB = helper.get_attribute_by_name(node, 'transB') + transB = helper.get_attribute_by_name(node, "transB") if transB is not None: transB.i = 0 # Connect the new graph @@ -486,6 +541,7 @@ def fuse_Gemm_into_Gemm(g): g.node.remove(node) topological_sort(g) + def fuse_MatMul_and_Add_into_Gemm(g): """ Fuse MatMul and Add layers into a new Gemm layers. @@ -496,7 +552,7 @@ def fuse_MatMul_and_Add_into_Gemm(g): node_to_remove = [] node_to_add = [] for node in g.node: - if node.op_type != 'MatMul': + if node.op_type != "MatMul": continue add_node = None for i in g.node: @@ -506,7 +562,11 @@ def fuse_MatMul_and_Add_into_Gemm(g): add_node = i break value_to_remove = helper.find_value_by_name(g, node.output[0]) - if add_node is None or value_to_remove is None or add_node.op_type != 'Add': + if ( + add_node is None + or value_to_remove is None + or add_node.op_type != "Add" + ): continue input_list = node.input input_list.append(add_node.input[1]), @@ -518,7 +578,7 @@ def fuse_MatMul_and_Add_into_Gemm(g): alpha=1.0, beta=1.0, transA=0, - transB=0 + transB=0, ) node_to_add.append(new_node) node_to_remove.append(node) @@ -528,13 +588,14 @@ def fuse_MatMul_and_Add_into_Gemm(g): g.node.remove(node) g.node.extend(node_to_add) + def fuse_consecutive_transposes(g): node_to_del = [] for node in g.node: - if node.op_type != 'Transpose': + if node.op_type != "Transpose": continue pre_node = helper.find_node_by_output_name(g, node.input[0]) - if pre_node.op_type != 'Transpose': + if 
pre_node.op_type != "Transpose": continue pre_permutation = list(pre_node.attribute[0].ints) @@ -547,11 +608,11 @@ def fuse_consecutive_transposes(g): new_permutation.append(pre_permutation[ind]) new_trans_node = onnx.helper.make_node( - 'Transpose', + "Transpose", [pre_node.input[0]], [node.output[0]], name=node.name, - perm=new_permutation + perm=new_permutation, ) g.node.extend([new_trans_node]) @@ -567,20 +628,24 @@ def fuse_consecutive_transposes(g): topological_sort(g) + def fuse_mul_and_add_into_bn(g): node_to_del = [] for node in g.node: - if node.op_type != 'Add': + if node.op_type != "Add": continue add_node = node - input_nodes_add = [helper.find_node_by_output_name(g, input_name) for input_name in add_node.input] - if any([n == None for n in input_nodes_add]): + input_nodes_add = [ + helper.find_node_by_output_name(g, input_name) + for input_name in add_node.input + ] + if any([n is None for n in input_nodes_add]): continue mul_node, const_add = None, None for input_node_add in input_nodes_add: - if input_node_add.op_type == 'Mul': + if input_node_add.op_type == "Mul": mul_node = input_node_add - elif input_node_add.op_type == 'Constant': + elif input_node_add.op_type == "Constant": const_add = input_node_add else: pass @@ -591,7 +656,7 @@ def fuse_mul_and_add_into_bn(g): input_node = helper.find_node_by_output_name(g, input_name) if not input_node: data_input_name = input_name - elif input_node.op_type == 'Constant': + elif input_node.op_type == "Constant": if not const_mul: const_mul = input_node else: @@ -611,9 +676,14 @@ def fuse_mul_and_add_into_bn(g): data_input_value = helper.find_value_by_name(g, data_input_name) if data_input_value is None: data_input_value = helper.find_input_by_name(g, data_input_name) - _ , previous_node_output_shape = helper.find_size_shape_from_value(data_input_value) + _, previous_node_output_shape = helper.find_size_shape_from_value( + data_input_value + ) # only allow 4 dim data input due to the hardware limitation - if 
previous_node_output_shape is None or len(previous_node_output_shape) != 4: + if ( + previous_node_output_shape is None + or len(previous_node_output_shape) != 4 + ): continue # check if mul's dim and input channel dimension are matched @@ -637,16 +707,25 @@ def fuse_mul_and_add_into_bn(g): continue bn_name = add_node.output[0] - const_mean = helper.list_to_constant(bn_name+'_mean', [c_dim], [0.0 for _ in range(c_dim)]) - const_var = helper.list_to_constant(bn_name+'_var', [c_dim], [1.0 for _ in range(c_dim)]) + const_mean = helper.list_to_constant( + bn_name + "_mean", [c_dim], [0.0 for _ in range(c_dim)] + ) + const_var = helper.list_to_constant( + bn_name + "_var", [c_dim], [1.0 for _ in range(c_dim)] + ) bn_node = onnx.helper.make_node( - 'BatchNormalization', - [data_input_name, const_mul.output[0], const_add.output[0],\ - const_mean.output[0], const_var.output[0]], + "BatchNormalization", + [ + data_input_name, + const_mul.output[0], + const_add.output[0], + const_mean.output[0], + const_var.output[0], + ], [add_node.output[0]], name=bn_name, - epsilon=0.00000001 + epsilon=0.00000001, ) mid_val_info = helper.find_value_by_name(g, mul_node.output[0]) @@ -657,24 +736,16 @@ def fuse_mul_and_add_into_bn(g): g.value_info.remove(bais_val_info) new_scale_val_info = onnx.helper.make_tensor_value_info( - const_mul.output[0], - const_mul.attribute[0].t.data_type, - [c_dim] - ) + const_mul.output[0], const_mul.attribute[0].t.data_type, [c_dim] + ) new_bais_val_info = onnx.helper.make_tensor_value_info( - const_add.output[0], - const_add.attribute[0].t.data_type, - [c_dim] + const_add.output[0], const_add.attribute[0].t.data_type, [c_dim] ) mean_val_info = onnx.helper.make_tensor_value_info( - const_mean.output[0], - const_mean.attribute[0].t.data_type, - [c_dim] + const_mean.output[0], const_mean.attribute[0].t.data_type, [c_dim] ) var_val_info = onnx.helper.make_tensor_value_info( - const_var.output[0], - const_var.attribute[0].t.data_type, - [c_dim] + 
const_var.output[0], const_var.attribute[0].t.data_type, [c_dim] ) g.value_info.extend([new_scale_val_info]) @@ -695,17 +766,17 @@ def fuse_mul_and_add_into_bn(g): def fuse_mul_and_add_into_gemm(g): node_to_del = [] for node in g.node: - if node.op_type != 'Add': + if node.op_type != "Add": continue add_node = node mul_node = helper.find_node_by_output_name(g, add_node.input[0]) - if not mul_node or mul_node.op_type != 'Mul': + if not mul_node or mul_node.op_type != "Mul": continue mul_const = helper.find_node_by_output_name(g, mul_node.input[1]) - if not mul_const or mul_const.op_type != 'Constant': + if not mul_const or mul_const.op_type != "Constant": continue add_const = helper.find_node_by_output_name(g, add_node.input[1]) - if not add_const or add_const.op_type != 'Constant': + if not add_const or add_const.op_type != "Constant": continue input_val = helper.find_value_by_name(g, mul_node.input[0]) @@ -735,26 +806,26 @@ def fuse_mul_and_add_into_gemm(g): b_data[i][i] = mul_const_data[i] b_data = b_data.flatten().tolist() b_tensor = onnx.helper.make_tensor( - name=mul_const.name+'_tensor', + name=mul_const.name + "_tensor", data_type=mul_const.attribute[0].t.data_type, dims=[dim, dim], - vals=b_data + vals=b_data, ) b_const_node = onnx.helper.make_node( - 'Constant', + "Constant", [], [mul_const.output[0]], value=b_tensor, - name=mul_const.output[0] + name=mul_const.output[0], ) add_const.attribute[0].t.dims.insert(0, 1) gemm_node = onnx.helper.make_node( - 'Gemm', + "Gemm", [mul_node.input[0], b_const_node.output[0], add_const.output[0]], [add_node.output[0]], - name=add_node.output[0] + name=add_node.output[0], ) g.node.extend([gemm_node, b_const_node]) @@ -775,22 +846,23 @@ def fuse_mul_and_add_into_gemm(g): topological_sort(g) + def fuse_conv_and_add_into_conv(g): node_to_del = [] for node in g.node: # Check if two nodes can be fused - if node.op_type != 'Add': + if node.op_type != "Add": continue add_node = node add_const = 
helper.find_node_by_output_name(g, add_node.input[1]) - if not add_const or add_const.op_type != 'Constant': + if not add_const or add_const.op_type != "Constant": continue conv_node = helper.find_node_by_output_name(g, add_node.input[0]) - if not conv_node or conv_node.op_type != 'Conv': + if not conv_node or conv_node.op_type != "Conv": continue weight_node = helper.find_node_by_output_name(g, conv_node.input[1]) - if not weight_node or weight_node.op_type != 'Constant': + if not weight_node or weight_node.op_type != "Constant": continue m_dim = weight_node.attribute[0].t.dims[0] @@ -807,20 +879,28 @@ def fuse_conv_and_add_into_conv(g): output_value_info = helper.find_value_by_name(g, add_node.output[0]) if output_value_info is not None: g.value_info.remove(output_value_info) - add_weight_value_info = helper.find_value_by_name(g, add_const.output[0]) + add_weight_value_info = helper.find_value_by_name( + g, add_const.output[0] + ) if add_weight_value_info is not None: g.value_info.remove(add_weight_value_info) # Replace next node input if any. 
- following_nodes = helper.find_following_nodes_by_input_value_name(g, add_node.output[0]) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, add_node.output[0] + ) for following_node in following_nodes: - replace_node_input(following_node, add_node.output[0], add_node.input[0]) + replace_node_input( + following_node, add_node.output[0], add_node.input[0] + ) # Replace output if any todel_output = helper.find_output_by_name(g, add_node.output[0]) if todel_output is not None: g.output.remove(todel_output) previous_output = helper.find_output_by_name(g, add_node.input[0]) if previous_output is None: - the_input_value = helper.find_value_by_name(g, add_node.input[0]) + the_input_value = helper.find_value_by_name( + g, add_node.input[0] + ) g.output.extend([the_input_value]) while node_to_del: @@ -833,16 +913,20 @@ def fuse_consecutive_reducemean(g): node_to_del = [] for node in g.node: # Find consecutive ReduceMean - if node.op_type != 'ReduceMean': + if node.op_type != "ReduceMean": continue pre_node = helper.find_node_by_output_name(g, node.input[0]) - if pre_node is None or pre_node.op_type != 'ReduceMean': + if pre_node is None or pre_node.op_type != "ReduceMean": continue # Check attributes - pre_keepdims = helper.get_var_attribute_by_name(pre_node, 'keepdims', 'int') - pre_axes = helper.get_list_attribute_by_name(pre_node, 'axes', 'int') - cur_keepdims = helper.get_var_attribute_by_name(node, 'keepdims', 'int') - cur_axes = helper.get_list_attribute_by_name(node, 'axes', 'int') + pre_keepdims = helper.get_var_attribute_by_name( + pre_node, "keepdims", "int" + ) + pre_axes = helper.get_list_attribute_by_name(pre_node, "axes", "int") + cur_keepdims = helper.get_var_attribute_by_name( + node, "keepdims", "int" + ) + cur_axes = helper.get_list_attribute_by_name(node, "axes", "int") if pre_keepdims != 0 or cur_keepdims != 0: continue axes = sorted(pre_axes + cur_axes) @@ -850,17 +934,17 @@ def fuse_consecutive_reducemean(g): continue # Merge 
two ReduceMean into GlobalAveragePool. new_gap_node = onnx.helper.make_node( - 'GlobalAveragePool', + "GlobalAveragePool", [pre_node.input[0]], - [node.output[0] + '_intermedia'], - name = node.name + '_gap' + [node.output[0] + "_intermedia"], + name=node.name + "_gap", ) new_flatten_node = onnx.helper.make_node( - 'Flatten', - [node.output[0] + '_intermedia'], + "Flatten", + [node.output[0] + "_intermedia"], [node.output[0]], - name = node.name + '_flatten', - axis = 1 + name=node.name + "_flatten", + axis=1, ) # Clean up @@ -876,14 +960,17 @@ def fuse_consecutive_reducemean(g): topological_sort(g) + def fuse_slice_nodes_into_conv(g): # define pattern checker def check_is_slice(node): - if node.op_type == 'Concat': + if node.op_type == "Concat": return True - if node.op_type != 'Slice': + if node.op_type != "Slice": return False - following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + ) if len(following_nodes) != 1: return False # also check attributes @@ -891,7 +978,7 @@ def fuse_slice_nodes_into_conv(g): return False # starts should be 0 or 1 starts_node = helper.find_node_by_output_name(g, node.input[1]) - if starts_node.op_type != 'Constant': + if starts_node.op_type != "Constant": return False _, starts_list = helper.constant_to_list(starts_node) for num in starts_list: @@ -899,11 +986,11 @@ def fuse_slice_nodes_into_conv(g): return False # ends ends_node = helper.find_node_by_output_name(g, node.input[2]) - if ends_node.op_type != 'Constant': + if ends_node.op_type != "Constant": return False # axes should be 2 or 3 axes_node = helper.find_node_by_output_name(g, node.input[3]) - if axes_node.op_type != 'Constant': + if axes_node.op_type != "Constant": return False _, axes_list = helper.constant_to_list(axes_node) for num in axes_list: @@ -911,7 +998,7 @@ def fuse_slice_nodes_into_conv(g): return False # Steps can only be 2 steps_node = 
helper.find_node_by_output_name(g, node.input[4]) - if steps_node.op_type != 'Constant': + if steps_node.op_type != "Constant": return False _, steps_list = helper.constant_to_list(steps_node) for num in steps_list: @@ -919,16 +1006,25 @@ def fuse_slice_nodes_into_conv(g): return False # Recursion return check_is_slice(following_nodes[0]) + # defind concat finder def find_concat_node(node): - while node.op_type != 'Concat': - node = helper.find_following_nodes_by_input_value_name(g, node.output[0])[0] + while node.op_type != "Concat": + node = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + )[0] return node + # define remove node function. def remove_nodes(input_name): - following_nodes = helper.find_following_nodes_by_input_value_name(g, input_name) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, input_name + ) # Remove concat directly - if len(following_nodes) == 1 and following_nodes[0].op_type == 'Concat': + if ( + len(following_nodes) == 1 + and following_nodes[0].op_type == "Concat" + ): g.node.remove(following_nodes[0]) return for following_node in following_nodes: @@ -936,37 +1032,58 @@ def fuse_slice_nodes_into_conv(g): remove_nodes(following_node.output[0]) # Remove weights for i in range(1, len(following_node.input)): - if len(helper.find_following_nodes_by_input_value_name(g, following_node.input[i])) > 1: + if ( + len( + helper.find_following_nodes_by_input_value_name( + g, following_node.input[i] + ) + ) + > 1 + ): # More than one following nodes. Skip. 
continue - input_weight = helper.find_node_by_output_name(g, following_node.input[i]) + input_weight = helper.find_node_by_output_name( + g, following_node.input[i] + ) g.node.remove(input_weight) # Remove Slice nodes g.node.remove(following_node) + # define remove value_info function def remove_value_infos(input_name): - following_nodes = helper.find_following_nodes_by_input_value_name(g, input_name) - if following_nodes[0].op_type == 'Concat': + following_nodes = helper.find_following_nodes_by_input_value_name( + g, input_name + ) + if following_nodes[0].op_type == "Concat": return for following_node in following_nodes: - output_value = helper.find_value_by_name(g, following_node.output[0]) + output_value = helper.find_value_by_name( + g, following_node.output[0] + ) # Remove output values if output_value is not None: g.value_info.remove(output_value) # Remove weight values for i in range(1, len(following_node.input)): - input_value = helper.find_value_by_name(g, following_node.input[i]) + input_value = helper.find_value_by_name( + g, following_node.input[i] + ) if input_value is not None: g.value_info.remove(input_value) # Recursion remove_value_infos(following_node.output[0]) + # define get slice position def get_slice_position(final_slice_output): slice_position = [0, 0] prev_node = helper.find_node_by_output_name(g, final_slice_output) while prev_node is not None: - starts_np = helper.constant_to_numpy(helper.find_node_by_output_name(g, prev_node.input[1])) - axes_np = helper.constant_to_numpy(helper.find_node_by_output_name(g, prev_node.input[3])) + starts_np = helper.constant_to_numpy( + helper.find_node_by_output_name(g, prev_node.input[1]) + ) + axes_np = helper.constant_to_numpy( + helper.find_node_by_output_name(g, prev_node.input[3]) + ) for i in range(len(axes_np)): if axes_np[i] == 2: slice_position[0] = starts_np[i] @@ -974,12 +1091,15 @@ def fuse_slice_nodes_into_conv(g): slice_position[1] = starts_np[i] prev_node = 
helper.find_node_by_output_name(g, prev_node.input[0]) return slice_position + # Check pattern from each input for input_value in g.input: - nodes_after_input = helper.find_following_nodes_by_input_value_name(g, input_value.name) + nodes_after_input = helper.find_following_nodes_by_input_value_name( + g, input_value.name + ) pattern_matched = True for following_node in nodes_after_input: - if following_node.op_type != 'Slice': + if following_node.op_type != "Slice": pattern_matched = False break else: @@ -996,24 +1116,33 @@ def fuse_slice_nodes_into_conv(g): input_shape = helper.get_shape_from_value_info(input_value) channel_num = input_shape[1] # Construct weight - weight_np = np.zeros((input_shape[1] * 4, input_shape[1], 3, 3), dtype=np.float32) + weight_np = np.zeros( + (input_shape[1] * 4, input_shape[1], 3, 3), dtype=np.float32 + ) for i in range(4): # Check each branch slice_position = get_slice_position(concat_node.input[i]) for j in range(channel_num): - weight_np[i * channel_num + j, j, slice_position[0], slice_position[1]] = 1 - weight_node = helper.numpy_to_constant(concat_node.name + '_weight', weight_np) + weight_np[ + i * channel_num + j, + j, + slice_position[0], + slice_position[1], + ] = 1 + weight_node = helper.numpy_to_constant( + concat_node.name + "_weight", weight_np + ) # Construct Conv node new_conv = onnx.helper.make_node( - 'Conv', - [input_value.name, concat_node.name + '_weight'], + "Conv", + [input_value.name, concat_node.name + "_weight"], [concat_node.output[0]], - name = concat_node.name + '_fused', - dilations = [1, 1], - group = 1, - kernel_shape = [3, 3], - strides = [2, 2], - pads = [0, 0, 2, 2] + name=concat_node.name + "_fused", + dilations=[1, 1], + group=1, + kernel_shape=[3, 3], + strides=[2, 2], + pads=[0, 0, 2, 2], ) # Delete old nodes, weights and value_infos remove_value_infos(input_value.name) @@ -1027,33 +1156,41 @@ def fuse_relu_min_into_clip(g): node_to_del = [] for node in g.node: # Check Min node - if node.op_type 
!= 'Min': + if node.op_type != "Min": continue min_node = node # Check Constant node min_const = helper.find_node_by_output_name(g, min_node.input[1]) - if not min_const or min_const.op_type != 'Constant': + if not min_const or min_const.op_type != "Constant": continue min_shape, min_value = helper.constant_to_list(min_const) if min_shape != 1: continue # Check Relu node relu_node = helper.find_node_by_output_name(g, min_node.input[0]) - if not relu_node or relu_node.op_type != 'Relu': + if not relu_node or relu_node.op_type != "Relu": continue # Create Clip node - relu_min_const_node = helper.list_to_constant(relu_node.name+'_min_value', [], [0.0]) + relu_min_const_node = helper.list_to_constant( + relu_node.name + "_min_value", [], [0.0] + ) clip_node = onnx.helper.make_node( "Clip", - [relu_node.input[0], relu_min_const_node.output[0], min_const.output[0]], + [ + relu_node.input[0], + relu_min_const_node.output[0], + min_const.output[0], + ], [min_node.output[0]], - name=min_node.name + name=min_node.name, ) node_to_del.extend([relu_node, min_node]) - old_relu_const_val_info = helper.find_value_by_name(g, min_node.input[0]) + old_relu_const_val_info = helper.find_value_by_name( + g, min_node.input[0] + ) if old_relu_const_val_info: g.value_info.remove(old_relu_const_val_info) g.node.extend([relu_min_const_node, clip_node]) @@ -1061,4 +1198,4 @@ def fuse_relu_min_into_clip(g): while node_to_del: g.node.remove(node_to_del.pop()) - topological_sort(g) \ No newline at end of file + topological_sort(g) diff --git a/tools/optimizer_scripts/tools/general_graph.py b/tools/optimizer_scripts/tools/general_graph.py index 352445b..f9904f2 100644 --- a/tools/optimizer_scripts/tools/general_graph.py +++ b/tools/optimizer_scripts/tools/general_graph.py @@ -1,9 +1,11 @@ from collections import deque + class Node: """A Node which maps a node proto. It has pointers to its parents and children. """ + def __init__(self, onnx_node): """Initialize a node. 
This initialization only set up the mapping to node proto. The pointers should be set up by outside. @@ -17,12 +19,12 @@ class Node: self.name = onnx_node.name self.proto = onnx_node + class Graph: - """A graph which is constructed from the onnx proto. - """ + """A graph which is constructed from the onnx proto.""" + def __init__(self, onnx_graph): - """Construct the graph from onnx. - """ + """Construct the graph from onnx.""" self.input_nodes = [] self.output_nodes = [] self.name2node = {} @@ -51,9 +53,9 @@ class Graph: for value in onnx_graph.value_info: node = self.output2node[value.name] node.output_value = value + def get_sorted_node_list(self): - """Return a node list in topological order. - """ + """Return a node list in topological order.""" visited = set() todo = deque() result = [] diff --git a/tools/optimizer_scripts/tools/helper.py b/tools/optimizer_scripts/tools/helper.py index 18bc1e3..02da09d 100644 --- a/tools/optimizer_scripts/tools/helper.py +++ b/tools/optimizer_scripts/tools/helper.py @@ -6,21 +6,26 @@ import struct import numpy as np import logging -__ONNX_VERSION__ = -1 +__ONNX_VERSION__ = -1 logger = logging.getLogger("optimizer_scripts") + def setup_current_opset_version(m): global __ONNX_VERSION__ __ONNX_VERSION__ = m.opset_import[0].version if __ONNX_VERSION__ not in [11]: - raise RuntimeError('Only support opset 11, but got ' + str(__ONNX_VERSION__)) + raise RuntimeError( + "Only support opset 11, but got " + str(__ONNX_VERSION__) + ) + def get_current_opset_version(): if __ONNX_VERSION__ == -1: - raise RuntimeError('do setup_current_opset_version first please') + raise RuntimeError("do setup_current_opset_version first please") return __ONNX_VERSION__ + def find_nodes_by_input_name(g, name): nodes = [] for node in g.node: @@ -28,6 +33,7 @@ def find_nodes_by_input_name(g, name): nodes.append(node) return nodes + def find_node_by_output_name(g, name): """ Find a node in the graph by its output name @@ -41,6 +47,7 @@ def 
find_node_by_output_name(g, name): return i return None + def find_node_by_node_name(g, name): """ Find a node in the graph by its output name @@ -54,6 +61,7 @@ def find_node_by_node_name(g, name): return i return None + def find_following_nodes_by_input_value_name(g, name): """ Find the following nodes of a specific value. @@ -63,6 +71,7 @@ def find_following_nodes_by_input_value_name(g, name): """ return find_nodes_by_input_name(g, name) + def find_value_by_name(g, name): """ Find a value_info in the graph by name @@ -76,6 +85,7 @@ def find_value_by_name(g, name): return i return None + def find_output_by_name(g, name): """ Find a value_info in the graph by name @@ -89,6 +99,7 @@ def find_output_by_name(g, name): return i return None + def find_input_by_name(g, name): """ Find a input in the graph by name @@ -102,6 +113,7 @@ def find_input_by_name(g, name): return i return None + def list_to_constant(name, shape, data, data_type=None): """Generate a constant node using the given infomation. 
@@ -119,18 +131,9 @@ def list_to_constant(name, shape, data, data_type=None): data_type = onnx.helper.TensorProto.INT64 else: data_type = onnx.helper.TensorProto.FLOAT - tensor = onnx.helper.make_tensor( - name, - data_type, - shape, - data - ) + tensor = onnx.helper.make_tensor(name, data_type, shape, data) new_w_node = onnx.helper.make_node( - "Constant", - [], - [name], - name = name, - value = tensor + "Constant", [], [name], name=name, value=tensor ) return new_w_node @@ -151,18 +154,9 @@ def scaler_to_constant(name, data, data_type=None): else: logger.error("Cannot create scaler constant with a list.") exit(1) - tensor = onnx.helper.make_tensor( - name, - data_type, - None, - [data] - ) + tensor = onnx.helper.make_tensor(name, data_type, None, [data]) new_w_node = onnx.helper.make_node( - "Constant", - [], - [name], - name = name, - value = tensor + "Constant", [], [name], name=name, value=tensor ) return new_w_node @@ -170,6 +164,7 @@ def scaler_to_constant(name, data, data_type=None): def numpy_to_constant(name, np_array): return list_to_constant(name, np_array.shape, np_array.flatten().tolist()) + def constant_to_list(node): """Generate a list from the constant node @@ -184,27 +179,27 @@ def constant_to_list(node): if len(tensor.int32_data) != 0: data = list(tensor.int32_data) else: - data = [i[0] for i in struct.iter_unpack('i', tensor.raw_data)] + data = [i[0] for i in struct.iter_unpack("i", tensor.raw_data)] elif tensor.data_type == onnx.helper.TensorProto.INT64: if len(tensor.int64_data) != 0: data = list(tensor.int64_data) else: - data = [i[0] for i in struct.iter_unpack('q', tensor.raw_data)] + data = [i[0] for i in struct.iter_unpack("q", tensor.raw_data)] elif tensor.data_type == onnx.helper.TensorProto.INT8: if len(tensor.int32_data) != 0: data = list(tensor.int32_data) else: - data = [i[0] for i in struct.iter_unpack('b', tensor.raw_data)] + data = [i[0] for i in struct.iter_unpack("b", tensor.raw_data)] elif tensor.data_type == 
onnx.helper.TensorProto.FLOAT: if len(tensor.float_data) != 0: data = list(tensor.float_data) else: - data = [i[0] for i in struct.iter_unpack('f', tensor.raw_data)] + data = [i[0] for i in struct.iter_unpack("f", tensor.raw_data)] elif tensor.data_type == onnx.helper.TensorProto.DOUBLE: if len(tensor.double_data) != 0: data = list(tensor.double_data) else: - data = [i[0] for i in struct.iter_unpack('d', tensor.raw_data)] + data = [i[0] for i in struct.iter_unpack("d", tensor.raw_data)] else: print("Not supported data type {}".format(tensor.data_type)) raise RuntimeError @@ -214,6 +209,7 @@ def constant_to_list(node): shape = list(tensor.dims) return shape, data + def constant_to_numpy(node): """Generate a numpy array from the constant node @@ -223,6 +219,7 @@ def constant_to_numpy(node): shape, data = constant_to_list(node) return np.array(data).reshape(shape) + def all_constant_input(node): """Find the inputs of the given node. If the inputs of this node are all\\ constant nodes, return True. Otherwise, return False. @@ -234,24 +231,26 @@ def all_constant_input(node): return False isConstant = True for parent in node.parents: - if parent.proto is None or parent.proto.op_type != 'Constant': + if parent.proto is None or parent.proto.op_type != "Constant": isConstant = False break return isConstant + def get_padding(size, kernel_size, strides): - """ Calculate the padding array for same padding in the Tensorflow fashion.\\ + """ Calculate the padding array for same padding in the Tensorflow fashion.\\ See https://www.tensorflow.org/api_guides/python/nn#Convolution for more. 
""" - if size[0] % strides[0] == 0: - pad_h = max(kernel_size[0] - strides[0], 0) - else: - pad_h = max(kernel_size[0] - (size[0] % strides[0]), 0) - if size[1] % strides[1] == 0: - pad_w = max(kernel_size[1] - strides[1], 0) - else: - pad_w = max(kernel_size[1] - (size[1] % strides[1]), 0) - return [pad_h//2, pad_w//2, pad_h-pad_h//2, pad_w-pad_w//2] + if size[0] % strides[0] == 0: + pad_h = max(kernel_size[0] - strides[0], 0) + else: + pad_h = max(kernel_size[0] - (size[0] % strides[0]), 0) + if size[1] % strides[1] == 0: + pad_w = max(kernel_size[1] - strides[1], 0) + else: + pad_w = max(kernel_size[1] - (size[1] % strides[1]), 0) + return [pad_h // 2, pad_w // 2, pad_h - pad_h // 2, pad_w - pad_w // 2] + def get_shape_from_value_info(value): """Get shape from a value info. @@ -261,12 +260,13 @@ def get_shape_from_value_info(value): """ return [d.dim_value for d in value.type.tensor_type.shape.dim] + def find_size_shape_from_value(value): - ''' + """ Find the size of data within the value_info object. :param value: value_info :return: int size and list shape of the data in the value_info - ''' + """ if not value: return None, None if not value.type.tensor_type.shape.dim: @@ -292,6 +292,7 @@ def get_attribute_by_name(node, attr_name): return attr return None + def get_list_attribute_by_name(node, attr_name: str, attr_type: str): """Get list attribute with specific name in the given node proto. @@ -317,12 +318,13 @@ def get_list_attribute_by_name(node, attr_name: str, attr_type: str): print("Warning: undefined type for list attribute extraction") return None + def get_var_attribute_by_name(node, attr_name: str, attr_type: str): """Get variable attribute with specific name in the given node proto. - :param node: the node proto.\\ - :param attr_name: str for the name of the target.\\ - :param attr_type: str which should be "float", "int", "string" or "tensor".\\ + :param node: the node proto. + :param attr_name: str for the name of the target. 
+ :param attr_type: str which should be "float", "int", "string" or "tensor". :return: if found, return the variable. Else, return None. """ attr_proto = get_attribute_by_name(node, attr_name) @@ -333,7 +335,7 @@ def get_var_attribute_by_name(node, attr_name: str, attr_type: str): elif attr_type == "float": return attr_proto.f elif attr_type == "string": - if type(attr_proto.s) == type(b'abc'): + if isinstance(attr_proto.s, bytes): return attr_proto.s.decode("utf-8") else: return attr_proto.s @@ -343,22 +345,25 @@ def get_var_attribute_by_name(node, attr_name: str, attr_type: str): print("Warning: undefined type for variable attribute extraction") return None + def flatten_with_depth(data, depth): output = [] if type(data) not in [type(np.array([1])), type([1])]: return [[data, 0]] for item in data: if type(item) not in [type(np.array([1])), type([1])]: - output.append([item, depth+1]) + output.append([item, depth + 1]) else: - output += flatten_with_depth(item, depth+1) + output += flatten_with_depth(item, depth + 1) return output + def flatten_to_list(data): flatten_depth = flatten_with_depth(data, 0) flat_data = [item[0] for item in flatten_depth] return flat_data + def get_shape(data): shape = [] if type(data) not in [type(np.array([1])), type([1])]: @@ -378,7 +383,7 @@ def slice_data(data, starts, ends, axes): starts_updated = [] ends_updated = [] for i in range(len(starts)): - start_updated = min(starts[i], shape[i]-1) % shape[i] + start_updated = min(starts[i], shape[i] - 1) % shape[i] starts_updated.append(start_updated) for j in range(len(starts)): if ends[j] >= shape[j]: @@ -393,19 +398,21 @@ def slice_data(data, starts, ends, axes): index_slices.append(list(range(shape[i]))) else: axe_ind = axes.index(i) - index_slices.append(list(range(starts_updated[axe_ind], ends_updated[axe_ind]))) + index_slices.append( + list(range(starts_updated[axe_ind], ends_updated[axe_ind])) + ) indices = [1] - for i in range(len(shape)-1, -1, -1): - step = 
np.prod(shape[i+1:]) + for i in range(len(shape) - 1, -1, -1): + step = np.prod(shape[i + 1:]) temp_pos = indices new_indices = [] for n in index_slices[i]: for pos in temp_pos: - new_indices.append(int(n*step+pos)) + new_indices.append(int(n * step + pos)) indices = new_indices - sliced_data = [flat_data[k-1] for k in indices] + sliced_data = [flat_data[k - 1] for k in indices] # reshape to correct shape. new_shape = [] @@ -414,48 +421,51 @@ def slice_data(data, starts, ends, axes): new_shape.append(shape[i]) else: axe_ind = axes.index(i) - new_shape.append(ends_updated[axe_ind]-starts_updated[axe_ind]) + new_shape.append(ends_updated[axe_ind] - starts_updated[axe_ind]) if any([dim < 1 for dim in new_shape]): - raise RuntimeError('Invalid starts ends.') - + raise RuntimeError("Invalid starts ends.") + sliced_data = np.reshape(sliced_data, new_shape) return sliced_data + def concatenate(data_sets, axis): # check shapes shapes = [] shapes_ = [] for data_set in data_sets: - shape = get_shape(data_set) - shapes.append(list(shape)) - shape.pop(axis) - shapes_.append(shape) + shape = get_shape(data_set) + shapes.append(list(shape)) + shape.pop(axis) + shapes_.append(shape) if not all([s == shapes_[0] for s in shapes_]): - raise RuntimeError('data sets shapes do not match') - + raise RuntimeError("data sets shapes do not match") + new_dim = sum([s[axis] for s in shapes]) new_shape = list(shapes[0]) new_shape[axis] = new_dim flat_data_sets = [] for data_set in data_sets: - flat_data_sets.append(flatten_to_list(data_set)) - + flat_data_sets.append(flatten_to_list(data_set)) + sub_block_size = 1 - for i in range(axis+1, len(shapes[0])): - sub_block_size *= shapes[0][i] - + for i in range(axis + 1, len(shapes[0])): + sub_block_size *= shapes[0][i] + split_num = 1 for i in range(axis): - split_num *= shapes[0][i] + split_num *= shapes[0][i] total_flat_data = [] for i in range(split_num): - for j in range(len(shapes)): - block_size = sub_block_size*shapes[j][axis] - 
total_flat_data.extend(flat_data_sets[j][i*block_size:(i+1)*block_size]) - + for j in range(len(shapes)): + block_size = sub_block_size * shapes[j][axis] + total_flat_data.extend( + flat_data_sets[j][i * block_size:(i + 1) * block_size] + ) + new_data = np.reshape(total_flat_data, new_shape) return new_data @@ -464,158 +474,169 @@ def concatenate(data_sets, axis): def broadcast_data_sets(data_set_1, data_set_2): shape1 = get_shape(data_set_1) shape2 = get_shape(data_set_2) - + # compare shapes and get broadcasted shape - list_a, list_b = (shape1, shape2) if len(shape1) > len(shape2) else (shape2, shape1) + list_a, list_b = ( + (shape1, shape2) if len(shape1) > len(shape2) else (shape2, shape1) + ) while len(list_a) > len(list_b): list_b.insert(0, 0) broadcasted_shape = [] for i in range(len(list_a)): - if list_b[i] == 0: - broadcasted_shape.append(list_a[i]) - elif list_b[i] == 1: - broadcasted_shape.append(list_a[i]) - elif list_a[i] == 1: - broadcasted_shape.append(list_b[i]) - elif list_a[i] == list_b[i]: - broadcasted_shape.append(list_a[i]) - else: - raise RuntimeError('Can not broadcast two data sets') + if list_b[i] == 0: + broadcasted_shape.append(list_a[i]) + elif list_b[i] == 1: + broadcasted_shape.append(list_a[i]) + elif list_a[i] == 1: + broadcasted_shape.append(list_b[i]) + elif list_a[i] == list_b[i]: + broadcasted_shape.append(list_a[i]) + else: + raise RuntimeError("Can not broadcast two data sets") # prepare data for broadcasting. 
- shape1 = list(map(lambda x:x if x != 0 else 1, shape1)) - shape2 = list(map(lambda x:x if x != 0 else 1, shape2)) + shape1 = list(map(lambda x: x if x != 0 else 1, shape1)) + shape2 = list(map(lambda x: x if x != 0 else 1, shape2)) data_1 = np.reshape(data_set_1, shape1) data_2 = np.reshape(data_set_2, shape2) for i in range(len(shape1)): - if shape1[i] != broadcasted_shape[i]: - new_data_total = [list(data_1) for _ in range(broadcasted_shape[i])] - data_1 = concatenate(new_data_total, axis=i) + if shape1[i] != broadcasted_shape[i]: + new_data_total = [ + list(data_1) for _ in range(broadcasted_shape[i]) + ] + data_1 = concatenate(new_data_total, axis=i) for i in range(len(shape2)): - if shape2[i] != broadcasted_shape[i]: - new_data_total = [list(data_2) for _ in range(broadcasted_shape[i])] - data_2 = concatenate(new_data_total, axis=i) + if shape2[i] != broadcasted_shape[i]: + new_data_total = [ + list(data_2) for _ in range(broadcasted_shape[i]) + ] + data_2 = concatenate(new_data_total, axis=i) return data_1, data_2 def add(data_set_1, data_set_2): - broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(data_set_1, data_set_2) + broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets( + data_set_1, data_set_2 + ) - flat_data_1 = flatten_to_list(broadcasted_data_1) - flat_data_2 = flatten_to_list(broadcasted_data_2) - shape = get_shape(broadcasted_data_1) - res = [] - for i in range(len(flat_data_1)): - res.append(flat_data_1[i]+flat_data_2[i]) - - res = np.reshape(res, shape) + flat_data_1 = flatten_to_list(broadcasted_data_1) + flat_data_2 = flatten_to_list(broadcasted_data_2) + shape = get_shape(broadcasted_data_1) + res = [] + for i in range(len(flat_data_1)): + res.append(flat_data_1[i] + flat_data_2[i]) - return res + res = np.reshape(res, shape) + + return res def reduceprod(data_set, axis, keepdims=1): - flat_data = flatten_to_list(data_set) - old_shape = get_shape(data_set) + flat_data = flatten_to_list(data_set) + old_shape = 
get_shape(data_set) - temp_shape = old_shape - temp_flat_data = flat_data - for ax in axis: - split_num = 1 - step = 1 - for i in range(ax): - split_num *= temp_shape[i] - for i in range(ax+1, len(temp_shape)): - step *= temp_shape[i] - - block_size = len(temp_flat_data)//split_num - new_flat_data = [] - for j in range(split_num): - block_data = temp_flat_data[j*block_size:(j+1)*block_size] - reduced_block_data = [] - for k in range(step): - val = block_data[k] - for l in range(1, block_size//step): - val *= block_data[k+l*step] - reduced_block_data.append(val) - new_flat_data.extend(reduced_block_data) - temp_flat_data = new_flat_data - temp_shape[ax] = 1 - - new_flat_data = temp_flat_data - new_shape = temp_shape - if not keepdims: - axis = sorted(list(axis)) - for pos in axis[::-1]: - new_shape.pop(pos) - - return np.reshape(new_flat_data, new_shape) + temp_shape = old_shape + temp_flat_data = flat_data + for ax in axis: + split_num = 1 + step = 1 + for i in range(ax): + split_num *= temp_shape[i] + for i in range(ax + 1, len(temp_shape)): + step *= temp_shape[i] + + block_size = len(temp_flat_data) // split_num + new_flat_data = [] + for j in range(split_num): + block_data = temp_flat_data[j * block_size:(j + 1) * block_size] + reduced_block_data = [] + for k in range(step): + val = block_data[k] + for li in range(1, block_size // step): + val *= block_data[k + li * step] + reduced_block_data.append(val) + new_flat_data.extend(reduced_block_data) + temp_flat_data = new_flat_data + temp_shape[ax] = 1 + + new_flat_data = temp_flat_data + new_shape = temp_shape + if not keepdims: + axis = sorted(list(axis)) + for pos in axis[::-1]: + new_shape.pop(pos) + + return np.reshape(new_flat_data, new_shape) def transpose(data_set, permutation): - # find series of local swaps - data_set = list(data_set) - perm = list(permutation) - shape = get_shape(data_set) - flat_data = flatten_to_list(data_set) - assert set(perm) == set(range(len(shape))), 'invalid permutation' + # 
find series of local swaps + data_set = list(data_set) + perm = list(permutation) + shape = get_shape(data_set) + flat_data = flatten_to_list(data_set) + assert set(perm) == set(range(len(shape))), "invalid permutation" - new_shape = [shape[i] for i in perm] - swaps = [] - bubbled = True - while bubbled: - bubbled = False - for i in range(len(new_shape)-1): - if perm[i] > perm[i+1]: - swaps.append([i, i+1]) - p_1, p_2 = perm[i], perm[i+1] - perm[i], perm[i+1] = p_2, p_1 - bubbled = True - - # apply local swaps - current_shape = list(shape) - temp_flat_data = flat_data + new_shape = [shape[i] for i in perm] + swaps = [] + bubbled = True + while bubbled: + bubbled = False + for i in range(len(new_shape) - 1): + if perm[i] > perm[i + 1]: + swaps.append([i, i + 1]) + p_1, p_2 = perm[i], perm[i + 1] + perm[i], perm[i + 1] = p_2, p_1 + bubbled = True - for swap in swaps[::-1]: - ind_1, ind_2 = swap[0], swap[1] - dim_1 = current_shape[ind_1] - dim_2 = current_shape[ind_2] - split_num = 1 - block_size = 1 + # apply local swaps + current_shape = list(shape) + temp_flat_data = flat_data - for i in range(ind_1): - split_num *= current_shape[i] - for i in range(ind_2+1, len(current_shape)): - block_size *= current_shape[i] + for swap in swaps[::-1]: + ind_1, ind_2 = swap[0], swap[1] + dim_1 = current_shape[ind_1] + dim_2 = current_shape[ind_2] + split_num = 1 + block_size = 1 - data_blocks = np.reshape(temp_flat_data, [-1, block_size]) - flat_data_1 = [] - for k in range(split_num): - block = [] - for m in range(dim_2): - for n in range(dim_1): - block_pos = k*dim_1*dim_2 + n*dim_2+m - block.extend(data_blocks[block_pos]) - flat_data_1.extend(block) + for i in range(ind_1): + split_num *= current_shape[i] + for i in range(ind_2 + 1, len(current_shape)): + block_size *= current_shape[i] - temp_flat_data = flat_data_1 - current_shape[ind_1] = dim_2 - current_shape[ind_2] = dim_1 + data_blocks = np.reshape(temp_flat_data, [-1, block_size]) + flat_data_1 = [] + for k in 
range(split_num): + block = [] + for m in range(dim_2): + for n in range(dim_1): + block_pos = k * dim_1 * dim_2 + n * dim_2 + m + block.extend(data_blocks[block_pos]) + flat_data_1.extend(block) + + temp_flat_data = flat_data_1 + current_shape[ind_1] = dim_2 + current_shape[ind_2] = dim_1 + + return np.reshape(temp_flat_data, current_shape) - return np.reshape(temp_flat_data, current_shape) def subtract(data_set_1, data_set_2): - broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(data_set_1, data_set_2) + broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets( + data_set_1, data_set_2 + ) shape = get_shape(broadcasted_data_1) flat_data_1 = flatten_to_list(broadcasted_data_1) flat_data_2 = flatten_to_list(broadcasted_data_2) - substracted_data = [flat_data_1[i] - flat_data_2[i] for i in range(len(flat_data_1))] + substracted_data = [ + flat_data_1[i] - flat_data_2[i] for i in range(len(flat_data_1)) + ] new_data = np.reshape(substracted_data, shape) return new_data - - \ No newline at end of file diff --git a/tools/optimizer_scripts/tools/modhelper.py b/tools/optimizer_scripts/tools/modhelper.py index 5e8302f..ca5e040 100644 --- a/tools/optimizer_scripts/tools/modhelper.py +++ b/tools/optimizer_scripts/tools/modhelper.py @@ -1,7 +1,7 @@ -"""This module contains helper functions that do graph modifications. +""" +This module contains helper functions that do graph modifications. """ -import onnx from . import helper @@ -10,9 +10,10 @@ def replace_node_input(node, old_input, new_input): if input_name == old_input: node.input[i] = new_input + def delete_nodes(g, node_list): node_to_delete = [] - #Find target nodes + # Find target nodes for node in g.node: if node.name not in node_list: continue @@ -23,16 +24,28 @@ def delete_nodes(g, node_list): for node in node_to_delete: # Check the node whether if it is valid to delete if len(node.input) == 0: - print("Deleting an Constant node. 
Please make sure you also delete all its following nodes") + print( + "Deleting an Constant node. " + "Please make sure you also delete all its following nodes" + ) elif len(node.input) > 1: - print("Warning: Node {} has more than one input. This script cannot delete merge nodes.".format(node.name)) + print( + f"Warning: Node {node.name} has more than one input. " + "This script cannot delete merge nodes." + ) # Connect the nodes around the target node. # Set the following node input as the previous node output. - following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + ) if len(node.input) == 0: for following_node in following_nodes: following_node.input.remove(node.output[0]) - elif len(following_nodes) > 0 and len(node.input) == 1 and helper.find_input_by_name(g, node.input[0]) is not None: + elif ( + len(following_nodes) > 0 + and len(node.input) == 1 + and helper.find_input_by_name(g, node.input[0]) is not None + ): # The node input is an input new_input = helper.find_value_by_name(g, node.output[0]) g.input.append(new_input) @@ -40,9 +53,11 @@ def delete_nodes(g, node_list): g.value_info.remove(new_input) elif len(following_nodes) > 0: for following_node in following_nodes: - replace_node_input(following_node, node.output[0], node.input[0]) + replace_node_input( + following_node, node.output[0], node.input[0] + ) else: - # If the node is the output, replace the output with the previous input. + # If the node is the output, replace it with previous input. value = helper.find_value_by_name(g, node.input[0]) output_values = [] while len(g.output): @@ -56,6 +71,7 @@ def delete_nodes(g, node_list): # Remove the node and value info. 
g.node.remove(node) + def delete_input(g, target_list): for name in target_list: input_value = helper.find_input_by_name(g, name) @@ -64,6 +80,7 @@ def delete_input(g, target_list): continue g.input.remove(input_value) + def delete_output(g, target_list): for name in target_list: output_value = helper.find_output_by_name(g, name) @@ -72,6 +89,7 @@ def delete_output(g, target_list): continue g.output.remove(output_value) + def delete_value_with_name_if_exists(g, name): value = helper.find_value_by_name(g, name) if value is not None: diff --git a/tools/optimizer_scripts/tools/other.py b/tools/optimizer_scripts/tools/other.py index 171179e..b003fbb 100644 --- a/tools/optimizer_scripts/tools/other.py +++ b/tools/optimizer_scripts/tools/other.py @@ -1,5 +1,6 @@ -"""Optimization functions that are not fusing, eliminating or replacing. In most -cases, these are the modifications on the original nodes. +""" +Optimization functions that are not fusing, eliminating or replacing. +In most cases, these are the modifications on the original nodes. """ import struct import collections @@ -15,9 +16,9 @@ from .helper import logger def polish_model(model): - ''' + """ This function combines several useful utility functions together. 
- ''' + """ onnx.checker.check_model(model) onnx.helper.strip_doc_string(model) model = onnx.shape_inference.infer_shapes(model) @@ -33,21 +34,31 @@ def format_value_info_shape(g): :param g: the onnx graph """ for value in g.input: - if len(value.type.tensor_type.shape.dim) > 0 and\ - (value.type.tensor_type.shape.dim[0].dim_value <= 0 or\ - not isinstance(value.type.tensor_type.shape.dim[0].dim_value, int)): + if len(value.type.tensor_type.shape.dim) > 0 and ( + value.type.tensor_type.shape.dim[0].dim_value <= 0 + or not isinstance( + value.type.tensor_type.shape.dim[0].dim_value, int + ) + ): value.type.tensor_type.shape.dim[0].dim_value = 1 for value in g.output: - if len(value.type.tensor_type.shape.dim) > 0 and\ - (value.type.tensor_type.shape.dim[0].dim_value <= 0 or\ - not isinstance(value.type.tensor_type.shape.dim[0].dim_value, int)): + if len(value.type.tensor_type.shape.dim) > 0 and ( + value.type.tensor_type.shape.dim[0].dim_value <= 0 + or not isinstance( + value.type.tensor_type.shape.dim[0].dim_value, int + ) + ): value.type.tensor_type.shape.dim[0].dim_value = 1 for value in g.value_info: - if len(value.type.tensor_type.shape.dim) > 0 and\ - (value.type.tensor_type.shape.dim[0].dim_value < 0 or\ - not isinstance(value.type.tensor_type.shape.dim[0].dim_value, int)): + if len(value.type.tensor_type.shape.dim) > 0 and ( + value.type.tensor_type.shape.dim[0].dim_value < 0 + or not isinstance( + value.type.tensor_type.shape.dim[0].dim_value, int + ) + ): value.type.tensor_type.shape.dim[0].dim_value = 1 + def add_name_to_node(g): """ If no name presents, give a name based on output name. @@ -58,6 +69,7 @@ def add_name_to_node(g): if len(node.name) == 0: node.name = node.output[0] + def rename_all_node_name(g): """ rename all nodes if the node name is a number: @@ -76,23 +88,28 @@ def rename_all_node_name(g): # in order to keep same output node name, skip if it is output node. 
output_value_info = helper.find_output_by_name(g, node.output[0]) - if output_value_info != None: + if output_value_info is not None: continue # rename the input of all the following nodes - following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + ) for following_node in following_nodes: - replace_node_input(following_node, node.output[0], new_node_output0_name ) + replace_node_input( + following_node, node.output[0], new_node_output0_name + ) # rename value info value_info = helper.find_value_by_name(g, node.output[0]) - if value_info != None: + if value_info is not None: value_info.name = new_node_output0_name # rename node node.output[0] = new_node_output0_name node.name = new_node_name + def add_output_to_value_info(g): """ If output does not present in value_info, copy one @@ -103,6 +120,7 @@ def add_output_to_value_info(g): if helper.find_value_by_name(g, output.name) is None: g.value_info.extend([output]) + def find_first_sequential_output(g, node): for value_name in node.output: value = helper.find_output_by_name(g, value_name) @@ -114,15 +132,17 @@ def find_first_sequential_output(g, node): return None return find_first_sequential_output(g, next_nodes[0]) + def remove_nodes(g, cut_nodes=[], cut_types=[]): node_to_delete = [] - #Find target nodes + # Find target nodes for node in g.node: if node.name not in cut_nodes and node.op_type not in cut_types: continue else: node_to_delete.append(node) - # Mapping originnal outputs to new outputs. This mapping is to keep the output order. + # Mapping originnal outputs to new outputs. + # This mapping is to keep the output order. 
output_mapping = {} new_output = set() for node in node_to_delete: @@ -131,7 +151,11 @@ def remove_nodes(g, cut_nodes=[], cut_types=[]): output_mapping[original_output.name] = [] for input_name in node.input: value = helper.find_value_by_name(g, input_name) - if value is not None and helper.find_output_by_name(g, input_name) is None and value.name not in new_output: + if ( + value is not None + and helper.find_output_by_name(g, input_name) is None + and value.name not in new_output + ): output_mapping[original_output.name].append(value) new_output.add(value.name) # Remove them @@ -143,7 +167,7 @@ def remove_nodes(g, cut_nodes=[], cut_types=[]): for input_value in g.input: visited_values.add(input_value.name) for node in g.node: - if node.op_type == 'Constant': + if node.op_type == "Constant": visited_values.add(node.output[0]) unused_constant_map[node.output[0]] = node continue @@ -166,20 +190,24 @@ def remove_nodes(g, cut_nodes=[], cut_types=[]): output_mapping[original_output.name] = [] for input_name in node.input: value = helper.find_value_by_name(g, input_name) - if value is not None and helper.find_output_by_name(g, input_name) is None and value.name not in new_output: + if ( + value is not None + and helper.find_output_by_name(g, input_name) is None + and value.name not in new_output + ): output_mapping[original_output.name].append(value) new_output.add(value.name) # Remove them while node_to_delete: g.node.remove(node_to_delete.pop()) - #Remove unused constants + # Remove unused constants for node in g.node: for input_name in node.input: if input_name in unused_constant_map: del unused_constant_map[input_name] for node in unused_constant_map.values(): g.node.remove(node) - #Remove unreachable value infos + # Remove unreachable value infos reachable_values = set() for input_value in g.input: reachable_values.add(input_value.name) @@ -205,13 +233,22 @@ def remove_nodes(g, cut_nodes=[], cut_types=[]): logger.info("Keep output {}".format(output_value.name)) 
g.output.extend([output_value]) elif output_value.name in output_mapping: - real_outputs = [i for i in output_mapping[output_value.name] if i.name in reachable_values] - logger.info("Replace output {} with {}".format(output_value.name, [i.name for i in real_outputs])) + real_outputs = [ + i + for i in output_mapping[output_value.name] + if i.name in reachable_values + ] + logger.info( + "Replace output {} with {}".format( + output_value.name, [i.name for i in real_outputs] + ) + ) g.output.extend(real_outputs) else: logger.info("Abandon output {}".format(output_value.name)) continue + def transpose_B_in_Gemm(g): """ If transB is set in Gemm, transpose it @@ -219,7 +256,7 @@ def transpose_B_in_Gemm(g): :param g: the onnx graph """ for node in g.node: - if node.op_type != 'Gemm': + if node.op_type != "Gemm": continue do_it = False for attr in node.attribute: @@ -241,18 +278,19 @@ def transpose_B_in_Gemm(g): w_node.attribute[0].t.dims[1] = dim_0 if w_node.attribute[0].t.raw_data: raw_data = w_node.attribute[0].t.raw_data - fl_data = [i[0] for i in struct.iter_unpack('f', raw_data)] + fl_data = [i[0] for i in struct.iter_unpack("f", raw_data)] else: fl_data = w_node.attribute[0].t.float_data w = np.reshape(fl_data, (dim_0, dim_1)) w = w.transpose((1, 0)).flatten() if w_node.attribute[0].t.raw_data: - buf = struct.pack('%sf' % len(w), *w) + buf = struct.pack("%sf" % len(w), *w) w_node.attribute[0].t.raw_data = buf else: for i in range(len(fl_data)): w_node.attribute[0].t.float_data[i] = w[i] + def topological_sort(g): """ Topological sort all the layers. 
@@ -273,12 +311,12 @@ def topological_sort(g): for _ in range(length): node = g.node.pop() node_map[node.name] = node - if len([i for i in node.input if i != '']) == 0: + if len([i for i in node.input if i != ""]) == 0: to_add.append(node.name) else: - in_degree[node.name] = len([i for i in node.input if i != '']) + in_degree[node.name] = len([i for i in node.input if i != ""]) for input_name in node.input: - if input_name == '': + if input_name == "": continue output_nodes[input_name].append(node.name) # sort @@ -308,10 +346,13 @@ def topological_sort(g): del in_degree[next_node_name] g.node.extend(sorted_nodes) if in_degree: - raise RuntimeError("Unreachable nodes exist: {}".format(in_degree.keys())) + raise RuntimeError( + "Unreachable nodes exist: {}".format(in_degree.keys()) + ) if node_map: raise RuntimeError("Unused nodes exist: {}".format(node_map.keys())) + def remove_zero_value_info(g): value_info_list = list(g.value_info) for vi in value_info_list: @@ -323,6 +364,7 @@ def remove_zero_value_info(g): g.value_info.remove(vi) break + def inference_shapes(m): while len(m.graph.value_info) > 0: m.graph.value_info.pop() @@ -346,26 +388,32 @@ def inference_shapes(m): m = polish_model(m) return m + def inference_resize_shape(g): for node in g.node: - if node.op_type != 'Resize': + if node.op_type != "Resize": continue output_value = helper.find_value_by_name(g, node.output[0]) - output_value = helper.find_output_by_name(g, node.output[0]) if output_value is None else output_value + output_value = ( + helper.find_output_by_name(g, node.output[0]) + if output_value is None + else output_value + ) if output_value is not None: continue - if len(node.input) == 4: # input: X, roi, scales, sizes + if len(node.input) == 4: # input: X, roi, scales, sizes shape_node = helper.find_node_by_output_name(g, node.input[3]) - if shape_node.op_type != 'Constant': + if shape_node.op_type != "Constant": continue _, shape_value = helper.constant_to_list(shape_node) output_value = 
onnx.helper.make_tensor_value_info( - node.output[0], - onnx.TensorProto.FLOAT, - [int(v) for v in shape_value]) + node.output[0], + onnx.TensorProto.FLOAT, + [int(v) for v in shape_value], + ) g.value_info.extend([output_value]) return True else: @@ -376,19 +424,21 @@ def inference_resize_shape(g): continue shape_value = helper.get_shape_from_value_info(input_value) scales_node = helper.find_node_by_output_name(g, node.input[2]) - if scales_node.op_type != 'Constant': + if scales_node.op_type != "Constant": continue _, scales_value = helper.constant_to_list(scales_node) for i in range(len(shape_value)): shape_value[i] *= scales_value[i] output_value = onnx.helper.make_tensor_value_info( - node.output[0], - onnx.TensorProto.FLOAT, - [int(v) for v in shape_value]) + node.output[0], + onnx.TensorProto.FLOAT, + [int(v) for v in shape_value], + ) g.value_info.extend([output_value]) return True return False + def inference_upsample_shape(g): """For onnx v1.4.1+, onnx cannot inference upsample output shape. Let's\\ do it ourselves. This function only inference the next upsample without\\ @@ -398,7 +448,7 @@ def inference_upsample_shape(g): :return: True if any Upsample shape is generated. Otherwise, False. 
""" for node in g.node: - if node.op_type != 'Upsample': + if node.op_type != "Upsample": continue output_value = helper.find_value_by_name(g, node.output[0]) if output_value is None: @@ -409,33 +459,37 @@ def inference_upsample_shape(g): input_value = helper.find_value_by_name(g, node.input[0]) if input_value is None: continue - #raise RuntimeError("Shape for {} has not been generated.".format(node.input[0])) if not helper.get_shape_from_value_info(input_value): continue - #raise RuntimeError("Shape for {} is empty.".format(node.input[0])) input_shape = helper.get_shape_from_value_info(input_value) # Get upsample weight weight_node = helper.find_node_by_output_name(g, node.input[1]) weight_shape, weight = helper.constant_to_list(weight_node) if len(input_shape) != weight_shape[0]: - raise RuntimeError("Unmatch input shape and weight shape: {} vs {}".format(input_shape, weight_shape)) + raise RuntimeError( + "Unmatch input shape and weight shape: {} vs {}".format( + input_shape, weight_shape + ) + ) # Calculate shape output_shape = list(input_shape) for i in range(len(output_shape)): output_shape[i] = int(input_shape[i] * weight[i]) output_value = onnx.helper.make_tensor_value_info( - node.output[0], - input_value.type.tensor_type.elem_type, - output_shape) + node.output[0], + input_value.type.tensor_type.elem_type, + output_shape, + ) g.value_info.extend([output_value]) return True return False + def inference_cov_shape(g): processed = False for node in g.node: # Check for Conv output shape need to be inferrenced. - if node.op_type != 'Conv': + if node.op_type != "Conv": continue # Input shape is not ready yet. Skip. 
input_value_info = helper.find_value_by_name(g, node.input[0]) @@ -450,8 +504,9 @@ def inference_cov_shape(g): output_value_info = helper.find_value_by_name(g, node.output[0]) if not output_value_info: output_value_info = helper.find_output_by_name(g, node.output[0]) - if output_value_info and \ - helper.get_shape_from_value_info(output_value_info): + if output_value_info and helper.get_shape_from_value_info( + output_value_info + ): continue # Now start the inference. @@ -461,30 +516,35 @@ def inference_cov_shape(g): if not kernel_shape: continue # If auto_pad is set, use the auto_pad. - auto_pad = helper.get_var_attribute_by_name(node, 'auto_pad', 'string') + auto_pad = helper.get_var_attribute_by_name(node, "auto_pad", "string") pads = None - if auto_pad is not None and auto_pad != 'NOTSET': - if auto_pad == 'SAME_LOWER' or auto_pad == 'SAME_UPPER': + if auto_pad is not None and auto_pad != "NOTSET": + if auto_pad == "SAME_LOWER" or auto_pad == "SAME_UPPER": new_output_value_info = onnx.helper.make_tensor_value_info( node.output[0], input_value_info.type.tensor_type.elem_type, - [input_shape[0], kernel_shape[0], input_shape[2], input_shape[3]] + [ + input_shape[0], + kernel_shape[0], + input_shape[2], + input_shape[3], + ], ) if output_value_info: g.value_info.remove(output_value_info) g.value_info.extend([new_output_value_info]) processed = True continue - elif auto_pad == 'VALID': + elif auto_pad == "VALID": pads = [0, 0, 0, 0] else: logger.error("Unrecognized auto_pad value: " + str(auto_pad)) exit(1) - strides = helper.get_attribute_by_name(node, 'strides').ints + strides = helper.get_attribute_by_name(node, "strides").ints if not pads: - pads = helper.get_attribute_by_name(node, 'pads').ints - dilation = helper.get_attribute_by_name(node, 'dilations').ints + pads = helper.get_attribute_by_name(node, "pads").ints + dilation = helper.get_attribute_by_name(node, "dilations").ints # Pytorch model has the case where strides only have one number if len(strides) 
== 1: @@ -492,16 +552,34 @@ def inference_cov_shape(g): if len(dilation) == 1: dilation.append(dilation[0]) - H = math.floor((input_shape[2]+pads[0]+pads[2]-\ - dilation[0]*(kernel_shape[2]-1)-1)/strides[0]+1) - W = math.floor((input_shape[3]+pads[1]+pads[3]-\ - dilation[1]*(kernel_shape[3]-1)-1)/strides[1]+1) + H = math.floor( + ( + input_shape[2] + + pads[0] + + pads[2] + - dilation[0] * (kernel_shape[2] - 1) + - 1 + ) + / strides[0] + + 1 + ) + W = math.floor( + ( + input_shape[3] + + pads[1] + + pads[3] + - dilation[1] * (kernel_shape[3] - 1) + - 1 + ) + / strides[1] + + 1 + ) output_shape = [input_shape[0], kernel_shape[0], H, W] new_output_value_info = onnx.helper.make_tensor_value_info( node.output[0], input_value_info.type.tensor_type.elem_type, - output_shape + output_shape, ) processed = True @@ -516,9 +594,9 @@ def inference_cov_shape(g): def inference_split_shape(g): processed = False for node in g.node: - if node.op_type != 'Split': + if node.op_type != "Split": continue - + input_val_info = helper.find_value_by_name(g, node.input[0]) if not input_val_info: input_val_info = helper.find_input_by_name(g, node.input[0]) @@ -530,18 +608,24 @@ def inference_split_shape(g): continue output_val_names = list(node.output) - output_vals = [helper.find_value_by_name(g, val_name) for val_name in output_val_names] + output_vals = [ + helper.find_value_by_name(g, val_name) + for val_name in output_val_names + ] - output_shapes = [helper.find_size_shape_from_value(output_val)[1] for output_val in output_vals] + output_shapes = [ + helper.find_size_shape_from_value(output_val)[1] + for output_val in output_vals + ] if not any([len(s) == 0 for s in output_shapes]): continue for att in node.attribute: - if att.name == 'axis': + if att.name == "axis": axis = att.i else: split = list(att.ints) - + new_output_vals = [] for i in range(len(output_val_names)): new_shape = list(input_shape) @@ -549,24 +633,23 @@ def inference_split_shape(g): new_output_val = 
onnx.helper.make_tensor_value_info( output_val_names[i], input_val_info.type.tensor_type.elem_type, - new_shape + new_shape, ) new_output_vals.append(new_output_val) - + for val in output_vals: if val is not None: g.value_info.remove(val) g.value_info.extend(new_output_vals) processed = True - + return processed def parse_shape_change_input(s: str): - """The input should be like 'input 1 1 224 224'. - """ - s_list = s.split(' ') + """The input should be like 'input 1 1 224 224'.""" + s_list = s.split(" ") if len(s_list) < 2: print("Cannot parse the shape change input: {}".format(s)) return None @@ -575,6 +658,7 @@ def parse_shape_change_input(s: str): shape.append(int(s_list[i])) return s_list[0], shape + def change_input_shape(g, target_list): for target in target_list: try: @@ -596,6 +680,7 @@ def change_input_shape(g, target_list): print("Cannot parse {} into name and int".format(target)) continue + def change_output_shape(g, target_list): for target in target_list: try: @@ -617,6 +702,7 @@ def change_output_shape(g, target_list): print("Cannot parse {} into name and int".format(target)) continue + def add_nop_conv_after(g, value_names): """Add do-nothing depthwise Conv nodes after the given value info. 
It will\\ take the given names as the inputs of the new node and replace the inputs\\ @@ -641,31 +727,31 @@ def add_nop_conv_after(g, value_names): # Construct 4 weights node_name = value_name + "_nop_conv" ones = [1.0] * channel - weight_node = helper.list_to_constant(node_name + "_weight", [channel, 1, 1, 1], ones) + weight_node = helper.list_to_constant( + node_name + "_weight", [channel, 1, 1, 1], ones + ) # Construct BN node conv_node = onnx.helper.make_node( "Conv", - [value_name, - weight_node.output[0]], + [value_name, weight_node.output[0]], [node_name], - name = node_name, - dilations = [1, 1], - group = channel, - kernel_shape = [1, 1], - pads = [0, 0, 0, 0], - strides = [1, 1] + name=node_name, + dilations=[1, 1], + group=channel, + kernel_shape=[1, 1], + pads=[0, 0, 0, 0], + strides=[1, 1], ) # Reconnect the graph - following_nodes = helper.find_following_nodes_by_input_value_name(g, value_name) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, value_name + ) if len(following_nodes) > 0: for following_node in following_nodes: replace_node_input(following_node, value_name, node_name) else: - # If the node is the output, replace the output with the previous input. new_value = onnx.helper.make_tensor_value_info( - node_name, - value.type.tensor_type.elem_type, - shape + node_name, value.type.tensor_type.elem_type, shape ) output_values = [] while len(g.output): @@ -680,12 +766,13 @@ def add_nop_conv_after(g, value_names): g.node.extend([conv_node, weight_node]) topological_sort(g) -def add_nop_bn_after(g, value_names): - """Add do-nothing BatchNormalization nodes after the given value info. It will\\ - take the given names as the inputs of the new node and replace the inputs\\ - of the following nodes. - :param g: the graph\\ +def add_nop_bn_after(g, value_names): + """Add do-nothing BatchNormalization nodes after the given value info. 
+ It will take the given names as the inputs of the new node and replace + the inputs of the following nodes. + + :param g: the graph :param value_names: a list of string which are the names of value_info. """ for value_name in value_names: @@ -705,32 +792,39 @@ def add_nop_bn_after(g, value_names): node_name = value_name + "_nop_bn" ones = [1.0] * channel zeros = [0.0] * channel - scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones) - bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros) - mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros) + scale_node = helper.list_to_constant( + node_name + "_scale", [channel], ones + ) + bias_node = helper.list_to_constant( + node_name + "_bias", [channel], zeros + ) + mean_node = helper.list_to_constant( + node_name + "_mean", [channel], zeros + ) var_node = helper.list_to_constant(node_name + "_var", [channel], ones) # Construct BN node bn_node = onnx.helper.make_node( "BatchNormalization", - [value_name, - scale_node.output[0], - bias_node.output[0], - mean_node.output[0], - var_node.output[0]], + [ + value_name, + scale_node.output[0], + bias_node.output[0], + mean_node.output[0], + var_node.output[0], + ], [node_name], - name = node_name + name=node_name, ) # Reconnect the graph - following_nodes = helper.find_following_nodes_by_input_value_name(g, value_name) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, value_name + ) if len(following_nodes) > 0: for following_node in following_nodes: replace_node_input(following_node, value_name, node_name) else: - # If the node is the output, replace the output with the previous input. 
new_value = onnx.helper.make_tensor_value_info( - node_name, - value.type.tensor_type.elem_type, - shape + node_name, value.type.tensor_type.elem_type, shape ) output_values = [] while len(g.output): @@ -745,12 +839,14 @@ def add_nop_bn_after(g, value_names): g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node]) topological_sort(g) -def add_bias_scale_bn_after(g, value_name, channel_bias, channel_scale): - """Add do-nothing BatchNormalization nodes after the given value info. It will\\ - take the given names as the inputs of the new node and replace the inputs\\ - of the following nodes. - :param g: the graph\\ +def add_bias_scale_bn_after(g, value_name, channel_bias, channel_scale): + """ + Add do-nothing BatchNormalization nodes after the given value info. + It will take the given names as the inputs of the new node and replace + the inputs of the following nodes. + + :param g: the graph :param value_name: a list of string which are the name of value_info. """ # Find the value first @@ -769,32 +865,37 @@ def add_bias_scale_bn_after(g, value_name, channel_bias, channel_scale): node_name = value_name + "_scale_shift_bn" ones = [1.0] * channel zeros = [0.0] * channel - scale_node = helper.list_to_constant(node_name + "_scale", [len(channel_scale)], channel_scale) - bias_node = helper.list_to_constant(node_name + "_bias", [len(channel_bias)], channel_bias) + scale_node = helper.list_to_constant( + node_name + "_scale", [len(channel_scale)], channel_scale + ) + bias_node = helper.list_to_constant( + node_name + "_bias", [len(channel_bias)], channel_bias + ) mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros) var_node = helper.list_to_constant(node_name + "_var", [channel], ones) # Construct BN node bn_node = onnx.helper.make_node( "BatchNormalization", - [value_name, - scale_node.output[0], - bias_node.output[0], - mean_node.output[0], - var_node.output[0]], + [ + value_name, + scale_node.output[0], + bias_node.output[0], + 
mean_node.output[0], + var_node.output[0], + ], [node_name], - name = node_name + name=node_name, ) # Reconnect the graph - following_nodes = helper.find_following_nodes_by_input_value_name(g, value_name) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, value_name + ) if len(following_nodes) > 0: for following_node in following_nodes: replace_node_input(following_node, value_name, node_name) else: - # If the node is the output, replace the output with the previous input. new_value = onnx.helper.make_tensor_value_info( - node_name, - value.type.tensor_type.elem_type, - shape + node_name, value.type.tensor_type.elem_type, shape ) output_values = [] while len(g.output): @@ -809,6 +910,7 @@ def add_bias_scale_bn_after(g, value_name, channel_bias, channel_scale): g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node]) topological_sort(g) + def duplicate_shared_Flatten(g): """To feed our compiler, bind Flatten with Gemm. If the output of one\\ Flatten goes to two Gemm nodes, duplicate the Flatten. @@ -817,15 +919,17 @@ def duplicate_shared_Flatten(g): """ for node in g.node: # Find a Flatten node - if node.op_type != 'Flatten': + if node.op_type != "Flatten": continue # Check Flatten outputs. 
Get following Gemm - output_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + output_nodes = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + ) if len(output_nodes) < 2: continue gemm_nodes = [] for output_node in output_nodes: - if output_node.op_type == 'Gemm': + if output_node.op_type == "Gemm": gemm_nodes.append(output_node) if len(gemm_nodes) < 2: continue @@ -838,13 +942,14 @@ def duplicate_shared_Flatten(g): node.input, [new_flatten_name], name=new_flatten_name, - axis=1 + axis=1, ) # Connect new graph replace_node_input(gemm_nodes[i], node.output[0], new_flatten_name) g.node.extend([new_flatten_node]) topological_sort(g) + def deconv_to_conv_info_extraction(input_size, node_proto): """Extract the information needed for deconv split. @@ -854,29 +959,62 @@ def deconv_to_conv_info_extraction(input_size, node_proto): """ attr = dict() # Get attributes from Deconv node - attr["auto_pad"] = helper.get_var_attribute_by_name(node_proto, "auto_pad", "string") - attr["dilations"] = helper.get_list_attribute_by_name(node_proto, "dilations", "int") - attr["group"] = helper.get_var_attribute_by_name(node_proto, "group", "int") - attr["kernel_shape"] = helper.get_list_attribute_by_name(node_proto, "kernel_shape", "int") - attr["output_padding"] = helper.get_list_attribute_by_name(node_proto, "output_padding", "int") + attr["auto_pad"] = helper.get_var_attribute_by_name( + node_proto, "auto_pad", "string" + ) + attr["dilations"] = helper.get_list_attribute_by_name( + node_proto, "dilations", "int" + ) + attr["group"] = helper.get_var_attribute_by_name( + node_proto, "group", "int" + ) + attr["kernel_shape"] = helper.get_list_attribute_by_name( + node_proto, "kernel_shape", "int" + ) + attr["output_padding"] = helper.get_list_attribute_by_name( + node_proto, "output_padding", "int" + ) attr["pads"] = helper.get_list_attribute_by_name(node_proto, "pads", "int") - attr["strides"] = helper.get_list_attribute_by_name(node_proto, 
"strides", "int") + attr["strides"] = helper.get_list_attribute_by_name( + node_proto, "strides", "int" + ) # Get output_padding if attr["output_padding"] is None: - if attr["auto_pad"] == "SAME_LOWER" or attr["auto_pad"] == "SAME_UPPER": - attr["output_padding"] = [attr["strides"][0] - 1, attr["strides"][1]] + if ( + attr["auto_pad"] == "SAME_LOWER" + or attr["auto_pad"] == "SAME_UPPER" + ): + attr["output_padding"] = [ + attr["strides"][0] - 1, + attr["strides"][1], + ] else: - attr["output_padding"] = [max(attr["strides"][0] - attr["kernel_shape"][0], 0), - max(attr["strides"][1] - attr["kernel_shape"][1], 0)] + attr["output_padding"] = [ + max(attr["strides"][0] - attr["kernel_shape"][0], 0), + max(attr["strides"][1] - attr["kernel_shape"][1], 0), + ] # Calculate conv_padding if attr["auto_pad"] == "SAME_LOWER" or attr["auto_pad"] == "SAME_UPPER": - pad1_h = attr["kernel_shape"][0] - (attr["kernel_shape"][0] - 1) // 2 - 1 - pad1_w = attr["kernel_shape"][1] - (attr["kernel_shape"][1] - 1) // 2 - 1 - head_h = min(attr["kernel_shape"][0] // 2, (attr["output_padding"][0] + 1) // 2) - head_w = min(attr["kernel_shape"][1] // 2, (attr["output_padding"][1] + 1) // 2) + pad1_h = ( + attr["kernel_shape"][0] - (attr["kernel_shape"][0] - 1) // 2 - 1 + ) + pad1_w = ( + attr["kernel_shape"][1] - (attr["kernel_shape"][1] - 1) // 2 - 1 + ) + head_h = min( + attr["kernel_shape"][0] // 2, (attr["output_padding"][0] + 1) // 2 + ) + head_w = min( + attr["kernel_shape"][1] // 2, (attr["output_padding"][1] + 1) // 2 + ) tail_h = attr["output_padding"][0] - head_h tail_w = attr["output_padding"][1] - head_w - attr["conv_pads"] = [pad1_h + head_h, pad1_w + head_w, pad1_h + tail_h, pad1_w + tail_w] + attr["conv_pads"] = [ + pad1_h + head_h, + pad1_w + head_w, + pad1_h + tail_h, + pad1_w + tail_w, + ] elif attr["pads"] is not None: sum_of_pads = sum(attr["pads"]) if sum_of_pads == 0: @@ -887,22 +1025,51 @@ def deconv_to_conv_info_extraction(input_size, node_proto): head_w = 0 tail_h = 
attr["output_padding"][0] - head_h tail_w = attr["output_padding"][1] - head_w - attr["conv_pads"] = [pad1_h + head_h, pad1_w + head_w, pad1_h + tail_h, pad1_w + tail_w] + attr["conv_pads"] = [ + pad1_h + head_h, + pad1_w + head_w, + pad1_h + tail_h, + pad1_w + tail_w, + ] else: # Calculate output shape tmp_output_shape = [0, 0] - tmp_output_shape[0] = attr["strides"][0] * (input_size[2] - 1) + attr["output_padding"][0] + attr["kernel_shape"][0] - attr["pads"][0] - attr["pads"][2] - tmp_output_shape[1] = attr["strides"][1] * (input_size[3] - 1) + attr["output_padding"][1] + attr["kernel_shape"][1] - attr["pads"][1] - attr["pads"][3] + tmp_output_shape[0] = ( + attr["strides"][0] * (input_size[2] - 1) + + attr["output_padding"][0] + + attr["kernel_shape"][0] + - attr["pads"][0] + - attr["pads"][2] + ) + tmp_output_shape[1] = ( + attr["strides"][1] * (input_size[3] - 1) + + attr["output_padding"][1] + + attr["kernel_shape"][1] + - attr["pads"][1] + - attr["pads"][3] + ) # Calculate real conv output shape tmp_center_shape = [0, 0] tmp_center_shape[0] = (input_size[2] - 1) * attr["strides"][0] + 1 tmp_center_shape[1] = (input_size[3] - 1) * attr["strides"][1] + 1 # Calculate padding total_padding = [0, 0] - total_padding[0] = tmp_output_shape[0] - tmp_center_shape[0] + attr["kernel_shape"][0] - 1 - total_padding[1] = tmp_output_shape[1] - tmp_center_shape[1] + attr["kernel_shape"][1] - 1 + total_padding[0] = ( + tmp_output_shape[0] + - tmp_center_shape[0] + + attr["kernel_shape"][0] + - 1 + ) + total_padding[1] = ( + tmp_output_shape[1] + - tmp_center_shape[1] + + attr["kernel_shape"][1] + - 1 + ) if total_padding[0] < 0 or total_padding[1] < 0: - raise RuntimeError(node_proto.name + " cannot infer conv padding.") + raise RuntimeError( + node_proto.name + " cannot infer conv padding." 
+ ) conv_pads_ = [0] * 4 conv_pads_[0] = total_padding[0] // 2 conv_pads_[1] = total_padding[1] // 2 @@ -916,9 +1083,15 @@ def deconv_to_conv_info_extraction(input_size, node_proto): head_w = 0 tail_h = attr["output_padding"][0] - head_h tail_w = attr["output_padding"][1] - head_w - attr["conv_pads"] = [pad1_h + head_h, pad1_w + head_w, pad1_h + tail_h, pad1_w + tail_w] + attr["conv_pads"] = [ + pad1_h + head_h, + pad1_w + head_w, + pad1_h + tail_h, + pad1_w + tail_w, + ] return attr + def split_ConvTranspose(model): """To feed our compiler, split ConvTranspose into Upsample and Conv. @@ -934,7 +1107,7 @@ def split_ConvTranspose(model): # Get a Convtranspose layer for node in g.node: # Find a Flatten node - if node.op_type != 'ConvTranspose': + if node.op_type != "ConvTranspose": continue # Check auto_pad auto_pad_proto = helper.get_attribute_by_name(node, "auto_pad") @@ -958,11 +1131,15 @@ def split_ConvTranspose(model): attr = deconv_to_conv_info_extraction(input_shape, node) # Generate Upsample scales upsample_output_shape = list(input_shape) - upsample_output_shape[2] = (input_shape[2] - 1) * attr["strides"][0] + 1 - upsample_output_shape[3] = (input_shape[3] - 1) * attr["strides"][1] + 1 + upsample_output_shape[2] = (input_shape[2] - 1) * attr["strides"][ + 0 + ] + 1 + upsample_output_shape[3] = (input_shape[3] - 1) * attr["strides"][ + 1 + ] + 1 upsample_node_name = node.name + "_inner_upsample" upsample_scale_name = upsample_node_name + "_scales" - scales_np = np.ones([4]).astype('float32') + scales_np = np.ones([4]).astype("float32") scales_np[2] = float(upsample_output_shape[2]) / input_shape[2] scales_np[3] = float(upsample_output_shape[3]) / input_shape[3] scales_node = helper.numpy_to_constant(upsample_scale_name, scales_np) @@ -972,19 +1149,21 @@ def split_ConvTranspose(model): [node.input[0], upsample_scale_name], [upsample_node_name], name=upsample_node_name, - mode="zeros" + mode="zeros", ) upsample_value_info = onnx.helper.make_tensor_value_info( 
upsample_node_name, input_value.type.tensor_type.elem_type, - upsample_output_shape + upsample_output_shape, ) # Check the weight layer, it may need a transpose if attr["group"] != input_shape[1]: weight_node = helper.find_node_by_output_name(g, node.input[1]) weight_np = helper.constant_to_numpy(weight_node) new_weight_np = np.transpose(weight_np, [1, 0, 2, 3]) - new_weight_node = helper.numpy_to_constant(node.input[1], new_weight_np) + new_weight_node = helper.numpy_to_constant( + node.input[1], new_weight_np + ) node_to_delete.append(weight_node) g.node.extend([new_weight_node]) value = helper.find_value_by_name(g, node.input[1]) @@ -1002,7 +1181,7 @@ def split_ConvTranspose(model): dilations=[int(i) for i in attr["dilations"]], group=int(attr["group"]), kernel_shape=[int(i) for i in attr["kernel_shape"]], - strides=[int(1), int(1)] + strides=[int(1), int(1)], ) # Reconnect the graph g.node.extend([scales_node, upsample_node, conv_node]) @@ -1013,20 +1192,28 @@ def split_ConvTranspose(model): g.node.remove(node) topological_sort(g) + def add_bn_on_skip_branch(g): for n in g.node: # Find merge node (Add) - if n.op_type != 'Add': + if n.op_type != "Add": continue if len(n.input) != 2: continue # TODO: Still need to consider more cases # Check if skip branch exist input_node_a = helper.find_node_by_output_name(g, n.input[0]) - output_of_input_node_a = helper.find_nodes_by_input_name(g, input_node_a.output[0]) + output_of_input_node_a = helper.find_nodes_by_input_name( + g, input_node_a.output[0] + ) input_node_b = helper.find_node_by_output_name(g, n.input[1]) - output_of_input_node_b = helper.find_nodes_by_input_name(g, input_node_b.output[0]) - if len(output_of_input_node_a) == 1 and len(output_of_input_node_b) == 1: + output_of_input_node_b = helper.find_nodes_by_input_name( + g, input_node_b.output[0] + ) + if ( + len(output_of_input_node_a) == 1 + and len(output_of_input_node_b) == 1 + ): continue if len(output_of_input_node_a) == 2: split_node = input_node_a 
@@ -1043,20 +1230,28 @@ def add_bn_on_skip_branch(g): node_name = value_name + "_nop_bn" ones = [1.0] * channel zeros = [0.0] * channel - scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones) - bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros) - mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros) + scale_node = helper.list_to_constant( + node_name + "_scale", [channel], ones + ) + bias_node = helper.list_to_constant( + node_name + "_bias", [channel], zeros + ) + mean_node = helper.list_to_constant( + node_name + "_mean", [channel], zeros + ) var_node = helper.list_to_constant(node_name + "_var", [channel], ones) # Construct BN node bn_node = onnx.helper.make_node( "BatchNormalization", - [value_name, - scale_node.output[0], - bias_node.output[0], - mean_node.output[0], - var_node.output[0]], + [ + value_name, + scale_node.output[0], + bias_node.output[0], + mean_node.output[0], + var_node.output[0], + ], [node_name], - name = node_name + name=node_name, ) # Reconnect the graph replace_node_input(n, value_name, node_name) @@ -1064,10 +1259,11 @@ def add_bn_on_skip_branch(g): g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node]) topological_sort(g) + def add_bn_before_add(g): for n in g.node: # Find merge node (Add) - if n.op_type != 'Add': + if n.op_type != "Add": continue if len(n.input) != 2: continue @@ -1075,10 +1271,11 @@ def add_bn_before_add(g): input_node_a = helper.find_node_by_output_name(g, n.input[0]) input_node_b = helper.find_node_by_output_name(g, n.input[1]) # Skip constant input add - if input_node_a is None or input_node_a.op_type == 'Constant': + if input_node_a is None or input_node_a.op_type == "Constant": continue - if input_node_b is None or input_node_b.op_type == 'Constant': + if input_node_b is None or input_node_b.op_type == "Constant": continue + def add_bn_after(prev_node): # Get the channel number from value info value_name = prev_node.output[0] @@ 
-1089,35 +1286,65 @@ def add_bn_before_add(g): node_name = value_name + "_nop_bn" ones = [1.0] * channel zeros = [0.0] * channel - scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones) - bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros) - mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros) - var_node = helper.list_to_constant(node_name + "_var", [channel], ones) + scale_node = helper.list_to_constant( + node_name + "_scale", [channel], ones + ) + bias_node = helper.list_to_constant( + node_name + "_bias", [channel], zeros + ) + mean_node = helper.list_to_constant( + node_name + "_mean", [channel], zeros + ) + var_node = helper.list_to_constant( + node_name + "_var", [channel], ones + ) # Construct BN node bn_node = onnx.helper.make_node( "BatchNormalization", - [value_name, - scale_node.output[0], - bias_node.output[0], - mean_node.output[0], - var_node.output[0]], + [ + value_name, + scale_node.output[0], + bias_node.output[0], + mean_node.output[0], + var_node.output[0], + ], [node_name], - name = node_name, - epsilon=0.00000001 + name=node_name, + epsilon=0.00000001, ) # Reconnect the graph replace_node_input(n, value_name, node_name) # Add node to the graph - g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node]) - if not input_node_a.op_type == 'BatchNormalization' or len(helper.find_following_nodes_by_input_value_name(g, input_node_a.output[0])) > 1: + g.node.extend( + [bn_node, scale_node, bias_node, mean_node, var_node] + ) + + if ( + not input_node_a.op_type == "BatchNormalization" + or len( + helper.find_following_nodes_by_input_value_name( + g, input_node_a.output[0] + ) + ) + > 1 + ): add_bn_after(input_node_a) - if not input_node_b.op_type == 'BatchNormalization' or len(helper.find_following_nodes_by_input_value_name(g, input_node_b.output[0])) > 1: + if ( + not input_node_b.op_type == "BatchNormalization" + or len( + helper.find_following_nodes_by_input_value_name( + 
g, input_node_b.output[0] + ) + ) + > 1 + ): add_bn_after(input_node_b) topological_sort(g) + def add_bn_before_activation(g): - activation_nodes = set(['Relu', 'Clip', 'PRelu', 'LeakyRelu']) - previous_nodes = set(['Conv', 'BatchNormalization']) + activation_nodes = set(["Relu", "Clip", "PRelu", "LeakyRelu"]) + previous_nodes = set(["Conv", "BatchNormalization"]) for n in g.node: # Find activation node if n.op_type not in activation_nodes: @@ -1126,6 +1353,7 @@ def add_bn_before_activation(g): input_node = helper.find_node_by_output_name(g, n.input[0]) if input_node is None or input_node.op_type in previous_nodes: continue + def add_bn_after(prev_node): # Get the channel number from value info value_name = prev_node.output[0] @@ -1136,29 +1364,43 @@ def add_bn_before_activation(g): node_name = value_name + "_nop_bn" ones = [1.0] * channel zeros = [0.0] * channel - scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones) - bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros) - mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros) - var_node = helper.list_to_constant(node_name + "_var", [channel], ones) + scale_node = helper.list_to_constant( + node_name + "_scale", [channel], ones + ) + bias_node = helper.list_to_constant( + node_name + "_bias", [channel], zeros + ) + mean_node = helper.list_to_constant( + node_name + "_mean", [channel], zeros + ) + var_node = helper.list_to_constant( + node_name + "_var", [channel], ones + ) # Construct BN node bn_node = onnx.helper.make_node( "BatchNormalization", - [value_name, - scale_node.output[0], - bias_node.output[0], - mean_node.output[0], - var_node.output[0]], + [ + value_name, + scale_node.output[0], + bias_node.output[0], + mean_node.output[0], + var_node.output[0], + ], [node_name], - name = node_name, - epsilon=0.00000001 + name=node_name, + epsilon=0.00000001, ) # Reconnect the graph replace_node_input(n, value_name, node_name) # Add node to the graph - 
g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node]) + g.node.extend( + [bn_node, scale_node, bias_node, mean_node, var_node] + ) + add_bn_after(input_node) topological_sort(g) + def rename_output_name(g, original_name, new_name): # Output output_value = helper.find_output_by_name(g, original_name) @@ -1178,19 +1420,28 @@ def rename_output_name(g, original_name, new_name): for node in nodes: replace_node_input(node, original_name, new_name) + def duplicate_param_shared_constant(g): for node in g.node: input_names = set() for n, input_node_name in enumerate(node.input): - param_data_node = helper.find_node_by_output_name(g, input_node_name) - if param_data_node is None or param_data_node.op_type != 'Constant': + param_data_node = helper.find_node_by_output_name( + g, input_node_name + ) + if ( + param_data_node is None + or param_data_node.op_type != "Constant" + ): continue if param_data_node.name not in input_names: input_names.add(input_node_name) continue - new_node_name = param_data_node.name + '_' + str(n) - helper.logger.debug(f"Duplicating weight: {param_data_node.name} -> {new_node_name}") + new_node_name = param_data_node.name + "_" + str(n) + helper.logger.debug( + f"Duplicating weight: {param_data_node.name} -> " + f"{new_node_name}" + ) duplicated_node = copy.deepcopy(param_data_node) duplicated_node.name = new_node_name diff --git a/tools/optimizer_scripts/tools/removing_transpose.py b/tools/optimizer_scripts/tools/removing_transpose.py index d0b7882..89f772b 100644 --- a/tools/optimizer_scripts/tools/removing_transpose.py +++ b/tools/optimizer_scripts/tools/removing_transpose.py @@ -1,317 +1,368 @@ from . import helper from . import other from . import modhelper -from . 
import fusing import numpy as np import onnx import onnx.utils -def eliminate_transposes(m): - g = m.graph - keep_eliminating = True - while keep_eliminating: - while swap_transpose_with_single_next_node(g): - pass - splitted = split_transpose_for_multiple_next_nodes(g) - annihilated = annihilate_transposes(g) - multiple_trans_swapped = swap_multiple_transposes_with_node(g) - keep_eliminating = splitted or annihilated or multiple_trans_swapped - if keep_eliminating: - m = other.polish_model(m) - g = m.graph - - return m +def eliminate_transposes(m): + g = m.graph + keep_eliminating = True + while keep_eliminating: + while swap_transpose_with_single_next_node(g): + pass + splitted = split_transpose_for_multiple_next_nodes(g) + annihilated = annihilate_transposes(g) + multiple_trans_swapped = swap_multiple_transposes_with_node(g) + keep_eliminating = splitted or annihilated or multiple_trans_swapped + + if keep_eliminating: + m = other.polish_model(m) + g = m.graph + + return m def swap_transpose_with_single_next_node(g): - swapped = False - passable_nodes = set(['Relu', 'Neg', 'LeakyRelu', 'Sqrt', 'Reciprocal', 'Add', 'Mul', 'Tanh']) - for node in g.node: - trans_node = node - # Check for transpose node - if trans_node.op_type != 'Transpose': - continue - next_nodes = helper.find_nodes_by_input_name(g, trans_node.output[0]) - if len(next_nodes) != 1: - continue - next_node = next_nodes[0] - # Check if the next node is the type can be swapped - if next_node.op_type not in passable_nodes: - continue + swapped = False + passable_nodes = set( + [ + "Relu", + "Neg", + "LeakyRelu", + "Sqrt", + "Reciprocal", + "Add", + "Mul", + "Tanh", + ] + ) + for node in g.node: + trans_node = node + # Check for transpose node + if trans_node.op_type != "Transpose": + continue + next_nodes = helper.find_nodes_by_input_name(g, trans_node.output[0]) + if len(next_nodes) != 1: + continue + next_node = next_nodes[0] + # Check if the next node is the type can be swapped + if 
next_node.op_type not in passable_nodes: + continue - input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in next_node.input] + input_nodes = [ + helper.find_node_by_output_name(g, input_name) + for input_name in next_node.input + ] - # Check if the node has nonconstant input other than the Transpose node itself - nonconstant_input = False - for input_node in input_nodes: - if input_node == None: - nonconstant_input = True - break - if input_node.name == trans_node.name: - continue - elif input_node.op_type == 'Constant': - continue - else: - nonconstant_input = True - break - if nonconstant_input: - continue + # Check if the node has nonconstant input + # other than the Transpose node itself + nonconstant_input = False + for input_node in input_nodes: + if input_node is None: + nonconstant_input = True + break + if input_node.name == trans_node.name: + continue + elif input_node.op_type == "Constant": + continue + else: + nonconstant_input = True + break + if nonconstant_input: + continue - for input_node in input_nodes: - if input_node.name == trans_node.name: - # if the input is just the transpose node - next_value_info = helper.find_value_by_name(g, next_node.output[0]) - mid_value_info = helper.find_value_by_name(g, trans_node.output[0]) + for input_node in input_nodes: + if input_node.name == trans_node.name: + # if the input is just the transpose node + next_value_info = helper.find_value_by_name( + g, next_node.output[0] + ) + mid_value_info = helper.find_value_by_name( + g, trans_node.output[0] + ) - output_nodes = helper.find_nodes_by_input_name(g, next_node.output[0]) - for out_node in output_nodes: - modhelper.replace_node_input(out_node, next_node.output[0], trans_node.name) + output_nodes = helper.find_nodes_by_input_name( + g, next_node.output[0] + ) + for out_node in output_nodes: + modhelper.replace_node_input( + out_node, next_node.output[0], trans_node.name + ) - next_node.input[0] = trans_node.input[0] - 
next_node.output[0] = next_node.name - trans_node.input[0] = next_node.name - trans_node.output[0] = trans_node.name + next_node.input[0] = trans_node.input[0] + next_node.output[0] = next_node.name + trans_node.input[0] = next_node.name + trans_node.output[0] = trans_node.name - if next_value_info: - next_value_info.name = trans_node.name - if mid_value_info: - g.value_info.remove(mid_value_info) - else: - # if the input is a constant node - old_tensor = input_node.attribute[0].t - old_shape, data = helper.constant_to_list(input_node) - # If the constant node is a scaler, no action is needed - if type(old_shape) == int: - old_shape = [old_shape] - permutation = list(trans_node.attribute[0].ints) - while len(old_shape) < len(permutation): - old_shape.insert(0, 1) - np_data = np.reshape(data, old_shape) - reverse_perm = [] - for i in range(len(permutation)): - reverse_perm.append(permutation.index(i)) - np_data = np.transpose(np_data, reverse_perm) - new_shape = np_data.shape - new_tensor = onnx.helper.make_tensor( - name=old_tensor.name, - data_type=old_tensor.data_type, - dims=new_shape, - vals=np_data.flatten().tolist() - ) - new_node = onnx.helper.make_node( - 'Constant', - [], - [input_node.output[0]], - name=input_node.name, - value=new_tensor - ) - g.node.extend([new_node]) + if next_value_info: + next_value_info.name = trans_node.name + if mid_value_info: + g.value_info.remove(mid_value_info) + else: + # if the input is a constant node + old_tensor = input_node.attribute[0].t + old_shape, data = helper.constant_to_list(input_node) + # If the constant node is a scaler, no action is needed + if type(old_shape) == int: + old_shape = [old_shape] + permutation = list(trans_node.attribute[0].ints) + while len(old_shape) < len(permutation): + old_shape.insert(0, 1) + np_data = np.reshape(data, old_shape) + reverse_perm = [] + for i in range(len(permutation)): + reverse_perm.append(permutation.index(i)) + np_data = np.transpose(np_data, reverse_perm) + new_shape = 
np_data.shape + new_tensor = onnx.helper.make_tensor( + name=old_tensor.name, + data_type=old_tensor.data_type, + dims=new_shape, + vals=np_data.flatten().tolist(), + ) + new_node = onnx.helper.make_node( + "Constant", + [], + [input_node.output[0]], + name=input_node.name, + value=new_tensor, + ) + g.node.extend([new_node]) - g.value_info.remove(helper.find_value_by_name(g, input_node.output[0])) - g.node.remove(input_node) + g.value_info.remove( + helper.find_value_by_name(g, input_node.output[0]) + ) + g.node.remove(input_node) - swapped = True + swapped = True - other.topological_sort(g) - return swapped + other.topological_sort(g) + return swapped def swap_multiple_transposes_with_node(g): - # here only consider same input transposes - swapped = False - passable_nodes = set(['Add', 'Mul']) - node_to_del = [] - for node in g.node: - if node.op_type not in passable_nodes: - continue - input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in node.input] - if any([input_node == None for input_node in input_nodes]): - continue - if any([input_node.op_type != 'Transpose' for input_node in input_nodes]): - continue + # here only consider same input transposes + swapped = False + passable_nodes = set(["Add", "Mul"]) + node_to_del = [] + for node in g.node: + if node.op_type not in passable_nodes: + continue + input_nodes = [ + helper.find_node_by_output_name(g, input_name) + for input_name in node.input + ] + if any([input_node is None for input_node in input_nodes]): + continue + if any( + [input_node.op_type != "Transpose" for input_node in input_nodes] + ): + continue - permutation = list(input_nodes[0].attribute[0].ints) - if any([list(input_node.attribute[0].ints) != permutation for input_node in input_nodes]): - continue - - for input_name in node.input: - input_node = helper.find_node_by_output_name(g, input_name) - modhelper.replace_node_input(node, input_name, input_node.input[0]) + permutation = list(input_nodes[0].attribute[0].ints) + 
if any( + [ + list(input_node.attribute[0].ints) != permutation + for input_node in input_nodes + ] + ): + continue - node_to_del.extend(input_nodes) - for input_node in input_nodes: - input_val_info = helper.find_value_by_name(g, input_node.output[0]) - if input_val_info is not None: - g.value_info.remove(input_val_info) - output_val_info = helper.find_value_by_name(g, node.output[0]) - if output_val_info is not None: - g.value_info.remove(output_val_info) + for input_name in node.input: + input_node = helper.find_node_by_output_name(g, input_name) + modhelper.replace_node_input(node, input_name, input_node.input[0]) - output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) - for i in range(len(output_nodes)): - new_trans_node_name = node.name+'_trans_'+str(i) - new_trans_node = onnx.helper.make_node( - 'Transpose', - [node.output[0]], - [new_trans_node_name], - name=new_trans_node_name, - perm=permutation - ) - modhelper.replace_node_input(output_nodes[i], node.output[0], new_trans_node_name) - - g.node.extend([new_trans_node]) - - swapped = True - - while node_to_del: - node = node_to_del.pop() - g.node.remove(node) - - other.topological_sort(g) - return swapped + node_to_del.extend(input_nodes) + for input_node in input_nodes: + input_val_info = helper.find_value_by_name(g, input_node.output[0]) + if input_val_info is not None: + g.value_info.remove(input_val_info) + output_val_info = helper.find_value_by_name(g, node.output[0]) + if output_val_info is not None: + g.value_info.remove(output_val_info) + + output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + for i in range(len(output_nodes)): + new_trans_node_name = node.name + "_trans_" + str(i) + new_trans_node = onnx.helper.make_node( + "Transpose", + [node.output[0]], + [new_trans_node_name], + name=new_trans_node_name, + perm=permutation, + ) + modhelper.replace_node_input( + output_nodes[i], node.output[0], new_trans_node_name + ) + + g.node.extend([new_trans_node]) + + swapped = 
True + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + other.topological_sort(g) + return swapped def annihilate_transposes(g): - node_to_del = [] - annihilated = False - for node in g.node: - if node.op_type != 'Transpose': - continue - pre_node = helper.find_node_by_output_name(g, node.input[0]) - if not pre_node or pre_node.op_type != 'Transpose': - continue - nodes_from_top_transpose = helper.find_nodes_by_input_name(g, pre_node.output[0]) - if len(nodes_from_top_transpose) > 1: - continue - - perm_1 = list(pre_node.attribute[0].ints) - perm_2 = list(node.attribute[0].ints) - if perm_1 != perm_2: - continue + node_to_del = [] + annihilated = False + for node in g.node: + if node.op_type != "Transpose": + continue + pre_node = helper.find_node_by_output_name(g, node.input[0]) + if not pre_node or pre_node.op_type != "Transpose": + continue + nodes_from_top_transpose = helper.find_nodes_by_input_name( + g, pre_node.output[0] + ) + if len(nodes_from_top_transpose) > 1: + continue - out_nodes = helper.find_nodes_by_input_name(g, node.output[0]) - for out_node in out_nodes: - modhelper.replace_node_input(out_node, node.output[0], pre_node.input[0]) - - node_to_del.extend([node, pre_node]) - mid_value_info = helper.find_value_by_name(g, pre_node.output[0]) - out_value_info = helper.find_value_by_name(g, node.output[0]) - g.value_info.remove(mid_value_info) - g.value_info.remove(out_value_info) + perm_1 = list(pre_node.attribute[0].ints) + perm_2 = list(node.attribute[0].ints) + if perm_1 != perm_2: + continue - annihilated = True - while node_to_del: - node = node_to_del.pop() - g.node.remove(node) - - return annihilated + out_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + for out_node in out_nodes: + modhelper.replace_node_input( + out_node, node.output[0], pre_node.input[0] + ) + + node_to_del.extend([node, pre_node]) + mid_value_info = helper.find_value_by_name(g, pre_node.output[0]) + out_value_info = 
helper.find_value_by_name(g, node.output[0]) + g.value_info.remove(mid_value_info) + g.value_info.remove(out_value_info) + + annihilated = True + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return annihilated def split_transpose_for_multiple_next_nodes(g): - splitted = False - node_to_del = [] - for node in g.node: - if node.op_type != 'Transpose': - continue - output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) - if len(output_nodes) < 2: - continue - for i in range(len(output_nodes)): - output_node = output_nodes[i] - new_trans_node_name = node.name + '_' + str(i) - new_trans_node = onnx.helper.make_node( - 'Transpose', - [node.input[0]], - [new_trans_node_name], - name=new_trans_node_name, - perm=list(node.attribute[0].ints) - ) - modhelper.replace_node_input(output_node, node.output[0], new_trans_node.output[0]) - g.node.extend([new_trans_node]) - - node_to_del.append(node) - val_info = helper.find_value_by_name(g, node.output[0]) - g.value_info.remove(val_info) + splitted = False + node_to_del = [] + for node in g.node: + if node.op_type != "Transpose": + continue + output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + if len(output_nodes) < 2: + continue + for i in range(len(output_nodes)): + output_node = output_nodes[i] + new_trans_node_name = node.name + "_" + str(i) + new_trans_node = onnx.helper.make_node( + "Transpose", + [node.input[0]], + [new_trans_node_name], + name=new_trans_node_name, + perm=list(node.attribute[0].ints), + ) + modhelper.replace_node_input( + output_node, node.output[0], new_trans_node.output[0] + ) + g.node.extend([new_trans_node]) + + node_to_del.append(node) + val_info = helper.find_value_by_name(g, node.output[0]) + g.value_info.remove(val_info) + + splitted = True + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + other.topological_sort(g) + return splitted - splitted = True - - while node_to_del: - node = node_to_del.pop() - g.node.remove(node) - 
- other.topological_sort(g) - return splitted def remove_trivial_transpose(g): - node_to_del = [] - for node in g.node: - if node.op_type != 'Transpose': - continue - permutation = list(node.attribute[0].ints) - if permutation != list(range(len(permutation))): - continue - - next_nodes = helper.find_nodes_by_input_name(g, node.output[0]) - if not next_nodes: - input_val_info = helper.find_value_by_name(g, node.input[0]) - out_val_info = helper.find_output_by_name(g, node.output[0]) - if not input_val_info: - input_val_info = helper.find_input_by_name(g, node.input[0]) - g.output.remove(out_val_info) - g.output.extend([input_val_info]) - else: - out_val_info = helper.find_value_by_name(g, node.output[0]) - for next_node in next_nodes: - modhelper.replace_node_input(next_node, node.output[0], node.input[0]) - g.value_info.remove(out_val_info) - - node_to_del.append(node) - - while node_to_del: - node = node_to_del.pop() - g.node.remove(node) - - other.topological_sort(g) + node_to_del = [] + for node in g.node: + if node.op_type != "Transpose": + continue + permutation = list(node.attribute[0].ints) + if permutation != list(range(len(permutation))): + continue + + next_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + if not next_nodes: + input_val_info = helper.find_value_by_name(g, node.input[0]) + out_val_info = helper.find_output_by_name(g, node.output[0]) + if not input_val_info: + input_val_info = helper.find_input_by_name(g, node.input[0]) + g.output.remove(out_val_info) + g.output.extend([input_val_info]) + else: + out_val_info = helper.find_value_by_name(g, node.output[0]) + for next_node in next_nodes: + modhelper.replace_node_input( + next_node, node.output[0], node.input[0] + ) + g.value_info.remove(out_val_info) + + node_to_del.append(node) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + other.topological_sort(g) + def fuse_Transpose_into_Gemm_weight(g): - node_to_del = [] - for node in g.node: - # Check pattern - 
if node.op_type != 'Gemm': - continue - prev_node = helper.find_node_by_output_name(g, node.input[0]) - if prev_node is None or prev_node.op_type != 'Flatten': - continue - transpose_node = helper.find_node_by_output_name(g, prev_node.input[0]) - if transpose_node.op_type != 'Transpose': - continue - # Check attribute - perm = helper.get_list_attribute_by_name(transpose_node, 'perm', 'int') - if perm != [0, 2, 3, 1]: - continue - transB = helper.get_var_attribute_by_name(node, 'transB', 'int') - if transB is not None and transB == 1: - continue - # Get the original weight - origin_weight = helper.find_node_by_output_name(g, node.input[1]) - origin_np = helper.constant_to_numpy(origin_weight) - # Calculate a new weight - shape = helper.get_shape_from_value_info(helper.find_value_by_name(g, prev_node.input[0])) - shape.append(-1) - new_np = np.reshape(origin_np, shape) - new_np = np.transpose(new_np, [0, 3, 1, 2, 4]) - new_np = np.reshape(new_np, [-1, new_np.shape[-1]]) - new_weight = helper.numpy_to_constant(origin_weight.output[0], new_np) - # Replace and eliminate - prev_node.input[0] = transpose_node.input[0] - node_to_del.append(transpose_node) - node_to_del.append(origin_weight) - g.value_info.remove(helper.find_value_by_name(g, transpose_node.output[0])) - g.node.extend([new_weight]) + node_to_del = [] + for node in g.node: + # Check pattern + if node.op_type != "Gemm": + continue + prev_node = helper.find_node_by_output_name(g, node.input[0]) + if prev_node is None or prev_node.op_type != "Flatten": + continue + transpose_node = helper.find_node_by_output_name(g, prev_node.input[0]) + if transpose_node.op_type != "Transpose": + continue + # Check attribute + perm = helper.get_list_attribute_by_name(transpose_node, "perm", "int") + if perm != [0, 2, 3, 1]: + continue + transB = helper.get_var_attribute_by_name(node, "transB", "int") + if transB is not None and transB == 1: + continue + # Get the original weight + origin_weight = 
helper.find_node_by_output_name(g, node.input[1]) + origin_np = helper.constant_to_numpy(origin_weight) + # Calculate a new weight + shape = helper.get_shape_from_value_info( + helper.find_value_by_name(g, prev_node.input[0]) + ) + shape.append(-1) + new_np = np.reshape(origin_np, shape) + new_np = np.transpose(new_np, [0, 3, 1, 2, 4]) + new_np = np.reshape(new_np, [-1, new_np.shape[-1]]) + new_weight = helper.numpy_to_constant(origin_weight.output[0], new_np) + # Replace and eliminate + prev_node.input[0] = transpose_node.input[0] + node_to_del.append(transpose_node) + node_to_del.append(origin_weight) + g.value_info.remove( + helper.find_value_by_name(g, transpose_node.output[0]) + ) + g.node.extend([new_weight]) - while node_to_del: - node = node_to_del.pop() - g.node.remove(node) + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) - other.topological_sort(g) + other.topological_sort(g) diff --git a/tools/optimizer_scripts/tools/replacing.py b/tools/optimizer_scripts/tools/replacing.py index 091e571..fdbaa62 100644 --- a/tools/optimizer_scripts/tools/replacing.py +++ b/tools/optimizer_scripts/tools/replacing.py @@ -1,6 +1,6 @@ -"""Optimizations that replace one node with another. """ -from os import dup +Optimizations that replace one node with another. +""" import struct import copy import logging @@ -10,6 +10,7 @@ from . import helper from . 
import modhelper from .other import topological_sort + def replace_initializer_with_Constant(g, duplicate_shared_weights=True): """ Replace initializers with Constant and a corresponding value_info @@ -27,26 +28,24 @@ def replace_initializer_with_Constant(g, duplicate_shared_weights=True): following_nodes = helper.find_nodes_by_input_name(g, tensor.name) if duplicate_shared_weights and len(following_nodes) >= 2: for i, node in enumerate(following_nodes): - new_name = tensor.name + "_duplicated_No" + str(i) if i > 0 else tensor.name - helper.logger.debug(f"Duplicating weight: {tensor.name} -> {new_name}") + new_name = ( + tensor.name + "_duplicated_No" + str(i) + if i > 0 + else tensor.name + ) + helper.logger.debug( + f"Duplicating weight: {tensor.name} -> {new_name}" + ) modhelper.replace_node_input(node, tensor.name, new_name) new_node = onnx.helper.make_node( - "Constant", - [], - [new_name], - name=new_name, - value=tensor + "Constant", [], [new_name], name=new_name, value=tensor ) # Add node to lists g.node.extend([new_node]) else: new_name = tensor.name new_node = onnx.helper.make_node( - "Constant", - [], - [new_name], - name=new_name, - value=tensor + "Constant", [], [new_name], name=new_name, value=tensor ) # Add node to lists g.node.extend([new_node]) @@ -62,6 +61,7 @@ def replace_initializer_with_Constant(g, duplicate_shared_weights=True): topological_sort(g) + def replace_Reshape_with_Flatten(g): """ Replace Reshape node into Flatten node if applicable. 
@@ -70,19 +70,18 @@ def replace_Reshape_with_Flatten(g): """ node_to_remove = [] for node in g.node: - if node.op_type != 'Reshape': + if node.op_type != "Reshape": continue found_Gemm = False # Flatten could be followed by Gemm for i in g.node: if len(i.input) == 0 or i.input[0] != node.output[0]: continue - if i.op_type == 'Gemm': - found = True + if i.op_type == "Gemm": break # Check weight shape_node = helper.find_node_by_output_name(g, node.input[1]) - if shape_node.op_type != 'Constant': + if shape_node.op_type != "Constant": continue shape_value = helper.constant_to_numpy(shape_node) if (shape_value.size != 2 or shape_value[0] != 1) and not found_Gemm: @@ -95,12 +94,13 @@ def replace_Reshape_with_Flatten(g): node.input.pop() node_to_remove.append(shape_node) # If found shape value_info, remove it - if shape_value != None: + if shape_value is not None: g.value_info.remove(shape_value) for node in node_to_remove: g.node.remove(node) + def replace_Squeeze_with_Reshape(g): """ Replace Squeeze nodes with Reshape node. @@ -110,7 +110,7 @@ def replace_Squeeze_with_Reshape(g): node_to_remove = [] for node in g.node: # Find Squeeze node - if node.op_type != 'Squeeze': + if node.op_type != "Squeeze": continue # Get the shape and Construct the shape output_value = helper.find_value_by_name(g, node.output[0]) @@ -118,14 +118,18 @@ def replace_Squeeze_with_Reshape(g): output_value = helper.find_output_by_name(g, node.output[0]) if output_value is None: raise RuntimeError("Cannot get shape for Squeeze") - shape = [dim.dim_value for dim in output_value.type.tensor_type.shape.dim] - const_node = helper.list_to_constant(node.name + "_shape", [len(shape)], shape) + shape = [ + dim.dim_value for dim in output_value.type.tensor_type.shape.dim + ] + const_node = helper.list_to_constant( + node.name + "_shape", [len(shape)], shape + ) # Construct the Reshape layer with same input, output and name. 
new_node = onnx.helper.make_node( "Reshape", [node.input[0], node.name + "_shape"], node.output, - name=node.name + name=node.name, ) # Append constructed nodes and append old node to remove_list g.node.extend([const_node, new_node]) @@ -136,6 +140,7 @@ def replace_Squeeze_with_Reshape(g): # Topological sort topological_sort(g) + def replace_Unsqueeze_with_Reshape(g): """ Replace Unsqueeze nodes with Reshape node. @@ -145,7 +150,7 @@ def replace_Unsqueeze_with_Reshape(g): node_to_remove = [] for node in g.node: # Find Squeeze node - if node.op_type != 'Unsqueeze': + if node.op_type != "Unsqueeze": continue # Get the shape and Construct the shape output_value = helper.find_value_by_name(g, node.output[0]) @@ -153,15 +158,19 @@ def replace_Unsqueeze_with_Reshape(g): output_value = helper.find_output_by_name(g, node.output[0]) if output_value is None: raise RuntimeError("Cannot get shape for Unsqueeze") - shape = [dim.dim_value for dim in output_value.type.tensor_type.shape.dim] + shape = [ + dim.dim_value for dim in output_value.type.tensor_type.shape.dim + ] - const_node = helper.list_to_constant(node.name + "_shape", [len(shape)], shape) + const_node = helper.list_to_constant( + node.name + "_shape", [len(shape)], shape + ) # Construct the Reshape layer with same input, output and name. new_node = onnx.helper.make_node( "Reshape", [node.input[0], node.name + "_shape"], node.output, - name=node.name + name=node.name, ) # Append constructed nodes and append old node to remove_list g.node.extend([const_node, new_node]) @@ -172,6 +181,7 @@ def replace_Unsqueeze_with_Reshape(g): # Topological sort topological_sort(g) + def replace_average_pool_with_GAP(g): """ Replace AveragePool nodes with GlobalAveragePool node when available. 
@@ -181,16 +191,16 @@ def replace_average_pool_with_GAP(g): node_to_remove = [] for node in g.node: # Find a average pool layer - if node.op_type != 'AveragePool': + if node.op_type != "AveragePool": continue # Check attributes not_replace = False for attr in node.attribute: - if attr.name == 'pads': + if attr.name == "pads": if list(attr.ints) != [0, 0, 0, 0]: not_replace = True break - if attr.name == 'kernel_shape': + if attr.name == "kernel_shape": kernel_shape = list(attr.ints) value_info = helper.find_value_by_name(g, node.input[0]) if value_info is None: @@ -206,10 +216,7 @@ def replace_average_pool_with_GAP(g): continue # Replace it with GlobalAveragePool new_node = onnx.helper.make_node( - "GlobalAveragePool", - node.input, - node.output, - name=node.name + "GlobalAveragePool", node.input, node.output, name=node.name ) g.node.extend([new_node]) node_to_remove.append(node) @@ -217,6 +224,7 @@ def replace_average_pool_with_GAP(g): g.node.remove(node) topological_sort(g) + def replace_dilated_conv(g): """ If the dilation of a convolution is not (1, 1), replace it with a regular @@ -227,7 +235,7 @@ def replace_dilated_conv(g): node_to_remove = [] for node in g.node: # Check if this is a conv layer - if node.op_type != 'Conv': + if node.op_type != "Conv": continue # Check if this has dilation has_dilations = False @@ -255,9 +263,9 @@ def replace_dilated_conv(g): if len(weight) == 0: # Unpack from raw data raw_data = w_node.attribute[0].t.raw_data - weight = [i[0] for i in struct.iter_unpack('f', raw_data)] + weight = [i[0] for i in struct.iter_unpack("f", raw_data)] weight = np.array(weight) - weight = np.reshape(weight ,shape) + weight = np.reshape(weight, shape) new_shape = copy.copy(shape) new_shape[2] = 1 + (shape[2] - 1) * dilations[0] new_shape[3] = 1 + (shape[3] - 1) * dilations[1] @@ -273,14 +281,10 @@ def replace_dilated_conv(g): w_node.attribute[0].t.name, w_node.attribute[0].t.data_type, new_shape, - new_weight.ravel() + new_weight.ravel(), ) 
new_w_node = onnx.helper.make_node( - "Constant", - [], - list(w_node.output), - name=w_node.name, - value=tensor + "Constant", [], list(w_node.output), name=w_node.name, value=tensor ) g.node.extend([new_w_node]) node_to_remove.append(w_node) @@ -298,6 +302,7 @@ def replace_dilated_conv(g): for node in node_to_remove: g.node.remove(node) + def replace_depthwise_1x1_with_bn(g): """Replace 1x1 DepthwiseConv node into BN node if applicable. @@ -306,13 +311,16 @@ def replace_depthwise_1x1_with_bn(g): node_to_remove = [] for node in g.node: # Check op_type - if node.op_type != 'Conv': + if node.op_type != "Conv": continue # Check attributes attr_map = {attr.name: attr for attr in node.attribute} if "group" not in attr_map or attr_map["group"].i == 1: continue - if attr_map["kernel_shape"].ints[0] != 1 or attr_map["kernel_shape"].ints[1] != 1: + if ( + attr_map["kernel_shape"].ints[0] != 1 + or attr_map["kernel_shape"].ints[1] != 1 + ): continue if "pads" in attr_map and sum(attr_map["pads"].ints) != 0: continue @@ -333,29 +341,42 @@ def replace_depthwise_1x1_with_bn(g): bias_name = node.input[2] else: bias_name = node.name + "_bias" - bias_node = helper.list_to_constant(bias_name, [attr_map["group"].i], [0.0] * attr_map["group"].i) + bias_node = helper.list_to_constant( + bias_name, [attr_map["group"].i], [0.0] * attr_map["group"].i + ) g.node.extend([bias_node]) # Construct mean and vars mean_name = node.name + "_mean" - mean_node = helper.list_to_constant(mean_name, [attr_map["group"].i], [0.0] * attr_map["group"].i) + mean_node = helper.list_to_constant( + mean_name, [attr_map["group"].i], [0.0] * attr_map["group"].i + ) var_name = node.name + "_var" - var_node = helper.list_to_constant(var_name, [attr_map["group"].i], [1.0] * attr_map["group"].i) + var_node = helper.list_to_constant( + var_name, [attr_map["group"].i], [1.0] * attr_map["group"].i + ) g.node.extend([mean_node, var_node]) # Convert bn_node = onnx.helper.make_node( - op_type='BatchNormalization', - 
inputs=[node.input[0], node.input[1], bias_name, mean_name, var_name], + op_type="BatchNormalization", + inputs=[ + node.input[0], + node.input[1], + bias_name, + mean_name, + var_name, + ], outputs=node.output, name=node.name, epsilon=0.00001, - momentum=0.9 - ) + momentum=0.9, + ) g.node.extend([bn_node]) node_to_remove.append(node) for node in node_to_remove: g.node.remove(node) topological_sort(g) + def replace_shape_with_constant(g): """Replace Shape with Constant.\\ This is the first step of reshape constant folding. @@ -366,13 +387,16 @@ def replace_shape_with_constant(g): node_to_remove = [] for node in g.node: # Find a Shape - if node.op_type != 'Shape': + if node.op_type != "Shape": continue # Check its input input_value = helper.find_value_by_name(g, node.input[0]) if input_value is None: input_value = helper.find_input_by_name(g, node.input[0]) - if input_value is None or len(input_value.type.tensor_type.shape.dim) == 0: + if ( + input_value is None + or len(input_value.type.tensor_type.shape.dim) == 0 + ): continue # Check for case where dimension could be 0 or -1 tmp = True @@ -382,16 +406,20 @@ def replace_shape_with_constant(g): continue # Repalce it input_shape = [ - d.dim_value for d in input_value.type.tensor_type.shape.dim] + d.dim_value for d in input_value.type.tensor_type.shape.dim + ] node_name = node.output[0] new_node = helper.list_to_constant( - node_name, [len(input_shape)], input_shape) + node_name, [len(input_shape)], input_shape + ) g.node.extend([new_node]) node_to_remove.append(node) # if the input value_info is not used by other node # delete this input value_info - val_info_used = sum([input_value.name in node.input for node in g.node]) + val_info_used = sum( + [input_value.name in node.input for node in g.node] + ) if val_info_used == 1: g.value_info.remove(input_value) @@ -404,6 +432,7 @@ def replace_shape_with_constant(g): return replaced + def replace_ConstantOfShape_with_constant(g): """Replace Shape with Constant.\\ This is 
the first step of reshape constant folding. @@ -414,24 +443,28 @@ def replace_ConstantOfShape_with_constant(g): node_to_remove = [] for node in g.node: # Find a Shape - if node.op_type != 'ConstantOfShape': + if node.op_type != "ConstantOfShape": continue # Check input input_value = helper.find_value_by_name(g, node.input[0]) if input_value is None: input_value = helper.find_input_by_name(g, node.input[0]) - if input_value is None or len(input_value.type.tensor_type.shape.dim) == 0: + if ( + input_value is None + or len(input_value.type.tensor_type.shape.dim) == 0 + ): continue # Replace to constant node pre_node = helper.find_node_by_output_name(g, node.input[0]) _, target_shape = helper.constant_to_list(pre_node) - value = helper.get_attribute_by_name(node, 'value').i + value = helper.get_attribute_by_name(node, "value").i node_name = node.output[0] new_node = helper.list_to_constant( - node_name, [target_shape[0]], [value] * target_shape[0]) + node_name, [target_shape[0]], [value] * target_shape[0] + ) g.node.extend([new_node]) @@ -439,7 +472,9 @@ def replace_ConstantOfShape_with_constant(g): node_to_remove.append(node) # delete value_info - val_info_used = sum([input_value.name in node.input for node in g.node]) + val_info_used = sum( + [input_value.name in node.input for node in g.node] + ) if val_info_used == 1: g.value_info.remove(input_value) @@ -452,6 +487,7 @@ def replace_ConstantOfShape_with_constant(g): return replaced + def replace_split_with_slices(g): """Replace split node with slice nodes. :param g: input graph. 
@@ -460,7 +496,7 @@ def replace_split_with_slices(g): node_to_remove = [] for node in g.node: # Find a Split - if node.op_type != 'Split': + if node.op_type != "Split": continue input_value = helper.find_value_by_name(g, node.input[0]) @@ -475,9 +511,9 @@ def replace_split_with_slices(g): axis = 0 split = [] for item in node.attribute: - if item.name == 'axis': + if item.name == "axis": axis = item.i - if item.name == 'split': + if item.name == "split": split = item.ints # For opset 11, axis could be negative. @@ -492,39 +528,51 @@ def replace_split_with_slices(g): pos += split[i] new_node_name = output_val_names[i] # Construct starts, ends, axes - starts_name = new_node_name + '_starts_' + str(i) - ends_name = new_node_name + '_ends_' + str(i) - axes_name = new_node_name + '_axes_' + str(i) - starts_node = helper.list_to_constant(starts_name, (1, ), [int(pos-split[i])]) - ends_node = helper.list_to_constant(ends_name, (1, ), [int(pos)]) - axes_node = helper.list_to_constant(axes_name, (1, ), [int(axis)]) + starts_name = new_node_name + "_starts_" + str(i) + ends_name = new_node_name + "_ends_" + str(i) + axes_name = new_node_name + "_axes_" + str(i) + starts_node = helper.list_to_constant( + starts_name, (1,), [int(pos - split[i])] + ) + ends_node = helper.list_to_constant( + ends_name, (1,), [int(pos)] + ) + axes_node = helper.list_to_constant( + axes_name, (1,), [int(axis)] + ) # Construtc node new_node = onnx.helper.make_node( - op_type='Slice', + op_type="Slice", inputs=[node.input[0], starts_name, ends_name, axes_name], outputs=[node.output[i]], - name=new_node_name + name=new_node_name, ) g.node.extend([starts_node, ends_node, axes_node, new_node]) node_to_remove.append(node) else: n_out = len(output_val_names) - width = length//n_out + width = length // n_out for i in range(n_out): new_node_name = output_val_names[i] # Construct starts, ends, axes - starts_name = new_node_name + '_starts_' + str(i) - ends_name = new_node_name + '_ends_' + str(i) - axes_name 
= new_node_name + '_axes_' + str(i) - starts_node = helper.list_to_constant(starts_name, (1, ), [int(i*width)]) - ends_node = helper.list_to_constant(ends_name, (1, ), [int((1+i)*width)]) - axes_node = helper.list_to_constant(axes_name, (1, ), [int(axis)]) + starts_name = new_node_name + "_starts_" + str(i) + ends_name = new_node_name + "_ends_" + str(i) + axes_name = new_node_name + "_axes_" + str(i) + starts_node = helper.list_to_constant( + starts_name, (1,), [int(i * width)] + ) + ends_node = helper.list_to_constant( + ends_name, (1,), [int((1 + i) * width)] + ) + axes_node = helper.list_to_constant( + axes_name, (1,), [int(axis)] + ) # Construtc node new_node = onnx.helper.make_node( - op_type='Slice', + op_type="Slice", inputs=[node.input[0], starts_name, ends_name, axes_name], outputs=[node.output[i]], - name=new_node_name + name=new_node_name, ) g.node.extend([starts_node, ends_node, axes_node, new_node]) node_to_remove.append(node) @@ -546,19 +594,19 @@ def replace_ReduceMean_with_GlobalAveragePool(g): node_to_remove = [] for node in g.node: # Find a ReduceMean layer - if node.op_type != 'ReduceMean': + if node.op_type != "ReduceMean": continue # Find if it have previous Transpose and its attribute meet the need. 
prev_node = helper.find_node_by_output_name(g, node.input[0]) - if prev_node is not None and prev_node.op_type != 'Transpose': + if prev_node is not None and prev_node.op_type != "Transpose": prev_node = None if prev_node is not None: - perm = helper.get_list_attribute_by_name(prev_node, 'perm', 'int') + perm = helper.get_list_attribute_by_name(prev_node, "perm", "int") if perm != [0, 2, 3, 1]: prev_node = None # Check attributes - axes = helper.get_list_attribute_by_name(node, 'axes', 'int') - keepdims = helper.get_var_attribute_by_name(node, 'keepdims', 'int') + axes = helper.get_list_attribute_by_name(node, "axes", "int") + keepdims = helper.get_var_attribute_by_name(node, "keepdims", "int") if axes is None: continue if prev_node is None and axes != [2, 3]: @@ -575,20 +623,17 @@ def replace_ReduceMean_with_GlobalAveragePool(g): if keepdims == 1: output_list = node.output else: - output_list = [node.output[0] + '_before_flatten'] + output_list = [node.output[0] + "_before_flatten"] flatten_node = onnx.helper.make_node( "Flatten", output_list, node.output, - name = node.name + "_flatten", - axis = 1 + name=node.name + "_flatten", + axis=1, ) g.node.extend([flatten_node]) new_node = onnx.helper.make_node( - "GlobalAveragePool", - input_list, - output_list, - name=node.name + "GlobalAveragePool", input_list, output_list, name=node.name ) g.node.extend([new_node]) node_to_remove.append(node) @@ -601,6 +646,7 @@ def replace_ReduceMean_with_GlobalAveragePool(g): g.node.remove(node) topological_sort(g) + def replace_mul_to_bn(g): """Replace single Mul node with Batchnorm node. :param g: input graph. 
@@ -608,30 +654,44 @@ def replace_mul_to_bn(g): """ node_to_del = [] for node in g.node: - if node.op_type != 'Mul': + if node.op_type != "Mul": continue mul_op_node = node # only support one input node - if len(mul_op_node.input) != 2: # OP node and value node + if len(mul_op_node.input) != 2: # OP node and value node continue input_op_node_name = mul_op_node.input[0] - mul_value_node = helper.find_node_by_output_name(g, mul_op_node.input[1]) - if not mul_value_node or mul_value_node.op_type != 'Constant': + mul_value_node = helper.find_node_by_output_name( + g, mul_op_node.input[1] + ) + if not mul_value_node or mul_value_node.op_type != "Constant": continue - prev_shape_value_info = helper.find_value_by_name(g, input_op_node_name) - prev_shape_value_info = helper.find_input_by_name(g, input_op_node_name) if prev_shape_value_info is None else prev_shape_value_info + prev_shape_value_info = helper.find_value_by_name( + g, input_op_node_name + ) + prev_shape_value_info = ( + helper.find_input_by_name(g, input_op_node_name) + if prev_shape_value_info is None + else prev_shape_value_info + ) if prev_shape_value_info is None: continue - _ , previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info) + _, previous_node_output_shape = helper.find_size_shape_from_value( + prev_shape_value_info + ) scale_shape, scale_data = helper.constant_to_list(mul_value_node) # channel dimension - c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1 + c_dim = ( + previous_node_output_shape[1] + if len(previous_node_output_shape) > 1 + else 1 + ) # only allow channelwise mul or const mul if scale_shape == [1, c_dim, 1, 1]: @@ -646,21 +706,31 @@ def replace_mul_to_bn(g): ones = [1.0] * c_dim zeros = [0.0] * c_dim bn_name = mul_op_node.output[0] - mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) - variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) - 
bias_value_node = helper.list_to_constant(bn_name+'_add', np.array(zeros).shape, zeros) - new_mul_value_node = helper.list_to_constant(bn_name+'_mul', np.array(muls).shape, muls) + mean_value_node = helper.list_to_constant( + bn_name + "_mean", np.array(zeros).shape, zeros + ) + variance_value_node = helper.list_to_constant( + bn_name + "_var", np.array(ones).shape, ones + ) + bias_value_node = helper.list_to_constant( + bn_name + "_add", np.array(zeros).shape, zeros + ) + new_mul_value_node = helper.list_to_constant( + bn_name + "_mul", np.array(muls).shape, muls + ) bn_node = onnx.helper.make_node( - 'BatchNormalization', - [input_op_node_name, - new_mul_value_node.output[0], - bias_value_node.output[0], - mean_value_node.output[0], - variance_value_node.output[0]], + "BatchNormalization", + [ + input_op_node_name, + new_mul_value_node.output[0], + bias_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0], + ], [mul_op_node.output[0]], name=bn_name, - epsilon=0.00000001 + epsilon=0.00000001, ) scale_val_info = helper.find_value_by_name(g, mul_value_node.output[0]) @@ -680,6 +750,7 @@ def replace_mul_to_bn(g): topological_sort(g) + def replace_div_to_bn(g): """Replace single Div node with Batchnorm node. :param g: input graph. 
@@ -687,30 +758,44 @@ def replace_div_to_bn(g): """ node_to_del = [] for node in g.node: - if node.op_type != 'Div': + if node.op_type != "Div": continue div_op_node = node # only support one input node - if len(div_op_node.input) != 2: # OP node and value node + if len(div_op_node.input) != 2: # OP node and value node continue input_op_node_name = div_op_node.input[0] - div_value_node = helper.find_node_by_output_name(g, div_op_node.input[1]) - if not div_value_node or div_value_node.op_type != 'Constant': + div_value_node = helper.find_node_by_output_name( + g, div_op_node.input[1] + ) + if not div_value_node or div_value_node.op_type != "Constant": continue - prev_shape_value_info = helper.find_value_by_name(g, input_op_node_name) - prev_shape_value_info = helper.find_input_by_name(g, input_op_node_name) if prev_shape_value_info is None else prev_shape_value_info + prev_shape_value_info = helper.find_value_by_name( + g, input_op_node_name + ) + prev_shape_value_info = ( + helper.find_input_by_name(g, input_op_node_name) + if prev_shape_value_info is None + else prev_shape_value_info + ) if prev_shape_value_info is None: continue - _ , previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info) + _, previous_node_output_shape = helper.find_size_shape_from_value( + prev_shape_value_info + ) scale_shape, scale_data = helper.constant_to_list(div_value_node) # channel dimension - c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1 + c_dim = ( + previous_node_output_shape[1] + if len(previous_node_output_shape) > 1 + else 1 + ) # only allow channelwise div or const div if scale_shape == [1, c_dim, 1, 1]: @@ -726,21 +811,31 @@ def replace_div_to_bn(g): zeros = [0.0] * c_dim muls = (1 / np.array(muls)).tolist() bn_name = div_op_node.output[0] - mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) - variance_value_node = helper.list_to_constant(bn_name+'_var', 
np.array(ones).shape, ones) - bias_value_node = helper.list_to_constant(bn_name+'_add', np.array(zeros).shape, zeros) - new_mul_value_node = helper.list_to_constant(bn_name+'_mul', np.array(muls).shape, muls) + mean_value_node = helper.list_to_constant( + bn_name + "_mean", np.array(zeros).shape, zeros + ) + variance_value_node = helper.list_to_constant( + bn_name + "_var", np.array(ones).shape, ones + ) + bias_value_node = helper.list_to_constant( + bn_name + "_add", np.array(zeros).shape, zeros + ) + new_mul_value_node = helper.list_to_constant( + bn_name + "_mul", np.array(muls).shape, muls + ) bn_node = onnx.helper.make_node( - 'BatchNormalization', - [input_op_node_name, - new_mul_value_node.output[0], - bias_value_node.output[0], - mean_value_node.output[0], - variance_value_node.output[0]], + "BatchNormalization", + [ + input_op_node_name, + new_mul_value_node.output[0], + bias_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0], + ], [div_op_node.output[0]], name=bn_name, - epsilon=0.00000001 + epsilon=0.00000001, ) scale_val_info = helper.find_value_by_name(g, div_value_node.output[0]) @@ -768,30 +863,44 @@ def replace_add_to_bn(g): """ node_to_del = [] for node in g.node: - if node.op_type != 'Add': + if node.op_type != "Add": continue add_op_node = node # only support one input node - if len(add_op_node.input) != 2: # OP node and value node + if len(add_op_node.input) != 2: # OP node and value node continue input_op_node_name = add_op_node.input[0] - add_value_node = helper.find_node_by_output_name(g, add_op_node.input[1]) - if not add_value_node or add_value_node.op_type != 'Constant': + add_value_node = helper.find_node_by_output_name( + g, add_op_node.input[1] + ) + if not add_value_node or add_value_node.op_type != "Constant": continue - prev_shape_value_info = helper.find_value_by_name(g, input_op_node_name) - prev_shape_value_info = helper.find_input_by_name(g, input_op_node_name) if prev_shape_value_info is None else 
prev_shape_value_info + prev_shape_value_info = helper.find_value_by_name( + g, input_op_node_name + ) + prev_shape_value_info = ( + helper.find_input_by_name(g, input_op_node_name) + if prev_shape_value_info is None + else prev_shape_value_info + ) if prev_shape_value_info is None: continue - _ , previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info) + _, previous_node_output_shape = helper.find_size_shape_from_value( + prev_shape_value_info + ) bias_shape, bias_data = helper.constant_to_list(add_value_node) # channel dimension - c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1 + c_dim = ( + previous_node_output_shape[1] + if len(previous_node_output_shape) > 1 + else 1 + ) # only allow channelwise add or const add if bias_shape == [1, c_dim, 1, 1]: @@ -806,21 +915,31 @@ def replace_add_to_bn(g): ones = [1.0] * c_dim zeros = [0.0] * c_dim bn_name = add_op_node.output[0] - mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) - variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) - scale_value_node = helper.list_to_constant(bn_name+'_mul', np.array(ones).shape, ones) - new_add_value_node = helper.list_to_constant(bn_name+'_add', np.array(bias).shape, bias) + mean_value_node = helper.list_to_constant( + bn_name + "_mean", np.array(zeros).shape, zeros + ) + variance_value_node = helper.list_to_constant( + bn_name + "_var", np.array(ones).shape, ones + ) + scale_value_node = helper.list_to_constant( + bn_name + "_mul", np.array(ones).shape, ones + ) + new_add_value_node = helper.list_to_constant( + bn_name + "_add", np.array(bias).shape, bias + ) bn_node = onnx.helper.make_node( - 'BatchNormalization', - [input_op_node_name, - scale_value_node.output[0], - new_add_value_node.output[0], - mean_value_node.output[0], - variance_value_node.output[0]], + "BatchNormalization", + [ + input_op_node_name, + scale_value_node.output[0], + 
new_add_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0], + ], [add_op_node.output[0]], name=bn_name, - epsilon=0.00000001 + epsilon=0.00000001, ) add_val_info = helper.find_value_by_name(g, add_value_node.output[0]) @@ -840,6 +959,7 @@ def replace_add_to_bn(g): topological_sort(g) + def replace_sub_to_bn(g): """Replace single Sub node with BatchNorm node. :param g: input graph. @@ -847,13 +967,13 @@ def replace_sub_to_bn(g): """ node_to_del = [] for node in g.node: - if node.op_type != 'Sub': + if node.op_type != "Sub": continue sub_op_node = node # only support one input node - if len(sub_op_node.input) != 2: # OP node and value node + if len(sub_op_node.input) != 2: # OP node and value node continue # Check the input type @@ -861,11 +981,13 @@ def replace_sub_to_bn(g): input_2nd_name = sub_op_node.input[1] input_1st_node = helper.find_node_by_output_name(g, input_1st_name) input_2nd_node = helper.find_node_by_output_name(g, input_2nd_name) - if input_1st_node is not None and input_1st_node.op_type == 'Constant': + if input_1st_node is not None and input_1st_node.op_type == "Constant": real_input_name = input_2nd_name reverse = True constant_node = input_1st_node - elif input_2nd_node is not None and input_2nd_node.op_type == 'Constant': + elif ( + input_2nd_node is not None and input_2nd_node.op_type == "Constant" + ): real_input_name = input_1st_name reverse = False constant_node = input_2nd_node @@ -874,15 +996,25 @@ def replace_sub_to_bn(g): # Get shapes prev_shape_value_info = helper.find_value_by_name(g, real_input_name) - prev_shape_value_info = helper.find_input_by_name(g, real_input_name) if prev_shape_value_info is None else prev_shape_value_info + prev_shape_value_info = ( + helper.find_input_by_name(g, real_input_name) + if prev_shape_value_info is None + else prev_shape_value_info + ) if prev_shape_value_info is None: continue - _ , previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info) 
+ _, previous_node_output_shape = helper.find_size_shape_from_value( + prev_shape_value_info + ) bias_shape, bias_data = helper.constant_to_list(constant_node) # channel dimension - c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1 + c_dim = ( + previous_node_output_shape[1] + if len(previous_node_output_shape) > 1 + else 1 + ) # only allow channelwise sub or const sub if bias_shape == [1, c_dim, 1, 1]: @@ -903,21 +1035,31 @@ def replace_sub_to_bn(g): scale = ones bias *= -1 bn_name = sub_op_node.output[0] - mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) - variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) - scale_value_node = helper.list_to_constant(bn_name+'_mul', np.array(scale).shape, scale) - new_add_value_node = helper.list_to_constant(bn_name+'_add', np.array(bias).shape, bias) + mean_value_node = helper.list_to_constant( + bn_name + "_mean", np.array(zeros).shape, zeros + ) + variance_value_node = helper.list_to_constant( + bn_name + "_var", np.array(ones).shape, ones + ) + scale_value_node = helper.list_to_constant( + bn_name + "_mul", np.array(scale).shape, scale + ) + new_add_value_node = helper.list_to_constant( + bn_name + "_add", np.array(bias).shape, bias + ) bn_node = onnx.helper.make_node( - 'BatchNormalization', - [real_input_name, - scale_value_node.output[0], - new_add_value_node.output[0], - mean_value_node.output[0], - variance_value_node.output[0]], + "BatchNormalization", + [ + real_input_name, + scale_value_node.output[0], + new_add_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0], + ], [sub_op_node.output[0]], name=bn_name, - epsilon=0.00000001 + epsilon=0.00000001, ) add_val_info = helper.find_value_by_name(g, constant_node.output[0]) @@ -937,19 +1079,20 @@ def replace_sub_to_bn(g): topological_sort(g) + def replace_sub_with_bn_and_add(g): """Replace two input Sub node with BN and Add: A - B 
= A + (-1) * B :param g: input graph. :return: """ for node in g.node: - if node.op_type != 'Sub': + if node.op_type != "Sub": continue sub_op_node = node # only support one input node - if len(sub_op_node.input) != 2: # OP node and value node + if len(sub_op_node.input) != 2: # OP node and value node continue # Check the input type @@ -957,9 +1100,11 @@ def replace_sub_with_bn_and_add(g): input_2nd_name = sub_op_node.input[1] input_1st_node = helper.find_node_by_output_name(g, input_1st_name) input_2nd_node = helper.find_node_by_output_name(g, input_2nd_name) - if input_1st_node is not None and input_1st_node.op_type == 'Constant': + if input_1st_node is not None and input_1st_node.op_type == "Constant": continue - elif input_2nd_node is not None and input_2nd_node.op_type == 'Constant': + elif ( + input_2nd_node is not None and input_2nd_node.op_type == "Constant" + ): continue # Get shapes @@ -970,30 +1115,45 @@ def replace_sub_with_bn_and_add(g): continue # Get channel dimension - _ , input_2nd_shape = helper.find_size_shape_from_value(input_2nd_value_info) + _, input_2nd_shape = helper.find_size_shape_from_value( + input_2nd_value_info + ) if len(input_2nd_shape) < 2: - helper.logger.debug(f"{sub_op_node.name} cannot be replaced due to the input shape.") + helper.logger.debug( + f"{sub_op_node.name} cannot be replaced " + "due to the input shape." + ) c_dim = input_2nd_shape[1] # Create * -1 bn node. 
ones = [1.0] * c_dim zeros = [0.0] * c_dim scale = [-1.0] * c_dim - bn_name = input_2nd_name + '_neg_for_' + node.name - mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) - variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) - scale_value_node = helper.list_to_constant(bn_name+'_mul', np.array(scale).shape, scale) - bias_value_node = helper.list_to_constant(bn_name+'_add', np.array(zeros).shape, zeros) + bn_name = input_2nd_name + "_neg_for_" + node.name + mean_value_node = helper.list_to_constant( + bn_name + "_mean", np.array(zeros).shape, zeros + ) + variance_value_node = helper.list_to_constant( + bn_name + "_var", np.array(ones).shape, ones + ) + scale_value_node = helper.list_to_constant( + bn_name + "_mul", np.array(scale).shape, scale + ) + bias_value_node = helper.list_to_constant( + bn_name + "_add", np.array(zeros).shape, zeros + ) bn_node = onnx.helper.make_node( - 'BatchNormalization', - [input_2nd_name, - scale_value_node.output[0], - bias_value_node.output[0], - mean_value_node.output[0], - variance_value_node.output[0]], + "BatchNormalization", + [ + input_2nd_name, + scale_value_node.output[0], + bias_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0], + ], [bn_name], name=bn_name, - epsilon=0.00000001 + epsilon=0.00000001, ) # Change sub to add @@ -1001,29 +1161,42 @@ def replace_sub_with_bn_and_add(g): # Replace add input modhelper.replace_node_input(sub_op_node, input_2nd_name, bn_name) - g.node.extend([scale_value_node, bias_value_node, mean_value_node, variance_value_node, bn_node]) + g.node.extend( + [ + scale_value_node, + bias_value_node, + mean_value_node, + variance_value_node, + bn_node, + ] + ) topological_sort(g) + def replace_Sum_with_Adds(g): node_to_del = [] for node in g.node: # Check for sum - if node.op_type != 'Sum': + if node.op_type != "Sum": continue # Check for input number if len(node.input) == 1: # If input number 
is 1, delete the sum node. - following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + following_nodes = helper.find_following_nodes_by_input_value_name( + g, node.output[0] + ) for following_node in following_nodes: - modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + modhelper.replace_node_input( + following_node, node.output[0], node.input[0] + ) node_to_del.append(node) if helper.find_value_by_name(node.output[0]) is not None: g.value_info.remove(helper.find_value_by_name(node.output[0])) elif len(node.input) == 2: # If input number is 2, replace it with add. - node.op_type = 'Add' + node.op_type = "Add" continue elif len(node.input) > 2: # If input number is larger than 2, replace it with n-1 add. @@ -1032,23 +1205,29 @@ def replace_Sum_with_Adds(g): first_node = onnx.helper.make_node( "Add", [node.input[0], node.input[1]], - [node.output[0] + '_replacement_1'], - name=node.name + '_replacement_1' + [node.output[0] + "_replacement_1"], + name=node.name + "_replacement_1", ) # Last node has the same output as the original sum node last_node = onnx.helper.make_node( "Add", - [node.output[0] + '_replacement_' + str(input_count - 2), node.input[input_count - 1]], + [ + node.output[0] + "_replacement_" + str(input_count - 2), + node.input[input_count - 1], + ], [node.output[0]], - name=node.name + name=node.name, ) g.node.extend([first_node, last_node]) for i in range(2, input_count - 1): new_node = onnx.helper.make_node( "Add", - [node.output[0] + '_replacement_' + str(i - 1), node.input[i]], - [node.output[0] + '_replacement_' + str(i)], - name=node.name + '_replacement_' + str(i) + [ + node.output[0] + "_replacement_" + str(i - 1), + node.input[i], + ], + [node.output[0] + "_replacement_" + str(i)], + name=node.name + "_replacement_" + str(i), ) g.node.extend([new_node]) node_to_del.append(node) @@ -1063,14 +1242,16 @@ def replace_Sum_with_Adds(g): def replace_constant_input_concat_with_pad(g): - """If 
single input is concating with constant node of same number. Replace it with pad. Currently only support 2-3 inputs. + """ + If a single input is concatenated with a constant node of the same number, + replace it with Pad. Currently only supports 2-3 inputs. :param g: input graph. :return: """ node_to_del = [] for node in g.node: # Check for Concat node - if node.op_type != 'Concat': + if node.op_type != "Concat": continue # Check concat node input @@ -1080,16 +1261,22 @@ def replace_constant_input_concat_with_pad(g): if len(node.input) == 2: input_1st_node = helper.find_node_by_output_name(g, node.input[0]) input_2nd_node = helper.find_node_by_output_name(g, node.input[1]) - if input_1st_node is not None and input_1st_node.op_type == 'Constant': - mode = 'left' + if ( + input_1st_node is not None + and input_1st_node.op_type == "Constant" + ): + mode = "left" constant_value = helper.constant_to_numpy(input_1st_node) real_input_name = node.input[1] value = constant_value.flatten()[0] # Check if the values are all the same. if np.any(constant_value - value): continue - elif input_2nd_node is not None and input_2nd_node.op_type == 'Constant': - mode = 'right' + elif ( + input_2nd_node is not None + and input_2nd_node.op_type == "Constant" + ): + mode = "right" constant_value = helper.constant_to_numpy(input_2nd_node) real_input_name = node.input[0] value = constant_value.flatten()[0] @@ -1100,14 +1287,19 @@ def replace_constant_input_concat_with_pad(g): # No constant input case continue elif len(node.input) == 3: - # For 3 inputs concat node, the 1st and the 3rd input should be constant with the same value. + # For 3 inputs concat node, the 1st and the 3rd input should be + # constant with the same value. 
input_1st_node = helper.find_node_by_output_name(g, node.input[0]) input_2nd_node = helper.find_node_by_output_name(g, node.input[1]) input_3rd_node = helper.find_node_by_output_name(g, node.input[2]) - if input_1st_node is None or input_1st_node.op_type != 'Constant' or \ - input_3rd_node is None or input_3rd_node.op_type != 'Constant': + if ( + input_1st_node is None + or input_1st_node.op_type != "Constant" + or input_3rd_node is None + or input_3rd_node.op_type != "Constant" + ): continue - mode = 'both' + mode = "both" real_input_name = node.input[1] input_1st_value = helper.constant_to_numpy(input_1st_node) input_3rd_value = helper.constant_to_numpy(input_3rd_node) @@ -1124,41 +1316,46 @@ def replace_constant_input_concat_with_pad(g): input_value_info = helper.find_value_by_name(g, real_input_name) input_shape = helper.get_shape_from_value_info(input_value_info) pads = [0] * (len(input_shape) * 2) - axis = helper.get_var_attribute_by_name(node, 'axis', 'int') + axis = helper.get_var_attribute_by_name(node, "axis", "int") if axis < 0: axis = len(input_shape) - axis - if mode == 'left': + if mode == "left": left_value_info = helper.find_value_by_name(g, node.input[0]) - left_input_shape = helper.get_shape_from_value_info(left_value_info) + left_input_shape = helper.get_shape_from_value_info( + left_value_info + ) pads[axis] = left_input_shape[axis] - elif mode == 'right': + elif mode == "right": right_value_info = helper.find_value_by_name(g, node.input[1]) - right_input_shape = helper.get_shape_from_value_info(right_value_info) + right_input_shape = helper.get_shape_from_value_info( + right_value_info + ) pads[axis + len(input_shape)] = right_input_shape[axis] else: # mode shoule be both left_value_info = helper.find_value_by_name(g, node.input[0]) - left_input_shape = helper.get_shape_from_value_info(left_value_info) + left_input_shape = helper.get_shape_from_value_info( + left_value_info + ) pads[axis] = left_input_shape[axis] right_value_info = 
helper.find_value_by_name(g, node.input[2]) - right_input_shape = helper.get_shape_from_value_info(right_value_info) + right_input_shape = helper.get_shape_from_value_info( + right_value_info + ) pads[axis + len(input_shape)] = right_input_shape[axis] pads_node = helper.list_to_constant( - node.name + '_pads', - (len(pads), ), - pads + node.name + "_pads", (len(pads),), pads ) constant_value_node = helper.scaler_to_constant( - node.name + '_constant_value', - value + node.name + "_constant_value", value ) # Create new Pad node new_pad_node = onnx.helper.make_node( "Pad", [real_input_name, pads_node.name, constant_value_node.name], [node.output[0]], - name = node.name, - mode = "constant" + name=node.name, + mode="constant", ) # Replace node_to_del.append(node) @@ -1168,4 +1365,3 @@ def replace_constant_input_concat_with_pad(g): g.node.remove(node_to_del.pop()) topological_sort(g) - diff --git a/tools/optimizer_scripts/tools/special.py b/tools/optimizer_scripts/tools/special.py index 38de4f5..275f8c5 100644 --- a/tools/optimizer_scripts/tools/special.py +++ b/tools/optimizer_scripts/tools/special.py @@ -1,11 +1,10 @@ """Special operations on model. """ -import logging import onnx.helper import numpy as np from . import helper from . import other -from . import modhelper + def change_first_conv_from_bgr_to_rgb(m): """For input channel format BGR model, use this function to change the first @@ -16,12 +15,14 @@ def change_first_conv_from_bgr_to_rgb(m): # Check for first node. g = m.graph input_name = g.input[0].name - first_nodes = helper.find_following_nodes_by_input_value_name(g, input_name) + first_nodes = helper.find_following_nodes_by_input_value_name( + g, input_name + ) if len(first_nodes) > 1: return False first_node = first_nodes[0] # Now we have the first node. Check this first node. 
- if first_node.op_type != 'Conv': + if first_node.op_type != "Conv": return False weight_value = helper.find_value_by_name(g, first_node.input[1]) weight_shape = helper.get_shape_from_value_info(weight_value) @@ -41,10 +42,12 @@ def change_first_conv_from_bgr_to_rgb(m): other.topological_sort(g) return True + def change_input_from_bgr_to_rgb(m): - """For input channel format BGR model, use this function to modify the model - to accepct RGB image.If the first node is a non-group Conv. Modify weight to - adapt the input into RGB. Otherwise create a new node. + """ + For input channel format BGR model, use this function to modify the model + to accept RGB image. If the first node is a non-group Conv. + Modify weight to adapt the input into RGB. Otherwise create a new node. :param m: the model proto """ @@ -61,34 +64,33 @@ return # Otherwise, create a special conv node and replace the input # Construct weight - weight_np = np.zeros((3, 3, 3, 3)).astype('float32') + weight_np = np.zeros((3, 3, 3, 3)).astype("float32") weight_np[0, 2, 1, 1] = 1.0 weight_np[1, 1, 1, 1] = 1.0 weight_np[2, 0, 1, 1] = 1.0 new_weight = helper.numpy_to_constant("bgr_shuffle_weight", weight_np) # Construct Conv new_conv = onnx.helper.make_node( - 'Conv', - ['rgb_input', "bgr_shuffle_weight"], + "Conv", + ["rgb_input", "bgr_shuffle_weight"], [g.input[0].name], - name='bgr_shuffle', + name="bgr_shuffle", dilations=[1, 1], kernel_shape=[3, 3], pads=[1, 1, 1, 1], - strides=[1, 1] + strides=[1, 1], ) # Connect the graph old_input_value = g.input.pop() new_input_value = onnx.helper.make_tensor_value_info( - 'rgb_input', - old_input_value.type.tensor_type.elem_type, - input_shape + "rgb_input", old_input_value.type.tensor_type.elem_type, input_shape ) g.input.extend([new_input_value]) g.node.extend([new_weight, new_conv]) # topological sort other.topological_sort(g) + def add_0_5_to_normalized_input(m): """For normalized input between -0.5 ~ 0.5, add 0.5 to the 
input to keep it between 0 ~ 1. @@ -105,41 +107,37 @@ def add_0_5_to_normalized_input(m): return # Construct weight ch = input_shape[1] - weight_np = np.zeros((ch, ch, 3, 3)).astype('float32') + weight_np = np.zeros((ch, ch, 3, 3)).astype("float32") for i in range(ch): weight_np[i, i, 1, 1] = 1.0 new_weight = helper.numpy_to_constant("input_norm_weight", weight_np) # Construct bias - bias_np = np.array([0.5] * ch).astype('float32') + bias_np = np.array([0.5] * ch).astype("float32") new_bias = helper.numpy_to_constant("input_norm_bias", bias_np) # Construct Conv new_conv = onnx.helper.make_node( - 'Conv', - ['origin_input', "input_norm_weight", "input_norm_bias"], + "Conv", + ["origin_input", "input_norm_weight", "input_norm_bias"], [g.input[0].name], - name='input_norm', + name="input_norm", dilations=[1, 1], kernel_shape=[3, 3], pads=[1, 1, 1, 1], - strides=[1, 1] + strides=[1, 1], ) # Construct value_infos old_input_value = g.input.pop() weight_value = onnx.helper.make_tensor_value_info( - 'input_norm_weight', + "input_norm_weight", old_input_value.type.tensor_type.elem_type, - [3, 3, 3, 3] + [3, 3, 3, 3], ) bias_value = onnx.helper.make_tensor_value_info( - 'input_norm_bias', - old_input_value.type.tensor_type.elem_type, - [3] + "input_norm_bias", old_input_value.type.tensor_type.elem_type, [3] ) # Connect the graph new_input_value = onnx.helper.make_tensor_value_info( - 'origin_input', - old_input_value.type.tensor_type.elem_type, - input_shape + "origin_input", old_input_value.type.tensor_type.elem_type, input_shape ) g.input.extend([new_input_value]) g.node.extend([new_weight, new_bias, new_conv]) @@ -147,9 +145,9 @@ def add_0_5_to_normalized_input(m): # topological sort other.topological_sort(g) + def add_rgb2yynn_node(m): - """Add a conv layer which can convert rgb to yynn input. - """ + """Add a conv layer which can convert rgb to yynn input.""" g = m.graph if len(g.input) > 1: print("This model has multiple inputs. 
Cannot change to rgb input.") @@ -159,37 +157,32 @@ def add_rgb2yynn_node(m): print("The input shape is not BCHW. Cannot normalize input.") return # Construct weight - ch = input_shape[1] - weight_np = np.zeros((3, 3, 4, 4)).astype('float32') - weight_np[1, 1, :3, :2] = np.array([[[[0.299], - [0.587], - [0.114]]]]) - weight_np[1, 1, 3, 2:] = 1. + weight_np = np.zeros((3, 3, 4, 4)).astype("float32") + weight_np[1, 1, :3, :2] = np.array([[[[0.299], [0.587], [0.114]]]]) + weight_np[1, 1, 3, 2:] = 1.0 weight_np = np.transpose(weight_np, (3, 2, 0, 1)) new_weight = helper.numpy_to_constant("input_rgb2yynn_weight", weight_np) # Construct conv node new_conv = onnx.helper.make_node( - 'Conv', - ['new_input', "input_rgb2yynn_weight"], + "Conv", + ["new_input", "input_rgb2yynn_weight"], [g.input[0].name], - name='input_rgba2yynn', + name="input_rgba2yynn", dilations=[1, 1], kernel_shape=[3, 3], pads=[1, 1, 1, 1], - strides=[1, 1] + strides=[1, 1], ) # Construct value_infos old_input_value = g.input.pop() weight_value = onnx.helper.make_tensor_value_info( - 'input_rgb2yynn_weight', + "input_rgb2yynn_weight", old_input_value.type.tensor_type.elem_type, - [4, 4, 3, 3] + [4, 4, 3, 3], ) # Connect the graph new_input_value = onnx.helper.make_tensor_value_info( - 'new_input', - old_input_value.type.tensor_type.elem_type, - input_shape + "new_input", old_input_value.type.tensor_type.elem_type, input_shape ) g.input.extend([new_input_value]) g.node.extend([new_weight, new_conv]) @@ -197,6 +190,7 @@ def add_rgb2yynn_node(m): # topological sort other.topological_sort(g) + def swap_MatMul_inputs(g, original_matmul_node): # Create Transpose nodes input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0]) @@ -206,11 +200,12 @@ def swap_MatMul_inputs(g, original_matmul_node): else: perm = [0, 2, 1] new_input_b_node = onnx.helper.make_node( - 'Transpose', - inputs = [input_a_value.name], - outputs = [input_a_value.name + '_transposed'], - name = 
f"{input_a_value.name}_transposed_for_{original_matmul_node.name}", - perm = perm + "Transpose", + inputs=[input_a_value.name], + outputs=[input_a_value.name + "_transposed"], + name=f"{input_a_value.name}_transposed_for_" + f"{original_matmul_node.name}", + perm=perm, ) input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1]) input_b_shape = helper.get_shape_from_value_info(input_b_value) @@ -219,18 +214,19 @@ def swap_MatMul_inputs(g, original_matmul_node): else: perm = [0, 1, 3, 2] new_input_a_node = onnx.helper.make_node( - 'Transpose', - inputs = [input_b_value.name], - outputs = [input_b_value.name + '_transposed'], - name = f'{input_b_value.name}_transposed_for_{original_matmul_node.name}', - perm = perm + "Transpose", + inputs=[input_b_value.name], + outputs=[input_b_value.name + "_transposed"], + name=f"{input_b_value.name}_transposed_for_" + f"{original_matmul_node.name}", + perm=perm, ) # Create new MatMul node new_matmul_node = onnx.helper.make_node( - 'MatMul', - inputs = [new_input_a_node.output[0], new_input_b_node.output[0]], - outputs = [original_matmul_node.output[0] + '_transposed'], - name = original_matmul_node.name + '_transposed' + "MatMul", + inputs=[new_input_a_node.output[0], new_input_b_node.output[0]], + outputs=[original_matmul_node.output[0] + "_transposed"], + name=original_matmul_node.name + "_transposed", ) # Create final Transpose node output_value = helper.find_value_by_name(g, original_matmul_node.output[0]) @@ -240,17 +236,25 @@ def swap_MatMul_inputs(g, original_matmul_node): else: perm = [0, 1, 3, 2] new_final_transpose_node = onnx.helper.make_node( - 'Transpose', - inputs = [new_matmul_node.output[0]], - outputs = [original_matmul_node.output[0]], - name = original_matmul_node.name + '_final_transpose', - perm = perm + "Transpose", + inputs=[new_matmul_node.output[0]], + outputs=[original_matmul_node.output[0]], + name=original_matmul_node.name + "_final_transpose", + perm=perm, ) # Add new nodes - 
g.node.extend([new_input_a_node, new_input_b_node, new_matmul_node, new_final_transpose_node]) + g.node.extend( + [ + new_input_a_node, + new_input_b_node, + new_matmul_node, + new_final_transpose_node, + ] + ) # Delete original nodes g.node.remove(original_matmul_node) + def split_MatMul_batch_then_concat(g, original_matmul_node): new_nodes = [] final_concat_inputs = [] @@ -265,49 +269,85 @@ def split_MatMul_batch_then_concat(g, original_matmul_node): batch_count = input_a_shape[1] for i in range(batch_count): # Create Split nodes for input A - starts_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_starts", (1, ), [i]) - ends_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_ends", (1, ), [i+1]) - axes_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_axes", (1, ), [len(input_a_shape) - 3]) + starts_node = helper.list_to_constant( + f"{input_a_value.name}_sliced_{i}_starts", (1,), [i] + ) + ends_node = helper.list_to_constant( + f"{input_a_value.name}_sliced_{i}_ends", (1,), [i + 1] + ) + axes_node = helper.list_to_constant( + f"{input_a_value.name}_sliced_{i}_axes", + (1,), + [len(input_a_shape) - 3], + ) new_sliced_a_node = onnx.helper.make_node( - 'Slice', - inputs = [input_a_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]], - outputs = [f"{input_a_value.name}_sliced_{i}"], - name = f"{input_a_value.name}_sliced_{i}_for_{original_matmul_node.name}" + "Slice", + inputs=[ + input_a_value.name, + starts_node.output[0], + ends_node.output[0], + axes_node.output[0], + ], + outputs=[f"{input_a_value.name}_sliced_{i}"], + name=f"{input_a_value.name}_sliced_{i}_for_" + f"{original_matmul_node.name}", + ) + new_nodes.extend( + [starts_node, ends_node, axes_node, new_sliced_a_node] ) - new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_a_node]) # Create Split nodes for input B - starts_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_starts", (1, ), [i]) - ends_node = 
helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_ends", (1, ), [i+1]) - axes_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_axes", (1, ), [len(input_b_shape) - 3]) - new_sliced_b_node = onnx.helper.make_node( - 'Slice', - inputs = [input_b_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]], - outputs = [f"{input_b_value.name}_sliced_{i}"], - name = f"{input_b_value.name}_sliced_{i}_for_{original_matmul_node.name}" + starts_node = helper.list_to_constant( + f"{input_b_value.name}_sliced_{i}_starts", (1,), [i] + ) + ends_node = helper.list_to_constant( + f"{input_b_value.name}_sliced_{i}_ends", (1,), [i + 1] + ) + axes_node = helper.list_to_constant( + f"{input_b_value.name}_sliced_{i}_axes", + (1,), + [len(input_b_shape) - 3], + ) + new_sliced_b_node = onnx.helper.make_node( + "Slice", + inputs=[ + input_b_value.name, + starts_node.output[0], + ends_node.output[0], + axes_node.output[0], + ], + outputs=[f"{input_b_value.name}_sliced_{i}"], + name=f"{input_b_value.name}_sliced_{i}_for_" + f"{original_matmul_node.name}", + ) + new_nodes.extend( + [starts_node, ends_node, axes_node, new_sliced_b_node] ) - new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_b_node]) # Create MatMul nodes new_matmul_node = onnx.helper.make_node( - 'MatMul', - inputs = [new_sliced_a_node.output[0], new_sliced_b_node.output[0]], - outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"], - name = f"{original_matmul_node.name}_sliced_{i}" + "MatMul", + inputs=[new_sliced_a_node.output[0], new_sliced_b_node.output[0]], + outputs=[f"{original_matmul_node.output[0]}_sliced_{i}"], + name=f"{original_matmul_node.name}_sliced_{i}", ) new_nodes.append(new_matmul_node) final_concat_inputs.append(new_matmul_node.output[0]) # Create Concat nodes output_value = helper.find_value_by_name(g, original_matmul_node.output[0]) if output_value is None: - output_value = helper.find_output_by_name(g, original_matmul_node.output[0]) + 
output_value = helper.find_output_by_name( + g, original_matmul_node.output[0] + ) if output_value is None: - helper.logger.error(f"Cannot find value_info for {original_matmul_node.output[0]}") + helper.logger.error( + f"Cannot find value_info for {original_matmul_node.output[0]}" + ) output_shape = helper.get_shape_from_value_info(output_value) new_concat_node = onnx.helper.make_node( "Concat", - inputs = final_concat_inputs, - outputs = [original_matmul_node.output[0]], - name = f"{original_matmul_node.name}_final_concat", - axis = len(output_shape) - 3 + inputs=final_concat_inputs, + outputs=[original_matmul_node.output[0]], + name=f"{original_matmul_node.name}_final_concat", + axis=len(output_shape) - 3, ) new_nodes.append(new_concat_node) # Add new nodes @@ -320,7 +360,9 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node): new_nodes = [] final_concat_inputs = [] # Get the batch count - input_b_node = helper.find_node_by_output_name(g, original_matmul_node.input[1]) + input_b_node = helper.find_node_by_output_name( + g, original_matmul_node.input[1] + ) input_b_np = helper.constant_to_numpy(input_b_node) if len(input_b_np.shape) == 3: batch_count = input_b_np.shape[0] @@ -329,17 +371,19 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node): for i in range(batch_count): # Create new constant node if len(input_b_np.shape) == 3: - new_np = input_b_np[i:i+1, ...] + new_np = input_b_np[i:i + 1, ...] else: - new_np = input_b_np[:, i:i+1, ...] - new_weight = helper.numpy_to_constant(f"{input_b_node.name}_sliced_{i}", new_np) + new_np = input_b_np[:, i:i + 1, ...] 
+ new_weight = helper.numpy_to_constant( + f"{input_b_node.name}_sliced_{i}", new_np + ) new_nodes.append(new_weight) # Create MatMul nodes new_matmul_node = onnx.helper.make_node( - 'MatMul', - inputs = [original_matmul_node.input[0], new_weight.output[0]], - outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"], - name = f"{original_matmul_node.name}_sliced_{i}" + "MatMul", + inputs=[original_matmul_node.input[0], new_weight.output[0]], + outputs=[f"{original_matmul_node.output[0]}_sliced_{i}"], + name=f"{original_matmul_node.name}_sliced_{i}", ) new_nodes.append(new_matmul_node) final_concat_inputs.append(new_matmul_node.output[0]) @@ -348,10 +392,10 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node): output_shape = helper.get_shape_from_value_info(output_value) new_concat_node = onnx.helper.make_node( "Concat", - inputs = final_concat_inputs, - outputs = [original_matmul_node.output[0]], - name = f"{original_matmul_node.name}_final_concat", - axis = len(output_shape) - 3 + inputs=final_concat_inputs, + outputs=[original_matmul_node.output[0]], + name=f"{original_matmul_node.name}_final_concat", + axis=len(output_shape) - 3, ) new_nodes.append(new_concat_node) # Add new nodes @@ -367,7 +411,7 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node): def special_MatMul_process(g): for node in g.node: - if node.op_type != 'MatMul': + if node.op_type != "MatMul": continue input_a_name = node.input[0] input_a_value = helper.find_value_by_name(g, input_a_name) @@ -383,19 +427,30 @@ def special_MatMul_process(g): continue # Too many dimensions or too few dimensions. Not supported. Skip if len(input_a_shape) > 4 or len(input_b_shape) > 4: - helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have too many dimensions.") + helper.logger.warning( + f"Cannot optimize MatMul {node.name}: " + "inputs have too many dimensions." 
+ ) continue if len(input_a_shape) < 2 or len(input_b_shape) < 2: - helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have two few dimensions.") + helper.logger.warning( + f"Cannot optimize MatMul {node.name}: " + "inputs have too few dimensions." + ) continue - # For 4 dimension, check the first dimension (should be 1) and treated as 3 dimensions. + # For 4 dimensions, check the first dimension (should be 1) + # and treat it as 3 dimensions. extra_dim = None if len(input_a_shape) == 4: extra_dim = input_a_shape[0] input_a_shape = input_a_shape[1:] if len(input_b_shape) == 4: if input_b_shape[0] != extra_dim: - helper.logger.warning(f"Cannot optimize MatMul {node.name}: input dimension batch sizes does not match ({extra_dim} vs {input_b_shape[0]}).") + helper.logger.warning( + f"Cannot optimize MatMul {node.name}: " + "input dimension batch sizes do not match " + f"({extra_dim} vs {input_b_shape[0]})." + ) continue input_b_shape = input_b_shape[1:] # Check input B dimension @@ -404,20 +459,31 @@ def special_MatMul_process(g): continue # If B is B x W x V, but B is a constant. input_b_node = helper.find_node_by_output_name(g, input_b_name) - if input_b_node is not None and input_b_node.op_type == 'Constant': + if input_b_node is not None and input_b_node.op_type == "Constant": # Constant input - helper.logger.debug(f"Optimizing MatMul node {node.name}: split constant input.") + helper.logger.debug( + f"Optimizing MatMul node {node.name}: split constant input." + ) split_MatMul_Constant_input_then_concat(g, node) # If B is B x W x V and A is 1 x H x W, do the swap. - elif len(input_a_shape) == 2 or (input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1)): - helper.logger.debug(f"Optimizing MatMul node {node.name}: swap input.") + elif len(input_a_shape) == 2 or ( + input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1) + ): + helper.logger.debug( + f"Optimizing MatMul node {node.name}: swap input."
+ ) swap_MatMul_inputs(g, node) # If B is B x W x V and A is B x H x W, do the split. elif input_b_shape[0] == input_a_shape[0]: - helper.logger.debug(f"Optimizing MatMul node {node.name}: split input batch.") + helper.logger.debug( + f"Optimizing MatMul node {node.name}: split input batch." + ) split_MatMul_batch_then_concat(g, node) # Other cases are not supported: If B is B x W x V but A is X x H x W. else: - helper.logger.warning(f"Cannot optimize MatMul {node.name}: unknown reason. Might be shape mismatch.") + helper.logger.warning( + f"Cannot optimize MatMul {node.name}: " + "unknown reason. Might be shape mismatch." + ) continue - other.topological_sort(g) \ No newline at end of file + other.topological_sort(g)