diff --git a/tools/optimizer_scripts/.clang-format b/tools/optimizer_scripts/.clang-format new file mode 100644 index 0000000..2593ef5 --- /dev/null +++ b/tools/optimizer_scripts/.clang-format @@ -0,0 +1 @@ +BasedOnStyle: Google \ No newline at end of file diff --git a/tools/optimizer_scripts/.gitignore b/tools/optimizer_scripts/.gitignore new file mode 100644 index 0000000..991fd07 --- /dev/null +++ b/tools/optimizer_scripts/.gitignore @@ -0,0 +1,7 @@ +__pycache__ +.vscode +*.pyc +models.py +temp.py +.ssh/ +docker/test_models/ \ No newline at end of file diff --git a/tools/optimizer_scripts/README.md b/tools/optimizer_scripts/README.md new file mode 100644 index 0000000..cac99c5 --- /dev/null +++ b/tools/optimizer_scripts/README.md @@ -0,0 +1,189 @@ +# Converter Scripts + +[![pipeline status](http://192.168.200.1:8088/jiyuan/converter_scripts/badges/master/pipeline.svg)](http://192.168.200.1:8088/jiyuan/converter_scripts/commits/master) + +This project collects various optimization scripts and converter scritps for +Kneron toolchain. This collection does not include the Keras to ONNX converter +and the Caffe to ONNX converter. They are in seperate projects. + +**The scripts not listed below are used as libraries and cannot be used +directly.** + +## onnx2onnx.py + +### 1.1. Description + +General optimizations on ONNX model for Kneron toolchain. Though Kneron +toolchains are designed to take ONNX models as input, they have some +restrictions on the models (e.g. inferenced shapes for all value_info). Thus, we +have this tool to do some general optimization and conversion on ONNX models. +**Notice that this script should take an valid ONNX model as input.** It cannot +turn an invalid ONNX model into a valid one. + +### 1.2. Basic Usage + +```bash +python onnx2onnx.py input.onnx -o output.onnx +``` + +### 1.3. Optimizations Included + +* Fusing BN into Conv. +* Fusing BN into Gemm. +* Fusing consecutive Gemm. +* Eliminating Identify layers and Dropout layers. 
+* Eliminating last shape changing nodes. +* Replacing initializers into Constant nodes. +* Replacing global AveragePool with GAP. +* Replacing Squeeze and Unsqueeze with Reshape. +* Replacing 1x1 depthwise with BN. +* Inferencing Upsample shapes. +* Transposing B in Gemm. + +## pytorch2onnx.py + +### 2.1. Description + +Convert Pytorch models or Pytorch generated ONNX models into Kneron toolchain +compatible ONNX files. This script include most of the optimizations in +`onnx2onnx.py`. It also includes some optimizations for Pytorch model only. + +### 2.2. Basic Usage + +```bash +# Take Pytorch model name, input channel number, input height, input width +python pytorch2onnx.py input.pth output.onnx --input-size 3 224 224 +# Or take Pytorch exported ONNX. +python pytorch2onnx.py input.onnx output.onnx +``` + +### 2.3. Optimizations Included + +* Adding name to nodes. +* Unsqueeze nodes constant folding. +* Reshape nodes constant folding. +* Optimizations in `onnx2onnx.py`. + +## editor.py + +### 3.1. Description + +This is an simple ONNX editor which achieves the following functions: + +* Add nop BN or Conv nodes. +* Delete specific nodes or inputs. +* Cut the graph from certain node (Delete all the nodes following the node). +* Reshape inputs and outputs + +### 3.2 Usage + +``` +usage: editor.py [-h] [-c CUT_NODE [CUT_NODE ...]] + [--cut-type CUT_TYPE [CUT_TYPE ...]] + [-d DELETE_NODE [DELETE_NODE ...]] + [--delete-input DELETE_INPUT [DELETE_INPUT ...]] + [-i INPUT_CHANGE [INPUT_CHANGE ...]] + [-o OUTPUT_CHANGE [OUTPUT_CHANGE ...]] + [--add-conv ADD_CONV [ADD_CONV ...]] + [--add-bn ADD_BN [ADD_BN ...]] + in_file out_file + +Edit an ONNX model. The processing sequense is 'delete nodes/values' -> 'add +nodes' -> 'change shapes'. 
Cutting cannot be done with other operations +together + +positional arguments: + in_file input ONNX FILE + out_file ouput ONNX FILE + +optional arguments: + -h, --help show this help message and exit + -c CUT_NODE [CUT_NODE ...], --cut CUT_NODE [CUT_NODE ...] + remove nodes from the given nodes(inclusive) + --cut-type CUT_TYPE [CUT_TYPE ...] + remove nodes by type from the given nodes(inclusive) + -d DELETE_NODE [DELETE_NODE ...], --delete DELETE_NODE [DELETE_NODE ...] + delete nodes by names and only those nodes + --delete-input DELETE_INPUT [DELETE_INPUT ...] + delete inputs by names + -i INPUT_CHANGE [INPUT_CHANGE ...], --input INPUT_CHANGE [INPUT_CHANGE ...] + change input shape (e.g. -i 'input_0 1 3 224 224') + -o OUTPUT_CHANGE [OUTPUT_CHANGE ...], --output OUTPUT_CHANGE [OUTPUT_CHANGE ...] + change output shape (e.g. -o 'input_0 1 3 224 224') + --add-conv ADD_CONV [ADD_CONV ...] + add nop conv using specific input + --add-bn ADD_BN [ADD_BN ...] + add nop bn using specific input +``` + +### 3.3. Example + +Here is an example of when and how to use the editor.py. + +```bash +# In the `res` folder, there is a vdsr model from tensorflow. +# We need to convert this model firstly. +./tf2onnx.sh res/vdsr_41_20layer_1.pb res/tmp.onnx images:0 output:0 +# This onnx file seems valid. But, it's channel last for the input and output. +# It is using Traspose to convert to channel first, affacting the performance. +# Thus, here we use the editor to delete these Transpose and reset the shapes. +python editor.py debug.onnx new.onnx -d Conv2D__6 Conv2D_19__84 -i 'images:0 1 3 41 41' -o 'output:0 1 3 41 41' +# Now, it has no Transpose and take channel first inputs directly. +``` + +## test_models_opt.py + +### 4.1. Description +Compare all original and optimized onnx models under a specified directory. +Using different endings to locate original and optimized model paths. Apply +onnxruntime inference to the models, and compare the results from original +and optimized models. 
Calculate basic statistics and store to a csv file. + +### 4.2. Usage + +```bash +python DIR ending1 ending2 csv_out_file -p=Y/N + +# csv_out_file is file path for the stats data. +# -p --plot is the plot option, if Y, stats plots will be generated. +``` + +### 4.3. Statistics +* max_rel_diff +* max_abs_diff +* mean_rel_diff +* mean_abs_diff +* std_rel_diff +* std_abs_diff +* acc_with_diff_precision +* percentile + +### 4.4. Plots +* Max Relative Difference Histogram +* Max Absolute Difference Histogram +* Rel_diff Percentiles of Raw and Optimized Models +* Abs_diff Percentiles of Raw and Optimized Models +* Accuracies with Different Precisions + +## tensorflow2onnx.py + +### 5.1. Description +Convert and optimize tensorflow models. If input file is frozen tensorflow .pb model, +convert to onnx model and do the custmized optimization afterwards. If input model is already +onnx model, apply optimization and save optimized model. + +### 5.2 Dependency + +This scripts depends on the tensorflow-onnx project. Please [check and install it](https://github.com/onnx/tensorflow-onnx/tree/r1.5) before using this script. We currently support up to version 1.5.5. For other versions, you may need to try it our yourself. + +### 5.3. Basic Usage +```bash +python tensorflow2onnx.py in_file out_file -t=True/False + +# -t --test, is the option for test mode, if True, shape change after input will not be eliminated. +``` + +### 5.4. Model Save Paths +`in_file` is the input model path, `out_file` specifies output optimized model path. +If input file is `.pb` model, an unoptimized onnx model will be saved to the output directory as well. 
+ diff --git a/tools/optimizer_scripts/consecutive_conv_opt.py b/tools/optimizer_scripts/consecutive_conv_opt.py new file mode 100644 index 0000000..c7d4068 --- /dev/null +++ b/tools/optimizer_scripts/consecutive_conv_opt.py @@ -0,0 +1,59 @@ +import numpy as np +import onnx +import sys + +from tools.other import topological_sort +from tools import helper + +def fuse_bias_in_consecutive_1x1_conv(g): + for second in g.node: + # Find two conv + if second.op_type != 'Conv': + continue + first = helper.find_node_by_output_name(g, second.input[0]) + if first is None or first.op_type != 'Conv': + continue + # Check if the first one has only one folloing node + if len(helper.find_following_nodes_by_input_value_name(g, first.output[0])) != 1: + continue + # If first node has no bias, continue + if len(first.input) == 2: + continue + # Check their kernel size + first_kernel_shape = helper.get_list_attribute_by_name(first, 'kernel_shape', 'int') + second_kernel_shape = helper.get_list_attribute_by_name(second, 'kernel_shape', 'int') + prod = first_kernel_shape[0] * first_kernel_shape[1] * second_kernel_shape[0] * second_kernel_shape[1] + if prod != 1: + continue + print('Found: ', first.name, ' ', second.name) + # Get bias of the nodes + first_bias_node = helper.find_node_by_output_name(g, first.input[2]) + second_weight_node = helper.find_node_by_output_name(g, second.input[1]) + second_bias_node = helper.find_node_by_output_name(g, second.input[2]) + first_bias = helper.constant_to_numpy(first_bias_node) + second_weight = helper.constant_to_numpy(second_weight_node) + second_bias = helper.constant_to_numpy(second_bias_node) + # Calculate the weight for second node + first_bias = np.reshape(first_bias, (1, first_bias.size)) + second_weight = np.reshape(second_weight, (second_weight.shape[0], second_weight.shape[1])) + second_weight = np.transpose(second_weight) + new_second_bias = second_bias + np.matmul(first_bias, second_weight) + new_second_bias = 
np.reshape(new_second_bias, (new_second_bias.size,)) + # Generate new weight + new_first_bias = np.reshape(first_bias, (first_bias.size, )) + for i in range(new_first_bias.shape[0]): + new_first_bias[i] = 0.0 + new_first_bias_node = helper.numpy_to_constant(first_bias_node.output[0], new_first_bias) + new_second_bias_node = helper.numpy_to_constant(second_bias_node.output[0], new_second_bias) + # Delete old weight and add new weights + g.node.remove(first_bias_node) + g.node.remove(second_bias_node) + g.node.extend([new_first_bias_node, new_second_bias_node]) + topological_sort(g) + +if __name__ == "__main__": + if len(sys.argv) != 3: + exit(1) + m = onnx.load(sys.argv[1]) + fuse_bias_in_consecutive_1x1_conv(m.graph) + onnx.save(m, sys.argv[2]) \ No newline at end of file diff --git a/tools/optimizer_scripts/docker/Dockerfile b/tools/optimizer_scripts/docker/Dockerfile new file mode 100644 index 0000000..bb62f7f --- /dev/null +++ b/tools/optimizer_scripts/docker/Dockerfile @@ -0,0 +1,24 @@ +FROM continuumio/miniconda3:latest +LABEL maintainer="jiyuan@kneron.us" + +# Install python packages +RUN conda update -y conda && \ +conda install -y python=3.6 && \ +conda install -y -c intel caffe && \ +conda install -y -c pytorch pytorch=1.3.1 torchvision=0.4.2 cpuonly && \ +conda install -y -c conda-forge tensorflow=1.5.1 keras=2.2.4 && \ +pip install onnx==1.4.1 onnxruntime==1.1.0 tf2onnx==1.5.4 && \ +ln -s /opt/conda/lib/libgflags.so.2.2.2 /opt/conda/lib/libgflags.so.2 + +# Install git lfs packages +RUN apt-get update && apt-get install -y curl apt-utils && \ +curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \ +apt-get install -y git-lfs + +RUN conda clean -a -y && rm -rf /var/lib/apt/lists/* + +# copy the test data +COPY ./test_models /test_models + +# Clean the environment and finalize the process +WORKDIR /root \ No newline at end of file diff --git a/tools/optimizer_scripts/editor.py b/tools/optimizer_scripts/editor.py new 
file mode 100644 index 0000000..8ccc6ca --- /dev/null +++ b/tools/optimizer_scripts/editor.py @@ -0,0 +1,118 @@ +import onnx +import onnx.utils +try: + from onnx import optimizer +except ImportError: + import onnxoptimizer as optimizer +import argparse + +import tools.modhelper as helper +import tools.other as other +import tools.replacing as replacing +# Main process +# Argument parser +parser = argparse.ArgumentParser(description="Edit an ONNX model.\nThe processing sequense is 'delete nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting cannot be done with other operations together") +parser.add_argument('in_file', type=str, help='input ONNX FILE') +parser.add_argument('out_file', type=str, help="ouput ONNX FILE") +parser.add_argument('-c', '--cut', dest='cut_node', type=str, nargs='+', help="remove nodes from the given nodes(inclusive)") +parser.add_argument('--cut-type', dest='cut_type', type=str, nargs='+', help="remove nodes by type from the given nodes(inclusive)") +parser.add_argument('-d', '--delete', dest='delete_node', type=str, nargs='+', help="delete nodes by names and only those nodes") +parser.add_argument('--delete-input', dest='delete_input', type=str, nargs='+', help="delete inputs by names") +parser.add_argument('--delete-output', dest='delete_output', type=str, nargs='+', help="delete outputs by names") +parser.add_argument('-i', '--input', dest='input_change', type=str, nargs='+', help="change input shape (e.g. -i 'input_0 1 3 224 224')") +parser.add_argument('-o', '--output', dest='output_change', type=str, nargs='+', help="change output shape (e.g. -o 'input_0 1 3 224 224')") +parser.add_argument('--add-conv', dest='add_conv', type=str, nargs='+', help='add nop conv using specific input') +parser.add_argument('--add-bn', dest='add_bn', type=str, nargs='+', help='add nop bn using specific input') +parser.add_argument('--rename-output', dest='rename_output', type=str, nargs='+', help='Rename the specific output(e.g. 
--rename-output old_name new_name)') +parser.add_argument('--pixel-bias-value', dest='pixel_bias_value', type=str, nargs='+', help='(per channel) set pixel value bias bn layer at model front for normalization( e.g. --pixel_bias_value "[104.0, 117.0, 123.0]" )') +parser.add_argument('--pixel-scale-value', dest='pixel_scale_value', type=str, nargs='+', help='(per channel) set pixel value scale bn layer at model front for normalization( e.g. --pixel_scale_value "[0.0078125, 0.0078125, 0.0078125]" )') + +args = parser.parse_args() + +# Load model and polish +m = onnx.load(args.in_file) +m = other.polish_model(m) +g = m.graph +replacing.replace_initializer_with_Constant(g) +other.topological_sort(g) + +# Remove nodes according to the given arguments. +if args.delete_node is not None: + helper.delete_nodes(g, args.delete_node) + +if args.delete_input is not None: + helper.delete_input(g, args.delete_input) + +if args.delete_output is not None: + helper.delete_output(g, args.delete_output) + +# Add do-nothing Conv node +if args.add_conv is not None: + other.add_nop_conv_after(g, args.add_conv) + other.topological_sort(g) + +# Add do-nothing BN node +if args.add_bn is not None: + other.add_nop_bn_after(g, args.add_bn) + other.topological_sort(g) + +# Add bias scale BN node +if args.pixel_bias_value is not None or args.pixel_scale_value is not None: + + if len(g.input) > 1: + raise ValueError(" '--pixel-bias-value' and '--pixel-scale-value' only support one input node model currently") + + i_n = g.input[0] + + pixel_bias_value = [0] * i_n.type.tensor_type.shape.dim[1].dim_value + pixel_scale_value = [1] * i_n.type.tensor_type.shape.dim[1].dim_value + + if args.pixel_bias_value is not None and len(args.pixel_bias_value) == 1: + pixel_bias_value = [float(n) for n in args.pixel_bias_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')] + + if args.pixel_scale_value is not None and len(args.pixel_scale_value) == 1: + pixel_scale_value = [float(n) for n in 
args.pixel_scale_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')] + + + if i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_bias_value) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value): + raise ValueError("--pixel-bias-value (" + str(pixel_bias_value) + ") and --pixel-scale-value (" + str(pixel_scale_value) + ") should be same as input dimension:" + str(i_n.type.tensor_type.shape.dim[1].dim_value) ) + other.add_bias_scale_bn_after(g, i_n.name, pixel_bias_value, pixel_scale_value) + +# Change input and output shapes as requested +if args.input_change is not None: + other.change_input_shape(g, args.input_change) +if args.output_change is not None: + other.change_output_shape(g, args.output_change) + +# Cutting nodes according to the given arguments. +if args.cut_node is not None or args.cut_type is not None: + if args.cut_node is None: + other.remove_nodes(g, cut_types=args.cut_type) + elif args.cut_type is None: + other.remove_nodes(g, cut_nodes=args.cut_node) + else: + other.remove_nodes(g, cut_nodes=args.cut_node, cut_types=args.cut_type) + other.topological_sort(g) + +# Rename nodes +if args.rename_output: + if len(args.rename_output) % 2 != 0: + print("Rename output should be paires of names.") + else: + for i in range(0, len(args.rename_output), 2): + other.rename_output_name(g, args.rename_output[i], args.rename_output[i + 1]) + +# Remove useless nodes +if args.delete_node or args.delete_input or args.input_change or args.output_change: + # If shape changed during the modification, redo shape inference. 
+ while(len(g.value_info) > 0): + g.value_info.pop() +passes = ['extract_constant_to_initializer'] +m = optimizer.optimize(m, passes) +g = m.graph +replacing.replace_initializer_with_Constant(g) +other.topological_sort(g) +# Polish and output +m = other.polish_model(m) +other.add_output_to_value_info(m.graph) +onnx.save(m, args.out_file) \ No newline at end of file diff --git a/tools/optimizer_scripts/norm_on_scaled_onnx.py b/tools/optimizer_scripts/norm_on_scaled_onnx.py new file mode 100644 index 0000000..f99a866 --- /dev/null +++ b/tools/optimizer_scripts/norm_on_scaled_onnx.py @@ -0,0 +1,52 @@ +import onnx +import sys +import json + +from tools import special + +if len(sys.argv) != 3: + print("python norm_on_scaled_onnx.py input.onnx input.json") + exit(1) + +# Modify onnx +m = onnx.load(sys.argv[1]) +special.add_0_5_to_normalized_input(m) +onnx.save(m, sys.argv[1][:-4] + 'norm.onnx') + +# Change input node +origin_file = open(sys.argv[2], 'r') +origin_json = json.load(origin_file) +origin_json["input_node"]["output_datapath_radix"] = [8] +new_json_str = json.dumps(origin_json) + +# Modify json +file = open(sys.argv[1][:-4] + 'norm.onnx' + '.json', 'w') +s = """{{ + \"{0}\" : + {{ + \"bias_bitwidth\" : 16, + \"{0}_bias\" : [15], + \"{0}_weight\" : [3,3,3], + \"conv_coarse_shift\" : [-4,-4,-4], + \"conv_fine_shift\" : [0,0,0], + \"conv_total_shift\" : [-4,-4,-4], + \"cpu_mode\" : false, + \"delta_input_bitwidth\" : [0], + \"delta_output_bitwidth\" : 8, + \"flag_radix_bias_eq_output\" : true, + \"input_scale\" : [[1.0,1.0,1.0]], + \"output_scale\" : [1.0, 1.0, 1.0], + \"psum_bitwidth\" : 16, + \"weight_bitwidth\" : 8, + \"input_datapath_bitwidth\" : [8], + \"input_datapath_radix\" : [8], + \"working_input_bitwidth\" : 8, + \"working_input_radix\" : [8], + \"working_output_bitwidth\" : 16, + \"working_output_radix\" : 15, + \"output_datapath_bitwidth\" : 8, + \"output_datapath_radix\" : 7 + }},\n""".format('input_norm') +file.write(s + new_json_str[1:]) 
+file.close() +origin_file.close() diff --git a/tools/optimizer_scripts/onnx1_3to1_4.py b/tools/optimizer_scripts/onnx1_3to1_4.py new file mode 100644 index 0000000..64b72b5 --- /dev/null +++ b/tools/optimizer_scripts/onnx1_3to1_4.py @@ -0,0 +1,135 @@ +# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git + +import sys +import onnx +import numpy as np +from onnx import numpy_helper +from tools import other, helper + +""" +Change onnx model from version 1.3 to version 1.4. +Modify the BN node by removing the spatial attribute +Modify the Upsample node by removing the 'scales' attribute, and adding a constant node instead. +Model's ir_version and opset_import are updated. +""" + +def remove_BN_spatial(g): + for node in g.node: + if node.op_type != 'BatchNormalization': + continue + for att in node.attribute: + if att.name == 'spatial': + node.attribute.remove(att) + + +def upsample_attribute_to_const(g): + for node in g.node: + if node.op_type != 'Upsample': + continue + scales_exist = False + for att in node.attribute: + if att.name == 'scales': + scales_exist = True + break + if not scales_exist: + continue + + shape = [len(att.floats)] + node.attribute.remove(att) + new_node = helper.list_to_constant(node.name+'_input', shape, att.floats) + + g.node.extend([new_node]) + value_info = onnx.helper.make_tensor_value_info(node.name+'_input', onnx.TensorProto.FLOAT, shape) + node.input.extend([node.name+'_input']) + g.value_info.extend([value_info]) + +def relu6_to_clip(g): + for node in g.node: + if node.op_type != 'Relu': + continue + max_val = helper.get_var_attribute_by_name(node, 'max', 'float') + if max_val is None: + continue + new_node = onnx.helper.make_node( + "Clip", + node.input, + node.output, + name=node.name, + max=max_val, + min=0.0 + ) + g.node.remove(node) + g.node.extend([new_node]) + +def PRelu_weight_reshape(g): + # For PRelu with single dimension weight. 
Expand it to 1, x, 1, 1 + for node in g.node: + if node.op_type != "PRelu": + continue + slope = helper.find_node_by_output_name(g, node.input[1]) + if slope is not None: + # Constant node + if len(slope.attribute[0].t.dims) != 1: + continue + slope.attribute[0].t.dims.append(slope.attribute[0].t.dims[0]) + slope.attribute[0].t.dims[0] = 1 + slope.attribute[0].t.dims.append(1) + slope.attribute[0].t.dims.append(1) + else: + # Initializer + for i in g.initializer: + if i.name == node.input[1]: + slope = i + break + if len(slope.dims) != 1: + continue + slope.dims.append(slope.dims[0]) + slope.dims[0] = 1 + slope.dims.append(1) + slope.dims.append(1) + input_value = helper.find_input_by_name(g, node.input[1]) + new_input = onnx.helper.make_tensor_value_info( + node.input[1], + input_value.type.tensor_type.elem_type, + (1, slope.dims[1], 1, 1)) + g.input.remove(input_value) + g.input.append(new_input) + value_info = helper.find_value_by_name(g, node.input[1]) + if value_info is not None: + g.value_info.remove(value_info) + +def do_convert(m): + graph = m.graph + + # Modify the nodes. + remove_BN_spatial(graph) + upsample_attribute_to_const(graph) + relu6_to_clip(graph) + PRelu_weight_reshape(graph) + other.topological_sort(graph) + + # Change model properties. + m.ir_version = 4 + m.opset_import[0].version = 9 + return m + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage:{} file_in file_out".format(sys.argv[0])) + exit(1) + + model = onnx.load(sys.argv[1]) + graph = model.graph + + # Modify the nodes. + remove_BN_spatial(graph) + upsample_attribute_to_const(graph) + relu6_to_clip(graph) + PRelu_weight_reshape(graph) + other.topological_sort(graph) + + # Change model properties. 
+ model.ir_version = 4 + model.opset_import[0].version = 9 + + onnx.save(model, sys.argv[2]) diff --git a/tools/optimizer_scripts/onnx1_4to1_6.py b/tools/optimizer_scripts/onnx1_4to1_6.py new file mode 100644 index 0000000..825b3cd --- /dev/null +++ b/tools/optimizer_scripts/onnx1_4to1_6.py @@ -0,0 +1,184 @@ +# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git + +import sys +import onnx +import onnx.utils +import numpy as np +from onnx import numpy_helper +from tools import other, helper, replacing + +""" +Change onnx model from version 1.4 to version 1.6. +""" + +def replace_all_attribute_to_const_node_in_pad_node(g): + node_to_remove = [] + node_to_extend = [] + for node in g.node: + if node.op_type != 'Pad': + continue + + pad_loc_node = None # must have + pad_mode = 'constant' + pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [0.0]) # need scalar + for att in node.attribute: + if att.name == 'mode': + pad_mode = helper.get_var_attribute_by_name(node, 'mode', 'string') + if att.name == 'pads': + pad_loc_node = helper.list_to_constant(node.name+'_pad_loc', [len(att.ints)], att.ints) + if att.name == 'value': + pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [att.f]) + + new_node = onnx.helper.make_node( + "Pad", + [node.input[0], pad_loc_node.name, pad_value_node.name], + [node.output[0]], + name=node.output[0], + mode=pad_mode, + ) + node_to_remove.append(node) + node_to_extend.append(new_node) + node_to_extend.append(pad_loc_node) + node_to_extend.append(pad_value_node) + + for node in node_to_remove: + g.node.remove(node) + for node in node_to_extend: + g.node.extend([node]) + + +def upsampling_to_resize(g): + for node in g.node: + if node.op_type != 'Upsample': + continue + upsampling_mode = helper.get_var_attribute_by_name(node, 'mode', 'string') + + scale_value_node = helper.find_node_by_output_name(g, node.input[1]) + if scale_value_node.op_type != "Constant": + raise TypeError('seems there is a dynamic 
"scales" param in Upsampling node: ' + node.name + ' , you might need to do constant folding first') + + roi_node = helper.list_to_constant(node.name+'_roi_value', [0], []) + + new_node = onnx.helper.make_node( + "Resize", + [node.input[0], roi_node.name, scale_value_node.name], + [node.output[0]], + name=node.output[0], + mode=upsampling_mode, + coordinate_transformation_mode = 'asymmetric' + ) + + g.node.remove(node) + g.node.extend([new_node]) + g.node.extend([roi_node]) + + +def replace_all_attribute_to_const_node_in_slice_node(g): + for node in g.node: + if node.op_type != 'Slice': + continue + + axes_const_node = None + ends_const_node = None + starts_const_node = None + steps_const_node = None + for att in node.attribute: + if att.name == 'axes': + axes_const_node = helper.list_to_constant(node.name+'_axes_value', [len(att.ints)], att.ints) + + if att.name == 'ends': + ends_const_node = helper.list_to_constant(node.name+'_ends_value', [len(att.ints)], att.ints) + + if att.name == 'starts': + starts_const_node = helper.list_to_constant(node.name+'_starts_value', [len(att.ints)], att.ints) + + if att.name == 'steps': + steps_const_node = helper.list_to_constant(node.name+'_steps_value',[ len(att.ints)], att.ints) + + ## pop out from back + attr_len = len(node.attribute) + for i in range(attr_len): + node.attribute.remove(node.attribute[ attr_len -1 - i ]) + + ## according the spec, we need to add node in specific order + if starts_const_node != None: + g.node.extend([starts_const_node]) + node.input.extend([starts_const_node.name]) + if ends_const_node != None: + g.node.extend([ends_const_node]) + node.input.extend([ends_const_node.name]) + if axes_const_node != None: + g.node.extend([axes_const_node]) + node.input.extend([axes_const_node.name]) + if steps_const_node != None: + g.node.extend([steps_const_node]) + node.input.extend([steps_const_node.name]) + + +def replace_min_max_attribute_to_const_node_in_clip_node(g): + for node in g.node: + if node.op_type 
!= 'Clip': + continue + + max_const_node = None + min_const_node = None + for att in node.attribute: + if att.name == 'max': + max_const_node = helper.list_to_constant(node.name+'_max_value', [], [att.f]) + + if att.name == 'min': + min_const_node = helper.list_to_constant(node.name+'_min_value', [], [att.f]) + + ## pop out from back + node.attribute.remove(node.attribute[1]) + node.attribute.remove(node.attribute[0]) + + ## according the spec, we need to add node in specific order + g.node.extend([min_const_node]) + g.node.extend([max_const_node]) + node.input.extend([min_const_node.name]) + node.input.extend([max_const_node.name]) + +def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto: + """Update ir_version from 4 to 6 and update opset from 9 to 11. + + Args: + model (onnx.ModelProto): input onnx model. + + Returns: + onnx.ModelProto: updated onnx model. + """ + graph = model.graph + + if model.opset_import[0].version == 11: + print("(Stop) the input model is already opset 11, no need to upgrade") + exit(1) + + # deal with empty node name issue + other.add_name_to_node(graph) + # simplify the node param type from initializer to constant + replacing.replace_initializer_with_Constant(graph) + + # Modify the nodes. + replace_min_max_attribute_to_const_node_in_clip_node(graph) + replace_all_attribute_to_const_node_in_slice_node(graph) + replace_all_attribute_to_const_node_in_pad_node(graph) + upsampling_to_resize(graph) + other.topological_sort(graph) + + # Change model properties. 
import onnx
import onnx.utils
try:
    from onnx import optimizer
except ImportError:
    import onnxoptimizer as optimizer
import sys
import argparse
import logging

from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import special
from tools import combo
from tools.helper import logger
# from tools import temp


def onnx2onnx_flow(m: onnx.ModelProto,
                   disable_fuse_bn=False,
                   bn_on_skip=False,
                   bn_before_add=False,
                   bgr=False,
                   norm=False,
                   rgba2yynn=False,
                   eliminate_tail=False,
                   opt_matmul=False,
                   duplicate_shared_weights=True) -> onnx.ModelProto:
    """Optimize an ONNX model for the Kneron toolchain.

    Args:
        m (ModelProto): the input onnx ModelProto
        disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.
        bn_on_skip (bool, optional): add BN operator on skip branches. Defaults to False.
        bn_before_add (bool, optional): add BN before Add node on every branch. Defaults to False.
        bgr (bool, optional): add a Conv layer to convert rgb input to bgr. Defaults to False.
        norm (bool, optional): add a Conv layer to add 0.5 to the input. Defaults to False.
        rgba2yynn (bool, optional): add a Conv layer to convert rgb input to yynn. Defaults to False.
        eliminate_tail (bool, optional): remove the trailing NPU unsupported nodes. Defaults to False.
        opt_matmul (bool, optional): optimize the MatMul layers according to the NPU limit.
            Defaults to False.
        duplicate_shared_weights (bool, optional): duplicate shared weights. Defaults to True.

    Returns:
        ModelProto: the optimized onnx model object.
    """
    # temp.weight_broadcast(m.graph)
    m = combo.preprocess(m, disable_fuse_bn, duplicate_shared_weights)
    # temp.fuse_bias_in_consecutive_1x1_conv(m.graph)

    # Add BN on skip branch (the two BN-insertion modes are exclusive;
    # bn_on_skip wins when both are set).
    if bn_on_skip:
        other.add_bn_on_skip_branch(m.graph)
    elif bn_before_add:
        other.add_bn_before_add(m.graph)
        other.add_bn_before_activation(m.graph)

    # General fusion/replacement passes.
    m = combo.common_optimization(m)
    # Special input-conversion options.
    if bgr:
        special.change_input_from_bgr_to_rgb(m)
    if norm:
        special.add_0_5_to_normalized_input(m)
    if rgba2yynn:
        special.add_rgb2yynn_node(m)

    # Remove useless last node
    if eliminate_tail:
        eliminating.remove_useless_last_nodes(m.graph)

    # Postprocessing
    m = combo.postprocess(m)

    # Put matmul after postprocess to avoid transpose moving downwards
    if opt_matmul:
        special.special_MatMul_process(m.graph)

    m = other.polish_model(m)

    return m


# Main process
if __name__ == "__main__":
    # Argument parser
    parser = argparse.ArgumentParser(description="Optimize an ONNX model for Kneron compiler")
    parser.add_argument('in_file', help='input ONNX FILE')
    # Typo fix: "ouput" -> "output" in the user-visible help text.
    parser.add_argument('-o', '--output', dest='out_file', type=str, help="output ONNX FILE")
    parser.add_argument('--log', default='i', type=str, help="set log level")
    parser.add_argument('--bgr', action='store_true', default=False, help="set if the model is trained in BGR mode")
    parser.add_argument('--norm', action='store_true', default=False, help="set if you have the input -0.5~0.5")
    parser.add_argument('--rgba2yynn', action='store_true', default=False, help="set if the model has yynn input but you want to take rgba images")
    parser.add_argument('--add-bn-on-skip', dest='bn_on_skip', action='store_true', default=False,
                        help="set if you only want to add BN on skip branches")
    parser.add_argument('--add-bn', dest='bn_before_add', action='store_true', default=False,
                        help="set if you want to add BN before Add")
    parser.add_argument('-t', '--eliminate-tail-unsupported', dest='eliminate_tail', action='store_true', default=False,
                        help='whether remove the last unsupported node for hardware')
    parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
                        help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.")
    parser.add_argument('--opt-matmul', dest='opt_matmul', action='store_true', default=False,
                        help="set if you want to optimize the MatMul operations for the kneron hardware.")
    parser.add_argument('--no-duplicate-shared-weights', dest='no_duplicate_shared_weights', action='store_true', default=False,
                        help='do not duplicate shared weights. Defaults to False.')
    args = parser.parse_args()

    # Default output name: replace the '.onnx' suffix with '_polished.onnx'.
    if args.out_file is None:
        outfile = args.in_file[:-5] + "_polished.onnx"
    else:
        outfile = args.out_file

    # Map the single-letter --log value to a logging level ('i' is default).
    if args.log == 'w':
        logging.basicConfig(level=logging.WARN)
    elif args.log == 'd':
        logging.basicConfig(level=logging.DEBUG)
    elif args.log == 'e':
        logging.basicConfig(level=logging.ERROR)
    else:
        logging.basicConfig(level=logging.INFO)

    # onnx Polish model includes:
    # -- nop
    # -- eliminate_identity
    # -- eliminate_nop_transpose
    # -- eliminate_nop_pad
    # -- eliminate_unused_initializer
    # -- fuse_consecutive_squeezes
    # -- fuse_consecutive_transposes
    # -- fuse_add_bias_into_conv
    # -- fuse_transpose_into_gemm

    # Basic model organize
    m = onnx.load(args.in_file)

    m = onnx2onnx_flow(m, args.disable_fuse_bn, args.bn_on_skip, args.bn_before_add,
                       args.bgr, args.norm, args.rgba2yynn, args.eliminate_tail,
                       args.opt_matmul, not args.no_duplicate_shared_weights)

    onnx.save(m, outfile)
import onnxruntime
import onnx
import argparse
import numpy as np
from tools import helper


# Map from ONNX TensorProto elem_type codes to numpy dtype names.
# NOTE(review): code 0 is UNDEFINED and 16 is BFLOAT16 in the ONNX spec;
# both map to generic 'float' here -- confirm that is intended.
onnx2np_dtype = {0: 'float', 1: 'float32', 2: 'uint8', 3: 'int8',
                 4: 'uint16', 5: 'int16', 6: 'int32', 7: 'int64',
                 8: 'str', 9: 'bool', 10: 'float16', 11: 'double',
                 12: 'uint32', 13: 'uint64', 14: 'complex64',
                 15: 'complex128', 16: 'float'}


def onnx_model_results(path_a, path_b, total_times=10):
    """Run two onnx models through onnxruntime and collect their outputs.

    Both models are fed the same random data on every iteration.

    Args:
        path_a: path of the first onnx model.
        path_b: path of the second onnx model.
        total_times: number of inference runs. Defaults to 10.

    Returns:
        (results_a, results_b): per-output lists of collected results.
    """
    # Load model a and model b into onnxruntime.
    session_a = onnxruntime.InferenceSession(path_a, None)
    session_b = onnxruntime.InferenceSession(path_b, None)
    outputs_a = session_a.get_outputs()
    outputs_b = session_b.get_outputs()

    # The models must expose the same number of outputs with matching shapes
    # (symbolic/dynamic dimensions are normalized to 1 before comparing).
    assert len(outputs_a) == len(outputs_b), 'Two models have different output numbers.'
    for i in range(len(outputs_a)):
        out_shape_a, out_shape_b = outputs_a[i].shape, outputs_b[i].shape
        out_shape_a = [dim if isinstance(dim, int) else 1 for dim in out_shape_a]
        out_shape_b = [dim if isinstance(dim, int) else 1 for dim in out_shape_b]
        assert out_shape_a == out_shape_b, 'Output {} has unmatched shapes'.format(i)

    # Load both graphs to tell real inputs apart from initializer-backed
    # inputs (older onnx models list initializers in graph.input as well).
    model_a, model_b = onnx.load(path_a), onnx.load(path_b)
    graph_a, graph_b = model_a.graph, model_b.graph
    inputs_a, inputs_b = graph_a.input, graph_b.input
    init_a, init_b = graph_a.initializer, graph_b.initializer

    # Remove initializers from the raw input name sets.
    input_names_a = {ele.name for ele in inputs_a}
    input_names_b = {ele.name for ele in inputs_b}
    init_names_a = {ele.name for ele in init_a}
    init_names_b = {ele.name for ele in init_b}
    real_inputs_names_a = input_names_a - init_names_a
    real_inputs_names_b = input_names_b - init_names_b

    # Keep the original declaration order of the real inputs.
    real_inputs_a = [item for item in inputs_a if item.name in real_inputs_names_a]
    real_inputs_b = [item for item in inputs_b if item.name in real_inputs_names_b]

    # Each model is expected to have exactly one input with a non-zero size.
    real_single_input_a = None
    real_single_input_b = None
    size_a, size_b = 0, 0
    shape_a, shape_b = [], []
    for item_a in real_inputs_a:
        size, shape = helper.find_size_shape_from_value(item_a)
        if size:
            assert real_single_input_a is None, 'Multiple inputs of first model, single input expected.'
            real_single_input_a = item_a
            size_a, shape_a = size, shape
    for item_b in real_inputs_b:
        size, shape = helper.find_size_shape_from_value(item_b)
        if size:
            assert real_single_input_b is None, 'Multiple inputs of second model, single input expected.'
            real_single_input_b = item_b
            size_b, shape_b = size, shape
    assert size_a == size_b, 'Sizes of two models do not match.'

    # Resolve the numpy dtype of each model's real input.
    input_data_type_a = onnx2np_dtype[real_single_input_a.type.tensor_type.elem_type]
    input_data_type_b = onnx2np_dtype[real_single_input_b.type.tensor_type.elem_type]

    # Run both models `total_times` times on identical random data.
    results_a = [[] for _ in range(len(outputs_a))]
    results_b = [[] for _ in range(len(outputs_b))]
    for _ in range(total_times):
        # Fresh uniform random data each iteration, reshaped per model.
        data = np.random.random(size_a)
        input_a = np.reshape(data, shape_a).astype(input_data_type_a)
        input_b = np.reshape(data, shape_b).astype(input_data_type_b)

        # Zero-size inputs get empty arrays; the real input gets the data.
        input_dict_a = {}
        input_dict_b = {}
        for item_a in real_inputs_a:
            item_type_a = onnx2np_dtype[item_a.type.tensor_type.elem_type]
            input_dict_a[item_a.name] = np.array([]).astype(item_type_a) \
                if item_a.name != real_single_input_a.name else input_a
        for item_b in real_inputs_b:
            item_type_b = onnx2np_dtype[item_b.type.tensor_type.elem_type]
            input_dict_b[item_b.name] = np.array([]).astype(item_type_b) \
                if item_b.name != real_single_input_b.name else input_b

        ra = session_a.run([], input_dict_a)
        rb = session_b.run([], input_dict_b)
        for i in range(len(outputs_a)):
            results_a[i].append(ra[i])
            results_b[i].append(rb[i])

    return results_a, results_b
import onnx
import argparse
import glob
import csv
import numpy as np
import matplotlib.pyplot as plt

from tools import helper
import onnx_vs_onnx as onnx_tester


def compare_results(results_a, results_b):
    """Compute basic statistics comparing two sets of inference results.

    Args:
        results_a: results from running the first model multiple times.
        results_b: results from running the second model multiple times.

    Returns:
        A list of [name, value] pairs holding the statistics below.
    """
    # The nested result lists can be of nonuniform shape, so flatten them
    # (keeping shape information) before comparing element-wise.
    ra_flat = helper.flatten_with_depth(results_a, 0)
    rb_flat = helper.flatten_with_depth(results_b, 0)
    shape_a = [item[1] for item in ra_flat]
    shape_b = [item[1] for item in rb_flat]
    assert shape_a == shape_b, 'two results data shape doesn\'t match'
    ra_raw = [item[0] for item in ra_flat]
    rb_raw = [item[0] for item in rb_flat]

    # Element-wise absolute difference, and relative difference defined as
    # abs(diff) / max(abs(ra), abs(rb)) (0 when both values are 0).
    raw_diff = [ra_raw[i] - rb_raw[i] for i in range(len(ra_raw))]
    abs_diff = [abs(num) for num in raw_diff]
    rel_diff = []
    for i in range(len(ra_raw)):
        divider = max([abs(ra_raw[i]), abs(rb_raw[i])])
        rel_diff.append(abs_diff[i] / divider if divider != 0 else 0)

    max_rel_diff = max(rel_diff)
    max_abs_diff = max(abs_diff)
    mean_rel_diff = np.average(rel_diff)
    mean_abs_diff = np.average(abs_diff)
    std_rel_diff = np.std(rel_diff)
    std_abs_diff = np.std(abs_diff)

    # Share of elements that agree when rounded to 0..7 decimal digits.
    acc_with_diff_precision = []
    for digit in range(8):
        correct = 0
        for i in range(len(ra_raw)):
            if format(ra_raw[i], '.' + str(digit) + 'f') \
                    == format(rb_raw[i], '.' + str(digit) + 'f'):
                correct += 1
        acc_with_diff_precision.append([digit, float(format(correct / len(ra_raw), '.3f'))])

    # Percentile curves of the difference distributions in 5% steps.
    # (Fixed: the original inline comments labeled these two lists the
    # wrong way around.)
    rel_diff.sort()
    abs_diff.sort()
    rel_diff_percentiles = []  # percentiles of rel_diff
    abs_diff_percentiles = []  # percentiles of abs_diff
    for i in range(20):
        rel_diff_percentiles.append(['{}%'.format(i * 5), rel_diff[int((i / 20) * len(rel_diff))]])
        abs_diff_percentiles.append(['{}%'.format(i * 5), abs_diff[int((i / 20) * len(abs_diff))]])

    return [
        ['max_rel_diff', max_rel_diff],
        ['max_abs_diff', max_abs_diff],
        ['mean_rel_diff', mean_rel_diff],
        ['mean_abs_diff', mean_abs_diff],
        ['std_rel_diff', std_rel_diff],
        ['std_abs_diff', std_abs_diff],
        ['acc_with_diff_precision', acc_with_diff_precision],
        ['rel_diff_percentiles', rel_diff_percentiles],
        ['abs_diff_percentiles', abs_diff_percentiles]
    ]
parser.add_argument('ending1', type=str, help='model file name ending(eg, .onnx)') + parser.add_argument('ending2', type=str, help='opt model file name ending(eg. _opt.onnx)') + parser.add_argument('out_file', type=str, help='output csv file name') + parser.add_argument('-p', '--plot', default='N', help='get plots (Y/N)') + parser.add_argument('-i', '--iter_times', default=10, type=int, help='inference times') + + args = parser.parse_args() + + old_models_paths = glob.glob(args.dir+'*'+args.ending1) + new_models_paths = glob.glob(args.dir+'*'+args.ending2) + + stats_table = [[ + 'Model', + 'max_rel_diff', + 'max_abs_diff', + 'mean_rel_diff', + 'mean_abs_diff', + 'std_rel_diff', + 'std_abs_diff', + 'acc_with_diff_precision', + 'rel_diff_percentiles', + 'abs_diff_percentiles' + ]] + + for new_model_path in new_models_paths: + old_model_path = new_model_path[:-len(args.ending2)] + args.ending1 + if old_model_path not in old_models_paths: + continue + + # run inference + results_a, results_b = onnx_tester.onnx_model_results(old_model_path, new_model_path, total_times=args.iter_times) + + # compare inference results + comparision = compare_results(results_a, results_b) + + new_line = [old_model_path.split('/')[-1]] + for item in comparision: + new_line.append(item[1]) + + stats_table.append(new_line) + + # try to read existing file + old_stats_table = [] + try: + old_file = open(args.out_file, 'r') + reader = csv.reader(old_file) + old_header = reader.__next__() + for row in reader: + old_stats_table.append(row) + old_file.close() + except: + pass + + # compare and merge possible old stat data file with new stat data file + header = stats_table[0] + stats_table = stats_table[1:] + new_model_names = set([item[0] for item in stats_table]) + for row in old_stats_table: + if row[0] not in new_model_names: + stats_table.append(row) + stats_table.insert(0, header) + + # write a new stat data file, overwrite old file + new_file = open(args.out_file, 'w', newline='') + writer = 
csv.writer(new_file) + for row in stats_table: + writer.writerow(row) + new_file.close() + + # make some plots + if args.plot == 'Y': + if len(stats_table) < 2: + exit(0) + + sample_table = stats_table[1:] if len(stats_table) < 6 else stats_table[1:6] + + max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]] + plt.hist(max_rel_diffs, bins=15) + plt.title('Max Relavtive Difference Histogram') + plt.xlabel('Max Relative Difference') + plt.ylabel('Counts') + plt.savefig('max_rel_diff_hist.png') + plt.close() + + max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]] + plt.hist(max_abs_diffs, bins=15) + plt.title('Max Absolute Difference Histogram') + plt.xlabel('Max Absolute Difference') + plt.ylabel('Counts') + plt.savefig('max_abs_diff_hist.png') + plt.close() + + for line in sample_table: + model_name = line[0] + percentiles = line[-2] + x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))] + y = [ele[1] for ele in percentiles] + plt.plot(x, y, label=model_name) + plt.title('Rel_diff Percentiles of Raw and Optimized Models') + plt.xlabel('percentage') + plt.ylabel('relative difference') + plt.legend() + plt.savefig('rel_diff_percentiles.png') + plt.close() + + for line in sample_table: + model_name = line[0] + percentiles = line[-1] + x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))] + y = [ele[1] for ele in percentiles] + plt.plot(x, y, label=model_name) + plt.title('Abs_diff Percentiles of Raw and Optimized Models') + plt.xlabel('percentage') + plt.ylabel('absolute difference') + plt.legend() + plt.savefig('abs_diff_percentiles.png') + plt.close() + + for line in sample_table: + model_name = line[0] + accuracies = line[-3] + x = [acc[0] for acc in accuracies] + y = [acc[1] for acc in accuracies] + plt.plot(x, y, label=model_name) + plt.title('Accuracies with Different Precisions') + plt.xlabel('Decimals') + plt.ylabel('Precision') + plt.legend() + plt.savefig('precisions.png') + plt.close() 
import onnx
import onnx.utils
try:
    from onnx import optimizer
except ImportError:
    import onnxoptimizer as optimizer
import sys
import numpy as np
import struct
import logging
import argparse

from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import combo
from tools import special
from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow

# Debug use
# logging.basicConfig(level=logging.DEBUG)

######################################
#     Generate a prototype onnx      #
######################################

parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
parser.add_argument('in_file', help='input ONNX or PTH FILE')
# Typo fix: "ouput" -> "output" in the user-visible help text.
parser.add_argument('out_file', help="output ONNX FILE")
parser.add_argument('--input-size', dest='input_size', nargs=3,
                    help='if you using pth, please use this argument to set up the input size of the model. It should be in \'CH H W\' format, e.g. \'--input-size 3 256 512\'.')
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
                    help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.")

args = parser.parse_args()

if len(args.in_file) <= 4:
    # When the filename is too short.
    logging.error("Invalid input file: {}".format(args.in_file))
    exit(1)
elif args.in_file[-4:] == '.pth':
    # Pytorch pth case: export to out_file first, then re-load and optimize.
    logging.warning("Converting from pth to onnx is not recommended.")
    onnx_in = args.out_file
    # Import pytorch libraries lazily; only the pth path needs them.
    from torch.autograd import Variable
    import torch
    import torch.onnx
    # import torchvision
    # Standard ImageNet input - 3 channels, 224x224.
    # Values don't matter as we care about network structure.
    # But they can also be real inputs.
    if args.input_size is None:
        logging.error("'--input-size' is required for the pth input file.")
        exit(1)
    dummy_input = Variable(torch.randn(1, int(args.input_size[0]), int(args.input_size[1]), int(args.input_size[2])))
    # BUG FIX: load from the parsed positional argument, not sys.argv[1] --
    # sys.argv[1] is wrong whenever an option precedes the positional.
    model = torch.load(args.in_file, map_location='cpu')
    # model = torchvision.models.resnet34(pretrained=True)
    # Invoke export.
    # torch.save(model, "resnet34.pth")
    torch.onnx.export(model, dummy_input, args.out_file, opset_version=11)
elif args.in_file[-4:] == 'onnx':
    onnx_in = args.in_file
else:
    # When the file is neither an onnx nor a pytorch pth.
    logging.error("Invalid input file: {}".format(args.in_file))
    exit(1)

onnx_out = args.out_file

######################################
#          Optimize onnx             #
######################################

m = onnx.load(onnx_in)

m = torch_exported_onnx_flow(m, args.disable_fuse_bn)

onnx.save(m, onnx_out)
import onnx
import onnx.utils
try:
    from onnx import optimizer
except ImportError:
    import onnxoptimizer as optimizer
import sys
import numpy as np
import struct
import logging
import argparse

# BUG FIX: use absolute imports like every sibling script. The relative form
# (`from .tools import ...`) fails both when this file is run as a script
# (the __main__ block below) and when pytorch2onnx.py imports it as a
# top-level module.
from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import combo
from tools import special


# Define general pytorch exported onnx optimize process
def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.ModelProto:
    """Optimize the Pytorch exported onnx.

    Args:
        m (ModelProto): the input onnx model
        disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.

    Returns:
        ModelProto: the optimized onnx model
    """
    m = combo.preprocess(m, disable_fuse_bn)
    # Fold Pytorch's dynamic-shape constructs before the generic passes.
    m = combo.pytorch_constant_folding(m)
    m = combo.common_optimization(m)
    m = combo.postprocess(m)

    return m


# Main Process
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
    parser.add_argument('in_file', help='input ONNX')
    # Typo fix: "ouput" -> "output" in the user-visible help text.
    parser.add_argument('out_file', help="output ONNX FILE")
    parser.add_argument('--log', default='i', type=str, help="set log level")
    parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
                        help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.")

    args = parser.parse_args()

    # Map the single-letter --log value to a logging level ('i' is default).
    if args.log == 'w':
        logging.basicConfig(level=logging.WARN)
    elif args.log == 'd':
        logging.basicConfig(level=logging.DEBUG)
    elif args.log == 'e':
        logging.basicConfig(level=logging.ERROR)
    else:
        logging.basicConfig(level=logging.INFO)

    if len(args.in_file) <= 4:
        # When the filename is too short.
        logging.error("Invalid input file: {}".format(args.in_file))
        exit(1)
    elif args.in_file[-4:] == 'onnx':
        onnx_in = args.in_file
    else:
        # When the file is not an onnx file.
        logging.error("Invalid input file: {}".format(args.in_file))
        exit(1)

    onnx_out = args.out_file

    ######################################
    #          Optimize onnx             #
    ######################################

    m = onnx.load(onnx_in)

    m = torch_exported_onnx_flow(m, args.disable_fuse_bn)

    onnx.save(m, onnx_out)
import tensorflow as tf
import tf2onnx
import argparse
import logging
import sys
import onnx
import onnx.utils
from tensorflow.python.platform import gfile
from tools import combo, eliminating, replacing, other


def tf2onnx_flow(pb_path: str, test_mode=False) -> onnx.ModelProto:
    """Convert frozen graph pb file into onnx

    Args:
        pb_path (str): input pb file path
        test_mode (bool, optional): test mode. Defaults to False.

    Raises:
        Exception: invalid input file

    Returns:
        onnx.ModelProto: converted onnx
    """
    # tf2onnx version as a concatenated integer, e.g. '1.6.0' -> 160.
    # NOTE(review): digit concatenation is not a robust version compare, but
    # it works for the 1.6 threshold used below.
    TF2ONNX_VERSION = int(tf2onnx.version.version.replace('.', ''))

    if 160 <= TF2ONNX_VERSION:
        from tf2onnx import tf_loader
    else:
        from tf2onnx import loader as tf_loader

    if pb_path[-3:] == '.pb':
        model_name = pb_path.split('/')[-1][:-3]

        # always reset tensorflow session at begin
        tf.reset_default_graph()

        with tf.Session() as sess:
            with gfile.FastGFile(pb_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                sess.graph.as_default()
                tf.import_graph_def(graph_def, name='')

            if 160 <= TF2ONNX_VERSION:
                onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions = tf2onnx.tf_utils.tflist_to_onnx(
                    sess.graph,
                    {})
            else:
                onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tf2onnx.tfonnx.tflist_to_onnx(
                    sess.graph.get_operations(),
                    {})

            # BUG FIX: the original removed items from `onnx_nodes` while
            # iterating over the same list, which silently skips elements.
            # Build a filtered list instead.
            onnx_nodes = [n for n in onnx_nodes if len(n.output) != 0]

            # find inputs and outputs of graph
            nodes_inputs = set()
            nodes_outputs = set()

            for n in onnx_nodes:
                if n.op_type == 'Placeholder':
                    continue
                for input_name in n.input:
                    nodes_inputs.add(input_name)
                for output_name in n.output:
                    nodes_outputs.add(output_name)

            # Graph inputs are names consumed but never produced.
            graph_input_names = set()
            for input_name in nodes_inputs:
                if input_name not in nodes_outputs:
                    graph_input_names.add(input_name)

            # Graph outputs are produced names that no node consumes.
            graph_output_names = set()
            for n in onnx_nodes:
                if n.input and n.input[0] not in nodes_outputs:
                    continue
                if len(n.output) == 0:
                    n.output.append(n.name + ':0')
                    graph_output_names.add(n.output[0])
                else:
                    output_name = n.output[0]
                    if (output_name not in nodes_inputs) and (0 < len(n.input)):
                        graph_output_names.add(output_name)

        logging.info('Model Inputs: %s', str(list(graph_input_names)))
        logging.info('Model Outputs: %s', str(list(graph_output_names)))

        graph_def, inputs, outputs = tf_loader.from_graphdef(model_path=pb_path,
                                                             input_names=list(graph_input_names),
                                                             output_names=list(graph_output_names))

        with tf.Graph().as_default() as tf_graph:
            tf.import_graph_def(graph_def, name='')

        if 160 <= TF2ONNX_VERSION:
            with tf_loader.tf_session(graph=tf_graph):
                onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
                                                             input_names=inputs,
                                                             output_names=outputs,
                                                             opset=11)
        else:
            with tf.Session(graph=tf_graph):
                onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
                                                             input_names=inputs,
                                                             output_names=outputs,
                                                             opset=11)

        # Optimize with tf2onnx.optimizer
        onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph)
        model_proto = onnx_graph.make_model(model_name)

        # Make tf2onnx output compatible with the spec. of other.polish_model
        replacing.replace_initializer_with_Constant(model_proto.graph)
        model_proto = other.polish_model(model_proto)

    else:
        raise Exception('expect .pb file as input, but got "' + str(pb_path) + '"')

    # rename
    m = model_proto

    m = combo.preprocess(m)
    m = combo.common_optimization(m)
    m = combo.tensorflow_optimization(m)
    m = combo.postprocess(m)

    if not test_mode:
        g = m.graph
        eliminating.eliminate_shape_changing_after_input(g)

    m = other.polish_model(m)
    return m


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert tensorflow pb file to onnx file and optimized onnx file. Or just optimize tensorflow onnx file.')
    parser.add_argument('in_file', help='input file')
    parser.add_argument('out_file', help='output optimized model file')
    # BUG FIX: with the old plain `default=False` option, ANY value passed on
    # the command line -- including the string "False" -- was truthy. Use a
    # proper boolean flag instead.
    parser.add_argument('-t', '--test_mode', action='store_true', default=False,
                        help='test mode will not eliminate shape changes after input')

    args = parser.parse_args()
    logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO)
    m = tf2onnx_flow(args.in_file, args.test_mode)
    onnx.save(m, args.out_file)
    logging.info('Save Optimized ONNX: %s', args.out_file)
    """Run a TFLite model and an ONNX model on the same random inputs.

    The ONNX model input is NCHW while the TFLite input is NHWC; the same
    random tensor is transposed accordingly before each run. Both models
    must have exactly one input and one output.

    Args:
        tflite_file: path of the tflite model.
        onnx_file: path of the onnx model.
        total_times: number of inference runs. Defaults to 10.

    Returns:
        (tflite_results, onnx_results): outputs collected from each run.
    """
    # Setup onnx session and get meta data
    onnx_session = onnxruntime.InferenceSession(onnx_file, None)
    onnx_outputs = onnx_session.get_outputs()
    assert len(onnx_outputs) == 1, "The onnx model has more than one output"
    onnx_model = onnx.load(onnx_file)
    onnx_graph = onnx_model.graph
    onnx_inputs = onnx_graph.input
    assert len(onnx_inputs) == 1, "The onnx model has more than one input"
    _, onnx_input_shape = helper.find_size_shape_from_value(onnx_inputs[0])
    # Setup TFLite session and get meta data
    tflite_session = tf.lite.Interpreter(model_path=tflite_file)
    tflite_session.allocate_tensors()
    tflite_inputs = tflite_session.get_input_details()
    tflite_outputs = tflite_session.get_output_details()
    tflite_input_shape = tflite_inputs[0]['shape']
    # Compare input shape: onnx NCHW against tflite NHWC.
    assert(len(onnx_input_shape) == len(tflite_input_shape)), "TFLite and ONNX shape unmatch."
    assert(onnx_input_shape == [tflite_input_shape[0], tflite_input_shape[3], tflite_input_shape[1], tflite_input_shape[2]]), "TFLite and ONNX shape unmatch."
    # Generate random number and run
    tflite_results = []
    onnx_results = []
    for _ in range(total_times):
        # Generate input
        tflite_input_data = np.array(np.random.random_sample(tflite_input_shape), dtype=np.float32)
        onnx_input_data = np.transpose(tflite_input_data, [0, 3, 1, 2])
        # Run tflite
        tflite_session.set_tensor(tflite_inputs[0]['index'], tflite_input_data)
        tflite_session.invoke()
        tflite_results.append(tflite_session.get_tensor(tflite_outputs[0]['index']))
        # Run onnx
        onnx_input_dict = {onnx_inputs[0].name: onnx_input_data}
        onnx_results.append(onnx_session.run([], onnx_input_dict)[0])

    return tflite_results, onnx_results


if __name__ == '__main__':
    # Argument parser.
    parser = argparse.ArgumentParser(description="Compare a TFLite model and an ONNX model to check if they have the same output.")
    parser.add_argument('tflite_file', help='input tflite file')
    parser.add_argument('onnx_file', help='input ONNX file')

    args = parser.parse_args()

    results_a, results_b = compare_tflite_and_onnx(args.tflite_file, args.onnx_file, total_times=10)
    # Flatten (keeping shape info) so the nonuniform outputs can be compared.
    ra_flat = helper.flatten_with_depth(results_a, 0)
    rb_flat = helper.flatten_with_depth(results_b, 0)
    shape_a = [item[1] for item in ra_flat]
    shape_b = [item[1] for item in rb_flat]
    assert shape_a == shape_b, 'two results data shape doesn\'t match'
    ra_raw = [item[0] for item in ra_flat]
    rb_raw = [item[0] for item in rb_flat]

    try:
        # Exit non-zero on mismatch (to 8 decimal places).
        np.testing.assert_almost_equal(ra_raw, rb_raw, 8)
        print('Two models have the same behaviour.')
    except Exception as mismatch:
        print(mismatch)
        exit(1)
def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True):
    """The most common used functions before other processing.

    Args:
        model_proto: the original model input.
        disable_fuse_bn (bool, optional): skip the ONNX 'fuse_bn_into_conv'
            pass when True. Defaults to False.
        duplicate_shared_weights (bool, optional): duplicate shared weights.
            Defaults to True.

    Return:
        the new model after preprocessing

    It includes:

    - inference shapes
    - optimize model by ONNX library
    - give names to the nodes
    - replace initializer with Constant node
    - replace -1 batch size with 1
    - eliminate dropout and identity
    - eliminate no children inputs
    - topological sort

    The optimizations provided by ONNX:

    - eliminate_identity
    - eliminate_nop_dropout
    - eliminate_nop_transpose
    - eliminate_nop_pad
    - eliminate_unused_initializer
    - eliminate_deadend
    - fuse_consecutive_squeezes
    - fuse_consecutive_transposes
    - fuse_add_bias_into_conv
    - fuse_transpose_into_gemm
    - fuse_matmul_add_bias_into_gemm
    - fuse_bn_into_conv
    - fuse_pad_into_conv

    NOTE: the pass ordering below is deliberate (e.g. names must exist before
    the optimizer runs, and again after it, since onnx optimizer > 1.7 may
    strip node names). Do not reorder without re-validating.
    """
    logger.info("Preprocessing the model...")
    helper.setup_current_opset_version(model_proto)
    eliminating.eliminate_empty_value_infos(model_proto.graph)
    other.add_name_to_node(model_proto.graph)
    other.rename_all_node_name(model_proto.graph)
    replacing.replace_initializer_with_Constant(model_proto.graph)
    other.topological_sort(model_proto.graph)
    m = other.polish_model(model_proto)
    passes = ['extract_constant_to_initializer',
              'eliminate_nop_dropout',
              'eliminate_deadend',
              'fuse_matmul_add_bias_into_gemm',
              'fuse_pad_into_conv']
    if not disable_fuse_bn:
        passes.append('fuse_bn_into_conv')
    m = optimizer.optimize(m, passes)
    g = m.graph
    # Add name again since onnx optimizer higher than 1.7 may remove node names.
    other.add_name_to_node(g)
    if duplicate_shared_weights:
        replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=True)
        other.duplicate_param_shared_constant(g)
    else:
        replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=False)
    other.topological_sort(g)
    m = other.polish_model(m)
    g = m.graph
    eliminating.eliminate_consecutive_Cast(m.graph)
    eliminating.eliminate_Cast_after_input(m.graph)
    eliminating.eliminate_nop_pads(g)
    eliminating.eliminate_nop_cast(g)
    eliminating.eliminate_Identify_and_Dropout(g)
    eliminating.eliminate_trivial_maxpool(g)
    eliminating.eliminate_no_children_input(g)
    other.format_value_info_shape(g)
    other.topological_sort(g)
    m = other.inference_shapes(m)
    g = m.graph
    replacing.replace_split_with_slices(g)
    other.topological_sort(g)

    return m
") + m = other.polish_model(m) + g = m.graph + other.transpose_B_in_Gemm(g) + fusing.fuse_BN_into_Gemm(g) + fusing.fuse_BN_with_Reshape_into_Gemm(g) + fusing.fuse_Gemm_into_Gemm(g) + fusing.fuse_consecutive_reducemean(g) + fusing.fuse_slice_nodes_into_conv(g) + fusing.fuse_relu_min_into_clip(g) + other.duplicate_shared_Flatten(g) + replacing.replace_average_pool_with_GAP(g) + + m = other.polish_model(m) + g = m.graph + + replacing.replace_Squeeze_with_Reshape(g) + replacing.replace_Unsqueeze_with_Reshape(g) + replacing.replace_Reshape_with_Flatten(g) + replacing.replace_ReduceMean_with_GlobalAveragePool(g) + replacing.replace_Sum_with_Adds(g) + replacing.replace_constant_input_concat_with_pad(g) + other.topological_sort(g) + return m + + +def pytorch_constant_folding(m): + """Constant folding needed by Pytorch exported models. It should be done + before using onnx optimizers since the dynamic shape structure may affect + the optimizations. + + :param m: the original model input\\ + :return: the new model after preprocessing + """ + logger.info("Working on constant folding.") + replacing.replace_shape_with_constant(m.graph) + replacing.replace_ConstantOfShape_with_constant(m.graph) + + # constant_folding + m = other.inference_shapes(m) + while constant_folding.constant_folding(m.graph): + logging.debug("After constant folding jobs.") + other.topological_sort(m.graph) + while len(m.graph.value_info) != 0: + m.graph.value_info.pop() + + m = other.inference_shapes(m) + replacing.replace_shape_with_constant(m.graph) + other.topological_sort(m.graph) + m = torch_pattern_match(m) + m = optimizer.optimize(m, ['eliminate_deadend']) + return m + + +def tensorflow_optimization(m): + """Optimizations for tf models can be used in most cases. 
+ + :param m: the original model input\\ + :return: the new model after preprocessing + + It includes: + + - eliminate shape change after input + - eliminate Reshape cast + - eliminate Squeeze before Reshape + - fuse Transpose into Constant + - replace Shape with Constant + """ + + fusing.fuse_Transpose_into_Constant(m.graph) + fusing.fuse_MatMul_and_Add_into_Gemm(m.graph) + other.topological_sort(m.graph) + + m = other.polish_model(m) + + # constant folding + replacing.replace_shape_with_constant(m.graph) + + # constant_folding + m = other.inference_shapes(m) + while constant_folding.constant_folding(m.graph): + logging.debug("After constant folding jobs.") + other.topological_sort(m.graph) + while len(m.graph.value_info) != 0: + m.graph.value_info.pop() + + m = other.inference_shapes(m) + replacing.replace_shape_with_constant(m.graph) + other.topological_sort(m.graph) + m = tf_pattern_match(m) + m = optimizer.optimize(m, ['eliminate_deadend']) + + eliminating.eliminate_consecutive_reshape(m.graph) + eliminating.eliminate_Squeeze_before_Reshape(m.graph) + other.topological_sort(m.graph) + return m + + +def postprocess(m): + """Inference the shape and prepare for export. 
+ + :param m: the original model input\\ + :return: the new model after preprocessing + """ + logger.info("Postprocessing the model...") + while len(m.graph.value_info) > 0: + m.graph.value_info.pop() + m = other.polish_model(m) + eliminating.eliminate_single_input_Concat(m.graph) + eliminating.eliminate_nop_Maxpool_and_AveragePool(m.graph) + eliminating.eliminate_trivial_elementwise_calculation(m.graph) + m = other.polish_model(m) + + replacing.replace_depthwise_1x1_with_bn(m.graph) + m = other.polish_model(m) + + # removing transpose + m = removing_transpose.eliminate_transposes(m) + m = other.polish_model(m) + removing_transpose.remove_trivial_transpose(m.graph) + removing_transpose.fuse_Transpose_into_Gemm_weight(m.graph) + + # fuse some nodes + fusing.fuse_mul_and_add_into_bn(m.graph) + m = other.polish_model(m) + fusing.fuse_mul_and_add_into_gemm(m.graph) + m = other.polish_model(m) + fusing.fuse_conv_and_add_into_conv(m.graph) + m = other.polish_model(m) + replacing.replace_mul_to_bn(m.graph) + replacing.replace_div_to_bn(m.graph) + replacing.replace_add_to_bn(m.graph) + replacing.replace_sub_to_bn(m.graph) + replacing.replace_sub_with_bn_and_add(m.graph) + m = other.polish_model(m) + + other.add_output_to_value_info(m.graph) + m = optimizer.optimize(m, ['eliminate_deadend']) + m.producer_name = 'kneron_formatter' + return m diff --git a/tools/optimizer_scripts/tools/common_pattern.py b/tools/optimizer_scripts/tools/common_pattern.py new file mode 100644 index 0000000..b65d5bd --- /dev/null +++ b/tools/optimizer_scripts/tools/common_pattern.py @@ -0,0 +1,157 @@ +from collections import defaultdict +import numpy as np +import onnx.helper +import onnx.utils + +from . import modhelper +from . import helper +from . import other + +def torch_pattern_match(m): + # Create a map from optype to the nodes. 
+ optype2node = defaultdict(list) + for node in m.graph.node: + optype2node[node.op_type].append(node) + for matmul_node in optype2node['MatMul']: + pattern_matmul_mul_add(m.graph, matmul_node) + for resize_node in optype2node['Resize']: + # torch nn.UpsamplingBilinear2d will be given us 4 input: "X, roi, scales, sizes" + if len(resize_node.input) != 4: + continue + make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name) + m = onnx.shape_inference.infer_shapes(m) + polish_RESIZE_input_param_node(m.graph, resize_node.name) + m = other.polish_model(m) + return m + +def tf_pattern_match(m): + # Create a map from optype to the nodes. + optype2node = defaultdict(list) + for node in m.graph.node: + optype2node[node.op_type].append(node) + for matmul_node in optype2node['MatMul']: + pattern_matmul_mul_add(m.graph, matmul_node) + for resize_node in optype2node['Resize']: + # In tensorflow2onnx, ReizeXXX will be given us 4 input: "X, roi, scales, sizes" + # and node output name will be given the "node name + :0" + if len(resize_node.input) != 4: + continue + make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name) + m = onnx.shape_inference.infer_shapes(m) + polish_RESIZE_input_param_node(m.graph, resize_node.name) + m = other.polish_model(m) + return m + +def pattern_matmul_mul_add(g, matmul_node): + # Check node match - Mul node + next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0]) + if len(next_nodes) != 1: + return + if next_nodes[0].op_type != 'Mul': + return + mul_node = next_nodes[0] + # Check node match - Add node + next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0]) + if len(next_nodes) != 1: + return + if next_nodes[0].op_type != 'Add': + return + add_node = next_nodes[0] + # Check Mul weight + mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1]) + if mul_weight_node.op_type != 'Constant': + return + weight_size, mul_weight = helper.constant_to_list(mul_weight_node) + for i in mul_weight: + if i 
!= 1: + return + channel = weight_size[0] + # Check Add weight + add_weight_node = helper.find_node_by_output_name(g, add_node.input[1]) + if add_weight_node.op_type != 'Constant': + return + # Check MatMul weight to see if it need weight broadcast + matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1]) + matmul_weight = helper.constant_to_numpy(matmul_weight_node) + if matmul_weight.shape[1] == 1: + # Weight broadcast + new_matmul_weight = np.tile(matmul_weight, channel) + new_matmul_weight_node = helper.numpy_to_constant(matmul_weight_node.name, new_matmul_weight) + g.node.remove(matmul_weight_node) + g.node.extend([new_matmul_weight_node]) + value = helper.find_value_by_name(g, matmul_weight_node.output[0]) + if value is not None: + g.value_info.remove(value) + # Remove Mul node + g.node.remove(mul_weight_node) + value = helper.find_value_by_name(g, mul_weight_node.output[0]) + if value is not None: + g.value_info.remove(value) + g.node.remove(mul_node) + value = helper.find_value_by_name(g, mul_node.output[0]) + if value is not None: + g.value_info.remove(value) + # Fuse Matmul and Add + gemm_node = onnx.helper.make_node( + 'Gemm', + [matmul_node.input[0], matmul_node.input[1], add_node.input[1]], + [add_node.output[0]], + name = matmul_node.name, + alpha = 1.0, + beta = 1.0, + transA = 0, + transB = 0 + ) + g.node.extend([gemm_node]) + # Clean up + g.node.remove(matmul_node) + g.node.remove(add_node) + value = helper.find_value_by_name(g, matmul_node.output[0]) + if value is not None: + g.value_info.remove(value) + other.topological_sort(g) + +def make_UpsamplingBilinear2d_value_info(g, resize_node_name): + resize_node = helper.find_node_by_node_name(g, resize_node_name) + + shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3]) + shape_data = helper.constant_to_numpy(shape_data_node).astype(int) + l_shape_data = list(shape_data) + if l_shape_data[0] == 0: + l_shape_data[0] = 1 + l_shape_data[0] + shape_data = 
np.array(l_shape_data) + + new_output_value_info = onnx.helper.make_tensor_value_info( + resize_node.output[0], + onnx.helper.TensorProto.FLOAT, + shape_data.tolist() + ) + + g.value_info.extend([new_output_value_info]) + +def polish_RESIZE_input_param_node(g, resize_node_name): + resize_node = helper.find_node_by_node_name(g, resize_node_name) + + shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3]) + shape_data = helper.constant_to_numpy(shape_data_node).astype(int) + + # handle 0 batch size which is invalid + if shape_data[0] == 0: + shape_data[0] = 1 + + pre_node_output_value_info = helper.find_value_by_name(g, resize_node.input[0]) + ori_shape = np.array([pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value, + pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value, + pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value, + pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value]) + + resize_node.input.remove(resize_node.input[3]) + + + resize_scales = np.array(shape_data/ori_shape).astype(float) + resize_scale_node = helper.list_to_constant('resize_scales_node_' + resize_node.name, resize_scales.shape, resize_scales, data_type=onnx.helper.TensorProto.FLOAT) + + resize_node.input[2] = resize_scale_node.name + g.node.extend([resize_scale_node]) + + other.topological_sort(g) diff --git a/tools/optimizer_scripts/tools/constant_folding.py b/tools/optimizer_scripts/tools/constant_folding.py new file mode 100644 index 0000000..8149628 --- /dev/null +++ b/tools/optimizer_scripts/tools/constant_folding.py @@ -0,0 +1,995 @@ +import onnx.utils +import onnx +import numpy as np +import logging +import traceback + +from . 
import helper +from .general_graph import Graph, Node +from .other import topological_sort +from .replacing import replace_shape_with_constant +from .helper import logger + +def are_all_inputs_Constant_with_one_child(g, node): + for input_name in node.input: + input_node = helper.find_node_by_output_name(g, input_name) + if input_node is None or input_node.op_type != 'Constant': + return False + relative_outputs = helper.find_nodes_by_input_name(g, input_name) + if len(relative_outputs) > 1: + return False + return True + + +def constant_folding(g): + """ Do constant folding until nothing more can be done. + + :param g: The onnx GraphProto\\ + :return: If any node is folded, return True. Otherwise, return False. + """ + keep_folding = True # Keep the while loop + folded = False # Return value + try: + # Before constant folding, duplicate the constant nodes. + duplicate_constant_node(g) + while keep_folding: + keep_folding = False + for node in g.node: + # Check if the node is foldable + if node.op_type not in constant_folding_nodes.keys(): + continue + # Check if the parents of the node are all single follower constant node. + if not are_all_inputs_Constant_with_one_child(g, node): + continue + # Constant folding for the specific node + if constant_folding_nodes[node.op_type](g, node): + logging.debug("Constant nodes and %s %s are folded.", + node.op_type, node.name) + folded = True + keep_folding = True + else: + logging.debug( + "Constant nodes and %s %s are skipped.", node.op_type, node.name) + except Exception as e: + logger.error("An exception is raised while constant folding.") + logger.error(traceback.format_exc()) + return folded + + + +def duplicate_constant_node(g): + """ Duplicate the constant node if its following nodes contain constant folding + nodes. Create and link the new constant nodes to the constant folding nodes. 
+ """ + for node in g.node: + # Find a valid constant node + if node.op_type != 'Constant': + continue + output_val_info = helper.find_value_by_name(g, node.output[0]) + if output_val_info is None: + print("Cannot inference the shape of Const node output: " + + node.output[0]) + exit(1) + data_shape = helper.get_shape_from_value_info(output_val_info) + output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + + # For constant that has only one following node, no need to duplicate + if len(output_nodes) < 2: + continue + + # Check if its following nodes are foldable + foldable_output_nodes = list(filter(lambda n: n.op_type in + constant_folding_nodes.keys(), output_nodes)) + if not foldable_output_nodes: + continue + + # Duplicate the node needed by foldable nodes + for i in range(len(foldable_output_nodes)): + logging.debug("Found constant %s and %s %s are availble for folding. Duplicate constant.", + node.name, foldable_output_nodes[i].op_type, foldable_output_nodes[i].name) + output_name = node.output[0] + '_dup_' + str(i) + new_constant_node = onnx.helper.make_node( + 'Constant', + [], + [output_name], + name=output_name, + value=node.attribute[0].t + ) + new_val_info = onnx.helper.make_tensor_value_info( + output_name, + node.attribute[0].t.data_type, + data_shape + ) + input_ind = list(foldable_output_nodes[i].input).index( + node.output[0]) + foldable_output_nodes[i].input[input_ind] = output_name + + g.node.extend([new_constant_node]) + g.value_info.extend([new_val_info]) + + # If all following nodes are foldable node, delete the original node. 
+ if len(foldable_output_nodes) == len(output_nodes): + g.node.remove(node) + g.value_info.remove(output_val_info) + + topological_sort(g) + + return + +def slice_constant_folding(g, node): + op_version = helper.get_current_opset_version() + # only support opset 9 & 11 + if op_version == 11: + return slice_constant_folding_Opset_11(g, node) + elif op_version == 9: + return slice_constant_folding_Opset_9(g, node) + +def slice_constant_folding_Opset_11(g, node): + """ Fold constant and slice nodes to a single constant node. + """ + pre_node = helper.find_node_by_output_name(g, node.input[0]) + pre_shape, data_list = helper.constant_to_list(pre_node) + + starts_node = helper.find_node_by_output_name(g, node.input[1]) + _, starts = helper.constant_to_list(starts_node) + + ends_node = helper.find_node_by_output_name(g, node.input[2]) + _, ends = helper.constant_to_list(ends_node) + + + axes_node = None if len(node.input) <= 3 else helper.find_node_by_output_name(g, node.input[3]) + if not axes_node: + axes = list(range(len(helper.get_shape(data_list)))) + else: + _, axes = helper.constant_to_list(axes_node) + + steps_node = None if len(node.input) <= 4 else helper.find_node_by_output_name(g, node.input[4]) + if not steps_node: + steps = [1]*len(helper.get_shape(data_list)) + else: + _, steps = helper.constant_to_list(steps_node) + + + data_list = list(map(int, data_list)) + starts = list(map(int, starts)) + ends = list(map(int, ends)) + axes = list(map(int, axes)) + steps = list(map(int, steps)) + + data_list = np.reshape(data_list, pre_shape) + + new_data = None + for idx, _ in enumerate(axes): + new_data = np.apply_along_axis( lambda x: x[starts[idx] : ends[idx] : steps[idx]], idx, data_list ) + + new_node = helper.list_to_constant(node.output[0], helper.get_shape( + new_data), helper.flatten_to_list(new_data)) + g.node.extend([new_node]) + value_info = helper.find_value_by_name(g, pre_node.output[0]) + if value_info is not None: + g.value_info.remove(value_info) + 
g.node.remove(node) + g.node.remove(pre_node) + + return True + +def slice_constant_folding_Opset_9(g, node): + """ Fold constant and slice nodes to a single constant node. + """ + pre_node = helper.find_node_by_output_name(g, node.input[0]) + pre_shape, data_list = helper.constant_to_list(pre_node) + + data_list = np.reshape(data_list, pre_shape) + axes = helper.get_attribute_by_name(node, 'axes') + ends = list(helper.get_attribute_by_name(node, 'ends').ints) + starts = list(helper.get_attribute_by_name(node, 'starts').ints) + + if not axes: + axes = list(range(len(helper.get_shape(data_list)))) + else: + axes = list(axes.ints) + + new_data = helper.slice_data(data_list, starts, ends, axes) + new_node = helper.list_to_constant(node.output[0], helper.get_shape( + new_data), helper.flatten_to_list(new_data)) + g.node.extend([new_node]) + value_info = helper.find_value_by_name(g, pre_node.output[0]) + if value_info is not None: + g.value_info.remove(value_info) + g.node.remove(node) + g.node.remove(pre_node) + + return True + +def cast_constant_folding(g, node): + """ Fold constant and cast node to a single constant node. 
+ """ + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data = helper.constant_to_list(pre_node) + data_type = node.attribute[0].i + if data_type in (6, 7): + data = list(map(int, data)) + elif data_type == onnx.helper.TensorProto.FLOAT: + data = list(map(float, data)) + else: + raise RuntimeError('data type not supported') + + if shape == 1: + tensor = onnx.helper.make_tensor( + name=pre_node.attribute[0].name, + data_type=data_type, + dims=[], + vals=data + ) + else: + tensor = onnx.helper.make_tensor( + name=pre_node.attribute[0].name, + data_type=data_type, + dims=shape, + vals=helper.flatten_to_list(data) + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=tensor + ) + g.node.extend([new_node]) + + value_info = helper.find_value_by_name(g, pre_node.output[0]) + if value_info is not None: + g.value_info.remove(value_info) + value_info = helper.find_value_by_name(g, node.output[0]) + if value_info is not None: + g.value_info.remove(value_info) + g.node.remove(pre_node) + g.node.remove(node) + + return True + + +def reduceprod_constant_folding(g, node): + """ Fold constant and reduceprod nodes to a single constant node. 
+ """ + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data_set = helper.constant_to_list(pre_node) + tensor = pre_node.attribute[0].t + + data_set = np.reshape(data_set, shape) + for att in node.attribute: + if att.name == 'axes': + axes = list(att.ints) + else: + keepdims = int(att.i) + + new_data = np.prod(data_set, axis=tuple(axes), keepdims=keepdims == 1) + new_shape = helper.get_shape(new_data) + new_flat_data = helper.flatten_to_list(new_data) + new_tensor = onnx.helper.make_tensor( + name=node.output[0], + data_type=tensor.data_type, + dims=new_shape, + vals=new_flat_data + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + g.node.extend([new_node]) + value_info = None + for item in g.value_info: + if item.name == pre_node.output[0]: + value_info = item + if value_info is not None: + g.value_info.remove(value_info) + g.node.remove(pre_node) + g.node.remove(node) + + return True + + +def reshape_constant_input_folding(g, node): + """ Fold constant and reshape nodes to a single constant node. 
+ """ + pre_data_node = helper.find_node_by_output_name(g, node.input[0]) + pre_shape_node = helper.find_node_by_output_name(g, node.input[1]) + + data = helper.constant_to_numpy(pre_data_node) + _, shape = helper.constant_to_list(pre_shape_node) + new_data = np.reshape(data, shape) + + new_tensor = onnx.helper.make_tensor( + name=node.output[0], + data_type=pre_data_node.attribute[0].t.data_type, + dims=new_data.shape, + vals=helper.flatten_to_list(new_data) + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + g.node.extend([new_node]) + + data_val_info = helper.find_value_by_name(g, pre_data_node.output[0]) + shape_val_info = helper.find_value_by_name(g, pre_shape_node.output[0]) + + g.value_info.remove(data_val_info) + g.value_info.remove(shape_val_info) + + g.node.remove(node) + g.node.remove(pre_data_node) + g.node.remove(pre_shape_node) + + return True + + +def concat_constant_folding(g, node): + """ Fold constant and concat nodes to a single constant node. 
+ """ + node_to_del = [] + valid_inputs = True + for input_name in node.input: + input_node = helper.find_node_by_output_name(g, input_name) + input_node_output = helper.find_nodes_by_input_name(g, input_name) + if len(input_node_output) > 1: + valid_inputs = False + break + if input_node.op_type != 'Constant': + valid_inputs = False + break + + if not valid_inputs: + return False + + input_data = [] + input_shapes = [] + for input_name in node.input: + input_node = helper.find_node_by_output_name(g, input_name) + s, d = helper.constant_to_list(input_node) + d = np.reshape(d, s) + input_data.append(d) + input_shapes.append(s) + node_to_del.append(input_node) + + concat_data = np.concatenate(input_data, axis=node.attribute[0].i) + node_data_type = input_node.attribute[0].t.data_type + if concat_data.dtype in [np.int32, np.int64]: + node_data_type = onnx.helper.TensorProto.INT64 + elif concat_data.dtype in [np.float32, np.float64]: + node_data_type = onnx.helper.TensorProto.FLOAT + + new_node = helper.list_to_constant( + node.output[0], + helper.get_shape(concat_data), + helper.flatten_to_list(concat_data), + data_type=node_data_type + ) + g.node.extend([new_node]) + node_to_del.append(node) + + for input_name in node.input: + val_info = helper.find_value_by_name(g, input_name) + if val_info: + g.value_info.remove(val_info) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def transpose_constant_folding(g, node): + """Fold constant and transpose nodes to a single constant node. 
+ """ + node_to_del = [] + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data = helper.constant_to_list(pre_node) + np_data = np.reshape(data, shape) + permutation = list(node.attribute[0].ints) + + new_data = np.transpose(np_data, permutation) + new_shape = new_data.shape + new_node = helper.list_to_constant( + node.output[0], + new_shape, + new_data.flatten().tolist(), + data_type=pre_node.attribute[0].t.data_type + ) + + g.node.extend([new_node]) + node_to_del.extend([node, pre_node]) + + pre_val_info = helper.find_value_by_name(g, node.input[0]) + g.value_info.remove(pre_val_info) + + next_val_info = helper.find_value_by_name(g, node.output[0]) + g.value_info.remove(next_val_info) + + new_val_info = onnx.helper.make_tensor_value_info( + node.output[0], + pre_node.attribute[0].t.data_type, + new_shape + ) + g.value_info.extend([new_val_info]) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + folded = True + + return folded + + +def unsqueeze_constant_folding(g, node): + """Fold constant and unsqueeze nodes to a single constant node. 
+ """ + node_to_del = [] + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data = helper.constant_to_list(pre_node) + if type(shape) == int: + np_data = data[0] + else: + np_data = np.reshape(data, shape) + axes = list(node.attribute[0].ints) + axes.sort() + + for dim in axes: + np_data = np.expand_dims(np_data, axis=dim) + new_shape = np_data.shape + new_node = helper.list_to_constant( + node.output[0], + new_shape, + np_data.flatten().tolist(), + data_type=pre_node.attribute[0].t.data_type + ) + g.node.extend([new_node]) + node_to_del.extend([node, pre_node]) + + pre_val_info = helper.find_value_by_name(g, node.input[0]) + next_val_info = helper.find_value_by_name(g, node.output[0]) + if pre_val_info is not None: + g.value_info.remove(pre_val_info) + else: + print(node.name) + if next_val_info is not None: + g.value_info.remove(next_val_info) + + new_val_info = onnx.helper.make_tensor_value_info( + node.output[0], + pre_node.attribute[0].t.data_type, + new_shape + ) + g.value_info.extend([new_val_info]) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def gather_constant_folding(g, node): + """Fold constant and gather nodes to a single constant node. 
+ """ + node_to_del = [] + + pre_data_node = helper.find_node_by_output_name(g, node.input[0]) + pre_indices_node = helper.find_node_by_output_name(g, node.input[1]) + + shape, data = helper.constant_to_list(pre_data_node) + indice_shape, indices = helper.constant_to_list(pre_indices_node) + if type(indice_shape) == int: + indices = indices[0] + + np_data = np.reshape(data, shape) + if len(node.attribute) < 1: + axis = 0 + else: + axis = node.attribute[0].i + + new_data = np.take(np_data, indices, axis=axis) + new_shape = new_data.shape + new_node = helper.list_to_constant( + node.output[0], + new_shape, + new_data.flatten().tolist(), + data_type=pre_data_node.attribute[0].t.data_type + ) + + node_to_del.extend([node, pre_data_node, pre_indices_node]) + g.node.extend([new_node]) + + val_info_1 = helper.find_value_by_name(g, node.input[0]) + val_info_2 = helper.find_value_by_name(g, node.input[1]) + val_info_3 = helper.find_value_by_name(g, node.output[0]) + new_val_info = onnx.helper.make_tensor_value_info( + new_node.output[0], + pre_data_node.attribute[0].t.data_type, + new_shape + ) + + if val_info_1 is not None: + g.value_info.remove(val_info_1) + if val_info_2 is not None: + g.value_info.remove(val_info_2) + if val_info_3 is not None: + g.value_info.remove(val_info_3) + g.value_info.extend([new_val_info]) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def add_constant_folding(g, node): + """Fold constant and add nodes to a single constant node. 
+ """ + node_to_del = [] + pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) + pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) + if not pre_node_1 or not pre_node_2: + return False + + shape1, data1 = helper.constant_to_list(pre_node_1) + shape2, data2 = helper.constant_to_list(pre_node_2) + np_data1 = np.reshape(data1, shape1) + np_data2 = np.reshape(data2, shape2) + try: + new_data = np.add(np_data1, np_data2) + except: + raise RuntimeError('can\'t broadcast and add two data sets') + + new_node = helper.list_to_constant( + node.output[0], + new_data.shape, + new_data.flatten().tolist(), + data_type=pre_node_1.attribute[0].t.data_type + ) + + g.node.extend([new_node]) + node_to_del.extend([node, pre_node_1, pre_node_2]) + g.value_info.remove(helper.find_value_by_name(g, pre_node_1.output[0])) + g.value_info.remove(helper.find_value_by_name(g, pre_node_2.output[0])) + folded = True + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return folded + + +def sqrt_constant_folding(g, node): + """ Fold constant and sqrt nodes to a single node. 
+ """ + node_to_del = [] + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data = helper.constant_to_list(pre_node) + np_data = np.sqrt(np.reshape(data, shape)) + output_val_info = helper.find_value_by_name(g, node.output[0]) + input_val_info = helper.find_value_by_name(g, node.input[0]) + data_type = output_val_info.type.tensor_type.elem_type + + new_tensor = onnx.helper.make_tensor( + name=node.output[0]+'_data', + data_type=data_type, + dims=shape, + vals=np_data.flatten().tolist() + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + g.value_info.remove(input_val_info) + node_to_del.extend([pre_node, node]) + g.node.extend([new_node]) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def reciprocal_constant_folding(g, node): + """ Fold constant and reciprocal nodes to a single constant node. + """ + node_to_del = [] + + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data = helper.constant_to_list(pre_node) + data = list(map(lambda x: x if abs(x) > 1.e-8 else 1.e-8, data)) + np_data = np.reshape(data, shape) + np_data = np.reciprocal(np_data) + + input_val_info = helper.find_value_by_name(g, node.input[0]) + output_val_info = helper.find_value_by_name(g, node.output[0]) + data_type = output_val_info.type.tensor_type.elem_type + + new_tensor = onnx.helper.make_tensor( + name=node.output[0]+'_data', + data_type=data_type, + dims=shape, + vals=np_data.flatten().tolist() + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + node_to_del.extend([node, pre_node]) + g.node.extend([new_node]) + + g.value_info.remove(input_val_info) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def mul_constant_folding(g, node): + """ Fold constant and mul nodes to a single constant node. 
def div_constant_folding(g, node):
    """Fold a Div node whose two inputs are Constant nodes into one Constant.

    :param g: the onnx graph
    :param node: the Div node to fold
    :return: True if the node was folded, False otherwise
    """
    node_to_del = []
    pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
    pre_node_2 = helper.find_node_by_output_name(g, node.input[1])

    pre_value_info1 = helper.find_value_by_name(g, node.input[0])
    pre_value_info2 = helper.find_value_by_name(g, node.input[1])
    if pre_value_info1 is None or pre_value_info2 is None:
        return False

    shape1, data1 = helper.constant_to_list(pre_node_1)
    shape2, data2 = helper.constant_to_list(pre_node_2)
    np_data1 = np.reshape(data1, shape1)
    np_data2 = np.reshape(data2, shape2)

    try:
        new_data = np.divide(np_data1, np_data2)
    except ValueError as exc:
        # np.divide raises ValueError when the two shapes cannot broadcast.
        # (Narrowed from a bare except; message fixed from "multiply".)
        raise RuntimeError('can not broadcast and divide two data sets') from exc

    # Special shape for single element: constant_to_list returns 1 as shape.
    if shape1 == 1 and shape2 == 1:
        new_shape = []
    else:
        new_shape = new_data.shape

    # Keep integer tensors integer (7 == onnx.TensorProto.INT64).
    if pre_node_1.attribute[0].t.data_type == onnx.TensorProto.INT64:
        new_data = new_data.astype('int64')

    new_tensor = onnx.helper.make_tensor(
        name=node.output[0] + '_data',
        data_type=pre_node_1.attribute[0].t.data_type,
        dims=new_shape,
        vals=new_data.flatten().tolist()
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    node_to_del.extend([node, pre_node_1, pre_node_2])
    g.node.extend([new_node])

    g.value_info.remove(pre_value_info1)
    g.value_info.remove(pre_value_info2)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def sub_constant_folding(g, node):
    """Fold a Sub node whose two inputs are Constant nodes into one Constant.

    :param g: the onnx graph
    :param node: the Sub node to fold
    :return: True if the node was folded, False otherwise
    """
    node_to_del = []
    pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
    pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
    pre_val_info_1 = helper.find_value_by_name(g, node.input[0])
    pre_val_info_2 = helper.find_value_by_name(g, node.input[1])
    # Consistency fix: mul/div folding bail out when the value_info entries
    # are missing; previously this crashed on g.value_info.remove(None).
    if pre_val_info_1 is None or pre_val_info_2 is None:
        return False

    shape1, data1 = helper.constant_to_list(pre_node_1)
    shape2, data2 = helper.constant_to_list(pre_node_2)

    # NOTE(review): unlike mul/div, the flat lists are not reshaped before
    # subtraction, so general broadcasting is not supported here — confirm.
    new_data = np.subtract(data1, data2)
    # Special shape for single element.
    if shape1 == 1 and shape2 == 1:
        new_shape = []
    else:
        new_shape = new_data.shape

    new_tensor = onnx.helper.make_tensor(
        name=node.output[0] + '_data',
        data_type=pre_node_1.attribute[0].t.data_type,
        dims=new_shape,
        vals=helper.flatten_to_list(new_data)
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    g.node.extend([new_node])
    node_to_del.extend([node, pre_node_1, pre_node_2])

    g.value_info.remove(pre_val_info_1)
    g.value_info.remove(pre_val_info_2)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def neg_constant_folding(g, node):
    """Fold a Neg node whose input is a Constant node into one Constant.

    :param g: the onnx graph
    :param node: the Neg node to fold
    :return: True (the node is always folded)
    """
    node_to_del = []
    pre_node = helper.find_node_by_output_name(g, node.input[0])

    shape, data_list = helper.constant_to_list(pre_node)
    new_data_list = [-num for num in data_list]

    new_tensor = onnx.helper.make_tensor(
        name=pre_node.name + '_neg_tensor',
        data_type=pre_node.attribute[0].t.data_type,
        dims=shape,
        vals=new_data_list
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    g.node.extend([new_node])
    node_to_del.extend([pre_node, node])
    # Consistency fix: guard like floor_constant_folding does; previously a
    # missing value_info raised ValueError from remove(None).
    old_value = helper.find_value_by_name(g, node.input[0])
    if old_value is not None:
        g.value_info.remove(old_value)

    while node_to_del:
        g.node.remove(node_to_del.pop())

    return True


def floor_constant_folding(g, node):
    """Fold a Floor node whose input is a Constant node into one Constant.

    :param g: the onnx graph
    :param node: the Floor node to fold
    :return: True (the node is always folded)
    """
    node_to_del = []
    pre_node = helper.find_node_by_output_name(g, node.input[0])

    shape, data = helper.constant_to_list(pre_node)
    new_data = np.floor(data).flatten().tolist()

    # constant_to_list returns the integer 1 as the shape of a single element.
    if shape == 1:
        new_shape = []
    else:
        new_shape = shape

    new_tensor = onnx.helper.make_tensor(
        name=node.output[0] + '_data',
        data_type=pre_node.attribute[0].t.data_type,
        dims=new_shape,
        vals=helper.flatten_to_list(new_data)
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    g.node.extend([new_node])
    node_to_del.extend([pre_node, node])
    old_value = helper.find_value_by_name(g, node.input[0])
    if old_value is not None:
        g.value_info.remove(old_value)

    while node_to_del:
        g.node.remove(node_to_del.pop())

    return True


def bn_constant_folding(g, node):
    """Fold a BatchNormalization node whose five inputs are all Constant
    nodes into a single Constant node.

    :param g: the onnx graph
    :param node: the BatchNormalization node to fold
    :return: True if the node was folded, False otherwise
    """
    # Prepare data
    node_to_del = []
    input_node = helper.find_node_by_output_name(g, node.input[0])
    scale_node = helper.find_node_by_output_name(g, node.input[1])
    bias_node = helper.find_node_by_output_name(g, node.input[2])
    mean_node = helper.find_node_by_output_name(g, node.input[3])
    var_node = helper.find_node_by_output_name(g, node.input[4])

    input_value_info = []
    for i in range(5):
        input_value_info.append(helper.find_value_by_name(g, node.input[i]))

    if input_value_info[0] is None:
        return False

    input_data = helper.constant_to_numpy(input_node)
    scale_data = helper.constant_to_numpy(scale_node)
    bias_data = helper.constant_to_numpy(bias_node)
    mean_data = helper.constant_to_numpy(mean_node)
    var_data = helper.constant_to_numpy(var_node)

    epsilon = helper.get_var_attribute_by_name(node, 'epsilon', 'float')
    if epsilon is None:
        epsilon = 0.00001  # ONNX default epsilon

    # y = scale * (x - mean) / sqrt(var + epsilon) + bias
    new_data = scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon) + bias_data

    new_node = helper.numpy_to_constant(node.output[0], new_data)

    # Reconnect the graph
    node_to_del.extend([node, input_node, scale_node, bias_node, mean_node, var_node])
    g.node.extend([new_node])

    for value in input_value_info:
        if value is not None:
            g.value_info.remove(value)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def DequantizeLinear_constant_folding(g, node):
    """Fold a DequantizeLinear node whose inputs are Constant nodes into a
    single Constant node.

    :param g: the onnx graph
    :param node: the DequantizeLinear node to fold
    :return: True if the node was folded, False otherwise
    """
    # Prepare data
    node_to_del = []
    x_node = helper.find_node_by_output_name(g, node.input[0])
    x_scale_node = helper.find_node_by_output_name(g, node.input[1])
    if len(node.input) > 2:
        x_zero_point_node = helper.find_node_by_output_name(g, node.input[2])
    else:
        x_zero_point_node = None

    input_value_info = []
    for i in range(len(node.input)):
        input_value_info.append(helper.find_value_by_name(g, node.input[i]))

    if input_value_info[0] is None:
        return False

    x_data = helper.constant_to_numpy(x_node)
    x_scale_data = helper.constant_to_numpy(x_scale_node)
    if x_zero_point_node is not None:
        x_zero_point_data = helper.constant_to_numpy(x_zero_point_node)
    else:
        # Per the ONNX spec the zero point defaults to 0.
        x_zero_point_data = np.array([0.0])

    # y = (x - x_zero_point) * x_scale
    new_data = (x_data.astype(np.float32) - x_zero_point_data.astype(np.float32)) * x_scale_data

    new_node = helper.numpy_to_constant(node.output[0], new_data)

    # Reconnect the graph
    node_to_del.extend([node, x_node, x_scale_node])
    if x_zero_point_node is not None:
        node_to_del.append(x_zero_point_node)
    g.node.extend([new_node])

    for value in input_value_info:
        if value is not None:
            g.value_info.remove(value)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


# Available constant folding names to function map.
# Op type -> folding function lookup table, kept in alphabetical order
# ('Sub' and 'Neg' were previously appended out of order).
constant_folding_nodes = {
    'Add': add_constant_folding,
    'BatchNormalization': bn_constant_folding,
    'Cast': cast_constant_folding,
    'Concat': concat_constant_folding,
    'DequantizeLinear': DequantizeLinear_constant_folding,
    'Div': div_constant_folding,
    'Floor': floor_constant_folding,
    'Gather': gather_constant_folding,
    'Mul': mul_constant_folding,
    'Neg': neg_constant_folding,
    'Reciprocal': reciprocal_constant_folding,
    'ReduceProd': reduceprod_constant_folding,
    'Reshape': reshape_constant_input_folding,
    'Slice': slice_constant_folding,
    'Sqrt': sqrt_constant_folding,
    'Sub': sub_constant_folding,
    'Transpose': transpose_constant_folding,
    'Unsqueeze': unsqueeze_constant_folding,
}


# ===== tools/eliminating.py =====
import collections
import struct

import onnx
import numpy as np

from . import other
from . import helper
from . import modhelper
from .general_graph import Graph


def eliminate_Identify_and_Dropout(g):
    """Eliminate Identity and Dropout layers by rewiring their consumers.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Identity' and node.op_type != 'Dropout':
            continue
        # If this node is the last node, leave it to `eliminate_useless_last node`
        if helper.find_output_by_name(g, node.output[0]) is not None:
            continue
        # Replace the parents in all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, node.output[0], node.input[0])
        # Delete value info (narrowed from a bare except: remove only raises
        # ValueError when the element is absent).
        value_between = helper.find_value_by_name(g, node.output[0])
        try:
            g.value_info.remove(value_between)
        except ValueError:
            print("No value info to delete while eliminating identity layers.")
        # Node is waiting for elimination
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)


# Remove last useless nodes
def remove_useless_last_nodes(g):
    """Remove useless nodes from the tail of the graph.

    BFS backwards from childless output nodes, peeling off shape-only /
    no-op layers and re-exposing their producers as graph outputs.

    :param g: the onnx graph
    """
    USELESS = ["Reshape", "Identity", "Transpose", "Flatten", "Dropout", "Mystery", "Constant", "Squeeze", "Unsqueeze", 'Softmax']
    graph = Graph(g)
    todo = collections.deque()
    for node in graph.output_nodes:
        if len(node.children) == 0:
            todo.append(node)
    node_to_remove = []
    while todo:
        # BFS find nodes to remove
        cur_node = todo.popleft()
        if cur_node.proto is None:
            continue
        if cur_node.proto.op_type not in USELESS:
            continue
        # Find the output
        cur_node_output = helper.find_output_by_name(g, cur_node.proto.output[0])
        for cur_input in cur_node.parents:
            cur_input.children.remove(cur_node)
            if len(cur_input.children) == 0:
                todo.append(cur_input)
            if cur_node_output is not None:
                # Promote the parent's value to a graph output if it is not one yet.
                cur_input_output = helper.find_value_by_name(g, cur_input.proto.output[0])
                cur_input_output_in_output = helper.find_output_by_name(g, cur_input.proto.output[0])
                if cur_input_output is not None and cur_input_output_in_output is None:
                    g.output.extend([cur_input_output])
        node_to_remove.append(cur_node.proto)
        try:
            g.value_info.remove(helper.find_value_by_name(g, cur_node.proto.output[0]))
        except ValueError:
            pass
        if cur_node_output is not None:
            g.output.remove(cur_node_output)
        cur_node.proto = None
        cur_node.parents.clear()
    for node in node_to_remove:
        g.node.remove(node)


######################################
#   TF only optimization passes      #
######################################

def eliminate_shape_changing_after_input(g):
    """Eliminate the shape-changing node right after an input and reshape
    the input itself instead.

    :param g: the onnx graph
    """
    node_to_remove = []
    REMOVE_LIST = ["Reshape", "Transpose", "Flatten", "Dropout", "Squeeze", "Unsqueeze"]
    for node in g.node:
        # Find an input and the shape node
        if node.op_type not in REMOVE_LIST:
            continue
        old_input = helper.find_input_by_name(g, node.input[0])
        if old_input is None:
            continue
        # If the input is used by multiple nodes, skip.
        counter = 0
        for tnode in g.node:
            if old_input.name in tnode.input:
                counter += 1
        if counter > 1:
            continue
        output_val_info = helper.find_value_by_name(g, node.output[0])

        if node.op_type == 'Reshape':
            shape_node = helper.find_node_by_output_name(g, node.input[1])
            if shape_node.op_type != 'Constant':
                continue

            # Manually set the input shape.
            # NOTE(review): old_size is taken from the shape tensor's own
            # value_info, not from the data input — confirm this is what
            # find_size_shape_from_value is expected to provide for the
            # -1 dimension resolution below.
            shape_info = helper.find_value_by_name(g, shape_node.output[0])
            old_size, old_shape = helper.find_size_shape_from_value(shape_info)

            _, new_shape = helper.constant_to_list(shape_node)
            for i in range(len(new_shape)):
                if new_shape[i] == -1:
                    dim = int(old_size // np.prod(new_shape) * (-1))
                    new_shape[i] = dim
            new_input = onnx.helper.make_tensor_value_info(
                output_val_info.name,
                output_val_info.type.tensor_type.elem_type,
                new_shape
            )

            node_to_remove.append(node)

            # Remove the shape constant too when this Reshape was its only consumer.
            shape_outputs = helper.find_nodes_by_input_name(g, shape_node.output[0])
            if len(shape_outputs) == 1:
                node_to_remove.append(shape_node)
                g.value_info.remove(helper.find_value_by_name(g, shape_node.output[0]))

            g.input.remove(old_input)
            g.input.extend([new_input])
            g.value_info.remove(output_val_info)
        elif node.op_type == 'Transpose':
            permutation = list(node.attribute[0].ints)
            pre_shape = helper.get_shape_from_value_info(old_input)
            new_shape = [pre_shape[i] for i in permutation]

            new_input = onnx.helper.make_tensor_value_info(
                output_val_info.name,
                output_val_info.type.tensor_type.elem_type,
                new_shape
            )

            node_to_remove.append(node)

            g.input.remove(old_input)
            g.input.extend([new_input])
            g.value_info.remove(output_val_info)
        elif node.op_type == 'Flatten':
            axis = node.attribute[0].int
            pre_shape = helper.get_shape_from_value_info(old_input)
            dim_1, dim_2 = 1, 1
            if axis == 0:
                dim_1 = 1
                dim_2 = np.prod(pre_shape)
            else:
                dim_1 = np.prod(pre_shape[:axis]).astype(int)
                dim_2 = np.prod(pre_shape[axis:]).astype(int)
            new_shape = [dim_1, dim_2]

            new_input = onnx.helper.make_tensor_value_info(
                output_val_info.name,
                output_val_info.type.tensor_type.elem_type,
                new_shape
            )

            node_to_remove.append(node)

            g.input.remove(old_input)
            g.input.extend([new_input])
            g.value_info.remove(output_val_info)
        elif node.op_type == 'Dropout':
            g.input.remove(old_input)
            g.input.extend([output_val_info])
            g.value_info.remove(output_val_info)

            node_to_remove.append(node)
        elif node.op_type == 'Squeeze':
            axis = list(node.attribute[0].ints)
            pre_shape = helper.get_shape_from_value_info(old_input)
            # Pop from the back so earlier positions stay valid.
            for pos in sorted(axis)[::-1]:
                if pre_shape[pos] != 1:
                    raise RuntimeError('invalid axis for squeeze')
                else:
                    pre_shape.pop(pos)
            new_shape = pre_shape

            new_input = onnx.helper.make_tensor_value_info(
                output_val_info.name,
                output_val_info.type.tensor_type.elem_type,
                new_shape
            )

            node_to_remove.append(node)

            g.input.remove(old_input)
            g.input.extend([new_input])
            g.value_info.remove(output_val_info)
        elif node.op_type == 'Unsqueeze':
            axis = list(node.attribute[0].ints)
            pre_shape = helper.get_shape_from_value_info(old_input)
            new_shape = pre_shape
            for pos in axis:
                new_shape.insert(pos, 1)
            new_input = onnx.helper.make_tensor_value_info(
                output_val_info.name,
                output_val_info.type.tensor_type.elem_type,
                new_shape
            )
            node_to_remove.append(node)

            g.input.remove(old_input)
            g.input.extend([new_input])
            g.value_info.remove(output_val_info)
        else:
            pass

    for node in node_to_remove:
        g.node.remove(node)

    other.topological_sort(g)


def eliminate_Reshape_Cast(g):
    """Eliminate the Cast layer feeding the shape input of a Reshape layer.

    The weight constant behind the Cast is converted to INT64 in place and
    wired directly into the Reshape.

    :param g: the onnx graph
    """
    # Find all reshape layers
    for node in g.node:
        if node.op_type != 'Reshape':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[1])
        # Robustness: the shape input may have no producer node at all.
        if prev_node is None or prev_node.op_type != 'Cast':
            continue
        # Now we find the cast weight pattern. Cast the weight, delete the cast.
        reshape_node = node
        cast_node = prev_node
        weight_node = helper.find_node_by_output_name(g, cast_node.input[0])
        if weight_node is None:
            raise RuntimeError("Unexpected None before Cast-Reshape.")
        weight_node.attribute[0].t.data_type = 7  # onnx.TensorProto.INT64
        if weight_node.attribute[0].t.raw_data:
            raw_data = weight_node.attribute[0].t.raw_data
            int_data = [i[0] for i in struct.iter_unpack('i', raw_data)]
            # BUG FIX: the repacked int64 bytes were previously assigned to a
            # local variable and never written back, leaving int32 bytes
            # tagged as INT64 in the tensor.
            weight_node.attribute[0].t.raw_data = struct.pack('q' * len(int_data), *int_data)
        elif len(weight_node.attribute[0].t.int64_data) > 0 \
                or len(weight_node.attribute[0].t.int32_data) > 0:
            # It's already int. Do nothing
            pass
        else:
            raise NotImplementedError()
        # Change Value info
        origin_weight_out = helper.find_value_by_name(g, weight_node.output[0])
        weight_node.output.pop()
        weight_node.output.extend([reshape_node.input[1]])
        # Delete
        g.value_info.remove(origin_weight_out)
        g.node.remove(cast_node)


def eliminate_Cast_after_input(g):
    """Eliminate the Cast layer right after the input by retyping the input.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Cast':
            continue
        old_input = helper.find_input_by_name(g, node.input[0])
        if old_input is None:
            continue
        next_val_info = helper.find_value_by_name(g, node.output[0])
        shape = helper.get_shape_from_value_info(next_val_info)
        new_val_info = onnx.helper.make_tensor_value_info(
            next_val_info.name,
            node.attribute[0].i,  # the Cast 'to' element type
            shape
        )
        # Delete old value_info
        g.input.remove(old_input)
        g.value_info.remove(next_val_info)
        # Append nodes to node_to_remove
        node_to_remove.append(node)
        # Add new input
        g.input.extend([new_val_info])
    for node in node_to_remove:
        g.node.remove(node)


def eliminate_consecutive_Cast(g):
    """If two Cast nodes are next to each other, remove the first Cast.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Cast':
            continue
        first_node = helper.find_node_by_output_name(g, node.input[0])
        if first_node is None or first_node.op_type != 'Cast':
            continue
        # Here we have two consecutive Cast nodes.
        # NOTE(review): this assumes the first Cast has no other consumers.
        # Reset the input of the later node
        node.input[0] = first_node.input[0]
        # Remove the first node and its output value info
        node_to_remove.append(first_node)
        first_output = helper.find_value_by_name(g, first_node.output[0])
        g.value_info.remove(first_output)
    for node in node_to_remove:
        g.node.remove(node)


def eliminate_Squeeze_before_Reshape(g):
    """If a Squeeze directly feeds a Reshape, remove the Squeeze.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Reshape':
            continue
        first_node = helper.find_node_by_output_name(g, node.input[0])
        if not first_node:
            continue
        if first_node.op_type != 'Squeeze':
            continue
        # Reset the input of the later node
        node.input[0] = first_node.input[0]
        # Remove the first node and its output value info
        node_to_remove.append(first_node)
        first_output = helper.find_value_by_name(g, first_node.output[0])
        g.value_info.remove(first_output)
    for node in node_to_remove:
        g.node.remove(node)


def eliminate_no_children_input(g):
    """Eliminate inputs with no children at all.

    :param g: the onnx graph
    """
    # Create a set of input names
    input_names = set([i.name for i in g.input])
    # If a name is used in any node, remove this name from the set.
    for n in g.node:
        for i in n.input:
            input_names.discard(i)
    # Remove the inputs with the left names.
    for i in input_names:
        info = helper.find_input_by_name(g, i)
        g.input.remove(info)


def eliminate_consecutive_reshape(g):
    """Replace consecutive Reshape nodes by a single Reshape node.

    :param g: the onnx graph
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Reshape':
            continue
        pre_data_node = helper.find_node_by_output_name(g, node.input[0])
        pre_shape_node = helper.find_node_by_output_name(g, node.input[1])
        if not pre_data_node or not pre_shape_node:
            continue
        if pre_shape_node.op_type != 'Constant':
            continue
        if pre_data_node.op_type != 'Reshape':
            continue

        pre_pre_shape_node = helper.find_node_by_output_name(g, pre_data_node.input[1])
        if pre_pre_shape_node.op_type != 'Constant':
            continue

        # Only the second reshape's target shape matters.
        new_reshape_node = onnx.helper.make_node(
            'Reshape',
            [pre_data_node.input[0], node.input[1]],
            [node.output[0]],
            name=node.output[0]
        )

        g.node.extend([new_reshape_node])
        node_to_del.append(node)
        node_to_del.append(pre_data_node)
        node_to_del.append(pre_pre_shape_node)

        val_info_to_del1 = helper.find_value_by_name(g, node.input[0])
        val_info_to_del2 = helper.find_value_by_name(g, pre_data_node.input[1])
        g.value_info.remove(val_info_to_del1)
        g.value_info.remove(val_info_to_del2)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)


def eliminate_single_input_Concat(g):
    """Eliminate single-input Concat layers.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Concat':
            continue
        # If this node has more than 1 input, continue.
        if len(node.input) > 1:
            continue
        # If this node is the output node, set its previous node as output nodes.
        if helper.find_output_by_name(g, node.output[0]) is not None:
            todel_output = helper.find_output_by_name(g, node.output[0])
            the_input_value = helper.find_value_by_name(g, node.input[0])
            g.output.remove(todel_output)
            g.output.extend([the_input_value])
            node_to_remove.append(node)
            continue
        # Replace the parents in all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, node.output[0], node.input[0])
        # Delete value info (narrowed from a bare except).
        value_between = helper.find_value_by_name(g, node.output[0])
        try:
            g.value_info.remove(value_between)
        except ValueError:
            print("No value info to delete while eliminating identity layers.")
        # Node is waiting for elimination
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)


def eliminate_nop_Maxpool_and_AveragePool(g):
    """Eliminate do-nothing MaxPool and AveragePool layers.

    Those layers have valid padding, 1x1 kernel and [1, 1] strides.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'MaxPool' and node.op_type != 'AveragePool':
            continue
        # If this node is actually working, continue.
        kernel = helper.get_list_attribute_by_name(node, "kernel_shape", "int")
        pads = helper.get_list_attribute_by_name(node, "pads", "int")
        strides = helper.get_list_attribute_by_name(node, "strides", "int")
        if kernel != [1, 1] or pads != [0, 0, 0, 0] or strides != [1, 1]:
            continue
        # If this node is the output node, set its previous node as output nodes.
        if helper.find_output_by_name(g, node.output[0]) is not None:
            todel_output = helper.find_output_by_name(g, node.output[0])
            the_input_value = helper.find_value_by_name(g, node.input[0])
            g.output.remove(todel_output)
            g.output.extend([the_input_value])
            node_to_remove.append(node)
            continue
        # Replace the parents in all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, node.output[0], node.input[0])
        # Delete value info (narrowed from a bare except).
        value_between = helper.find_value_by_name(g, node.output[0])
        try:
            g.value_info.remove(value_between)
        except ValueError:
            print("No value info to delete while eliminating identity layers.")
        # Node is waiting for elimination
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)


def eliminate_trivial_maxpool(g):
    """Eliminate MaxPool nodes with 1x1 kernels, unit strides/dilations and
    no padding — they pass the tensor through unchanged.

    :param g: the onnx graph
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'MaxPool':
            continue
        pads = None
        strides = None
        dilation = None
        kernel_shape = None
        for att in node.attribute:
            if att.name == 'pads':
                pads = list(att.ints)
            elif att.name == 'strides':
                strides = list(att.ints)
            elif att.name == 'kernel_shape':
                kernel_shape = list(att.ints)
            elif att.name == 'dilation':
                # NOTE(review): the ONNX MaxPool attribute is spelled
                # 'dilations' — confirm which spelling producers emit here.
                dilation = list(att.ints)
            else:
                pass
        if pads and any([pad != 0 for pad in pads]):
            continue
        if strides and any([stride != 1 for stride in strides]):
            continue
        if dilation and any([dila != 1 for dila in dilation]):
            continue
        if any([dim != 1 for dim in kernel_shape]):
            continue

        node_to_del.append(node)

        next_nodes = helper.find_nodes_by_input_name(g, node.output[0])

        # BUG FIX: guard the empty list before indexing next_nodes[0],
        # and use an identity check instead of '== None'.
        if not next_nodes or next_nodes[0] is None:
            output_value = helper.find_output_by_name(g, node.output[0])
            if not output_value:
                continue
            else:
                pre_val_info = helper.find_value_by_name(g, node.input[0])
                g.output.extend([pre_val_info])
                g.output.remove(output_value)

        for next_node in next_nodes:
            if next_node is None:
                continue
            modhelper.replace_node_input(next_node, node.output[0], node.input[0])

        next_val_info = helper.find_value_by_name(g, node.output[0])
        g.value_info.remove(next_val_info)

    while node_to_del:
        g.node.remove(node_to_del.pop())

    other.topological_sort(g)


def eliminate_empty_value_infos(g):
    """Remove value_info entries that carry no shape at all.

    :param g: the onnx graph
    """
    to_remove = []
    for value_info in g.value_info:
        if len(value_info.type.tensor_type.shape.dim) == 0:
            to_remove.append(value_info)
    for value_info in to_remove:
        g.value_info.remove(value_info)


def eliminate_nop_pads(g):
    """Eliminate Pad nodes whose pads are all zero.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Pad':
            continue
        # Check if the Pad is empty or not
        pads_node = helper.find_node_by_output_name(g, node.input[1])
        pads_np = helper.constant_to_numpy(pads_node)
        all_zero = True
        for value in pads_np:
            if value != 0:
                all_zero = False
        if not all_zero:
            continue
        # The optional constant_value input (node.input[2]) was looked up but
        # never used; it stays in the graph either way, so the dead lookup
        # was removed.
        # If this node is the output node, set its previous node as output nodes.
        if helper.find_output_by_name(g, node.output[0]) is not None:
            todel_output = helper.find_output_by_name(g, node.output[0])
            g.output.remove(todel_output)
            if helper.find_output_by_name(g, node.input[0]) is None:
                the_input_value = helper.find_value_by_name(g, node.input[0])
                if the_input_value is not None:
                    g.output.extend([the_input_value])
        # Replace the parents in all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, node.output[0], node.input[0])
        # Delete value info (narrowed from a bare except).
        value_between = helper.find_value_by_name(g, node.output[0])
        try:
            g.value_info.remove(value_between)
        except ValueError:
            helper.logger.info("No value info to delete while eliminating identity layers.")
        # Node is waiting for elimination
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)


def eliminate_trivial_elementwise_calculation(g):
    """Eliminate Add, Sub, Mul, Div nodes which do nothing.

    Add/Sub with an all-zero constant and Mul/Div with an all-one constant
    are identity operations.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        weight_node = None
        if node.op_type == 'Add' or node.op_type == 'Sub':
            # For add and sub, check if the weights are 0s.
            weight_node = helper.find_node_by_output_name(g, node.input[1])
            if weight_node is None or weight_node.op_type != 'Constant':
                continue
            weight_np = helper.constant_to_numpy(weight_node)
            if np.any(weight_np):
                continue
        elif node.op_type == 'Mul' or node.op_type == 'Div':
            # For Mul and Div, check if the weights are 1s.
            weight_node = helper.find_node_by_output_name(g, node.input[1])
            if weight_node is None or weight_node.op_type != 'Constant':
                continue
            weight_np = helper.constant_to_numpy(weight_node)
            weight_np = weight_np - 1
            if np.any(weight_np):
                continue
        else:
            # For other nodes, just skip
            continue
        # Remove the node
        node_to_remove.append(node)
        output_value_info = helper.find_value_by_name(g, node.output[0])
        if output_value_info is not None:
            g.value_info.remove(output_value_info)
        # Replace next node input if any.
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, node.output[0], node.input[0])
        todel_output = helper.find_output_by_name(g, node.output[0])
        if todel_output is not None:
            g.output.remove(todel_output)
            # Re-expose the input value as a graph output when needed.
            previous_output = helper.find_output_by_name(g, node.input[0])
            if previous_output is None:
                the_input_value = helper.find_value_by_name(g, node.input[0])
                g.output.extend([the_input_value])
        # Delete the constant node if it is not used by other nodes
        constant_following_nodes = helper.find_following_nodes_by_input_value_name(g, weight_node.output[0])
        if len(constant_following_nodes) == 1:
            node_to_remove.append(weight_node)
            output_value_info = helper.find_value_by_name(g, weight_node.output[0])
            if output_value_info is not None:
                g.value_info.remove(output_value_info)
    for node in node_to_remove:
        g.node.remove(node)


def eliminate_nop_cast(g):
    """Eliminate do-nothing Cast nodes (input and output types match).

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Cast':
            continue
        # Get input value_info
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            helper.logger.debug(f"Cannot find the input value_info for Cast node {node.name}. Skip elimination check.")
            continue
        # Get output value_info
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value is None:
            helper.logger.debug(f"Cannot find the output value_info for Cast node {node.name}. Skip elimination check.")
            continue
        # Compare the type.
        if input_value.type.tensor_type.elem_type != output_value.type.tensor_type.elem_type:
            continue
        # If this node is the output node, set its previous node as output nodes.
        if helper.find_output_by_name(g, node.output[0]) is not None:
            todel_output = helper.find_output_by_name(g, node.output[0])
            g.output.remove(todel_output)
            if helper.find_output_by_name(g, node.input[0]) is None:
                the_input_value = helper.find_value_by_name(g, node.input[0])
                if the_input_value is not None:
                    g.output.extend([the_input_value])
        # Replace the parents in all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, node.output[0], node.input[0])
        # Delete value info
        value_between = helper.find_value_by_name(g, node.output[0])
        if value_between is not None:
            g.value_info.remove(value_between)
        # Node is waiting for elimination
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)


# ===== tools/fusing.py =====
import onnx.helper
import numpy as np
def fuse_Transpose_into_Constant(g):
    """Fuse a Transpose layer into the Constant layer that feeds it.

    The Constant's tensor is transposed offline with numpy using the
    Transpose node's `perm` attribute, and a new Constant producing the
    Transpose's output replaces both nodes.

    :param g: the onnx graph to modify in place.
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Transpose':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None or prev_node.op_type != 'Constant':
            continue

        # Apply the permutation to the constant data offline.
        # node.attribute[0] is assumed to be `perm` — Transpose has a single
        # attribute, so index 0 is safe here.
        pre_shape, data_list = helper.constant_to_list(prev_node)
        w = np.reshape(data_list, pre_shape)
        w = w.transpose(node.attribute[0].ints)
        new_shape = w.shape
        w = w.flatten()

        new_tensor = onnx.helper.make_tensor(
            name=prev_node.name+'_data',
            data_type=prev_node.attribute[0].t.data_type,
            dims=new_shape,
            vals=w.tolist()
        )
        # The new Constant takes over the Transpose's output name, so no
        # downstream rewiring is needed.
        new_node = onnx.helper.make_node(
            'Constant',
            [],
            [node.output[0]],
            name=node.output[0],
            value=new_tensor
        )

        # Remember the elem_type before dropping the value between the two
        # old nodes.
        value_between = helper.find_value_by_name(g, prev_node.output[0])
        value_type = value_between.type.tensor_type.elem_type
        g.value_info.remove(value_between)

        g.node.extend([new_node])
        node_to_remove.append(node)
        node_to_remove.append(prev_node)

        # NOTE(review): the block below first ensures a value_info exists for
        # the new output and then unconditionally removes it again (the
        # `if new_node.output[0]:` test is always true for a named output).
        # The net effect is that the output ends up with no value_info; it
        # looks like shape inference is expected to restore it later —
        # confirm this is intentional.
        if new_node.output[0] not in [i.name for i in g.value_info]:
            new_value = onnx.helper.make_tensor_value_info(
                name=new_node.output[0],
                elem_type=value_type,
                shape=new_shape
            )
            g.value_info.extend([new_value])
        if new_node.output[0]:
            val_info_to_del = helper.find_value_by_name(g, new_node.output[0])
            g.value_info.remove(val_info_to_del)

    for node in node_to_remove:
        g.node.remove(node)

    topological_sort(g)
def fuse_Add_into_Conv(g):
    """Fuse an Add of a Constant into the preceding bias-less Conv.

    Pattern: Conv (2 inputs, no bias) -> Add(conv_out, Constant). The
    Constant becomes the Conv's third (bias) input and the Add is removed.
    NOTE(review): the Constant's shape is not validated here — it is assumed
    to be a valid Conv bias (1-D of size M); confirm upstream guarantees.

    :param g: the onnx graph to modify in place.
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        conv_node = helper.find_node_by_output_name(g, node.input[0])
        cons_node = helper.find_node_by_output_name(g, node.input[1])
        if conv_node is None or cons_node is None:
            continue
        if conv_node.op_type != 'Conv' or cons_node.op_type != 'Constant':
            continue
        # Skip Conv nodes that already carry a bias input.
        if len(conv_node.input) > 2:
            continue
        # This layer should be fused. Connect constant node into convolution node.
        add_node = node
        conv_node.input.extend([cons_node.output[0]])
        old_value = helper.find_value_by_name(g, conv_node.output[0])
        # The Conv takes over the Add's output name, so no consumer rewiring
        # is needed.
        conv_node.output[0] = add_node.output[0]
        # Remove the original conv output value_info if present (fix: guard
        # against None — removing a missing value_info used to raise).
        if old_value is not None:
            g.value_info.remove(old_value)
        # Defer node removal so we do not mutate g.node while iterating.
        node_to_remove.append(add_node)
    # Apply changes to the model
    for node in node_to_remove:
        g.node.remove(node)
def fuse_BN_into_Gemm(g):
    """Fuse a following BatchNormalization into the previous Gemm.

    Folds the BN affine transform into the Gemm's B and C constants:
        B' = B * scale / sqrt(var + eps)
        C' = (C - mean) * scale / sqrt(var + eps) + bias
    Gemm attributes alpha/beta are pre-applied to the weights and reset to
    1.0; transB is applied and reset to 0. transA is not supported.

    :param g: the graph to modify in place.
    :raises RuntimeError: if the Gemm uses transA == 1.
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm
        if node.op_type != 'BatchNormalization':
            continue
        gemm_node = helper.find_node_by_output_name(g, node.input[0])
        if gemm_node is None:
            continue
        if gemm_node.op_type != 'Gemm':
            continue
        # Only fuse when the Gemm output has no other consumers.
        if len(helper.find_following_nodes_by_input_value_name(g, gemm_node.output[0])) > 1:
            continue
        bn_node = node
        # Get original weights (all assumed to be Constant nodes).
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon (ONNX default 1e-5) is folded into the variance once here.
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha (Gemm default 1.0) is pre-multiplied into B.
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta (Gemm default 1.0) is pre-multiplied into C.
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA is unsupported — would require transposing the runtime input.
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB is applied to the constant B offline.
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused', new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused', new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        node_to_remove.extend([gemm_b_node,
                               gemm_c_node,
                               bn_node,
                               bn_scale_node,
                               bn_bias_node,
                               bn_mean_node,
                               bn_var_node])
        # Modify attributes: the folded effects are now baked into the
        # constants, so reset alpha/beta/transB to neutral values.
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph: the Gemm takes over the BN output name, and
        # the old weight value_infos are renamed to the new constants.
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_c_value.name = new_gemm_c_node.output[0]
        gemm_value = helper.find_value_by_name(g, gemm_node.output[0])
        g.value_info.remove(gemm_value)
        gemm_node.output[0] = bn_node.output[0]
        # Drop the value_infos of the four consumed BN weight inputs.
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or \\
    Squeeze and Unsqueeze surrounding.

    Pattern handled: Gemm -> A -> BatchNormalization -> B, where A is an
    Unsqueeze(axes=[2]) or a Reshape to rank 3 with a trailing 1, and B is a
    Flatten / Squeeze(axes=[2]) / Reshape to rank 2 (or absent when the BN is
    a graph output). The BN affine transform is folded into the Gemm's B/C
    constants (same math as fuse_BN_into_Gemm) and A, BN, and B are removed.

    Fix: the Squeeze check for the B node previously read `axes` from
    `a_node` (copy-paste), validating the wrong node's attribute.

    :param g: the graph to modify in place.
    :raises RuntimeError: if the Gemm uses transA == 1.
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm pattern: Gemm A BN B
        # Find BatchNorm Node
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find A Node
        a_node = helper.find_node_by_output_name(g, node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find Gemm Node
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find B Node (None is allowed when the BN is a graph output).
        b_node_list = helper.find_following_nodes_by_input_value_name(g, bn_node.output[0])
        if len(b_node_list) == 0:
            the_output = helper.find_output_by_name(g, bn_node.output[0])
            if the_output is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches — Gemm and A must each have a single consumer.
        if len(helper.find_following_nodes_by_input_value_name(g, gemm_node.output[0])) > 1:
            continue
        if len(helper.find_following_nodes_by_input_value_name(g, a_node.output[0])) > 1:
            continue
        # Check type of A
        if a_node.op_type == 'Unsqueeze':
            axes = helper.get_attribute_by_name(a_node, 'axes')
            if axes.ints != [2]:
                continue
        elif a_node.op_type == 'Reshape':
            a = helper.constant_to_list(helper.find_node_by_output_name(g, a_node.input[1]))[1]
            if len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check type of B
        if b_node is None:
            pass
        elif b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            # BUGFIX: read axes from b_node, not a_node.
            axes = helper.get_attribute_by_name(b_node, 'axes')
            if axes.ints != [2]:
                continue
        elif b_node.op_type == 'Reshape':
            a = helper.constant_to_list(helper.find_node_by_output_name(g, b_node.input[1]))[1]
            if len(a) != 2:
                continue
        else:
            continue
        # Construct new Nodes
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon (ONNX default 1e-5) is folded into the variance.
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused', new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused', new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify attributes — folded effects are now in the constants.
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Remove useless nodes
        node_to_remove.extend([gemm_b_node,
                               gemm_c_node,
                               bn_node,
                               bn_scale_node,
                               bn_bias_node,
                               bn_mean_node,
                               bn_var_node,
                               a_node])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete useless value infos
        value = helper.find_value_by_name(g, a_node.output[0])
        g.value_info.remove(value)
        if a_node.op_type == 'Reshape':
            value = helper.find_value_by_name(g, a_node.input[1])
            g.value_info.remove(value)
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
        value = helper.find_value_by_name(g, bn_node.output[0])
        if value is not None:
            g.value_info.remove(value)
        if b_node is not None:
            value = helper.find_value_by_name(g, gemm_node.output[0])
            g.value_info.remove(value)
            if b_node.op_type == 'Reshape':
                value = helper.find_value_by_name(g, b_node.input[1])
                g.value_info.remove(value)
        # Connect the new graph
        # Connect Gemm new weights
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_b_value.type.tensor_type.shape.dim[0].dim_value = new_gemm_b.shape[0]
        gemm_b_value.type.tensor_type.shape.dim[1].dim_value = new_gemm_b.shape[1]
        gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If b node is None, set the Gemm output as the graph output
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            g.output.remove(output_value)
            g.output.extend([helper.find_value_by_name(g, gemm_node.output[0])])
        else:
            # Else, set node B output as gemm output
            gemm_node.output[0] = b_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
def fuse_Gemm_into_Gemm(g):
    """Fuse the previous Gemm into the following Gemm.

    Two chained Gemm layers with constant weights collapse into one:
        B' = B_prev . B,   C' = C_prev . B + C
    alpha/beta/transB of both nodes are pre-applied to the constants before
    combining; transA is not supported on either node.

    :param g: the graph to modify in place.
    :raises RuntimeError: if either Gemm uses transA == 1.
    """
    node_to_remove = []
    for node in g.node:
        # Check for Gemm and Gemm
        if node.op_type != 'Gemm':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None:
            continue
        if prev_node.op_type != 'Gemm':
            continue
        # Get original weights (all assumed to be Constant nodes).
        prev_b_node = helper.find_node_by_output_name(g, prev_node.input[1])
        prev_b = helper.constant_to_numpy(prev_b_node)
        prev_c_node = helper.find_node_by_output_name(g, prev_node.input[2])
        prev_c = helper.constant_to_numpy(prev_c_node)
        b_node = helper.find_node_by_output_name(g, node.input[1])
        b = helper.constant_to_numpy(b_node)
        c_node = helper.find_node_by_output_name(g, node.input[2])
        c = helper.constant_to_numpy(c_node)
        # Apply attributes
        # alpha of each Gemm is pre-multiplied into its own B.
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        b = b * alpha
        alpha = helper.get_attribute_by_name(prev_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        prev_b = prev_b * alpha
        # beta of each Gemm is pre-multiplied into its own C.
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        c = c * beta
        beta = helper.get_attribute_by_name(prev_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        prev_c = prev_c * beta
        # transA unsupported on either node.
        transA = helper.get_attribute_by_name(node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        transA = helper.get_attribute_by_name(prev_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB is applied offline to each constant B.
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None and transB.i == 1:
            b = b.transpose()
        transB = helper.get_attribute_by_name(prev_node, 'transB')
        if transB is not None and transB.i == 1:
            prev_b = prev_b.transpose()
        # Calculate new weights
        new_b = prev_b.dot(b)
        new_c = prev_c.dot(b) + c
        # Replace original weights
        new_b_node = helper.numpy_to_constant(b_node.name + '_fused', new_b)
        new_c_node = helper.numpy_to_constant(c_node.name + '_fused', new_c)
        g.node.extend([new_b_node, new_c_node])
        node_to_remove.extend([b_node,
                               c_node,
                               prev_b_node,
                               prev_c_node,
                               prev_node])
        # Modify attributes — folded effects are now in the constants.
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is not None:
            beta.f = 1.0
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph: bypass the previous Gemm entirely.
        node.input[0] = prev_node.input[0]
        delete_value_with_name_if_exists(g, prev_node.output[0])
        for i in range(1, 3):
            delete_value_with_name_if_exists(g, prev_node.input[i])
            delete_value_with_name_if_exists(g, node.input[i])
        node.input[1] = new_b_node.output[0]
        node.input[2] = new_c_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
def fuse_MatMul_and_Add_into_Gemm(g):
    """Fuse a MatMul followed by an Add into a single Gemm layer.

    Pattern: MatMul(X, W) -> Add(matmul_out, Bias). The Gemm keeps the
    MatMul's name and takes over the Add's output, with neutral
    alpha/beta/transA/transB.

    :param g: the onnx graph to modify in place.
    """
    node_to_remove = []
    node_to_add = []
    for node in g.node:
        if node.op_type != 'MatMul':
            continue
        # Find the (single) Add consuming this MatMul's output as input 0.
        add_node = None
        for i in g.node:
            if not i.input:
                continue
            if i.input[0] == node.output[0]:
                add_node = i
                break
        value_to_remove = helper.find_value_by_name(g, node.output[0])
        if add_node is None or value_to_remove is None or add_node.op_type != 'Add':
            continue
        # Build the Gemm input list without mutating node.input in place
        # (fix: the original appended to the MatMul's own input field and had
        # a stray trailing comma).
        input_list = list(node.input)
        input_list.append(add_node.input[1])
        new_node = onnx.helper.make_node(
            "Gemm",
            input_list,
            add_node.output,
            name=node.name,
            alpha=1.0,
            beta=1.0,
            transA=0,
            transB=0
        )
        node_to_add.append(new_node)
        node_to_remove.append(node)
        node_to_remove.append(add_node)
        # The intermediate MatMul output value_info is now dangling.
        g.value_info.remove(value_to_remove)
    for node in node_to_remove:
        g.node.remove(node)
    g.node.extend(node_to_add)
def fuse_consecutive_transposes(g):
    """Fuse two back-to-back Transpose nodes into one.

    The composed permutation new[i] = prev[cur[i]] replaces both nodes with a
    single Transpose reading the first node's input and producing the second
    node's output.

    Fix: guard against `pre_node` being None (the Transpose consumes a graph
    input), which previously raised AttributeError.

    :param g: the onnx graph to modify in place.
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Transpose':
            continue
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        # BUGFIX: pre_node may be None when the input is a graph input.
        if pre_node is None or pre_node.op_type != 'Transpose':
            continue

        # NOTE(review): the first Transpose is removed without checking that
        # this node is its only consumer — confirm branching cannot occur
        # here, or add a consumer-count guard.
        pre_permutation = list(pre_node.attribute[0].ints)
        cur_permutation = list(node.attribute[0].ints)
        if len(pre_permutation) != len(cur_permutation):
            continue

        # Compose the two permutations.
        new_permutation = []
        for ind in cur_permutation:
            new_permutation.append(pre_permutation[ind])

        new_trans_node = onnx.helper.make_node(
            'Transpose',
            [pre_node.input[0]],
            [node.output[0]],
            name=node.name,
            perm=new_permutation
        )

        g.node.extend([new_trans_node])
        node_to_del.extend([pre_node, node])

        # Drop the now-dangling intermediate value_info.
        mid_val_info = helper.find_value_by_name(g, node.input[0])
        if mid_val_info:
            g.value_info.remove(mid_val_info)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    topological_sort(g)
def fuse_mul_and_add_into_bn(g):
    """Fuse a per-channel Mul followed by an Add into a BatchNormalization.

    Pattern: Mul(data, Constant scale) -> Add(mul_out, Constant bias), with
    4-D data input and scale/bias broadcastable over the channel dimension.
    A BN with mean=0, var=1 and a tiny epsilon reproduces y = x*scale + bias.

    :param g: the onnx graph to modify in place.
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        input_nodes_add = [helper.find_node_by_output_name(g, input_name) for input_name in add_node.input]
        if any([n == None for n in input_nodes_add]):
            continue
        # One Add input must be a Mul, the other a Constant (either order).
        mul_node, const_add = None, None
        for input_node_add in input_nodes_add:
            if input_node_add.op_type == 'Mul':
                mul_node = input_node_add
            elif input_node_add.op_type == 'Constant':
                const_add = input_node_add
            else:
                pass
        if not mul_node or not const_add:
            continue
        # Split the Mul inputs into the data input and the constant scale.
        data_input_name, const_mul = None, None
        for input_name in mul_node.input:
            input_node = helper.find_node_by_output_name(g, input_name)
            if not input_node:
                data_input_name = input_name
            elif input_node.op_type == 'Constant':
                if not const_mul:
                    const_mul = input_node
                else:
                    # Two constants: treat the second one as data.
                    data_input_name = input_name
            else:
                data_input_name = input_name

        if not const_mul:
            continue

        scale_shape, scale_data = helper.constant_to_list(const_mul)
        bias_shape, __ = helper.constant_to_list(const_add)
        c_dim = len(scale_data)
        if scale_shape != bias_shape:
            continue

        # The data value may be an internal value_info or a graph input.
        data_input_value = helper.find_value_by_name(g, data_input_name)
        if data_input_value is None:
            data_input_value = helper.find_input_by_name(g, data_input_name)
        _ , previous_node_output_shape = helper.find_size_shape_from_value(data_input_value)
        # only allow 4 dim data input due to the hardware limitation
        if previous_node_output_shape is None or len(previous_node_output_shape) != 4:
            continue

        # check if mul's dim and input channel dimension are matched
        if previous_node_output_shape[1] != c_dim:
            continue

        # Reshape the constants in place to the 1-D [c_dim] form BN expects.
        if scale_shape == [1, c_dim, 1, 1]:
            # remove all '1'
            for _ in range(3):
                const_add.attribute[0].t.dims.remove(1)
                const_mul.attribute[0].t.dims.remove(1)
        elif scale_shape == [1, c_dim]:
            # remove all '1'
            const_add.attribute[0].t.dims.remove(1)
            const_mul.attribute[0].t.dims.remove(1)
        elif scale_shape == 1 and c_dim == 1:
            # Single value weight
            const_add.attribute[0].t.dims.append(1)
            const_mul.attribute[0].t.dims.append(1)
        else:
            continue

        bn_name = add_node.output[0]
        const_mean = helper.list_to_constant(bn_name+'_mean', [c_dim], [0.0 for _ in range(c_dim)])
        const_var = helper.list_to_constant(bn_name+'_var', [c_dim], [1.0 for _ in range(c_dim)])

        # mean=0, var=1 make the BN a pure y = x*scale + bias affine op.
        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [data_input_name, const_mul.output[0], const_add.output[0],\
                const_mean.output[0], const_var.output[0]],
            [add_node.output[0]],
            name=bn_name,
            epsilon=0.00000001
        )

        # NOTE(review): these removes assume the value_infos exist; a missing
        # one would raise. Sibling passes guard with `is not None` — confirm
        # whether that guard is needed here too.
        mid_val_info = helper.find_value_by_name(g, mul_node.output[0])
        scale_val_info = helper.find_value_by_name(g, const_mul.output[0])
        bais_val_info = helper.find_value_by_name(g, const_add.output[0])
        g.value_info.remove(mid_val_info)
        g.value_info.remove(scale_val_info)
        g.value_info.remove(bais_val_info)

        # Re-register the reshaped weights plus the new mean/var constants.
        new_scale_val_info = onnx.helper.make_tensor_value_info(
            const_mul.output[0],
            const_mul.attribute[0].t.data_type,
            [c_dim]
        )
        new_bais_val_info = onnx.helper.make_tensor_value_info(
            const_add.output[0],
            const_add.attribute[0].t.data_type,
            [c_dim]
        )
        mean_val_info = onnx.helper.make_tensor_value_info(
            const_mean.output[0],
            const_mean.attribute[0].t.data_type,
            [c_dim]
        )
        var_val_info = onnx.helper.make_tensor_value_info(
            const_var.output[0],
            const_var.attribute[0].t.data_type,
            [c_dim]
        )

        g.value_info.extend([new_scale_val_info])
        g.value_info.extend([new_bais_val_info])
        g.value_info.extend([mean_val_info])
        g.value_info.extend([var_val_info])
        g.node.extend([bn_node])
        g.node.extend([const_mean])
        g.node.extend([const_var])
        node_to_del.extend([mul_node, add_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
def fuse_mul_and_add_into_gemm(g):
    """Fuse an element-wise Mul followed by an Add into a Gemm.

    Pattern: Mul(x, Constant scale) -> Add(mul_out, Constant bias), with x of
    shape [1, dim] and both constants 1-D of size dim. The scale becomes a
    diagonal [dim, dim] Gemm B matrix and the bias becomes C.

    :param g: the onnx graph to modify in place.
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        mul_node = helper.find_node_by_output_name(g, add_node.input[0])
        if not mul_node or mul_node.op_type != 'Mul':
            continue
        mul_const = helper.find_node_by_output_name(g, mul_node.input[1])
        if not mul_const or mul_const.op_type != 'Constant':
            continue
        add_const = helper.find_node_by_output_name(g, add_node.input[1])
        if not add_const or add_const.op_type != 'Constant':
            continue

        # The Mul data input may be an internal value or a graph input.
        input_val = helper.find_value_by_name(g, mul_node.input[0])
        if not input_val:
            input_val = helper.find_input_by_name(g, mul_node.input[0])
        if not input_val:
            continue

        _, input_shape = helper.find_size_shape_from_value(input_val)
        if not input_shape:
            continue

        # Gemm requires a flat row vector input: exactly [1, dim].
        dim = int(np.prod(input_shape))
        if input_shape != [1, dim]:
            continue

        mul_const_shape, mul_const_data = helper.constant_to_list(mul_const)
        add_const_shape, __ = helper.constant_to_list(add_const)

        if len(mul_const_shape) != 1 or mul_const_shape[0] != dim:
            continue
        if len(add_const_shape) != 1 or add_const_shape[0] != dim:
            continue

        # Build a diagonal matrix so that x @ B == x * scale.
        b_data = np.zeros([dim, dim])
        for i in range(dim):
            b_data[i][i] = mul_const_data[i]
        b_data = b_data.flatten().tolist()
        b_tensor = onnx.helper.make_tensor(
            name=mul_const.name+'_tensor',
            data_type=mul_const.attribute[0].t.data_type,
            dims=[dim, dim],
            vals=b_data
        )
        # The new Constant reuses the old scale output name, so the Gemm can
        # reference it directly.
        b_const_node = onnx.helper.make_node(
            'Constant',
            [],
            [mul_const.output[0]],
            value=b_tensor,
            name=mul_const.output[0]
        )

        # Reshape the bias constant in place to [1, dim] for Gemm C.
        add_const.attribute[0].t.dims.insert(0, 1)

        gemm_node = onnx.helper.make_node(
            'Gemm',
            [mul_node.input[0], b_const_node.output[0], add_const.output[0]],
            [add_node.output[0]],
            name=add_node.output[0]
        )

        g.node.extend([gemm_node, b_const_node])
        node_to_del.extend([mul_const, mul_node, add_node])

        # Drop the now-stale value_infos; shapes changed or values vanished.
        val_info_mid = helper.find_value_by_name(g, mul_node.output[0])
        val_info_mul_const = helper.find_value_by_name(g, mul_const.output[0])
        val_info_add_const = helper.find_value_by_name(g, add_const.output[0])
        if val_info_mid:
            g.value_info.remove(val_info_mid)
        if val_info_mul_const:
            g.value_info.remove(val_info_mul_const)
        if val_info_add_const:
            g.value_info.remove(val_info_add_const)

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
def fuse_conv_and_add_into_conv(g):
    """Fuse an Add of a [1, M, 1, 1] Constant into the preceding Conv's bias.

    Pattern: Conv (Constant weight, no bias) -> Add(conv_out, Constant). The
    Add constant is reshaped in place to [M] and attached as the Conv bias.

    Fix: skip Conv nodes that already carry a bias input (3 inputs) — the
    sibling fuse_Add_into_Conv has this guard; without it the Conv would end
    up with an invalid fourth input.

    :param g: the onnx graph to modify in place.
    """
    node_to_del = []
    for node in g.node:
        # Check if two nodes can be fused
        if node.op_type != 'Add':
            continue
        add_node = node
        add_const = helper.find_node_by_output_name(g, add_node.input[1])
        if not add_const or add_const.op_type != 'Constant':
            continue

        conv_node = helper.find_node_by_output_name(g, add_node.input[0])
        if not conv_node or conv_node.op_type != 'Conv':
            continue
        # BUGFIX: Conv must not already have a bias input.
        if len(conv_node.input) > 2:
            continue
        weight_node = helper.find_node_by_output_name(g, conv_node.input[1])
        if not weight_node or weight_node.op_type != 'Constant':
            continue

        # The Add constant must be per-output-channel: [1, M, 1, 1].
        m_dim = weight_node.attribute[0].t.dims[0]
        if add_const.attribute[0].t.dims != [1, m_dim, 1, 1]:
            continue
        # Reshape it in place to the 1-D [M] form Conv bias expects.
        for _ in range(3):
            add_const.attribute[0].t.dims.remove(1)

        # Link the add weight to constant.
        conv_node.input.extend([add_const.output[0]])

        # Remove the node
        node_to_del.append(node)
        output_value_info = helper.find_value_by_name(g, add_node.output[0])
        if output_value_info is not None:
            g.value_info.remove(output_value_info)
        # The bias value_info is stale after the reshape; drop it.
        add_weight_value_info = helper.find_value_by_name(g, add_const.output[0])
        if add_weight_value_info is not None:
            g.value_info.remove(add_weight_value_info)
        # Replace next node input if any.
        following_nodes = helper.find_following_nodes_by_input_value_name(g, add_node.output[0])
        for following_node in following_nodes:
            replace_node_input(following_node, add_node.output[0], add_node.input[0])
        # Replace output if any
        todel_output = helper.find_output_by_name(g, add_node.output[0])
        if todel_output is not None:
            g.output.remove(todel_output)
            previous_output = helper.find_output_by_name(g, add_node.input[0])
            if previous_output is None:
                the_input_value = helper.find_value_by_name(g, add_node.input[0])
                g.output.extend([the_input_value])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
def fuse_consecutive_reducemean(g):
    """Merge two chained ReduceMean nodes into GlobalAveragePool + Flatten.

    Only fires when both nodes use keepdims=0 and their combined axes are
    exactly [2, 3] — i.e. together they average over H and W. GAP keeps the
    [N, C, 1, 1] shape, so a Flatten(axis=1) restores the [N, C] output the
    second ReduceMean produced.

    :param g: the onnx graph to modify in place.
    """
    pending_removal = []
    for cur in g.node:
        if cur.op_type != 'ReduceMean':
            continue
        producer = helper.find_node_by_output_name(g, cur.input[0])
        if producer is None or producer.op_type != 'ReduceMean':
            continue
        # Both nodes must drop the reduced dims entirely.
        first_keep = helper.get_var_attribute_by_name(producer, 'keepdims', 'int')
        second_keep = helper.get_var_attribute_by_name(cur, 'keepdims', 'int')
        if first_keep != 0 or second_keep != 0:
            continue
        # Together they must reduce exactly the spatial axes H and W.
        first_axes = helper.get_list_attribute_by_name(producer, 'axes', 'int')
        second_axes = helper.get_list_attribute_by_name(cur, 'axes', 'int')
        if sorted(first_axes + second_axes) != [2, 3]:
            continue

        # Merge two ReduceMean into GlobalAveragePool + Flatten.
        intermediate_name = cur.output[0] + '_intermedia'
        gap = onnx.helper.make_node(
            'GlobalAveragePool',
            [producer.input[0]],
            [intermediate_name],
            name=cur.name + '_gap'
        )
        flatten = onnx.helper.make_node(
            'Flatten',
            [intermediate_name],
            [cur.output[0]],
            name=cur.name + '_flatten',
            axis=1
        )
        g.node.extend([gap, flatten])
        pending_removal.extend([producer, cur])

        # Drop the stale value_info between the two ReduceMean nodes.
        stale = helper.find_value_by_name(g, cur.input[0])
        if stale:
            g.value_info.remove(stale)

    while pending_removal:
        g.node.remove(pending_removal.pop())

    topological_sort(g)
def fuse_slice_nodes_into_conv(g):
    """Fuse a 4-way strided-Slice + Concat pattern on a graph input into one
    strided Conv.

    NOTE(review): this appears to target the space-to-depth "Focus" slicing
    pattern (four Slice branches with step 2 over H/W concatenated on the
    channel axis) — confirm against the models this pass was written for.
    The replacement Conv uses a handcrafted one-hot 3x3 weight with
    strides [2, 2] and pads [0, 0, 2, 2].

    :param g: the onnx graph to modify in place.
    """
    # define pattern checker: walks a Slice chain down to a Concat, checking
    # every Slice has constant starts in {0,1}, axes in {2,3}, steps == 2 and
    # a single consumer.
    def check_is_slice(node):
        if node.op_type == 'Concat':
            return True
        if node.op_type != 'Slice':
            return False
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        if len(following_nodes) != 1:
            return False
        # also check attributes (opset 10+ Slice: starts/ends/axes/steps are inputs)
        if len(node.input) != 5:
            return False
        # starts should be 0 or 1
        starts_node = helper.find_node_by_output_name(g, node.input[1])
        if starts_node.op_type != 'Constant':
            return False
        _, starts_list = helper.constant_to_list(starts_node)
        for num in starts_list:
            if num != 0 and num != 1:
                return False
        # ends
        ends_node = helper.find_node_by_output_name(g, node.input[2])
        if ends_node.op_type != 'Constant':
            return False
        # axes should be 2 or 3
        axes_node = helper.find_node_by_output_name(g, node.input[3])
        if axes_node.op_type != 'Constant':
            return False
        _, axes_list = helper.constant_to_list(axes_node)
        for num in axes_list:
            if num != 2 and num != 3:
                return False
        # Steps can only be 2
        steps_node = helper.find_node_by_output_name(g, node.input[4])
        if steps_node.op_type != 'Constant':
            return False
        _, steps_list = helper.constant_to_list(steps_node)
        for num in steps_list:
            if num != 2:
                return False
        # Recursion
        return check_is_slice(following_nodes[0])
    # define concat finder: follow single-consumer chain until the Concat.
    def find_concat_node(node):
        while node.op_type != 'Concat':
            node = helper.find_following_nodes_by_input_value_name(g, node.output[0])[0]
        return node
    # define remove node function: depth-first removal of the Slice chains,
    # their constant inputs (when not shared) and finally the Concat.
    def remove_nodes(input_name):
        following_nodes = helper.find_following_nodes_by_input_value_name(g, input_name)
        # Remove concat directly
        if len(following_nodes) == 1 and following_nodes[0].op_type == 'Concat':
            g.node.remove(following_nodes[0])
            return
        for following_node in following_nodes:
            # Recursion first
            remove_nodes(following_node.output[0])
            # Remove weights
            for i in range(1, len(following_node.input)):
                if len(helper.find_following_nodes_by_input_value_name(g, following_node.input[i])) > 1:
                    # More than one following nodes. Skip.
                    continue
                input_weight = helper.find_node_by_output_name(g, following_node.input[i])
                g.node.remove(input_weight)
            # Remove Slice nodes
            g.node.remove(following_node)
    # define remove value_info function: mirrors remove_nodes but for the
    # value_info entries of the soon-to-be-removed chain.
    def remove_value_infos(input_name):
        following_nodes = helper.find_following_nodes_by_input_value_name(g, input_name)
        if following_nodes[0].op_type == 'Concat':
            return
        for following_node in following_nodes:
            output_value = helper.find_value_by_name(g, following_node.output[0])
            # Remove output values
            if output_value is not None:
                g.value_info.remove(output_value)
            # Remove weight values
            for i in range(1, len(following_node.input)):
                input_value = helper.find_value_by_name(g, following_node.input[i])
                if input_value is not None:
                    g.value_info.remove(input_value)
            # Recursion
            remove_value_infos(following_node.output[0])
    # define get slice position: walk a branch upwards collecting the (H, W)
    # start offsets of its Slice nodes.
    def get_slice_position(final_slice_output):
        slice_position = [0, 0]
        prev_node = helper.find_node_by_output_name(g, final_slice_output)
        while prev_node is not None:
            starts_np = helper.constant_to_numpy(helper.find_node_by_output_name(g, prev_node.input[1]))
            axes_np = helper.constant_to_numpy(helper.find_node_by_output_name(g, prev_node.input[3]))
            for i in range(len(axes_np)):
                if axes_np[i] == 2:
                    slice_position[0] = starts_np[i]
                elif axes_np[i] == 3:
                    slice_position[1] = starts_np[i]
            prev_node = helper.find_node_by_output_name(g, prev_node.input[0])
        return slice_position
    # Check pattern from each input
    for input_value in g.input:
        nodes_after_input = helper.find_following_nodes_by_input_value_name(g, input_value.name)
        pattern_matched = True
        for following_node in nodes_after_input:
            if following_node.op_type != 'Slice':
                pattern_matched = False
                break
            else:
                pattern_matched = check_is_slice(following_node)
        if not pattern_matched:
            continue
        # Pattern found. Check limitation
        # Currently only support 2D (exactly four branches: 2x2 offsets).
        if len(nodes_after_input) != 4:
            continue
        # Get the concat node
        concat_node = find_concat_node(nodes_after_input[0])
        # Get basic information
        input_shape = helper.get_shape_from_value_info(input_value)
        channel_num = input_shape[1]
        # Construct weight: one-hot 3x3 kernels that pick the pixel at each
        # branch's (H, W) offset, ordered to match the Concat input order.
        weight_np = np.zeros((input_shape[1] * 4, input_shape[1], 3, 3), dtype=np.float32)
        for i in range(4):
            # Check each branch
            slice_position = get_slice_position(concat_node.input[i])
            for j in range(channel_num):
                weight_np[i * channel_num + j, j, slice_position[0], slice_position[1]] = 1
        weight_node = helper.numpy_to_constant(concat_node.name + '_weight', weight_np)
        # Construct Conv node
        new_conv = onnx.helper.make_node(
            'Conv',
            [input_value.name, concat_node.name + '_weight'],
            [concat_node.output[0]],
            name = concat_node.name + '_fused',
            dilations = [1, 1],
            group = 1,
            kernel_shape = [3, 3],
            strides = [2, 2],
            pads = [0, 0, 2, 2]
        )
        # Delete old nodes, weights and value_infos
        remove_value_infos(input_value.name)
        remove_nodes(input_value.name)
        # Replace node
        g.node.append(weight_node)
        g.node.append(new_conv)
def fuse_relu_min_into_clip(g):
    """Replace Relu followed by Min(x, scalar) with a single Clip node.

    Relu clamps below at 0 and the Min clamps above at the scalar, so
    Clip(input, 0.0, scalar) is equivalent. A new Constant supplies the 0.0
    lower bound (opset 11+ Clip takes min/max as inputs).

    :param g: the onnx graph to modify in place.
    """
    pending_removal = []
    for candidate in g.node:
        # Locate the Min node first, then validate its producers.
        if candidate.op_type != 'Min':
            continue
        min_node = candidate
        # Second input must be a scalar Constant (the upper bound).
        upper_const = helper.find_node_by_output_name(g, min_node.input[1])
        if not upper_const or upper_const.op_type != 'Constant':
            continue
        upper_shape, upper_value = helper.constant_to_list(upper_const)
        if upper_shape != 1:
            continue
        # First input must come from a Relu.
        relu_node = helper.find_node_by_output_name(g, min_node.input[0])
        if not relu_node or relu_node.op_type != 'Relu':
            continue

        # Build the Clip: lower bound 0.0 (new scalar Constant), upper bound
        # reuses the Min's constant. The Clip takes over the Min's output.
        lower_const = helper.list_to_constant(relu_node.name+'_min_value', [], [0.0])
        clip_node = onnx.helper.make_node(
            "Clip",
            [relu_node.input[0], lower_const.output[0], upper_const.output[0]],
            [min_node.output[0]],
            name=min_node.name
        )

        pending_removal.extend([relu_node, min_node])

        # The Relu output value_info is now dangling.
        stale_val_info = helper.find_value_by_name(g, min_node.input[0])
        if stale_val_info:
            g.value_info.remove(stale_val_info)
        g.node.extend([lower_const, clip_node])

    while pending_removal:
        g.node.remove(pending_removal.pop())

    topological_sort(g)
from collections import deque

class Node:
    """A Node which maps a node proto. It has pointers to its parents and
    children.
    """
    def __init__(self, onnx_node):
        """Initialize a node. This initialization only set up the mapping to
        node proto. The pointers should be set up by outside.

        :param onnx_node: the NodeProto to wrap, or None for a placeholder
            node (Graph uses placeholder nodes for graph inputs).
        """
        # Node name, copied from the proto when one is given.
        self.name = None
        # Producer Node objects for this node's inputs (filled by Graph).
        self.parents = []
        # Consumer Node objects for this node's outputs (filled by Graph).
        self.children = []
        # The wrapped NodeProto; None for input placeholder nodes.
        self.proto = None
        # The value_info for this node's (first) output, when known.
        self.output_value = None
        if onnx_node is not None:
            self.name = onnx_node.name
            self.proto = onnx_node

class Graph:
    """A graph which is constructed from the onnx proto.
    """
    def __init__(self, onnx_graph):
        """Construct the graph from onnx.

        Wraps every graph input and every NodeProto in a Node and links the
        parent/child pointers by matching value names.

        :param onnx_graph: the GraphProto to wrap.
        """
        self.input_nodes = []
        self.output_nodes = []
        # Map from node name to its Node wrapper.
        self.name2node = {}
        # Map from an output value name to the Node producing it.
        self.output2node = {}
        self.proto = onnx_graph
        # Add input nodes: each graph input becomes a placeholder Node.
        for value in onnx_graph.input:
            input_node = Node(None)
            input_node.name = "Input_" + value.name
            input_node.output_value = value
            self.name2node[input_node.name] = input_node
            self.output2node[value.name] = input_node
            self.input_nodes.append(input_node)
        output_value_names = [value.name for value in onnx_graph.output]
        # Add regular nodes.
        # NOTE(review): only output[0] is registered in output2node, and every
        # node input is looked up there without a guard — a node consuming an
        # initializer or a secondary output of a multi-output node would raise
        # KeyError here; confirm all inputs are produced upstream and that
        # g.node is already in topological order.
        for onnx_node in onnx_graph.node:
            node = Node(onnx_node)
            self.name2node[node.name] = node
            self.output2node[onnx_node.output[0]] = node
            for value_name in onnx_node.input:
                node.parents.append(self.output2node[value_name])
                self.output2node[value_name].children.append(node)
            if onnx_node.output[0] in output_value_names:
                self.output_nodes.append(node)
        # Add value infos to the producing nodes.
        for value in onnx_graph.value_info:
            node = self.output2node[value.name]
            node.output_value = value

    def get_sorted_node_list(self):
        """Return a node list in topological order.
        """
        visited = set()
        todo = deque()
        result = []
        # Seed the queue with graph inputs ...
        for node in self.input_nodes:
            todo.append(node)
            visited.add(node)
        # ... and with Constant nodes, which have no parents.
        for onnx_node in self.proto.node:
            if onnx_node.op_type == "Constant":
                node = self.name2node[onnx_node.name]
                todo.append(node)
                visited.add(node)
        # BFS; a child is emitted only once all of its parents were emitted.
        while todo:
            node = todo.popleft()
            result.append(node)
            for child in node.children:
                if child in visited:
                    continue
                ready = True
                for child_parent in child.parents:
                    if child_parent in visited:
                        continue
                    ready = False
                    break
                if ready:
                    todo.append(child)
                    visited.add(child)
        return result
+""" +import onnx +import onnx.helper +import struct +import numpy as np +import logging + +__ONNX_VERSION__ = -1 + +logger = logging.getLogger("optimizer_scripts") + +def setup_current_opset_version(m): + global __ONNX_VERSION__ + __ONNX_VERSION__ = m.opset_import[0].version + if __ONNX_VERSION__ not in [11]: + raise RuntimeError('Only support opset 11, but got ' + str(__ONNX_VERSION__)) + +def get_current_opset_version(): + if __ONNX_VERSION__ == -1: + raise RuntimeError('do setup_current_opset_version first please') + return __ONNX_VERSION__ + +def find_nodes_by_input_name(g, name): + nodes = [] + for node in g.node: + if name in node.input: + nodes.append(node) + return nodes + +def find_node_by_output_name(g, name): + """ + Find a node in the graph by its output name + + :param g: the onnx graph\\ + :param name: the target node output name\\ + :returns: the node find by name + """ + for i in g.node: + if name in i.output: + return i + return None + +def find_node_by_node_name(g, name): + """ + Find a node in the graph by its output name + + :param g: the onnx graph\\ + :param name: the target node output name\\ + :returns: the node find by name + """ + for i in g.node: + if i.name == name: + return i + return None + +def find_following_nodes_by_input_value_name(g, name): + """ Find the following nodes of a specific value. + + :param g: the onnx graph. \\ + :param name: the value name. \\ + :return: a list of following nodes. 
+ """ + return find_nodes_by_input_name(g, name) + +def find_value_by_name(g, name): + """ + Find a value_info in the graph by name + + :param g: the onnx graph\\ + :param name: the target value_info name\\ + :returns: the value_info find by name + """ + for i in g.value_info: + if i.name == name: + return i + return None + +def find_output_by_name(g, name): + """ + Find a value_info in the graph by name + + :param g: the onnx graph\\ + :param name: the target value_info name\\ + :returns: the value_info find by name + """ + for i in g.output: + if i.name == name: + return i + return None + +def find_input_by_name(g, name): + """ + Find a input in the graph by name + + :param g: the onnx graph\\ + :param name: the target input name\\ + :returns: the input find by name + """ + for i in g.input: + if i.name == name: + return i + return None + +def list_to_constant(name, shape, data, data_type=None): + """Generate a constant node using the given infomation. + + :name: the node name and the output value name\\ + :shape: the data shape\\ + :data: the data itself\\ + :returns: the generated onnx constant node + """ + if not data_type: + if isinstance(data, int): + data_type = onnx.helper.TensorProto.INT64 + elif isinstance(data, float): + data_type = onnx.helper.TensorProto.FLOAT + elif len(data) > 0 and isinstance(data[0], int): + data_type = onnx.helper.TensorProto.INT64 + else: + data_type = onnx.helper.TensorProto.FLOAT + tensor = onnx.helper.make_tensor( + name, + data_type, + shape, + data + ) + new_w_node = onnx.helper.make_node( + "Constant", + [], + [name], + name = name, + value = tensor + ) + return new_w_node + + +def scaler_to_constant(name, data, data_type=None): + """Generate a constant node using the given infomation. 
def numpy_to_constant(name, np_array):
    """Generate a Constant node holding the given numpy array.

    :name: the node name and the output value name\\
    :np_array: the numpy array with the data\\
    :returns: the generated onnx constant node
    """
    return list_to_constant(name, np_array.shape, np_array.flatten().tolist())

def constant_to_list(node):
    """Generate a list from the constant node

    :node: the Constant node\\
    :returns: the shape of the constant node (an int element count when the
        tensor is rank-0, otherwise a list of dims), and the flat data list\\
    :raises RuntimeError: for unsupported tensor data types
    """
    tensor = node.attribute[0].t
    # 1. check data type
    # 2. get data from raw or data
    # 3. get shape from dim
    if tensor.data_type == onnx.helper.TensorProto.INT32:
        if len(tensor.int32_data) != 0:
            data = list(tensor.int32_data)
        else:
            data = [i[0] for i in struct.iter_unpack('i', tensor.raw_data)]
    elif tensor.data_type == onnx.helper.TensorProto.INT64:
        if len(tensor.int64_data) != 0:
            data = list(tensor.int64_data)
        else:
            data = [i[0] for i in struct.iter_unpack('q', tensor.raw_data)]
    elif tensor.data_type == onnx.helper.TensorProto.INT8:
        # INT8 values are stored in the int32_data field per the ONNX spec.
        if len(tensor.int32_data) != 0:
            data = list(tensor.int32_data)
        else:
            data = [i[0] for i in struct.iter_unpack('b', tensor.raw_data)]
    elif tensor.data_type == onnx.helper.TensorProto.FLOAT:
        if len(tensor.float_data) != 0:
            data = list(tensor.float_data)
        else:
            data = [i[0] for i in struct.iter_unpack('f', tensor.raw_data)]
    elif tensor.data_type == onnx.helper.TensorProto.DOUBLE:
        if len(tensor.double_data) != 0:
            data = list(tensor.double_data)
        else:
            data = [i[0] for i in struct.iter_unpack('d', tensor.raw_data)]
    else:
        # Fix: raise with the message instead of print + message-less raise,
        # so callers actually see which data type was unsupported.
        raise RuntimeError("Not supported data type {}".format(tensor.data_type))
    if len(tensor.dims) == 0:
        # Rank-0 tensor: report the element count instead of a dims list.
        shape = len(data)
    else:
        shape = list(tensor.dims)
    return shape, data

def constant_to_numpy(node):
    """Generate a numpy array from the constant node

    :node: the Constant node\\
    :returns: the numpy array
    """
    shape, data = constant_to_list(node)
    return np.array(data).reshape(shape)
def all_constant_input(node):
    """Check whether every parent of the given Node is a Constant node.

    :param node: a Node (general_graph.Node) with proto/parents pointers\\
    :return: True when the node has a proto and all parents wrap Constant
        protos, False otherwise
    """
    if node.proto is None:
        return False
    return all(
        parent.proto is not None and parent.proto.op_type == 'Constant'
        for parent in node.parents
    )

def get_padding(size, kernel_size, strides):
    """ Calculate the padding array for same padding in the Tensorflow fashion.\\
    See https://www.tensorflow.org/api_guides/python/nn#Convolution for more.

    :param size: [height, width] of the input\\
    :param kernel_size: [kh, kw] of the kernel\\
    :param strides: [sh, sw] strides\\
    :return: [top, left, bottom, right] padding
    """
    leftover_h = size[0] % strides[0]
    leftover_w = size[1] % strides[1]
    # With a zero remainder the effective overlap is one full stride.
    pad_h = max(kernel_size[0] - (strides[0] if leftover_h == 0 else leftover_h), 0)
    pad_w = max(kernel_size[1] - (strides[1] if leftover_w == 0 else leftover_w), 0)
    top = pad_h // 2
    left = pad_w // 2
    return [top, left, pad_h - top, pad_w - left]

def get_shape_from_value_info(value):
    """Get shape from a value info.

    :param value: the value_info proto\\
    :return: list of the shape
    """
    return [entry.dim_value for entry in value.type.tensor_type.shape.dim]

def find_size_shape_from_value(value):
    '''
    Find the size of data within the value_info object.
    :param value: value_info
    :return: int size and list shape of the data in the value_info
    '''
    if not value:
        return None, None
    dims = value.type.tensor_type.shape.dim
    if not dims:
        return 0, []
    # Treat unknown (0 or negative-free) dims as 1 so the size stays positive.
    shape = [max(1, entry.dim_value) for entry in dims]
    size = 1
    for extent in shape:
        size *= extent
    return size, shape
+ """ + for attr in node.attribute: + if attr.name == attr_name: + return attr + return None + +def get_list_attribute_by_name(node, attr_name: str, attr_type: str): + """Get list attribute with specific name in the given node proto. + + :param node: the node proto.\\ + :param attr_name: a str for the name of the target.\\ + :param attr_type: a str which should be "float" or "int".\\ + :return: if found, return the list. Else, return None. + """ + attr_proto = get_attribute_by_name(node, attr_name) + if attr_proto is None: + return None + if attr_type == "int": + if len(attr_proto.ints) == 0: + return None + else: + return list(attr_proto.ints) + elif attr_type == "float": + if len(attr_proto.ints) == 0: + return None + else: + return list(attr_proto.floats) + else: + print("Warning: undefined type for list attribute extraction") + return None + +def get_var_attribute_by_name(node, attr_name: str, attr_type: str): + """Get variable attribute with specific name in the given node proto. + + :param node: the node proto.\\ + :param attr_name: str for the name of the target.\\ + :param attr_type: str which should be "float", "int", "string" or "tensor".\\ + :return: if found, return the variable. Else, return None. 
+ """ + attr_proto = get_attribute_by_name(node, attr_name) + if attr_proto is None: + return None + if attr_type == "int": + return attr_proto.i + elif attr_type == "float": + return attr_proto.f + elif attr_type == "string": + if type(attr_proto.s) == type(b'abc'): + return attr_proto.s.decode("utf-8") + else: + return attr_proto.s + elif attr_type == "tensor": + return attr_proto.t + else: + print("Warning: undefined type for variable attribute extraction") + return None + +def flatten_with_depth(data, depth): + output = [] + if type(data) not in [type(np.array([1])), type([1])]: + return [[data, 0]] + for item in data: + if type(item) not in [type(np.array([1])), type([1])]: + output.append([item, depth+1]) + else: + output += flatten_with_depth(item, depth+1) + return output + +def flatten_to_list(data): + flatten_depth = flatten_with_depth(data, 0) + flat_data = [item[0] for item in flatten_depth] + return flat_data + +def get_shape(data): + shape = [] + if type(data) not in [type(np.array([1])), type([1])]: + return [] + sub_data = data[0] + shape.append(len(data)) + while type(sub_data) in [type(np.array([1])), type([1])]: + shape.append(len(sub_data)) + sub_data = sub_data[0] + return shape + + +def slice_data(data, starts, ends, axes): + flat_data = [item[0] for item in flatten_with_depth(data, 0)] + shape = get_shape(data) + + starts_updated = [] + ends_updated = [] + for i in range(len(starts)): + start_updated = min(starts[i], shape[i]-1) % shape[i] + starts_updated.append(start_updated) + for j in range(len(starts)): + if ends[j] >= shape[j]: + end_updated = shape[j] + else: + end_updated = min(ends[j], shape[j]) % shape[j] + ends_updated.append(end_updated) + + index_slices = [] + for i in range(len(shape)): + if i not in axes: + index_slices.append(list(range(shape[i]))) + else: + axe_ind = axes.index(i) + index_slices.append(list(range(starts_updated[axe_ind], ends_updated[axe_ind]))) + + indices = [1] + for i in range(len(shape)-1, -1, -1): + step 
= np.prod(shape[i+1:]) + temp_pos = indices + new_indices = [] + for n in index_slices[i]: + for pos in temp_pos: + new_indices.append(int(n*step+pos)) + indices = new_indices + + sliced_data = [flat_data[k-1] for k in indices] + + # reshape to correct shape. + new_shape = [] + for i in range(len(shape)): + if i not in axes: + new_shape.append(shape[i]) + else: + axe_ind = axes.index(i) + new_shape.append(ends_updated[axe_ind]-starts_updated[axe_ind]) + if any([dim < 1 for dim in new_shape]): + raise RuntimeError('Invalid starts ends.') + + sliced_data = np.reshape(sliced_data, new_shape) + + return sliced_data + +def concatenate(data_sets, axis): + # check shapes + shapes = [] + shapes_ = [] + for data_set in data_sets: + shape = get_shape(data_set) + shapes.append(list(shape)) + shape.pop(axis) + shapes_.append(shape) + if not all([s == shapes_[0] for s in shapes_]): + raise RuntimeError('data sets shapes do not match') + + new_dim = sum([s[axis] for s in shapes]) + new_shape = list(shapes[0]) + new_shape[axis] = new_dim + + flat_data_sets = [] + for data_set in data_sets: + flat_data_sets.append(flatten_to_list(data_set)) + + sub_block_size = 1 + for i in range(axis+1, len(shapes[0])): + sub_block_size *= shapes[0][i] + + split_num = 1 + for i in range(axis): + split_num *= shapes[0][i] + + total_flat_data = [] + for i in range(split_num): + for j in range(len(shapes)): + block_size = sub_block_size*shapes[j][axis] + total_flat_data.extend(flat_data_sets[j][i*block_size:(i+1)*block_size]) + + new_data = np.reshape(total_flat_data, new_shape) + + return new_data + + +def broadcast_data_sets(data_set_1, data_set_2): + shape1 = get_shape(data_set_1) + shape2 = get_shape(data_set_2) + + # compare shapes and get broadcasted shape + list_a, list_b = (shape1, shape2) if len(shape1) > len(shape2) else (shape2, shape1) + while len(list_a) > len(list_b): + list_b.insert(0, 0) + broadcasted_shape = [] + for i in range(len(list_a)): + if list_b[i] == 0: + 
def broadcast_data_sets(data_set_1, data_set_2):
    """Broadcast two nested data sets to a common shape (numpy-style rules).

    :param data_set_1: nested list or numpy array\\
    :param data_set_2: nested list or numpy array\\
    :return: the two data sets expanded to the broadcast shape\\
    :raises RuntimeError: when the shapes cannot be broadcast together
    """
    shape1 = get_shape(data_set_1)
    shape2 = get_shape(data_set_2)

    # compare shapes and get broadcasted shape
    # NOTE: list_a/list_b alias shape1/shape2, so the shorter shape list is
    # left-padded with the sentinel 0 in place (meaning "missing dim").
    list_a, list_b = (shape1, shape2) if len(shape1) > len(shape2) else (shape2, shape1)
    while len(list_a) > len(list_b):
        list_b.insert(0, 0)
    broadcasted_shape = []
    for i in range(len(list_a)):
        if list_b[i] == 0:
            # Missing dimension on the shorter shape: take the other side.
            broadcasted_shape.append(list_a[i])
        elif list_b[i] == 1:
            broadcasted_shape.append(list_a[i])
        elif list_a[i] == 1:
            broadcasted_shape.append(list_b[i])
        elif list_a[i] == list_b[i]:
            broadcasted_shape.append(list_a[i])
        else:
            raise RuntimeError('Can not broadcast two data sets')

    # prepare data for broadcasting: sentinel 0 dims become real size-1 dims.
    shape1 = list(map(lambda x:x if x != 0 else 1, shape1))
    shape2 = list(map(lambda x:x if x != 0 else 1, shape2))
    data_1 = np.reshape(data_set_1, shape1)
    data_2 = np.reshape(data_set_2, shape2)

    # Tile each size-1 dimension up to the broadcast extent.
    for i in range(len(shape1)):
        if shape1[i] != broadcasted_shape[i]:
            new_data_total = [list(data_1) for _ in range(broadcasted_shape[i])]
            data_1 = concatenate(new_data_total, axis=i)
    for i in range(len(shape2)):
        if shape2[i] != broadcasted_shape[i]:
            new_data_total = [list(data_2) for _ in range(broadcasted_shape[i])]
            data_2 = concatenate(new_data_total, axis=i)

    return data_1, data_2


def add(data_set_1, data_set_2):
    """Element-wise addition of two nested data sets with broadcasting.

    :return: the sum as a numpy array
    """
    broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(data_set_1, data_set_2)

    flat_data_1 = flatten_to_list(broadcasted_data_1)
    flat_data_2 = flatten_to_list(broadcasted_data_2)
    shape = get_shape(broadcasted_data_1)
    res = []
    for i in range(len(flat_data_1)):
        res.append(flat_data_1[i]+flat_data_2[i])

    res = np.reshape(res, shape)

    return res


def reduceprod(data_set, axis, keepdims=1):
    """Multiply elements along the given axes, like ONNX ReduceProd.

    :param data_set: nested list or numpy array\\
    :param axis: iterable of axes to reduce over\\
    :param keepdims: keep reduced axes as size-1 dims when truthy\\
    :return: the reduced data as a numpy array
    """
    flat_data = flatten_to_list(data_set)
    old_shape = get_shape(data_set)

    # NOTE: temp_shape aliases old_shape; old_shape is not used afterwards.
    temp_shape = old_shape
    temp_flat_data = flat_data
    for ax in axis:
        # split_num: number of independent blocks above `ax`;
        # step: distance between consecutive elements along `ax`.
        split_num = 1
        step = 1
        for i in range(ax):
            split_num *= temp_shape[i]
        for i in range(ax+1, len(temp_shape)):
            step *= temp_shape[i]

        block_size = len(temp_flat_data)//split_num
        new_flat_data = []
        for j in range(split_num):
            block_data = temp_flat_data[j*block_size:(j+1)*block_size]
            reduced_block_data = []
            for k in range(step):
                # Multiply every slice along `ax` into a single value.
                val = block_data[k]
                for l in range(1, block_size//step):
                    val *= block_data[k+l*step]
                reduced_block_data.append(val)
            new_flat_data.extend(reduced_block_data)
        temp_flat_data = new_flat_data
        temp_shape[ax] = 1

    new_flat_data = temp_flat_data
    new_shape = temp_shape
    if not keepdims:
        # Drop the reduced size-1 dims, highest axis first so the positions
        # of the remaining axes stay valid.
        axis = sorted(list(axis))
        for pos in axis[::-1]:
            new_shape.pop(pos)

    return np.reshape(new_flat_data, new_shape)


def transpose(data_set, permutation):
    """Transpose nested data by the given axis permutation.

    Decomposes the permutation into adjacent swaps (bubble sort) and applies
    the swaps one by one to the flat data.

    :param data_set: nested list or numpy array\\
    :param permutation: a permutation of range(rank)\\
    :return: the transposed data as a numpy array
    """
    # find series of local swaps
    data_set = list(data_set)
    perm = list(permutation)
    shape = get_shape(data_set)
    flat_data = flatten_to_list(data_set)
    assert set(perm) == set(range(len(shape))), 'invalid permutation'

    new_shape = [shape[i] for i in perm]
    swaps = []
    bubbled = True
    while bubbled:
        bubbled = False
        for i in range(len(new_shape)-1):
            if perm[i] > perm[i+1]:
                swaps.append([i, i+1])
                p_1, p_2 = perm[i], perm[i+1]
                perm[i], perm[i+1] = p_2, p_1
                bubbled = True

    # apply local swaps (in reverse recording order)
    current_shape = list(shape)
    temp_flat_data = flat_data

    for swap in swaps[::-1]:
        ind_1, ind_2 = swap[0], swap[1]
        dim_1 = current_shape[ind_1]
        dim_2 = current_shape[ind_2]
        split_num = 1
        block_size = 1

        for i in range(ind_1):
            split_num *= current_shape[i]
        for i in range(ind_2+1, len(current_shape)):
            block_size *= current_shape[i]

        # Exchange the two adjacent axes block by block.
        data_blocks = np.reshape(temp_flat_data, [-1, block_size])
        flat_data_1 = []
        for k in range(split_num):
            block = []
            for m in range(dim_2):
                for n in range(dim_1):
                    block_pos = k*dim_1*dim_2 + n*dim_2+m
                    block.extend(data_blocks[block_pos])
            flat_data_1.extend(block)

        temp_flat_data = flat_data_1
        current_shape[ind_1] = dim_2
        current_shape[ind_2] = dim_1

    return np.reshape(temp_flat_data, current_shape)
def subtract(data_set_1, data_set_2):
    """Element-wise subtraction (data_set_1 - data_set_2) with broadcasting.

    :return: the difference as a numpy array
    """
    lhs, rhs = broadcast_data_sets(data_set_1, data_set_2)

    out_shape = get_shape(lhs)
    lhs_flat = flatten_to_list(lhs)
    rhs_flat = flatten_to_list(rhs)

    diffs = [a - b for a, b in zip(lhs_flat, rhs_flat)]

    return np.reshape(diffs, out_shape)


def replace_node_input(node, old_input, new_input):
    """Replace every occurrence of old_input in node.input with new_input.

    :param node: the node proto to modify in place\\
    :param old_input: the input value name to replace\\
    :param new_input: the replacement value name
    """
    for idx in range(len(node.input)):
        if node.input[idx] == old_input:
            node.input[idx] = new_input
def delete_nodes(g, node_list):
    """Delete the named nodes, reconnecting the graph around each one.

    :param g: the onnx graph\\
    :param node_list: a list of node names to delete
    """
    node_to_delete = []
    #Find target nodes
    for node in g.node:
        if node.name not in node_list:
            continue
        else:
            node_to_delete.append(node)
    if len(node_list) != len(node_to_delete):
        print("Some nodes do not exist in the graph. Skipping them.")
    for node in node_to_delete:
        # Check the node whether if it is valid to delete
        if len(node.input) == 0:
            print("Deleting an Constant node. Please make sure you also delete all its following nodes")
        elif len(node.input) > 1:
            print("Warning: Node {} has more than one input. This script cannot delete merge nodes.".format(node.name))
        # Connect the nodes around the target node.
        # Set the following node input as the previous node output.
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        if len(node.input) == 0:
            # Constant-like node: just detach its output from all consumers.
            for following_node in following_nodes:
                following_node.input.remove(node.output[0])
        elif len(following_nodes) > 0 and len(node.input) == 1 and helper.find_input_by_name(g, node.input[0]) is not None:
            # The node input is an input
            # The deleted node's output value becomes the new graph input.
            # NOTE(review): find_value_by_name may return None when shapes
            # were never inferred for node.output[0]; g.input.append(None)
            # would then fail — confirm the value_info always exists here.
            new_input = helper.find_value_by_name(g, node.output[0])
            g.input.append(new_input)
            g.input.remove(helper.find_input_by_name(g, node.input[0]))
            g.value_info.remove(new_input)
        elif len(following_nodes) > 0:
            # Mid-graph node: rewire every consumer to the node's own input.
            for following_node in following_nodes:
                replace_node_input(following_node, node.output[0], node.input[0])
        else:
            # If the node is the output, replace the output with the previous input.
            value = helper.find_value_by_name(g, node.input[0])
            # Pop and re-extend to keep the original output order.
            output_values = []
            while len(g.output):
                output_values.append(g.output.pop())
            while output_values:
                output_value = output_values.pop()
                if output_value.name == node.output[0]:
                    g.output.extend([value])
                else:
                    g.output.extend([output_value])
        # Remove the node and value info.
        g.node.remove(node)

def delete_input(g, target_list):
    """Remove the graph inputs whose names appear in target_list.

    :param g: the onnx graph\\
    :param target_list: a list of input value names
    """
    for name in target_list:
        input_value = helper.find_input_by_name(g, name)
        if input_value is None:
            print("Cannot find input {}".format(name))
            continue
        g.input.remove(input_value)

def delete_output(g, target_list):
    """Remove the graph outputs whose names appear in target_list.

    :param g: the onnx graph\\
    :param target_list: a list of output value names
    """
    for name in target_list:
        output_value = helper.find_output_by_name(g, name)
        if output_value is None:
            print("Cannot find output {}".format(name))
            continue
        g.output.remove(output_value)

def delete_value_with_name_if_exists(g, name):
    """Remove the value_info with the given name from the graph, if present.

    :param g: the onnx graph\\
    :param name: the value_info name
    """
    value = helper.find_value_by_name(g, name)
    if value is not None:
        g.value_info.remove(value)
def polish_model(model):
    '''
    This function combines several useful utility functions together.

    Checks the model, strips doc strings, infers shapes, runs the onnx
    optimizer passes, then checks the result again.
    '''
    onnx.checker.check_model(model)
    # NOTE(review): onnx.helper.strip_doc_string was removed in recent onnx
    # releases — confirm the pinned onnx version still provides it.
    onnx.helper.strip_doc_string(model)
    model = onnx.shape_inference.infer_shapes(model)
    model = optimizer.optimize(model)
    onnx.checker.check_model(model)
    return model


def format_value_info_shape(g):
    """
    Replace -1 and 0 batch size in value info

    :param g: the onnx graph
    """
    # NOTE(review): the isinstance check on dim_value is always True for
    # proto int fields, and inputs/outputs use <= 0 while value_info uses
    # < 0 (a 0 there is handled by remove_zero_value_info) — confirm the
    # asymmetry is intentional.
    for value in g.input:
        if len(value.type.tensor_type.shape.dim) > 0 and\
            (value.type.tensor_type.shape.dim[0].dim_value <= 0 or\
            not isinstance(value.type.tensor_type.shape.dim[0].dim_value, int)):
            value.type.tensor_type.shape.dim[0].dim_value = 1
    for value in g.output:
        if len(value.type.tensor_type.shape.dim) > 0 and\
            (value.type.tensor_type.shape.dim[0].dim_value <= 0 or\
            not isinstance(value.type.tensor_type.shape.dim[0].dim_value, int)):
            value.type.tensor_type.shape.dim[0].dim_value = 1
    for value in g.value_info:
        if len(value.type.tensor_type.shape.dim) > 0 and\
            (value.type.tensor_type.shape.dim[0].dim_value < 0 or\
            not isinstance(value.type.tensor_type.shape.dim[0].dim_value, int)):
            value.type.tensor_type.shape.dim[0].dim_value = 1

def add_name_to_node(g):
    """
    If no name presents, give a name based on output name.

    :param g: the onnx graph
    """
    for node in g.node:
        if len(node.name) == 0:
            node.name = node.output[0]

def rename_all_node_name(g):
    """
    rename all nodes if the node name is a number:

    new_name = old_name + "_kn"

    :param g: the onnx graph
    """

    for node in g.node:
        if not node.name.isdigit():
            # Skip not number names
            continue
        new_node_name = node.name + "_kn"
        new_node_output0_name = node.output[0] + "_kn"

        # in order to keep same output node name, skip if it is output node.
        output_value_info = helper.find_output_by_name(g, node.output[0])
        if output_value_info != None:
            continue

        # rename the input of all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        for following_node in following_nodes:
            replace_node_input(following_node, node.output[0], new_node_output0_name )

        # rename value info
        value_info = helper.find_value_by_name(g, node.output[0])
        if value_info != None:
            value_info.name = new_node_output0_name

        # rename node
        node.output[0] = new_node_output0_name
        node.name = new_node_name

def add_output_to_value_info(g):
    """
    If output does not present in value_info, copy one

    :param g: the onnx graph
    """
    for output in g.output:
        if helper.find_value_by_name(g, output.name) is None:
            g.value_info.extend([output])

def find_first_sequential_output(g, node):
    """Follow the first output edge from node until a graph output is found.

    :param g: the onnx graph\\
    :param node: the node to start from\\
    :return: the first reachable graph output value, or None if the chain
        ends without hitting a graph output
    """
    for value_name in node.output:
        value = helper.find_output_by_name(g, value_name)
        if value is not None:
            return value
    next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
    if len(next_nodes) == 0:
        # No following nodes
        return None
    return find_first_sequential_output(g, next_nodes[0])
def remove_nodes(g, cut_nodes=(), cut_types=()):
    """Remove the given nodes (matched by name or op_type) plus everything
    that becomes unreachable afterwards, remapping the graph outputs onto the
    last still-available values so the output order is preserved.

    :param g: the onnx graph\\
    :param cut_nodes: names of nodes to remove (tuple default instead of the
        original mutable list default)\\
    :param cut_types: op_types of nodes to remove
    """
    node_to_delete = []
    #Find target nodes
    for node in g.node:
        if node.name not in cut_nodes and node.op_type not in cut_types:
            continue
        else:
            node_to_delete.append(node)
    # Mapping originnal outputs to new outputs. This mapping is to keep the output order.
    output_mapping = {}
    new_output = set()
    for node in node_to_delete:
        original_output = find_first_sequential_output(g, node)
        # Fix: a cut node may not lead to any graph output at all; the second
        # mapping pass below already guards against None — mirror it here
        # instead of crashing on original_output.name.
        if original_output is None:
            continue
        if original_output.name not in output_mapping:
            output_mapping[original_output.name] = []
        for input_name in node.input:
            value = helper.find_value_by_name(g, input_name)
            if value is not None and helper.find_output_by_name(g, input_name) is None and value.name not in new_output:
                output_mapping[original_output.name].append(value)
                new_output.add(value.name)
    # Remove them
    while node_to_delete:
        g.node.remove(node_to_delete.pop())
    # Remove unreachable nodes
    visited_values = set()
    unused_constant_map = {}
    for input_value in g.input:
        visited_values.add(input_value.name)
    for node in g.node:
        if node.op_type == 'Constant':
            # Constants are always "reachable"; track them for later cleanup.
            visited_values.add(node.output[0])
            unused_constant_map[node.output[0]] = node
            continue
        can_reach = True
        for input_name in node.input:
            if input_name not in visited_values:
                can_reach = False
                break
        if can_reach:
            for output_name in node.output:
                visited_values.add(output_name)
        else:
            node_to_delete.append(node)
    # Mapping outputs again
    for node in node_to_delete:
        original_output = find_first_sequential_output(g, node)
        if original_output is None:
            continue
        if original_output.name not in output_mapping:
            output_mapping[original_output.name] = []
        for input_name in node.input:
            value = helper.find_value_by_name(g, input_name)
            if value is not None and helper.find_output_by_name(g, input_name) is None and value.name not in new_output:
                output_mapping[original_output.name].append(value)
                new_output.add(value.name)
    # Remove them
    while node_to_delete:
        g.node.remove(node_to_delete.pop())
    #Remove unused constants
    for node in g.node:
        for input_name in node.input:
            if input_name in unused_constant_map:
                del unused_constant_map[input_name]
    for node in unused_constant_map.values():
        g.node.remove(node)
    #Remove unreachable value infos
    reachable_values = set()
    for input_value in g.input:
        reachable_values.add(input_value.name)
    for node in g.node:
        for input_name in node.input:
            reachable_values.add(input_name)
        for output_name in node.output:
            reachable_values.add(output_name)
    value_to_remove = []
    for value_info in g.value_info:
        if value_info.name not in reachable_values:
            value_to_remove.append(value_info)
    while value_to_remove:
        value_info = value_to_remove.pop()
        g.value_info.remove(value_info)
    # Reorder output: keep surviving outputs, replace removed ones with their
    # mapped replacement values, drop what cannot be mapped.
    output_values = []
    while len(g.output):
        output_values.append(g.output.pop())
    while output_values:
        output_value = output_values.pop()
        if output_value.name in reachable_values:
            logger.info("Keep output {}".format(output_value.name))
            g.output.extend([output_value])
        elif output_value.name in output_mapping:
            real_outputs = [i for i in output_mapping[output_value.name] if i.name in reachable_values]
            logger.info("Replace output {} with {}".format(output_value.name, [i.name for i in real_outputs]))
            g.output.extend(real_outputs)
        else:
            logger.info("Abandon output {}".format(output_value.name))
            continue
def transpose_B_in_Gemm(g):
    """
    If transB is set in Gemm, transpose it

    Clears the transB attribute and physically transposes the weight
    constant (both its dims and its data), updating the value_info too.

    :param g: the onnx graph
    """
    for node in g.node:
        if node.op_type != 'Gemm':
            continue
        do_it = False
        for attr in node.attribute:
            if attr.name == "transB":
                if attr.i == 1:
                    # Clear the flag; the weight itself is transposed below.
                    attr.i = 0
                    do_it = True
                break
        if not do_it:
            continue
        # Transpose the weight and its output value
        w_node = helper.find_node_by_output_name(g, node.input[1])
        w_output = helper.find_value_by_name(g, node.input[1])
        dim_0 = w_output.type.tensor_type.shape.dim[0].dim_value
        dim_1 = w_output.type.tensor_type.shape.dim[1].dim_value
        # Swap the two dims on both the value_info and the tensor proto.
        w_output.type.tensor_type.shape.dim[0].dim_value = dim_1
        w_output.type.tensor_type.shape.dim[1].dim_value = dim_0
        w_node.attribute[0].t.dims[0] = dim_1
        w_node.attribute[0].t.dims[1] = dim_0
        # The data may live in raw_data (packed bytes) or float_data.
        if w_node.attribute[0].t.raw_data:
            raw_data = w_node.attribute[0].t.raw_data
            fl_data = [i[0] for i in struct.iter_unpack('f', raw_data)]
        else:
            fl_data = w_node.attribute[0].t.float_data
        w = np.reshape(fl_data, (dim_0, dim_1))
        w = w.transpose((1, 0)).flatten()
        # Write the transposed data back through the same channel it came in.
        if w_node.attribute[0].t.raw_data:
            buf = struct.pack('%sf' % len(w), *w)
            w_node.attribute[0].t.raw_data = buf
        else:
            for i in range(len(fl_data)):
                w_node.attribute[0].t.float_data[i] = w[i]
def remove_zero_value_info(g):
    """Remove value_info entries that have no shape or any zero dimension.

    :param g: the onnx graph
    """
    # Iterate over a copy: we mutate g.value_info while scanning.
    value_info_list = list(g.value_info)
    for vi in value_info_list:
        if not vi.type.tensor_type.shape.dim:
            g.value_info.remove(vi)
            continue
        if any(dim.dim_value == 0 for dim in vi.type.tensor_type.shape.dim):
            g.value_info.remove(vi)

def inference_shapes(m):
    """Re-run the custom shape inference passes until a fixed point.

    Drops all existing value_info first, then repeatedly applies the
    Conv/Upsample/Resize/Split inference passes, polishing the model after
    each productive round.

    :param m: the onnx model
    :return: the polished model with inferred shapes
    """
    while len(m.graph.value_info) > 0:
        m.graph.value_info.pop()
    g = m.graph
    inferencing_shapes = True
    while inferencing_shapes:
        inferencing_shapes = False
        if inference_cov_shape(g):
            inferencing_shapes = True
        if inference_upsample_shape(g):
            inferencing_shapes = True
        if inference_resize_shape(g):
            inferencing_shapes = True
        if inference_split_shape(g):
            inferencing_shapes = True
        if inferencing_shapes:
            topological_sort(g)
            m = polish_model(m)
            g = m.graph
    remove_zero_value_info(g)
    m = polish_model(m)
    return m

def inference_resize_shape(g):
    """Infer the output shape of the next Resize node missing one.

    Handles opset 11+ signatures (X, roi, scales[, sizes]) and the opset 10
    signature (X, scales). Fix: previously crashed with AttributeError when
    the sizes/scales producer could not be found (e.g. fed by an
    initializer) and with IndexError on 2-input Resize nodes.

    :param g: the onnx graph
    :return: True if a shape was generated, otherwise False.
    """
    for node in g.node:
        if node.op_type != 'Resize':
            continue

        output_value = helper.find_value_by_name(g, node.output[0])
        output_value = helper.find_output_by_name(g, node.output[0]) if output_value is None else output_value
        if output_value is not None:
            continue

        if len(node.input) == 4:  # input: X, roi, scales, sizes
            shape_node = helper.find_node_by_output_name(g, node.input[3])
            # Sizes must come from a Constant we can fold; skip otherwise.
            if shape_node is None or shape_node.op_type != 'Constant':
                continue

            _, shape_value = helper.constant_to_list(shape_node)
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0],
                onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
        else:
            # If output shape is not given, inference from scales.
            input_value = helper.find_value_by_name(g, node.input[0])
            if input_value is None:
                continue
            shape_value = helper.get_shape_from_value_info(input_value)
            # opset 11+: scales is input 2; opset 10: scales is input 1.
            scales_name = node.input[2] if len(node.input) >= 3 else node.input[1]
            scales_node = helper.find_node_by_output_name(g, scales_name)
            if scales_node is None or scales_node.op_type != 'Constant':
                continue
            _, scales_value = helper.constant_to_list(scales_node)
            for i in range(len(shape_value)):
                shape_value[i] *= scales_value[i]
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0],
                onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
    return False
def inference_upsample_shape(g):
    """For onnx v1.4.1+, onnx cannot inference upsample output shape. Let's
    do it ourselves. This function only inference the next upsample without
    output shape each time.

    :param g: the graph
    :return: True if any Upsample shape is generated. Otherwise, False.
    """
    for node in g.node:
        if node.op_type != 'Upsample':
            continue
        # Skip nodes whose output shape is already known.
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value and helper.get_shape_from_value_info(output_value):
            continue
        # Get input shape
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            continue
        #raise RuntimeError("Shape for {} has not been generated.".format(node.input[0]))
        if not helper.get_shape_from_value_info(input_value):
            continue
        #raise RuntimeError("Shape for {} is empty.".format(node.input[0]))
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get upsample weight (the scales constant feeding input 1).
        weight_node = helper.find_node_by_output_name(g, node.input[1])
        weight_shape, weight = helper.constant_to_list(weight_node)
        if len(input_shape) != weight_shape[0]:
            raise RuntimeError("Unmatch input shape and weight shape: {} vs {}".format(input_shape, weight_shape))
        # Calculate shape: out[i] = floor(in[i] * scale[i])
        output_shape = list(input_shape)
        for i in range(len(output_shape)):
            output_shape[i] = int(input_shape[i] * weight[i])
        output_value = onnx.helper.make_tensor_value_info(
            node.output[0],
            input_value.type.tensor_type.elem_type,
            output_shape)
        g.value_info.extend([output_value])
        # Only one Upsample is processed per call (callers loop to fixpoint).
        return True
    return False
def inference_cov_shape(g):
    """Infer missing output shapes for Conv nodes (NCHW, 2D spatial).

    Fixes: optional Conv attributes (strides/pads/dilations) previously
    caused an AttributeError when absent — ONNX defaults are now applied;
    the SAME_LOWER/SAME_UPPER output size now follows the ONNX spec
    (ceil(input / stride)) instead of assuming stride 1.

    :param g: the onnx graph
    :return: True if any output shape was inferred, otherwise False.
    """
    processed = False
    for node in g.node:
        # Check for Conv output shape need to be inferrenced.
        if node.op_type != 'Conv':
            continue
        # Input shape is not ready yet. Skip.
        input_value_info = helper.find_value_by_name(g, node.input[0])
        if not input_value_info:
            input_value_info = helper.find_input_by_name(g, node.input[0])
        if not input_value_info:
            continue
        _, input_shape = helper.find_size_shape_from_value(input_value_info)
        if not input_shape:
            continue
        # Output shape is already there. Skip.
        output_value_info = helper.find_value_by_name(g, node.output[0])
        if not output_value_info:
            output_value_info = helper.find_output_by_name(g, node.output[0])
        if output_value_info and \
            helper.get_shape_from_value_info(output_value_info):
            continue

        # Now start the inference.
        # Check kernel shape (taken from the weight's value_info: OIHW).
        kernel_value_info = helper.find_value_by_name(g, node.input[1])
        _, kernel_shape = helper.find_size_shape_from_value(kernel_value_info)
        if not kernel_shape:
            continue

        # Read optional attributes, falling back to ONNX defaults.
        strides_attr = helper.get_attribute_by_name(node, 'strides')
        strides = list(strides_attr.ints) if strides_attr is not None else [1, 1]
        dilation_attr = helper.get_attribute_by_name(node, 'dilations')
        dilation = list(dilation_attr.ints) if dilation_attr is not None else [1, 1]
        # Pytorch model has the case where strides only have one number
        if len(strides) == 1:
            strides.append(strides[0])
        if len(dilation) == 1:
            dilation.append(dilation[0])

        # If auto_pad is set, use the auto_pad.
        auto_pad = helper.get_var_attribute_by_name(node, 'auto_pad', 'string')
        pads = None
        if auto_pad is not None and auto_pad != 'NOTSET':
            if auto_pad == 'SAME_LOWER' or auto_pad == 'SAME_UPPER':
                # ONNX spec: SAME output spatial size is ceil(input / stride).
                new_output_value_info = onnx.helper.make_tensor_value_info(
                    node.output[0],
                    input_value_info.type.tensor_type.elem_type,
                    [input_shape[0], kernel_shape[0],
                     math.ceil(input_shape[2] / strides[0]),
                     math.ceil(input_shape[3] / strides[1])]
                )
                if output_value_info:
                    g.value_info.remove(output_value_info)
                g.value_info.extend([new_output_value_info])
                processed = True
                continue
            elif auto_pad == 'VALID':
                pads = [0, 0, 0, 0]
            else:
                logger.error("Unrecognized auto_pad value: " + str(auto_pad))
                exit(1)

        if not pads:
            pads_attr = helper.get_attribute_by_name(node, 'pads')
            pads = list(pads_attr.ints) if pads_attr is not None else [0, 0, 0, 0]

        # Standard Conv output formula (floor mode).
        H = math.floor((input_shape[2]+pads[0]+pads[2]-\
            dilation[0]*(kernel_shape[2]-1)-1)/strides[0]+1)
        W = math.floor((input_shape[3]+pads[1]+pads[3]-\
            dilation[1]*(kernel_shape[3]-1)-1)/strides[1]+1)
        output_shape = [input_shape[0], kernel_shape[0], H, W]

        new_output_value_info = onnx.helper.make_tensor_value_info(
            node.output[0],
            input_value_info.type.tensor_type.elem_type,
            output_shape
        )

        processed = True

        if output_value_info:
            g.value_info.remove(output_value_info)
        g.value_info.extend([new_output_value_info])

    return processed
def inference_split_shape(g):
    """Infer missing output shapes for Split nodes.

    Fixes: when the 'axis' or 'split' attribute is absent the original code
    raised NameError on the uninitialized locals; ONNX defaults are now
    applied (axis 0, equal split). Also, only an attribute actually named
    'split' is treated as the split sizes (previously any non-axis
    attribute was).

    :param g: the onnx graph
    :return: True if any output shapes were inferred, otherwise False.
    """
    processed = False
    for node in g.node:
        if node.op_type != 'Split':
            continue

        input_val_info = helper.find_value_by_name(g, node.input[0])
        if not input_val_info:
            input_val_info = helper.find_input_by_name(g, node.input[0])
        if not input_val_info:
            continue

        _, input_shape = helper.find_size_shape_from_value(input_val_info)
        if not input_shape:
            continue

        output_val_names = list(node.output)
        output_vals = [helper.find_value_by_name(g, val_name) for val_name in output_val_names]

        output_shapes = [helper.find_size_shape_from_value(output_val)[1] for output_val in output_vals]
        # All output shapes already known: nothing to do.
        if not any([len(s) == 0 for s in output_shapes]):
            continue

        # ONNX defaults: axis 0; equal split when 'split' is absent.
        axis = 0
        split = None
        for att in node.attribute:
            if att.name == 'axis':
                axis = att.i
            elif att.name == 'split':
                split = list(att.ints)
        if split is None:
            split = [input_shape[axis] // len(output_val_names)] * len(output_val_names)

        new_output_vals = []
        for i in range(len(output_val_names)):
            new_shape = list(input_shape)
            new_shape[axis] = split[i]
            new_output_val = onnx.helper.make_tensor_value_info(
                output_val_names[i],
                input_val_info.type.tensor_type.elem_type,
                new_shape
            )
            new_output_vals.append(new_output_val)

        # Replace any stale value_info with the freshly inferred ones.
        for val in output_vals:
            if val is not None:
                g.value_info.remove(val)
        g.value_info.extend(new_output_vals)

        processed = True

    return processed


def parse_shape_change_input(s: str):
    """Parse a shape-change spec like 'input 1 1 224 224'.

    Fix: split on any whitespace run (s.split()) so repeated spaces no
    longer produce empty tokens that crash int().

    :param s: 'name d1 d2 ...' string.
    :return: (name, [d1, d2, ...]) or None when fewer than two tokens.
    :raises ValueError: if a dimension token is not an integer (handled by
        callers).
    """
    s_list = s.split()
    if len(s_list) < 2:
        print("Cannot parse the shape change input: {}".format(s))
        return None
    shape = [int(token) for token in s_list[1:]]
    return s_list[0], shape

def change_input_shape(g, target_list):
    """Apply 'name d1 d2 ...' shape specs to matching graph inputs.

    Specs that fail to parse or do not match an input (by name or rank)
    are reported and skipped.

    :param g: the onnx graph
    :param target_list: list of shape-change spec strings.
    """
    for target in target_list:
        try:
            name, shape = parse_shape_change_input(target)
            input_value = helper.find_input_by_name(g, name)
            if input_value is None:
                print("Cannot find input {}".format(name))
                continue
            if len(shape) != len(input_value.type.tensor_type.shape.dim):
                print("The dimension doesn't match for input {}".format(name))
                continue
            for i in range(len(shape)):
                input_value.type.tensor_type.shape.dim[i].dim_value = shape[i]
        except TypeError:
            # This happens when the parser function returns None.
            continue
        except ValueError:
            # This happens when the input cannot be converter into int
            print("Cannot parse {} into name and int".format(target))
            continue
def change_output_shape(g, target_list):
    """Apply 'name d1 d2 ...' shape specs to matching graph outputs.

    Mirrors change_input_shape: specs that fail to parse or do not match an
    output (by name or rank) are reported and skipped.

    :param g: the onnx graph
    :param target_list: list of shape-change spec strings.
    """
    for target in target_list:
        try:
            name, shape = parse_shape_change_input(target)
            output_value = helper.find_output_by_name(g, name)
            if output_value is None:
                print("Cannot find output {}".format(name))
                continue
            if len(shape) != len(output_value.type.tensor_type.shape.dim):
                print("The dimension doesn't match for output {}".format(name))
                continue
            for i in range(len(shape)):
                output_value.type.tensor_type.shape.dim[i].dim_value = shape[i]
        except TypeError:
            # This happens when the parser function returns None.
            continue
        except ValueError:
            # This happens when the input cannot be converter into int
            print("Cannot parse {} into name and int".format(target))
            continue

def add_nop_conv_after(g, value_names):
    """Add do-nothing depthwise Conv nodes after the given value info. It will
    take the given names as the inputs of the new node and replace the inputs
    of the following nodes.

    :param g: the graph
    :param value_names: a list of string which are the names of value_info.
    """
    for value_name in value_names:
        # Find the value first
        value = helper.find_value_by_name(g, value_name)
        if value is None:
            value = helper.find_input_by_name(g, value_name)
        if value is None:
            value = helper.find_output_by_name(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue
        # Get the channel number from value info (assumes NCHW layout)
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Construct an identity depthwise weight (channel x 1 x 1 x 1 of ones)
        node_name = value_name + "_nop_conv"
        ones = [1.0] * channel
        weight_node = helper.list_to_constant(node_name + "_weight", [channel, 1, 1, 1], ones)
        # Construct the no-op depthwise Conv node
        conv_node = onnx.helper.make_node(
            "Conv",
            [value_name,
            weight_node.output[0]],
            [node_name],
            name = node_name,
            dilations = [1, 1],
            group = channel,
            kernel_shape = [1, 1],
            pads = [0, 0, 0, 0],
            strides = [1, 1]
        )
        # Reconnect the graph
        following_nodes = helper.find_following_nodes_by_input_value_name(g, value_name)
        if len(following_nodes) > 0:
            for following_node in following_nodes:
                replace_node_input(following_node, value_name, node_name)
        else:
            # If the node is the output, replace the output with the previous input.
            new_value = onnx.helper.make_tensor_value_info(
                node_name,
                value.type.tensor_type.elem_type,
                shape
            )
            # Rebuild g.output preserving order, swapping in the new value.
            output_values = []
            while len(g.output):
                output_values.append(g.output.pop())
            while output_values:
                output_value = output_values.pop()
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        # Add node to the graph
        g.node.extend([conv_node, weight_node])
    topological_sort(g)
def add_nop_bn_after(g, value_names):
    """Add do-nothing BatchNormalization nodes after the given value info. It will
    take the given names as the inputs of the new node and replace the inputs
    of the following nodes.

    :param g: the graph
    :param value_names: a list of string which are the names of value_info.
    """
    for value_name in value_names:
        # Find the value first
        value = helper.find_value_by_name(g, value_name)
        if value is None:
            value = helper.find_input_by_name(g, value_name)
        if value is None:
            value = helper.find_output_by_name(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue
        # Get the channel number from value info (assumes NCHW layout)
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Construct the 4 identity BN weights: scale=1, bias=0, mean=0, var=1
        node_name = value_name + "_nop_bn"
        ones = [1.0] * channel
        zeros = [0.0] * channel
        scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones)
        bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros)
        mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros)
        var_node = helper.list_to_constant(node_name + "_var", [channel], ones)
        # Construct BN node
        bn_node = onnx.helper.make_node(
            "BatchNormalization",
            [value_name,
            scale_node.output[0],
            bias_node.output[0],
            mean_node.output[0],
            var_node.output[0]],
            [node_name],
            name = node_name
        )
        # Reconnect the graph
        following_nodes = helper.find_following_nodes_by_input_value_name(g, value_name)
        if len(following_nodes) > 0:
            for following_node in following_nodes:
                replace_node_input(following_node, value_name, node_name)
        else:
            # If the node is the output, replace the output with the previous input.
            new_value = onnx.helper.make_tensor_value_info(
                node_name,
                value.type.tensor_type.elem_type,
                shape
            )
            # Rebuild g.output preserving order, swapping in the new value.
            output_values = []
            while len(g.output):
                output_values.append(g.output.pop())
            while output_values:
                output_value = output_values.pop()
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        # Add node to the graph
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
def add_bias_scale_bn_after(g, value_name, channel_bias, channel_scale):
    """Add a per-channel scale-and-shift BatchNormalization node after the
    given value info (mean=0, var=1, so the BN computes scale*x + bias).
    It takes the given name as the input of the new node and replaces the
    inputs of the following nodes.

    :param g: the graph
    :param value_name: name of the value_info to insert after.
    :param channel_bias: per-channel bias values.
    :param channel_scale: per-channel scale values.
    """
    # Find the value first
    value = helper.find_value_by_name(g, value_name)
    if value is None:
        value = helper.find_input_by_name(g, value_name)
    if value is None:
        value = helper.find_output_by_name(g, value_name)
    if value is None:
        print("Cannot find an value_info named {}".format(value_name))
        return
    # Get the channel number from value info (assumes NCHW layout)
    shape = helper.get_shape_from_value_info(value)
    channel = shape[1]
    # Construct the 4 BN weights: caller-provided scale/bias, mean=0, var=1
    node_name = value_name + "_scale_shift_bn"
    ones = [1.0] * channel
    zeros = [0.0] * channel
    scale_node = helper.list_to_constant(node_name + "_scale", [len(channel_scale)], channel_scale)
    bias_node = helper.list_to_constant(node_name + "_bias", [len(channel_bias)], channel_bias)
    mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros)
    var_node = helper.list_to_constant(node_name + "_var", [channel], ones)
    # Construct BN node
    bn_node = onnx.helper.make_node(
        "BatchNormalization",
        [value_name,
        scale_node.output[0],
        bias_node.output[0],
        mean_node.output[0],
        var_node.output[0]],
        [node_name],
        name = node_name
    )
    # Reconnect the graph
    following_nodes = helper.find_following_nodes_by_input_value_name(g, value_name)
    if len(following_nodes) > 0:
        for following_node in following_nodes:
            replace_node_input(following_node, value_name, node_name)
    else:
        # If the node is the output, replace the output with the previous input.
        new_value = onnx.helper.make_tensor_value_info(
            node_name,
            value.type.tensor_type.elem_type,
            shape
        )
        # Rebuild g.output preserving order, swapping in the new value.
        output_values = []
        while len(g.output):
            output_values.append(g.output.pop())
        while output_values:
            output_value = output_values.pop()
            if output_value.name == value_name:
                g.output.extend([new_value])
            else:
                g.output.extend([output_value])
    # Add node to the graph
    g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
def duplicate_shared_Flatten(g):
    """To feed our compiler, bind Flatten with Gemm. If the output of one
    Flatten goes to two Gemm nodes, duplicate the Flatten.

    :param g: the graph
    """
    for node in g.node:
        # Find a Flatten node
        if node.op_type != 'Flatten':
            continue
        # Check Flatten outputs. Get following Gemm
        output_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        if len(output_nodes) < 2:
            continue
        gemm_nodes = []
        for output_node in output_nodes:
            if output_node.op_type == 'Gemm':
                gemm_nodes.append(output_node)
        if len(gemm_nodes) < 2:
            continue
        # Process all the Gemm nodes except for the first one.
        # The first Gemm keeps the original Flatten; each further Gemm gets
        # its own copy fed from the same input.
        for i in range(1, len(gemm_nodes)):
            # Duplicate
            new_flatten_name = node.name + "_copy" + str(i)
            new_flatten_node = onnx.helper.make_node(
                "Flatten",
                node.input,
                [new_flatten_name],
                name=new_flatten_name,
                axis=1
            )
            # Connect new graph
            replace_node_input(gemm_nodes[i], node.output[0], new_flatten_name)
            g.node.extend([new_flatten_node])
    topological_sort(g)
def deconv_to_conv_info_extraction(input_size, node_proto):
    """Extract the information needed for deconv split.

    Fixes: in the SAME_* branch the second output_padding dimension was
    missing the "- 1" (it read [strides[0]-1, strides[1]]); absent optional
    attributes (dilations/strides/group) now get their ONNX defaults so the
    later Conv construction does not crash on None.

    :param input_size: input shape of the deconv node (NCHW).
    :param node_proto: the deconv node proto.
    :return: a dictionary of extracted params.
    """
    attr = dict()
    # Get attributes from Deconv node
    attr["auto_pad"] = helper.get_var_attribute_by_name(node_proto, "auto_pad", "string")
    attr["dilations"] = helper.get_list_attribute_by_name(node_proto, "dilations", "int")
    attr["group"] = helper.get_var_attribute_by_name(node_proto, "group", "int")
    attr["kernel_shape"] = helper.get_list_attribute_by_name(node_proto, "kernel_shape", "int")
    attr["output_padding"] = helper.get_list_attribute_by_name(node_proto, "output_padding", "int")
    attr["pads"] = helper.get_list_attribute_by_name(node_proto, "pads", "int")
    attr["strides"] = helper.get_list_attribute_by_name(node_proto, "strides", "int")
    # Apply ONNX defaults for absent optional attributes.
    if attr["dilations"] is None:
        attr["dilations"] = [1, 1]
    if attr["strides"] is None:
        attr["strides"] = [1, 1]
    if attr["group"] is None:
        attr["group"] = 1
    # Get output_padding
    if attr["output_padding"] is None:
        if attr["auto_pad"] == "SAME_LOWER" or attr["auto_pad"] == "SAME_UPPER":
            # Fix: both dims use strides - 1 (second dim previously missed -1).
            attr["output_padding"] = [attr["strides"][0] - 1,
                                      attr["strides"][1] - 1]
        else:
            attr["output_padding"] = [max(attr["strides"][0] - attr["kernel_shape"][0], 0),
                                      max(attr["strides"][1] - attr["kernel_shape"][1], 0)]
    # Calculate conv_padding
    if attr["auto_pad"] == "SAME_LOWER" or attr["auto_pad"] == "SAME_UPPER":
        pad1_h = attr["kernel_shape"][0] - (attr["kernel_shape"][0] - 1) // 2 - 1
        pad1_w = attr["kernel_shape"][1] - (attr["kernel_shape"][1] - 1) // 2 - 1
        head_h = min(attr["kernel_shape"][0] // 2, (attr["output_padding"][0] + 1) // 2)
        head_w = min(attr["kernel_shape"][1] // 2, (attr["output_padding"][1] + 1) // 2)
        tail_h = attr["output_padding"][0] - head_h
        tail_w = attr["output_padding"][1] - head_w
        attr["conv_pads"] = [pad1_h + head_h, pad1_w + head_w, pad1_h + tail_h, pad1_w + tail_w]
    elif attr["pads"] is not None:
        sum_of_pads = sum(attr["pads"])
        if sum_of_pads == 0:
            # Valid padding
            pad1_h = attr["kernel_shape"][0] - 0 - 1
            pad1_w = attr["kernel_shape"][1] - 0 - 1
            head_h = 0
            head_w = 0
            tail_h = attr["output_padding"][0] - head_h
            tail_w = attr["output_padding"][1] - head_w
            attr["conv_pads"] = [pad1_h + head_h, pad1_w + head_w, pad1_h + tail_h, pad1_w + tail_w]
        else:
            # Calculate output shape
            tmp_output_shape = [0, 0]
            tmp_output_shape[0] = attr["strides"][0] * (input_size[2] - 1) + attr["output_padding"][0] + attr["kernel_shape"][0] - attr["pads"][0] - attr["pads"][2]
            tmp_output_shape[1] = attr["strides"][1] * (input_size[3] - 1) + attr["output_padding"][1] + attr["kernel_shape"][1] - attr["pads"][1] - attr["pads"][3]
            # Calculate real conv output shape (input dilated by stride)
            tmp_center_shape = [0, 0]
            tmp_center_shape[0] = (input_size[2] - 1) * attr["strides"][0] + 1
            tmp_center_shape[1] = (input_size[3] - 1) * attr["strides"][1] + 1
            # Calculate padding
            total_padding = [0, 0]
            total_padding[0] = tmp_output_shape[0] - tmp_center_shape[0] + attr["kernel_shape"][0] - 1
            total_padding[1] = tmp_output_shape[1] - tmp_center_shape[1] + attr["kernel_shape"][1] - 1
            if total_padding[0] < 0 or total_padding[1] < 0:
                raise RuntimeError(node_proto.name + " cannot infer conv padding.")
            conv_pads_ = [0] * 4
            conv_pads_[0] = total_padding[0] // 2
            conv_pads_[1] = total_padding[1] // 2
            conv_pads_[2] = total_padding[0] - total_padding[0] // 2
            conv_pads_[3] = total_padding[1] - total_padding[1] // 2
            attr["conv_pads"] = conv_pads_
    else:
        # No pads attribute at all: treat as valid padding.
        pad1_h = attr["kernel_shape"][0] - 0 - 1
        pad1_w = attr["kernel_shape"][1] - 0 - 1
        head_h = 0
        head_w = 0
        tail_h = attr["output_padding"][0] - head_h
        tail_w = attr["output_padding"][1] - head_w
        attr["conv_pads"] = [pad1_h + head_h, pad1_w + head_w, pad1_h + tail_h, pad1_w + tail_w]
    return attr
def split_ConvTranspose(model):
    """To feed our compiler, split ConvTranspose into Upsample and Conv.

    :param model: the model
    """
    node_to_delete = []
    # Change model properties for upsample: the generated Upsample (with a
    # scales input) requires opset 9 / IR version 4.
    if model.ir_version < 3:
        print("Warning: Current model IR version is not fully supported.")
    model.ir_version = 4
    model.opset_import[0].version = 9
    g = model.graph
    # Get a Convtranspose layer
    for node in g.node:
        # Find a ConvTranspose node
        if node.op_type != 'ConvTranspose':
            continue
        # Check auto_pad
        auto_pad_proto = helper.get_attribute_by_name(node, "auto_pad")
        if auto_pad_proto is not None:
            print("Currently not split auto_pad ConvTranspose")
            continue
        # Check output_shape
        output_shape_proto = helper.get_attribute_by_name(node, "output_shape")
        if output_shape_proto is not None:
            print("Currently not split output_shape ConvTranspose")
            continue
        # Get input shape
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None:
            print("Cannot get value info named {}.".format(node.input[0]))
            exit(1)
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get attrbutes
        attr = deconv_to_conv_info_extraction(input_shape, node)
        # Generate Upsample scales: the upsampled (stride-dilated) size is
        # (in - 1) * stride + 1 per spatial dim.
        upsample_output_shape = list(input_shape)
        upsample_output_shape[2] = (input_shape[2] - 1) * attr["strides"][0] + 1
        upsample_output_shape[3] = (input_shape[3] - 1) * attr["strides"][1] + 1
        upsample_node_name = node.name + "_inner_upsample"
        upsample_scale_name = upsample_node_name + "_scales"
        scales_np = np.ones([4]).astype('float32')
        scales_np[2] = float(upsample_output_shape[2]) / input_shape[2]
        scales_np[3] = float(upsample_output_shape[3]) / input_shape[3]
        scales_node = helper.numpy_to_constant(upsample_scale_name, scales_np)
        # Generate a Upsample layer and an internal value info
        # NOTE(review): mode "zeros" is not a standard ONNX Upsample mode —
        # presumably a Kneron toolchain extension; confirm downstream support.
        upsample_node = onnx.helper.make_node(
            "Upsample",
            [node.input[0], upsample_scale_name],
            [upsample_node_name],
            name=upsample_node_name,
            mode="zeros"
        )
        upsample_value_info = onnx.helper.make_tensor_value_info(
            upsample_node_name,
            input_value.type.tensor_type.elem_type,
            upsample_output_shape
        )
        # Check the weight layer, it may need a transpose: ConvTranspose
        # weights are IOHW while Conv expects OIHW (skip for depthwise).
        if attr["group"] != input_shape[1]:
            weight_node = helper.find_node_by_output_name(g, node.input[1])
            weight_np = helper.constant_to_numpy(weight_node)
            new_weight_np = np.transpose(weight_np, [1, 0, 2, 3])
            new_weight_node = helper.numpy_to_constant(node.input[1], new_weight_np)
            node_to_delete.append(weight_node)
            g.node.extend([new_weight_node])
            value = helper.find_value_by_name(g, node.input[1])
            g.value_info.remove(value)
        # Generate a Conv layer
        conv_node_name = node.name + "_inner_conv"
        conv_node_input = [upsample_node_name]
        conv_node_input.extend(node.input[1:])
        conv_node = onnx.helper.make_node(
            "Conv",
            conv_node_input,
            [node.output[0]],
            name=conv_node_name,
            pads=[int(i) for i in attr["conv_pads"]],
            dilations=[int(i) for i in attr["dilations"]],
            group=int(attr["group"]),
            kernel_shape=[int(i) for i in attr["kernel_shape"]],
            strides=[int(1), int(1)]
        )
        # Reconnect the graph
        g.node.extend([scales_node, upsample_node, conv_node])
        g.value_info.extend([upsample_value_info])
        node_to_delete.append(node)
    # Delete useless nodes
    for node in node_to_delete:
        g.node.remove(node)
    topological_sort(g)
def add_bn_on_skip_branch(g):
    """Insert an identity BatchNormalization on the skip branch feeding an
    Add node, so the compiler sees a node on both branches.

    A "skip branch" is detected as an Add input whose producer's output is
    consumed by exactly two nodes.

    Fixes: the original crashed with AttributeError when an Add input had
    no producer node (graph input / initializer) or when the split value
    had no value_info; both cases are now skipped.

    :param g: the onnx graph
    """
    for n in g.node:
        # Find merge node (Add)
        if n.op_type != 'Add':
            continue
        if len(n.input) != 2:
            continue
        # TODO: Still need to consider more cases
        # Check if skip branch exist
        input_node_a = helper.find_node_by_output_name(g, n.input[0])
        input_node_b = helper.find_node_by_output_name(g, n.input[1])
        # Inputs without a producer node (graph inputs) cannot form a skip.
        if input_node_a is None or input_node_b is None:
            continue
        output_of_input_node_a = helper.find_nodes_by_input_name(g, input_node_a.output[0])
        output_of_input_node_b = helper.find_nodes_by_input_name(g, input_node_b.output[0])
        if len(output_of_input_node_a) == 1 and len(output_of_input_node_b) == 1:
            continue
        if len(output_of_input_node_a) == 2:
            split_node = input_node_a
        elif len(output_of_input_node_b) == 2:
            split_node = input_node_b
        else:
            continue
        # Get the channel number from value info (assumes NCHW layout)
        value_name = split_node.output[0]
        value = helper.find_value_by_name(g, value_name)
        if value is None:
            # No shape information available; cannot build the BN weights.
            continue
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Construct the 4 identity BN weights: scale=1, bias=0, mean=0, var=1
        node_name = value_name + "_nop_bn"
        ones = [1.0] * channel
        zeros = [0.0] * channel
        scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones)
        bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros)
        mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros)
        var_node = helper.list_to_constant(node_name + "_var", [channel], ones)
        # Construct BN node
        bn_node = onnx.helper.make_node(
            "BatchNormalization",
            [value_name,
            scale_node.output[0],
            bias_node.output[0],
            mean_node.output[0],
            var_node.output[0]],
            [node_name],
            name = node_name
        )
        # Reconnect the graph
        replace_node_input(n, value_name, node_name)
        # Add node to the graph
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
def add_bn_before_add(g):
    """Insert identity BatchNormalization nodes in front of Add inputs so
    every Add is fed by a BN (unless it already is, exclusively).

    :param g: the onnx graph
    """
    for n in g.node:
        # Find merge node (Add)
        if n.op_type != 'Add':
            continue
        if len(n.input) != 2:
            continue
        # Get two inputs
        input_node_a = helper.find_node_by_output_name(g, n.input[0])
        input_node_b = helper.find_node_by_output_name(g, n.input[1])
        # Skip constant input add
        if input_node_a is None or input_node_a.op_type == 'Constant':
            continue
        if input_node_b is None or input_node_b.op_type == 'Constant':
            continue
        # Closure over n: rewires only this Add to read from the new BN.
        def add_bn_after(prev_node):
            # Get the channel number from value info (assumes NCHW layout)
            value_name = prev_node.output[0]
            value = helper.find_value_by_name(g, value_name)
            shape = helper.get_shape_from_value_info(value)
            channel = shape[1]
            # Construct the 4 identity BN weights
            node_name = value_name + "_nop_bn"
            ones = [1.0] * channel
            zeros = [0.0] * channel
            scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones)
            bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros)
            mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros)
            var_node = helper.list_to_constant(node_name + "_var", [channel], ones)
            # Construct BN node
            bn_node = onnx.helper.make_node(
                "BatchNormalization",
                [value_name,
                scale_node.output[0],
                bias_node.output[0],
                mean_node.output[0],
                var_node.output[0]],
                [node_name],
                name = node_name,
                epsilon=0.00000001
            )
            # Reconnect the graph
            replace_node_input(n, value_name, node_name)
            # Add node to the graph
            g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
        # Insert a BN unless the input already is a BN consumed only by
        # this Add (shared BNs still get a private copy).
        if not input_node_a.op_type == 'BatchNormalization' or len(helper.find_following_nodes_by_input_value_name(g, input_node_a.output[0])) > 1:
            add_bn_after(input_node_a)
        if not input_node_b.op_type == 'BatchNormalization' or len(helper.find_following_nodes_by_input_value_name(g, input_node_b.output[0])) > 1:
            add_bn_after(input_node_b)
    topological_sort(g)
def add_bn_before_activation(g):
    """Insert an identity BatchNormalization in front of activation nodes
    (Relu/Clip/PRelu/LeakyRelu) that are not already preceded by a Conv or
    a BatchNormalization.

    :param g: the onnx graph
    """
    activation_nodes = set(['Relu', 'Clip', 'PRelu', 'LeakyRelu'])
    previous_nodes = set(['Conv', 'BatchNormalization'])
    for n in g.node:
        # Find activation node
        if n.op_type not in activation_nodes:
            continue
        # Get input; skip when already fed by Conv/BN (or by a graph input).
        input_node = helper.find_node_by_output_name(g, n.input[0])
        if input_node is None or input_node.op_type in previous_nodes:
            continue
        # Closure over n: rewires only this activation to the new BN.
        def add_bn_after(prev_node):
            # Get the channel number from value info (assumes NCHW layout)
            value_name = prev_node.output[0]
            value = helper.find_value_by_name(g, value_name)
            shape = helper.get_shape_from_value_info(value)
            channel = shape[1]
            # Construct the 4 identity BN weights
            node_name = value_name + "_nop_bn"
            ones = [1.0] * channel
            zeros = [0.0] * channel
            scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones)
            bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros)
            mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros)
            var_node = helper.list_to_constant(node_name + "_var", [channel], ones)
            # Construct BN node
            bn_node = onnx.helper.make_node(
                "BatchNormalization",
                [value_name,
                scale_node.output[0],
                bias_node.output[0],
                mean_node.output[0],
                var_node.output[0]],
                [node_name],
                name = node_name,
                epsilon=0.00000001
            )
            # Reconnect the graph
            replace_node_input(n, value_name, node_name)
            # Add node to the graph
            g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
        add_bn_after(input_node)
    topological_sort(g)
def rename_output_name(g, original_name, new_name):
    """Rename a graph output and every reference to it (value_info, the
    producing node's output, and consuming nodes' inputs).

    Fixes: uses the module-level ``logger`` like the rest of this file
    (was ``logging.error``), and no longer crashes when the output has no
    producer node.

    :param g: the onnx graph
    :param original_name: current output name.
    :param new_name: replacement name.
    """
    # Output
    output_value = helper.find_output_by_name(g, original_name)
    if output_value is None:
        logger.error("Cannot find output value named " + original_name)
        return
    output_value.name = new_name
    # Value Info
    value_info = helper.find_value_by_name(g, original_name)
    if value_info is not None:
        value_info.name = new_name
    # Node output (may be absent when the output is fed directly by input)
    node = helper.find_node_by_output_name(g, original_name)
    if node is not None:
        node.output[0] = new_name
    # Node input
    nodes = helper.find_nodes_by_input_name(g, original_name)
    for node in nodes:
        replace_node_input(node, original_name, new_name)

def duplicate_param_shared_constant(g):
    """Give every extra consumer of a shared Constant weight its own copy,
    so no parameter Constant is referenced more than once.

    Fixes: the seen-set previously recorded output names but checked node
    names (dedup missed when they differ), and was reset per consumer node
    so constants shared across different nodes were never duplicated. The
    set is now keyed on the value name and persists across the whole graph.

    :param g: the onnx graph
    """
    seen_value_names = set()
    for node in g.node:
        for idx, input_node_name in enumerate(node.input):
            param_data_node = helper.find_node_by_output_name(g, input_node_name)
            if param_data_node is None or param_data_node.op_type != 'Constant':
                continue
            # First reference keeps the original Constant.
            if input_node_name not in seen_value_names:
                seen_value_names.add(input_node_name)
                continue

            new_node_name = param_data_node.name + '_' + str(idx)
            helper.logger.debug(f"Duplicating weight: {param_data_node.name} -> {new_node_name}")
            duplicated_node = copy.deepcopy(param_data_node)

            duplicated_node.name = new_node_name
            duplicated_node.output[0] = new_node_name

            node.input[idx] = new_node_name
            g.node.extend([duplicated_node])
def eliminate_transposes(m):
    """Push, split, and cancel Transpose nodes until the graph stops changing.

    Each round first swaps single-consumer Transposes past passable nodes as
    far as possible, then tries the split / annihilate / multi-swap passes.
    The model is re-polished after every productive round.

    :param m: the onnx model
    :return: the (possibly re-polished) model
    """
    graph = m.graph
    while True:
        # Exhaust the cheap single-successor swaps first.
        while swap_transpose_with_single_next_node(graph):
            pass
        # Run every remaining pass; all must execute, so no short-circuiting.
        did_split = split_transpose_for_multiple_next_nodes(graph)
        did_annihilate = annihilate_transposes(graph)
        did_multi_swap = swap_multiple_transposes_with_node(graph)
        if not (did_split or did_annihilate or did_multi_swap):
            break
        m = other.polish_model(m)
        graph = m.graph
    return m
next_node.output[0]) + mid_value_info = helper.find_value_by_name(g, trans_node.output[0]) + + output_nodes = helper.find_nodes_by_input_name(g, next_node.output[0]) + for out_node in output_nodes: + modhelper.replace_node_input(out_node, next_node.output[0], trans_node.name) + + next_node.input[0] = trans_node.input[0] + next_node.output[0] = next_node.name + trans_node.input[0] = next_node.name + trans_node.output[0] = trans_node.name + + if next_value_info: + next_value_info.name = trans_node.name + if mid_value_info: + g.value_info.remove(mid_value_info) + else: + # if the input is a constant node + old_tensor = input_node.attribute[0].t + old_shape, data = helper.constant_to_list(input_node) + # If the constant node is a scaler, no action is needed + if type(old_shape) == int: + old_shape = [old_shape] + permutation = list(trans_node.attribute[0].ints) + while len(old_shape) < len(permutation): + old_shape.insert(0, 1) + np_data = np.reshape(data, old_shape) + reverse_perm = [] + for i in range(len(permutation)): + reverse_perm.append(permutation.index(i)) + np_data = np.transpose(np_data, reverse_perm) + new_shape = np_data.shape + new_tensor = onnx.helper.make_tensor( + name=old_tensor.name, + data_type=old_tensor.data_type, + dims=new_shape, + vals=np_data.flatten().tolist() + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [input_node.output[0]], + name=input_node.name, + value=new_tensor + ) + g.node.extend([new_node]) + + g.value_info.remove(helper.find_value_by_name(g, input_node.output[0])) + g.node.remove(input_node) + + swapped = True + + other.topological_sort(g) + return swapped + + +def swap_multiple_transposes_with_node(g): + # here only consider same input transposes + swapped = False + passable_nodes = set(['Add', 'Mul']) + node_to_del = [] + for node in g.node: + if node.op_type not in passable_nodes: + continue + input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in node.input] + if any([input_node == 
None for input_node in input_nodes]): + continue + if any([input_node.op_type != 'Transpose' for input_node in input_nodes]): + continue + + permutation = list(input_nodes[0].attribute[0].ints) + if any([list(input_node.attribute[0].ints) != permutation for input_node in input_nodes]): + continue + + for input_name in node.input: + input_node = helper.find_node_by_output_name(g, input_name) + modhelper.replace_node_input(node, input_name, input_node.input[0]) + + node_to_del.extend(input_nodes) + for input_node in input_nodes: + input_val_info = helper.find_value_by_name(g, input_node.output[0]) + if input_val_info is not None: + g.value_info.remove(input_val_info) + output_val_info = helper.find_value_by_name(g, node.output[0]) + if output_val_info is not None: + g.value_info.remove(output_val_info) + + output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + for i in range(len(output_nodes)): + new_trans_node_name = node.name+'_trans_'+str(i) + new_trans_node = onnx.helper.make_node( + 'Transpose', + [node.output[0]], + [new_trans_node_name], + name=new_trans_node_name, + perm=permutation + ) + modhelper.replace_node_input(output_nodes[i], node.output[0], new_trans_node_name) + + g.node.extend([new_trans_node]) + + swapped = True + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + other.topological_sort(g) + return swapped + + +def annihilate_transposes(g): + node_to_del = [] + annihilated = False + for node in g.node: + if node.op_type != 'Transpose': + continue + pre_node = helper.find_node_by_output_name(g, node.input[0]) + if not pre_node or pre_node.op_type != 'Transpose': + continue + nodes_from_top_transpose = helper.find_nodes_by_input_name(g, pre_node.output[0]) + if len(nodes_from_top_transpose) > 1: + continue + + perm_1 = list(pre_node.attribute[0].ints) + perm_2 = list(node.attribute[0].ints) + if perm_1 != perm_2: + continue + + out_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + for out_node in 
out_nodes: + modhelper.replace_node_input(out_node, node.output[0], pre_node.input[0]) + + node_to_del.extend([node, pre_node]) + mid_value_info = helper.find_value_by_name(g, pre_node.output[0]) + out_value_info = helper.find_value_by_name(g, node.output[0]) + g.value_info.remove(mid_value_info) + g.value_info.remove(out_value_info) + + annihilated = True + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return annihilated + + +def split_transpose_for_multiple_next_nodes(g): + splitted = False + node_to_del = [] + for node in g.node: + if node.op_type != 'Transpose': + continue + output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + if len(output_nodes) < 2: + continue + for i in range(len(output_nodes)): + output_node = output_nodes[i] + new_trans_node_name = node.name + '_' + str(i) + new_trans_node = onnx.helper.make_node( + 'Transpose', + [node.input[0]], + [new_trans_node_name], + name=new_trans_node_name, + perm=list(node.attribute[0].ints) + ) + modhelper.replace_node_input(output_node, node.output[0], new_trans_node.output[0]) + g.node.extend([new_trans_node]) + + node_to_del.append(node) + val_info = helper.find_value_by_name(g, node.output[0]) + g.value_info.remove(val_info) + + splitted = True + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + other.topological_sort(g) + return splitted + +def remove_trivial_transpose(g): + node_to_del = [] + for node in g.node: + if node.op_type != 'Transpose': + continue + permutation = list(node.attribute[0].ints) + if permutation != list(range(len(permutation))): + continue + + next_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + if not next_nodes: + input_val_info = helper.find_value_by_name(g, node.input[0]) + out_val_info = helper.find_output_by_name(g, node.output[0]) + if not input_val_info: + input_val_info = helper.find_input_by_name(g, node.input[0]) + g.output.remove(out_val_info) + g.output.extend([input_val_info]) + else: + 
out_val_info = helper.find_value_by_name(g, node.output[0]) + for next_node in next_nodes: + modhelper.replace_node_input(next_node, node.output[0], node.input[0]) + g.value_info.remove(out_val_info) + + node_to_del.append(node) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + other.topological_sort(g) + +def fuse_Transpose_into_Gemm_weight(g): + node_to_del = [] + for node in g.node: + # Check pattern + if node.op_type != 'Gemm': + continue + prev_node = helper.find_node_by_output_name(g, node.input[0]) + if prev_node is None or prev_node.op_type != 'Flatten': + continue + transpose_node = helper.find_node_by_output_name(g, prev_node.input[0]) + if transpose_node.op_type != 'Transpose': + continue + # Check attribute + perm = helper.get_list_attribute_by_name(transpose_node, 'perm', 'int') + if perm != [0, 2, 3, 1]: + continue + transB = helper.get_var_attribute_by_name(node, 'transB', 'int') + if transB is not None and transB == 1: + continue + # Get the original weight + origin_weight = helper.find_node_by_output_name(g, node.input[1]) + origin_np = helper.constant_to_numpy(origin_weight) + # Calculate a new weight + shape = helper.get_shape_from_value_info(helper.find_value_by_name(g, prev_node.input[0])) + shape.append(-1) + new_np = np.reshape(origin_np, shape) + new_np = np.transpose(new_np, [0, 3, 1, 2, 4]) + new_np = np.reshape(new_np, [-1, new_np.shape[-1]]) + new_weight = helper.numpy_to_constant(origin_weight.output[0], new_np) + # Replace and eliminate + prev_node.input[0] = transpose_node.input[0] + node_to_del.append(transpose_node) + node_to_del.append(origin_weight) + g.value_info.remove(helper.find_value_by_name(g, transpose_node.output[0])) + g.node.extend([new_weight]) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + other.topological_sort(g) diff --git a/tools/optimizer_scripts/tools/replacing.py b/tools/optimizer_scripts/tools/replacing.py new file mode 100644 index 0000000..091e571 --- 
/dev/null +++ b/tools/optimizer_scripts/tools/replacing.py @@ -0,0 +1,1171 @@ +"""Optimizations that replace one node with another. +""" +from os import dup +import struct +import copy +import logging +import onnx.helper +import numpy as np +from . import helper +from . import modhelper +from .other import topological_sort + +def replace_initializer_with_Constant(g, duplicate_shared_weights=True): + """ + Replace initializers with Constant and a corresponding value_info + If the initializer has related input, remove it. + + :param g: the onnx graph + """ + + input_map = {i.name: i for i in g.input} + for tensor in g.initializer: + # Check for the initializer related input and remove it + if tensor.name in input_map: + value_info = input_map[tensor.name] + g.input.remove(value_info) + following_nodes = helper.find_nodes_by_input_name(g, tensor.name) + if duplicate_shared_weights and len(following_nodes) >= 2: + for i, node in enumerate(following_nodes): + new_name = tensor.name + "_duplicated_No" + str(i) if i > 0 else tensor.name + helper.logger.debug(f"Duplicating weight: {tensor.name} -> {new_name}") + modhelper.replace_node_input(node, tensor.name, new_name) + new_node = onnx.helper.make_node( + "Constant", + [], + [new_name], + name=new_name, + value=tensor + ) + # Add node to lists + g.node.extend([new_node]) + else: + new_name = tensor.name + new_node = onnx.helper.make_node( + "Constant", + [], + [new_name], + name=new_name, + value=tensor + ) + # Add node to lists + g.node.extend([new_node]) + + # if value info already exists, remove it as well. + value_info = helper.find_value_by_name(g, tensor.name) + if value_info is not None: + g.value_info.remove(value_info) + + # Remove original initializer + while len(g.initializer) != 0: + g.initializer.pop() + + topological_sort(g) + +def replace_Reshape_with_Flatten(g): + """ + Replace Reshape node into Flatten node if applicable. 
+ + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + if node.op_type != 'Reshape': + continue + found_Gemm = False + # Flatten could be followed by Gemm + for i in g.node: + if len(i.input) == 0 or i.input[0] != node.output[0]: + continue + if i.op_type == 'Gemm': + found = True + break + # Check weight + shape_node = helper.find_node_by_output_name(g, node.input[1]) + if shape_node.op_type != 'Constant': + continue + shape_value = helper.constant_to_numpy(shape_node) + if (shape_value.size != 2 or shape_value[0] != 1) and not found_Gemm: + continue + # Replace it + node.op_type = "Flatten" + for _ in range(len(node.attribute)): + node.attribute.pop() + shape_value = helper.find_value_by_name(g, shape_node.output[0]) + node.input.pop() + node_to_remove.append(shape_node) + # If found shape value_info, remove it + if shape_value != None: + g.value_info.remove(shape_value) + + for node in node_to_remove: + g.node.remove(node) + +def replace_Squeeze_with_Reshape(g): + """ + Replace Squeeze nodes with Reshape node. + + :param g: the input graph + """ + node_to_remove = [] + for node in g.node: + # Find Squeeze node + if node.op_type != 'Squeeze': + continue + # Get the shape and Construct the shape + output_value = helper.find_value_by_name(g, node.output[0]) + if output_value is None: + output_value = helper.find_output_by_name(g, node.output[0]) + if output_value is None: + raise RuntimeError("Cannot get shape for Squeeze") + shape = [dim.dim_value for dim in output_value.type.tensor_type.shape.dim] + const_node = helper.list_to_constant(node.name + "_shape", [len(shape)], shape) + # Construct the Reshape layer with same input, output and name. 
+ new_node = onnx.helper.make_node( + "Reshape", + [node.input[0], node.name + "_shape"], + node.output, + name=node.name + ) + # Append constructed nodes and append old node to remove_list + g.node.extend([const_node, new_node]) + node_to_remove.append(node) + # Remove old nodes + for node in node_to_remove: + g.node.remove(node) + # Topological sort + topological_sort(g) + +def replace_Unsqueeze_with_Reshape(g): + """ + Replace Unsqueeze nodes with Reshape node. + + :param g: the input graph + """ + node_to_remove = [] + for node in g.node: + # Find Squeeze node + if node.op_type != 'Unsqueeze': + continue + # Get the shape and Construct the shape + output_value = helper.find_value_by_name(g, node.output[0]) + if output_value is None: + output_value = helper.find_output_by_name(g, node.output[0]) + if output_value is None: + raise RuntimeError("Cannot get shape for Unsqueeze") + shape = [dim.dim_value for dim in output_value.type.tensor_type.shape.dim] + + const_node = helper.list_to_constant(node.name + "_shape", [len(shape)], shape) + # Construct the Reshape layer with same input, output and name. + new_node = onnx.helper.make_node( + "Reshape", + [node.input[0], node.name + "_shape"], + node.output, + name=node.name + ) + # Append constructed nodes and append old node to remove_list + g.node.extend([const_node, new_node]) + node_to_remove.append(node) + # Remove old nodes + for node in node_to_remove: + g.node.remove(node) + # Topological sort + topological_sort(g) + +def replace_average_pool_with_GAP(g): + """ + Replace AveragePool nodes with GlobalAveragePool node when available. 
+ + :param g: the input graph + """ + node_to_remove = [] + for node in g.node: + # Find a average pool layer + if node.op_type != 'AveragePool': + continue + # Check attributes + not_replace = False + for attr in node.attribute: + if attr.name == 'pads': + if list(attr.ints) != [0, 0, 0, 0]: + not_replace = True + break + if attr.name == 'kernel_shape': + kernel_shape = list(attr.ints) + value_info = helper.find_value_by_name(g, node.input[0]) + if value_info is None: + not_replace = True + break + input_shape = [] + for dim in value_info.type.tensor_type.shape.dim: + input_shape.append(dim.dim_value) + if input_shape[-2:] != kernel_shape: + not_replace = True + break + if not_replace: + continue + # Replace it with GlobalAveragePool + new_node = onnx.helper.make_node( + "GlobalAveragePool", + node.input, + node.output, + name=node.name + ) + g.node.extend([new_node]) + node_to_remove.append(node) + for node in node_to_remove: + g.node.remove(node) + topological_sort(g) + +def replace_dilated_conv(g): + """ + If the dilation of a convolution is not (1, 1), replace it with a regular + convolution with an expanded kernel. 
+ + :param g: the input graph + """ + node_to_remove = [] + for node in g.node: + # Check if this is a conv layer + if node.op_type != 'Conv': + continue + # Check if this has dilation + has_dilations = False + has_strides = False + for attr in node.attribute: + if attr.name == "dilations": + dilations = list(attr.ints) + if dilations != [1, 1]: + has_dilations = True + if attr.name == "strides": + strides = list(attr.ints) + if strides != [1, 1]: + has_strides = True + if has_dilations and has_strides: + print("Warning: Both strides and dilations are set in ", node.name) + continue + if not has_dilations: + continue + # Construct new kernel + w_node = helper.find_node_by_output_name(g, node.input[1]) + w_output = helper.find_value_by_name(g, node.input[1]) + shape = list(w_node.attribute[0].t.dims) + # get original weight from float_data or raw data + weight = list(w_node.attribute[0].t.float_data) + if len(weight) == 0: + # Unpack from raw data + raw_data = w_node.attribute[0].t.raw_data + weight = [i[0] for i in struct.iter_unpack('f', raw_data)] + weight = np.array(weight) + weight = np.reshape(weight ,shape) + new_shape = copy.copy(shape) + new_shape[2] = 1 + (shape[2] - 1) * dilations[0] + new_shape[3] = 1 + (shape[3] - 1) * dilations[1] + new_weight = np.zeros(new_shape) + for batch in range(shape[0]): + for ch in range(shape[1]): + for h in range(shape[2]): + nh = h * dilations[0] + for w in range(shape[3]): + nw = w * dilations[1] + new_weight[batch, ch, nh, nw] = weight[batch, ch, h, w] + tensor = onnx.helper.make_tensor( + w_node.attribute[0].t.name, + w_node.attribute[0].t.data_type, + new_shape, + new_weight.ravel() + ) + new_w_node = onnx.helper.make_node( + "Constant", + [], + list(w_node.output), + name=w_node.name, + value=tensor + ) + g.node.extend([new_w_node]) + node_to_remove.append(w_node) + # Modify attributes and value info shapes + w_output.type.tensor_type.shape.dim[2].dim_value = new_shape[2] + 
w_output.type.tensor_type.shape.dim[3].dim_value = new_shape[3] + for attr in node.attribute: + if attr.name == "kernel_shape": + attr.ints[0] = new_shape[2] + attr.ints[1] = new_shape[3] + if attr.name == "dilations": + attr.ints[0] = 1 + attr.ints[1] = 1 + # Remove old weight nodes + for node in node_to_remove: + g.node.remove(node) + +def replace_depthwise_1x1_with_bn(g): + """Replace 1x1 DepthwiseConv node into BN node if applicable. + + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + # Check op_type + if node.op_type != 'Conv': + continue + # Check attributes + attr_map = {attr.name: attr for attr in node.attribute} + if "group" not in attr_map or attr_map["group"].i == 1: + continue + if attr_map["kernel_shape"].ints[0] != 1 or attr_map["kernel_shape"].ints[1] != 1: + continue + if "pads" in attr_map and sum(attr_map["pads"].ints) != 0: + continue + # Check scale + scale_node = helper.find_node_by_output_name(g, node.input[1]) + if scale_node is None or scale_node.attribute[0].t.dims[1] != 1: + continue + scale_node.attribute[0].t.dims.pop() + scale_node.attribute[0].t.dims.pop() + scale_node.attribute[0].t.dims.pop() + scale_info = helper.find_value_by_name(g, node.input[1]) + if scale_info is not None: + scale_info.type.tensor_type.shape.dim.pop() + scale_info.type.tensor_type.shape.dim.pop() + scale_info.type.tensor_type.shape.dim.pop() + # Check bias + if len(node.input) == 3: + bias_name = node.input[2] + else: + bias_name = node.name + "_bias" + bias_node = helper.list_to_constant(bias_name, [attr_map["group"].i], [0.0] * attr_map["group"].i) + g.node.extend([bias_node]) + # Construct mean and vars + mean_name = node.name + "_mean" + mean_node = helper.list_to_constant(mean_name, [attr_map["group"].i], [0.0] * attr_map["group"].i) + var_name = node.name + "_var" + var_node = helper.list_to_constant(var_name, [attr_map["group"].i], [1.0] * attr_map["group"].i) + g.node.extend([mean_node, var_node]) + # Convert + bn_node = 
onnx.helper.make_node( + op_type='BatchNormalization', + inputs=[node.input[0], node.input[1], bias_name, mean_name, var_name], + outputs=node.output, + name=node.name, + epsilon=0.00001, + momentum=0.9 + ) + g.node.extend([bn_node]) + node_to_remove.append(node) + for node in node_to_remove: + g.node.remove(node) + topological_sort(g) + +def replace_shape_with_constant(g): + """Replace Shape with Constant.\\ + This is the first step of reshape constant folding. + + :param g: the input graph\\ + :return: if anything modified, return true. + """ + node_to_remove = [] + for node in g.node: + # Find a Shape + if node.op_type != 'Shape': + continue + # Check its input + input_value = helper.find_value_by_name(g, node.input[0]) + if input_value is None: + input_value = helper.find_input_by_name(g, node.input[0]) + if input_value is None or len(input_value.type.tensor_type.shape.dim) == 0: + continue + # Check for case where dimension could be 0 or -1 + tmp = True + for d in input_value.type.tensor_type.shape.dim: + tmp = tmp and (d.dim_value > 0) + if not tmp: + continue + # Repalce it + input_shape = [ + d.dim_value for d in input_value.type.tensor_type.shape.dim] + node_name = node.output[0] + new_node = helper.list_to_constant( + node_name, [len(input_shape)], input_shape) + g.node.extend([new_node]) + node_to_remove.append(node) + + # if the input value_info is not used by other node + # delete this input value_info + val_info_used = sum([input_value.name in node.input for node in g.node]) + if val_info_used == 1: + g.value_info.remove(input_value) + + replaced = True if len(node_to_remove) > 0 else False + + for node in node_to_remove: + g.node.remove(node) + + topological_sort(g) + + return replaced + +def replace_ConstantOfShape_with_constant(g): + """Replace Shape with Constant.\\ + This is the first step of reshape constant folding. + + :param g: the input graph\\ + :return: if anything modified, return true. 
+ """ + node_to_remove = [] + for node in g.node: + # Find a Shape + if node.op_type != 'ConstantOfShape': + continue + # Check input + input_value = helper.find_value_by_name(g, node.input[0]) + if input_value is None: + input_value = helper.find_input_by_name(g, node.input[0]) + if input_value is None or len(input_value.type.tensor_type.shape.dim) == 0: + continue + + # Replace to constant node + pre_node = helper.find_node_by_output_name(g, node.input[0]) + _, target_shape = helper.constant_to_list(pre_node) + + value = helper.get_attribute_by_name(node, 'value').i + + node_name = node.output[0] + new_node = helper.list_to_constant( + node_name, [target_shape[0]], [value] * target_shape[0]) + + g.node.extend([new_node]) + + # remove old node + node_to_remove.append(node) + + # delete value_info + val_info_used = sum([input_value.name in node.input for node in g.node]) + if val_info_used == 1: + g.value_info.remove(input_value) + + replaced = True if len(node_to_remove) > 0 else False + + for node in node_to_remove: + g.node.remove(node) + + topological_sort(g) + + return replaced + +def replace_split_with_slices(g): + """Replace split node with slice nodes. + :param g: input graph. + :return: + """ + node_to_remove = [] + for node in g.node: + # Find a Split + if node.op_type != 'Split': + continue + + input_value = helper.find_value_by_name(g, node.input[0]) + if not input_value: + input_value = helper.find_input_by_name(g, node.input[0]) + _, shape = helper.find_size_shape_from_value(input_value) + if len(shape) == 0: + continue + + output_val_names = list(node.output) + + axis = 0 + split = [] + for item in node.attribute: + if item.name == 'axis': + axis = item.i + if item.name == 'split': + split = item.ints + + # For opset 11, axis could be negative. 
+ if axis < 0: + axis = len(shape) + axis + + length = input_value.type.tensor_type.shape.dim[axis].dim_value + if len(split) > 0: + n_out = len(split) + pos = 0 + for i in range(n_out): + pos += split[i] + new_node_name = output_val_names[i] + # Construct starts, ends, axes + starts_name = new_node_name + '_starts_' + str(i) + ends_name = new_node_name + '_ends_' + str(i) + axes_name = new_node_name + '_axes_' + str(i) + starts_node = helper.list_to_constant(starts_name, (1, ), [int(pos-split[i])]) + ends_node = helper.list_to_constant(ends_name, (1, ), [int(pos)]) + axes_node = helper.list_to_constant(axes_name, (1, ), [int(axis)]) + # Construtc node + new_node = onnx.helper.make_node( + op_type='Slice', + inputs=[node.input[0], starts_name, ends_name, axes_name], + outputs=[node.output[i]], + name=new_node_name + ) + g.node.extend([starts_node, ends_node, axes_node, new_node]) + node_to_remove.append(node) + else: + n_out = len(output_val_names) + width = length//n_out + for i in range(n_out): + new_node_name = output_val_names[i] + # Construct starts, ends, axes + starts_name = new_node_name + '_starts_' + str(i) + ends_name = new_node_name + '_ends_' + str(i) + axes_name = new_node_name + '_axes_' + str(i) + starts_node = helper.list_to_constant(starts_name, (1, ), [int(i*width)]) + ends_node = helper.list_to_constant(ends_name, (1, ), [int((1+i)*width)]) + axes_node = helper.list_to_constant(axes_name, (1, ), [int(axis)]) + # Construtc node + new_node = onnx.helper.make_node( + op_type='Slice', + inputs=[node.input[0], starts_name, ends_name, axes_name], + outputs=[node.output[i]], + name=new_node_name + ) + g.node.extend([starts_node, ends_node, axes_node, new_node]) + node_to_remove.append(node) + + for old_node in node_to_remove: + g.node.remove(old_node) + topological_sort(g) + + +def replace_ReduceMean_with_GlobalAveragePool(g): + """ + Replace ReduceMean with GlobalAveragePool node when available. 
+ + If there is preceeded Transpose, check the Transpose and the ReduceMean + together. If the keep_dims is set to 0, add a Flatten. + + :param g: the input graph + """ + node_to_remove = [] + for node in g.node: + # Find a ReduceMean layer + if node.op_type != 'ReduceMean': + continue + # Find if it have previous Transpose and its attribute meet the need. + prev_node = helper.find_node_by_output_name(g, node.input[0]) + if prev_node is not None and prev_node.op_type != 'Transpose': + prev_node = None + if prev_node is not None: + perm = helper.get_list_attribute_by_name(prev_node, 'perm', 'int') + if perm != [0, 2, 3, 1]: + prev_node = None + # Check attributes + axes = helper.get_list_attribute_by_name(node, 'axes', 'int') + keepdims = helper.get_var_attribute_by_name(node, 'keepdims', 'int') + if axes is None: + continue + if prev_node is None and axes != [2, 3]: + continue + if prev_node is not None and axes != [1, 2]: + continue + if keepdims is None: + keepdims = 1 + # Replace it with GlobalAveragePool + if prev_node: + input_list = prev_node.input + else: + input_list = node.input + if keepdims == 1: + output_list = node.output + else: + output_list = [node.output[0] + '_before_flatten'] + flatten_node = onnx.helper.make_node( + "Flatten", + output_list, + node.output, + name = node.name + "_flatten", + axis = 1 + ) + g.node.extend([flatten_node]) + new_node = onnx.helper.make_node( + "GlobalAveragePool", + input_list, + output_list, + name=node.name + ) + g.node.extend([new_node]) + node_to_remove.append(node) + if prev_node: + value = helper.find_value_by_name(g, prev_node.output[0]) + if value: + g.value_info.remove(value) + node_to_remove.append(prev_node) + for node in node_to_remove: + g.node.remove(node) + topological_sort(g) + +def replace_mul_to_bn(g): + """Replace single Mul node with Batchnorm node. + :param g: input graph. 
+ :return: + """ + node_to_del = [] + for node in g.node: + if node.op_type != 'Mul': + continue + + mul_op_node = node + + # only support one input node + if len(mul_op_node.input) != 2: # OP node and value node + continue + + input_op_node_name = mul_op_node.input[0] + mul_value_node = helper.find_node_by_output_name(g, mul_op_node.input[1]) + if not mul_value_node or mul_value_node.op_type != 'Constant': + continue + + prev_shape_value_info = helper.find_value_by_name(g, input_op_node_name) + prev_shape_value_info = helper.find_input_by_name(g, input_op_node_name) if prev_shape_value_info is None else prev_shape_value_info + if prev_shape_value_info is None: + continue + + _ , previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info) + scale_shape, scale_data = helper.constant_to_list(mul_value_node) + + # channel dimension + c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1 + + # only allow channelwise mul or const mul + if scale_shape == [1, c_dim, 1, 1]: + muls = scale_data + elif scale_shape == [c_dim, 1, 1]: + muls = scale_data + elif scale_shape == 1: + muls = scale_data * c_dim + else: + continue + + ones = [1.0] * c_dim + zeros = [0.0] * c_dim + bn_name = mul_op_node.output[0] + mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) + variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) + bias_value_node = helper.list_to_constant(bn_name+'_add', np.array(zeros).shape, zeros) + new_mul_value_node = helper.list_to_constant(bn_name+'_mul', np.array(muls).shape, muls) + + bn_node = onnx.helper.make_node( + 'BatchNormalization', + [input_op_node_name, + new_mul_value_node.output[0], + bias_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0]], + [mul_op_node.output[0]], + name=bn_name, + epsilon=0.00000001 + ) + + scale_val_info = helper.find_value_by_name(g, mul_value_node.output[0]) + 
g.value_info.remove(scale_val_info) + + g.node.extend([bn_node]) + g.node.extend([mean_value_node]) + g.node.extend([variance_value_node]) + g.node.extend([bias_value_node]) + g.node.extend([new_mul_value_node]) + + node_to_del.extend([mul_op_node]) + node_to_del.extend([mul_value_node]) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + +def replace_div_to_bn(g): + """Replace single Div node with Batchnorm node. + :param g: input graph. + :return: + """ + node_to_del = [] + for node in g.node: + if node.op_type != 'Div': + continue + + div_op_node = node + + # only support one input node + if len(div_op_node.input) != 2: # OP node and value node + continue + + input_op_node_name = div_op_node.input[0] + div_value_node = helper.find_node_by_output_name(g, div_op_node.input[1]) + if not div_value_node or div_value_node.op_type != 'Constant': + continue + + prev_shape_value_info = helper.find_value_by_name(g, input_op_node_name) + prev_shape_value_info = helper.find_input_by_name(g, input_op_node_name) if prev_shape_value_info is None else prev_shape_value_info + if prev_shape_value_info is None: + continue + + _ , previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info) + scale_shape, scale_data = helper.constant_to_list(div_value_node) + + # channel dimension + c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1 + + # only allow channelwise div or const div + if scale_shape == [1, c_dim, 1, 1]: + muls = scale_data + elif scale_shape == [c_dim, 1, 1]: + muls = scale_data + elif scale_shape == 1: + muls = scale_data * c_dim + else: + continue + + ones = [1.0] * c_dim + zeros = [0.0] * c_dim + muls = (1 / np.array(muls)).tolist() + bn_name = div_op_node.output[0] + mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) + variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) + bias_value_node = 
helper.list_to_constant(bn_name+'_add', np.array(zeros).shape, zeros) + new_mul_value_node = helper.list_to_constant(bn_name+'_mul', np.array(muls).shape, muls) + + bn_node = onnx.helper.make_node( + 'BatchNormalization', + [input_op_node_name, + new_mul_value_node.output[0], + bias_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0]], + [div_op_node.output[0]], + name=bn_name, + epsilon=0.00000001 + ) + + scale_val_info = helper.find_value_by_name(g, div_value_node.output[0]) + g.value_info.remove(scale_val_info) + + g.node.extend([bn_node]) + g.node.extend([mean_value_node]) + g.node.extend([variance_value_node]) + g.node.extend([bias_value_node]) + g.node.extend([new_mul_value_node]) + + node_to_del.extend([div_op_node]) + node_to_del.extend([div_value_node]) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + + +def replace_add_to_bn(g): + """Replace single Add node with Batchnorm node. + :param g: input graph. + :return: + """ + node_to_del = [] + for node in g.node: + if node.op_type != 'Add': + continue + + add_op_node = node + + # only support one input node + if len(add_op_node.input) != 2: # OP node and value node + continue + + input_op_node_name = add_op_node.input[0] + add_value_node = helper.find_node_by_output_name(g, add_op_node.input[1]) + if not add_value_node or add_value_node.op_type != 'Constant': + continue + + prev_shape_value_info = helper.find_value_by_name(g, input_op_node_name) + prev_shape_value_info = helper.find_input_by_name(g, input_op_node_name) if prev_shape_value_info is None else prev_shape_value_info + if prev_shape_value_info is None: + continue + + _ , previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info) + bias_shape, bias_data = helper.constant_to_list(add_value_node) + + # channel dimension + c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1 + + # only allow channelwise add or const add + if 
bias_shape == [1, c_dim, 1, 1]: + bias = bias_data + elif bias_shape == [c_dim, 1, 1]: + bias = bias_data + elif bias_shape == 1: + bias = bias_data * c_dim + else: + continue + + ones = [1.0] * c_dim + zeros = [0.0] * c_dim + bn_name = add_op_node.output[0] + mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) + variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) + scale_value_node = helper.list_to_constant(bn_name+'_mul', np.array(ones).shape, ones) + new_add_value_node = helper.list_to_constant(bn_name+'_add', np.array(bias).shape, bias) + + bn_node = onnx.helper.make_node( + 'BatchNormalization', + [input_op_node_name, + scale_value_node.output[0], + new_add_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0]], + [add_op_node.output[0]], + name=bn_name, + epsilon=0.00000001 + ) + + add_val_info = helper.find_value_by_name(g, add_value_node.output[0]) + g.value_info.remove(add_val_info) + + g.node.extend([bn_node]) + g.node.extend([mean_value_node]) + g.node.extend([variance_value_node]) + g.node.extend([scale_value_node]) + g.node.extend([new_add_value_node]) + + node_to_del.extend([add_op_node]) + node_to_del.extend([add_value_node]) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + +def replace_sub_to_bn(g): + """Replace single Sub node with BatchNorm node. + :param g: input graph. 
def replace_sub_to_bn(g):
    """Replace a single Sub node with a BatchNormalization node.

    Handles both ``input - constant`` and ``constant - input``:

    * ``x - c``  becomes BN with scale=1 and bias=-c.
    * ``c - x``  becomes BN with scale=-1 and bias=c.

    The constant must be a per-channel tensor (shape ``[1, C, 1, 1]`` or
    ``[C, 1, 1]``) or a scalar.

    :param g: input graph.
    :return: None. The graph is modified in place.
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Sub':
            continue

        sub_op_node = node

        # only support one input node
        if len(sub_op_node.input) != 2:  # OP node and value node
            continue

        # Check the input type: exactly one side must be a Constant.
        input_1st_name = sub_op_node.input[0]
        input_2nd_name = sub_op_node.input[1]
        input_1st_node = helper.find_node_by_output_name(g, input_1st_name)
        input_2nd_node = helper.find_node_by_output_name(g, input_2nd_name)
        if input_1st_node is not None and input_1st_node.op_type == 'Constant':
            real_input_name = input_2nd_name
            reverse = True
            constant_node = input_1st_node
        elif input_2nd_node is not None and input_2nd_node.op_type == 'Constant':
            real_input_name = input_1st_name
            reverse = False
            constant_node = input_2nd_node
        else:
            continue

        # Get shapes
        prev_shape_value_info = helper.find_value_by_name(g, real_input_name)
        prev_shape_value_info = helper.find_input_by_name(g, real_input_name) if prev_shape_value_info is None else prev_shape_value_info
        if prev_shape_value_info is None:
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info)
        bias_shape, bias_data = helper.constant_to_list(constant_node)

        # channel dimension
        c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1

        # only allow channelwise sub or const sub
        if bias_shape in ([1, c_dim, 1, 1], [c_dim, 1, 1]):
            bias = bias_data
        elif bias_shape == 1:
            bias = bias_data * c_dim
        else:
            continue

        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        if reverse:
            # c - x == (-1) * x + c : negate the scale, keep the bias.
            scale = [-1.0] * c_dim
        else:
            # x - c == 1 * x + (-c) : keep the scale, negate the bias.
            scale = ones
            # BUGFIX: the original did `bias *= -1`. On a Python list that is
            # list repetition by -1 and yields an EMPTY list, destroying the
            # BN bias. Negate element-wise instead.
            bias = [-b for b in bias]
        bn_name = sub_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean', np.array(zeros).shape, zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var', np.array(ones).shape, ones)
        scale_value_node = helper.list_to_constant(bn_name + '_mul', np.array(scale).shape, scale)
        new_add_value_node = helper.list_to_constant(bn_name + '_add', np.array(bias).shape, bias)

        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [real_input_name,
             scale_value_node.output[0],
             new_add_value_node.output[0],
             mean_value_node.output[0],
             variance_value_node.output[0]],
            [sub_op_node.output[0]],
            name=bn_name,
            epsilon=0.00000001
        )

        # Guard: the constant may have no value_info; removing None raises.
        add_val_info = helper.find_value_by_name(g, constant_node.output[0])
        if add_val_info is not None:
            g.value_info.remove(add_val_info)

        g.node.extend([bn_node])
        g.node.extend([mean_value_node])
        g.node.extend([variance_value_node])
        g.node.extend([scale_value_node])
        g.node.extend([new_add_value_node])

        node_to_del.extend([sub_op_node])
        node_to_del.extend([constant_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
+ :return: + """ + for node in g.node: + if node.op_type != 'Sub': + continue + + sub_op_node = node + + # only support one input node + if len(sub_op_node.input) != 2: # OP node and value node + continue + + # Check the input type + input_1st_name = sub_op_node.input[0] + input_2nd_name = sub_op_node.input[1] + input_1st_node = helper.find_node_by_output_name(g, input_1st_name) + input_2nd_node = helper.find_node_by_output_name(g, input_2nd_name) + if input_1st_node is not None and input_1st_node.op_type == 'Constant': + continue + elif input_2nd_node is not None and input_2nd_node.op_type == 'Constant': + continue + + # Get shapes + input_2nd_value_info = helper.find_value_by_name(g, input_2nd_name) + if input_2nd_value_info is None: + input_2nd_value_info = helper.find_input_by_name(g, input_2nd_name) + if input_2nd_value_info is None: + continue + + # Get channel dimension + _ , input_2nd_shape = helper.find_size_shape_from_value(input_2nd_value_info) + if len(input_2nd_shape) < 2: + helper.logger.debug(f"{sub_op_node.name} cannot be replaced due to the input shape.") + c_dim = input_2nd_shape[1] + + # Create * -1 bn node. 
def replace_Sum_with_Adds(g):
    """Replace Sum nodes with an equivalent chain of Add nodes.

    * 1 input: the Sum is a no-op; rewire its consumers and delete it.
    * 2 inputs: rewrite the node in place into an Add.
    * N > 2 inputs: build N-1 chained Adds, the last reusing the Sum's
      output name so consumers are untouched.

    :param g: input graph.
    :return: None. The graph is modified in place.
    """
    node_to_del = []

    for node in g.node:
        # Check for sum
        if node.op_type != 'Sum':
            continue
        # Check for input number
        if len(node.input) == 1:
            # If input number is 1, delete the sum node.
            following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
            for following_node in following_nodes:
                modhelper.replace_node_input(following_node, node.output[0], node.input[0])
            node_to_del.append(node)
            # BUGFIX: find_value_by_name takes the graph as its first
            # argument; the original omitted `g` (twice) and would raise
            # TypeError. Also look it up only once.
            stale_value_info = helper.find_value_by_name(g, node.output[0])
            if stale_value_info is not None:
                g.value_info.remove(stale_value_info)
        elif len(node.input) == 2:
            # If input number is 2, replace it with add.
            node.op_type = 'Add'
            continue
        elif len(node.input) > 2:
            # If input number is larger than 2, replace it with n-1 add.
            input_count = len(node.input)
            # First node has 2 inputs
            first_node = onnx.helper.make_node(
                "Add",
                [node.input[0], node.input[1]],
                [node.output[0] + '_replacement_1'],
                name=node.name + '_replacement_1'
            )
            # Last node has the same output as the original sum node
            last_node = onnx.helper.make_node(
                "Add",
                [node.output[0] + '_replacement_' + str(input_count - 2), node.input[input_count - 1]],
                [node.output[0]],
                name=node.name
            )
            g.node.extend([first_node, last_node])
            # Middle nodes chain replacement_i-1 with input i.
            for i in range(2, input_count - 1):
                new_node = onnx.helper.make_node(
                    "Add",
                    [node.output[0] + '_replacement_' + str(i - 1), node.input[i]],
                    [node.output[0] + '_replacement_' + str(i)],
                    name=node.name + '_replacement_' + str(i)
                )
                g.node.extend([new_node])
            node_to_del.append(node)
        else:
            logging.error("Sum node must have at least 1 input.")
            quit(1)

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
def replace_constant_input_concat_with_pad(g):
    """If a single input is concatenated with constant tensors whose elements
    are all the same value, replace the Concat with a Pad. Currently only
    supports 2-3 inputs (constant on the left, right, or both sides).

    :param g: input graph.
    :return: None. The graph is modified in place.
    """
    node_to_del = []
    for node in g.node:
        # Check for Concat node
        if node.op_type != 'Concat':
            continue

        # Check concat node input
        mode = None
        value = 0
        real_input_name = None
        if len(node.input) == 2:
            input_1st_node = helper.find_node_by_output_name(g, node.input[0])
            input_2nd_node = helper.find_node_by_output_name(g, node.input[1])
            if input_1st_node is not None and input_1st_node.op_type == 'Constant':
                mode = 'left'
                constant_value = helper.constant_to_numpy(input_1st_node)
                real_input_name = node.input[1]
                value = constant_value.flatten()[0]
                # Check if the values are all the same.
                if np.any(constant_value - value):
                    continue
            elif input_2nd_node is not None and input_2nd_node.op_type == 'Constant':
                mode = 'right'
                constant_value = helper.constant_to_numpy(input_2nd_node)
                real_input_name = node.input[0]
                value = constant_value.flatten()[0]
                # Check if the values are all the same.
                if np.any(constant_value - value):
                    continue
            else:
                # No constant input case
                continue
        elif len(node.input) == 3:
            # For 3 inputs concat node, the 1st and the 3rd input should be
            # constant with the same value.
            input_1st_node = helper.find_node_by_output_name(g, node.input[0])
            input_2nd_node = helper.find_node_by_output_name(g, node.input[1])
            input_3rd_node = helper.find_node_by_output_name(g, node.input[2])
            if input_1st_node is None or input_1st_node.op_type != 'Constant' or \
                input_3rd_node is None or input_3rd_node.op_type != 'Constant':
                continue
            mode = 'both'
            real_input_name = node.input[1]
            input_1st_value = helper.constant_to_numpy(input_1st_node)
            input_3rd_value = helper.constant_to_numpy(input_3rd_node)
            value = input_1st_value.flatten()[0]
            # Check if all the values are all the same
            if np.any(input_1st_value - value):
                continue
            elif np.any(input_3rd_value - value):
                continue
        else:
            # Too many inputs case.
            continue
        # Make weight nodes
        input_value_info = helper.find_value_by_name(g, real_input_name)
        input_shape = helper.get_shape_from_value_info(input_value_info)
        pads = [0] * (len(input_shape) * 2)
        axis = helper.get_var_attribute_by_name(node, 'axis', 'int')
        if axis < 0:
            # BUGFIX: normalizing a negative axis is rank + axis. The
            # original used `len(input_shape) - axis`, which for e.g.
            # axis=-1 on a rank-4 shape gave 5 (out of range).
            axis = len(input_shape) + axis
        if mode == 'left':
            left_value_info = helper.find_value_by_name(g, node.input[0])
            left_input_shape = helper.get_shape_from_value_info(left_value_info)
            pads[axis] = left_input_shape[axis]
        elif mode == 'right':
            right_value_info = helper.find_value_by_name(g, node.input[1])
            right_input_shape = helper.get_shape_from_value_info(right_value_info)
            pads[axis + len(input_shape)] = right_input_shape[axis]
        else:
            # mode should be both
            left_value_info = helper.find_value_by_name(g, node.input[0])
            left_input_shape = helper.get_shape_from_value_info(left_value_info)
            pads[axis] = left_input_shape[axis]
            right_value_info = helper.find_value_by_name(g, node.input[2])
            right_input_shape = helper.get_shape_from_value_info(right_value_info)
            pads[axis + len(input_shape)] = right_input_shape[axis]
        pads_node = helper.list_to_constant(
            node.name + '_pads',
            (len(pads), ),
            pads
        )
        constant_value_node = helper.scaler_to_constant(
            node.name + '_constant_value',
            value
        )
        # Create new Pad node
        new_pad_node = onnx.helper.make_node(
            "Pad",
            [real_input_name, pads_node.name, constant_value_node.name],
            [node.output[0]],
            name = node.name,
            mode = "constant"
        )
        # Replace
        node_to_del.append(node)
        g.node.extend([pads_node, constant_value_node, new_pad_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
+""" +import logging +import onnx.helper +import numpy as np +from . import helper +from . import other +from . import modhelper + +def change_first_conv_from_bgr_to_rgb(m): + """For input channel format BGR model, use this function to change the first + conv weight to adapt the input into RGB. + + :param m: the model proto + """ + # Check for first node. + g = m.graph + input_name = g.input[0].name + first_nodes = helper.find_following_nodes_by_input_value_name(g, input_name) + if len(first_nodes) > 1: + return False + first_node = first_nodes[0] + # Now we have the first node. Check this first node. + if first_node.op_type != 'Conv': + return False + weight_value = helper.find_value_by_name(g, first_node.input[1]) + weight_shape = helper.get_shape_from_value_info(weight_value) + if weight_shape[1] != 3: + return False + # Do weight shuffle + weight_node = helper.find_node_by_output_name(g, weight_value.name) + weight_np = helper.constant_to_numpy(weight_node) + b_channel = np.expand_dims(weight_np[:, 0, :, :], axis=1) + g_channel = np.expand_dims(weight_np[:, 1, :, :], axis=1) + r_channel = np.expand_dims(weight_np[:, 2, :, :], axis=1) + new_np = np.concatenate((r_channel, g_channel, b_channel), axis=1) + new_node = helper.numpy_to_constant(weight_value.name, new_np) + # Replace the weight and topological sort + g.node.remove(weight_node) + g.node.extend([new_node]) + other.topological_sort(g) + return True + +def change_input_from_bgr_to_rgb(m): + """For input channel format BGR model, use this function to modify the model + to accepct RGB image.If the first node is a non-group Conv. Modify weight to + adapt the input into RGB. Otherwise create a new node. + + :param m: the model proto + """ + g = m.graph + if len(g.input) > 1: + print("This model has multiple inputs. 
Cannot change to RGB input.") + return + input_shape = helper.get_shape_from_value_info(g.input[0]) + if len(input_shape) != 4 or input_shape[1] != 3: + print("The input shape is invalid for bgr conversion.") + return + # Try change conv weight first + if change_first_conv_from_bgr_to_rgb(m): + return + # Otherwise, create a special conv node and replace the input + # Construct weight + weight_np = np.zeros((3, 3, 3, 3)).astype('float32') + weight_np[0, 2, 1, 1] = 1.0 + weight_np[1, 1, 1, 1] = 1.0 + weight_np[2, 0, 1, 1] = 1.0 + new_weight = helper.numpy_to_constant("bgr_shuffle_weight", weight_np) + # Construct Conv + new_conv = onnx.helper.make_node( + 'Conv', + ['rgb_input', "bgr_shuffle_weight"], + [g.input[0].name], + name='bgr_shuffle', + dilations=[1, 1], + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[1, 1] + ) + # Connect the graph + old_input_value = g.input.pop() + new_input_value = onnx.helper.make_tensor_value_info( + 'rgb_input', + old_input_value.type.tensor_type.elem_type, + input_shape + ) + g.input.extend([new_input_value]) + g.node.extend([new_weight, new_conv]) + # topological sort + other.topological_sort(g) + +def add_0_5_to_normalized_input(m): + """For normalized input between -0.5 ~ 0.5, add 0.5 to the input to keep it + between 0 ~ 1. + + :param m: the model proto + """ + g = m.graph + if len(g.input) > 1: + print("This model has multiple inputs. Cannot normalize input.") + return + input_shape = helper.get_shape_from_value_info(g.input[0]) + if len(input_shape) != 4: + print("The input shape is not BCHW. 
Cannot normalize input.") + return + # Construct weight + ch = input_shape[1] + weight_np = np.zeros((ch, ch, 3, 3)).astype('float32') + for i in range(ch): + weight_np[i, i, 1, 1] = 1.0 + new_weight = helper.numpy_to_constant("input_norm_weight", weight_np) + # Construct bias + bias_np = np.array([0.5] * ch).astype('float32') + new_bias = helper.numpy_to_constant("input_norm_bias", bias_np) + # Construct Conv + new_conv = onnx.helper.make_node( + 'Conv', + ['origin_input', "input_norm_weight", "input_norm_bias"], + [g.input[0].name], + name='input_norm', + dilations=[1, 1], + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[1, 1] + ) + # Construct value_infos + old_input_value = g.input.pop() + weight_value = onnx.helper.make_tensor_value_info( + 'input_norm_weight', + old_input_value.type.tensor_type.elem_type, + [3, 3, 3, 3] + ) + bias_value = onnx.helper.make_tensor_value_info( + 'input_norm_bias', + old_input_value.type.tensor_type.elem_type, + [3] + ) + # Connect the graph + new_input_value = onnx.helper.make_tensor_value_info( + 'origin_input', + old_input_value.type.tensor_type.elem_type, + input_shape + ) + g.input.extend([new_input_value]) + g.node.extend([new_weight, new_bias, new_conv]) + g.value_info.extend([weight_value, bias_value, old_input_value]) + # topological sort + other.topological_sort(g) + +def add_rgb2yynn_node(m): + """Add a conv layer which can convert rgb to yynn input. + """ + g = m.graph + if len(g.input) > 1: + print("This model has multiple inputs. Cannot change to rgb input.") + return + input_shape = helper.get_shape_from_value_info(g.input[0]) + if len(input_shape) != 4: + print("The input shape is not BCHW. Cannot normalize input.") + return + # Construct weight + ch = input_shape[1] + weight_np = np.zeros((3, 3, 4, 4)).astype('float32') + weight_np[1, 1, :3, :2] = np.array([[[[0.299], + [0.587], + [0.114]]]]) + weight_np[1, 1, 3, 2:] = 1. 
+ weight_np = np.transpose(weight_np, (3, 2, 0, 1)) + new_weight = helper.numpy_to_constant("input_rgb2yynn_weight", weight_np) + # Construct conv node + new_conv = onnx.helper.make_node( + 'Conv', + ['new_input', "input_rgb2yynn_weight"], + [g.input[0].name], + name='input_rgba2yynn', + dilations=[1, 1], + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[1, 1] + ) + # Construct value_infos + old_input_value = g.input.pop() + weight_value = onnx.helper.make_tensor_value_info( + 'input_rgb2yynn_weight', + old_input_value.type.tensor_type.elem_type, + [4, 4, 3, 3] + ) + # Connect the graph + new_input_value = onnx.helper.make_tensor_value_info( + 'new_input', + old_input_value.type.tensor_type.elem_type, + input_shape + ) + g.input.extend([new_input_value]) + g.node.extend([new_weight, new_conv]) + g.value_info.extend([weight_value, old_input_value]) + # topological sort + other.topological_sort(g) + +def swap_MatMul_inputs(g, original_matmul_node): + # Create Transpose nodes + input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0]) + input_a_shape = helper.get_shape_from_value_info(input_a_value) + if len(input_a_shape) == 2: + perm = [1, 0] + else: + perm = [0, 2, 1] + new_input_b_node = onnx.helper.make_node( + 'Transpose', + inputs = [input_a_value.name], + outputs = [input_a_value.name + '_transposed'], + name = f"{input_a_value.name}_transposed_for_{original_matmul_node.name}", + perm = perm + ) + input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1]) + input_b_shape = helper.get_shape_from_value_info(input_b_value) + if len(input_b_shape) == 3: + perm = [0, 2, 1] + else: + perm = [0, 1, 3, 2] + new_input_a_node = onnx.helper.make_node( + 'Transpose', + inputs = [input_b_value.name], + outputs = [input_b_value.name + '_transposed'], + name = f'{input_b_value.name}_transposed_for_{original_matmul_node.name}', + perm = perm + ) + # Create new MatMul node + new_matmul_node = onnx.helper.make_node( + 'MatMul', + inputs = 
[new_input_a_node.output[0], new_input_b_node.output[0]], + outputs = [original_matmul_node.output[0] + '_transposed'], + name = original_matmul_node.name + '_transposed' + ) + # Create final Transpose node + output_value = helper.find_value_by_name(g, original_matmul_node.output[0]) + output_shape = helper.get_shape_from_value_info(output_value) + if len(output_shape) == 3: + perm = [0, 2, 1] + else: + perm = [0, 1, 3, 2] + new_final_transpose_node = onnx.helper.make_node( + 'Transpose', + inputs = [new_matmul_node.output[0]], + outputs = [original_matmul_node.output[0]], + name = original_matmul_node.name + '_final_transpose', + perm = perm + ) + # Add new nodes + g.node.extend([new_input_a_node, new_input_b_node, new_matmul_node, new_final_transpose_node]) + # Delete original nodes + g.node.remove(original_matmul_node) + +def split_MatMul_batch_then_concat(g, original_matmul_node): + new_nodes = [] + final_concat_inputs = [] + # Get the batch count + input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0]) + input_a_shape = helper.get_shape_from_value_info(input_a_value) + input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1]) + input_b_shape = helper.get_shape_from_value_info(input_b_value) + if len(input_a_shape) == 3: + batch_count = input_a_shape[0] + else: + batch_count = input_a_shape[1] + for i in range(batch_count): + # Create Split nodes for input A + starts_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_starts", (1, ), [i]) + ends_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_ends", (1, ), [i+1]) + axes_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_axes", (1, ), [len(input_a_shape) - 3]) + new_sliced_a_node = onnx.helper.make_node( + 'Slice', + inputs = [input_a_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]], + outputs = [f"{input_a_value.name}_sliced_{i}"], + name = 
f"{input_a_value.name}_sliced_{i}_for_{original_matmul_node.name}" + ) + new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_a_node]) + # Create Split nodes for input B + starts_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_starts", (1, ), [i]) + ends_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_ends", (1, ), [i+1]) + axes_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_axes", (1, ), [len(input_b_shape) - 3]) + new_sliced_b_node = onnx.helper.make_node( + 'Slice', + inputs = [input_b_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]], + outputs = [f"{input_b_value.name}_sliced_{i}"], + name = f"{input_b_value.name}_sliced_{i}_for_{original_matmul_node.name}" + ) + new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_b_node]) + # Create MatMul nodes + new_matmul_node = onnx.helper.make_node( + 'MatMul', + inputs = [new_sliced_a_node.output[0], new_sliced_b_node.output[0]], + outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"], + name = f"{original_matmul_node.name}_sliced_{i}" + ) + new_nodes.append(new_matmul_node) + final_concat_inputs.append(new_matmul_node.output[0]) + # Create Concat nodes + output_value = helper.find_value_by_name(g, original_matmul_node.output[0]) + if output_value is None: + output_value = helper.find_output_by_name(g, original_matmul_node.output[0]) + if output_value is None: + helper.logger.error(f"Cannot find value_info for {original_matmul_node.output[0]}") + output_shape = helper.get_shape_from_value_info(output_value) + new_concat_node = onnx.helper.make_node( + "Concat", + inputs = final_concat_inputs, + outputs = [original_matmul_node.output[0]], + name = f"{original_matmul_node.name}_final_concat", + axis = len(output_shape) - 3 + ) + new_nodes.append(new_concat_node) + # Add new nodes + g.node.extend(new_nodes) + # Delete original nodes + g.node.remove(original_matmul_node) + + +def 
split_MatMul_Constant_input_then_concat(g, original_matmul_node): + new_nodes = [] + final_concat_inputs = [] + # Get the batch count + input_b_node = helper.find_node_by_output_name(g, original_matmul_node.input[1]) + input_b_np = helper.constant_to_numpy(input_b_node) + if len(input_b_np.shape) == 3: + batch_count = input_b_np.shape[0] + else: + batch_count = input_b_np.shape[1] + for i in range(batch_count): + # Create new constant node + if len(input_b_np.shape) == 3: + new_np = input_b_np[i:i+1, ...] + else: + new_np = input_b_np[:, i:i+1, ...] + new_weight = helper.numpy_to_constant(f"{input_b_node.name}_sliced_{i}", new_np) + new_nodes.append(new_weight) + # Create MatMul nodes + new_matmul_node = onnx.helper.make_node( + 'MatMul', + inputs = [original_matmul_node.input[0], new_weight.output[0]], + outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"], + name = f"{original_matmul_node.name}_sliced_{i}" + ) + new_nodes.append(new_matmul_node) + final_concat_inputs.append(new_matmul_node.output[0]) + # Create Concat nodes + output_value = helper.find_value_by_name(g, original_matmul_node.output[0]) + output_shape = helper.get_shape_from_value_info(output_value) + new_concat_node = onnx.helper.make_node( + "Concat", + inputs = final_concat_inputs, + outputs = [original_matmul_node.output[0]], + name = f"{original_matmul_node.name}_final_concat", + axis = len(output_shape) - 3 + ) + new_nodes.append(new_concat_node) + # Add new nodes + g.node.extend(new_nodes) + # Delete original value info + input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1]) + if input_b_value is not None: + g.value_info.remove(input_b_value) + # Delete original nodes + g.node.remove(original_matmul_node) + g.node.remove(input_b_node) + + +def special_MatMul_process(g): + for node in g.node: + if node.op_type != 'MatMul': + continue + input_a_name = node.input[0] + input_a_value = helper.find_value_by_name(g, input_a_name) + input_b_name = node.input[1] + 
def special_MatMul_process(g):
    """Optimize MatMul nodes whose second input is batched (rank 3/4).

    Dispatch per node:
    * B rank 2 or leading batch 1: already supported, skip.
    * B batched and constant: split the constant per batch
      (split_MatMul_Constant_input_then_concat).
    * B batched, A unbatched: swap operands via transposes
      (swap_MatMul_inputs).
    * A and B share the same batch: slice both per batch
      (split_MatMul_batch_then_concat).

    NOTE(review): the helpers mutate g.node while this loop iterates it;
    behavior relies on protobuf repeated-field iteration semantics.

    :param g: input graph.
    :return: None. The graph is modified in place.
    """
    for node in g.node:
        if node.op_type != 'MatMul':
            continue
        input_a_name = node.input[0]
        input_a_value = helper.find_value_by_name(g, input_a_name)
        input_b_name = node.input[1]
        input_b_value = helper.find_value_by_name(g, input_b_name)
        if input_a_value is None or input_b_value is None:
            continue
        input_a_shape = helper.get_shape_from_value_info(input_a_value)
        input_b_shape = helper.get_shape_from_value_info(input_b_value)
        # Check shapes and choose the process
        # Normal case, Skip
        if len(input_b_shape) == 2:
            continue
        # Too many dimensions or too few dimensions. Not supported. Skip
        if len(input_a_shape) > 4 or len(input_b_shape) > 4:
            helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have too many dimensions.")
            continue
        if len(input_a_shape) < 2 or len(input_b_shape) < 2:
            # BUGFIX: message typo "two few" -> "too few".
            helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have too few dimensions.")
            continue
        # For 4 dimension, check the first dimension (should be 1) and
        # treat the rest as 3 dimensions.
        extra_dim = None
        if len(input_a_shape) == 4:
            extra_dim = input_a_shape[0]
            input_a_shape = input_a_shape[1:]
        if len(input_b_shape) == 4:
            if input_b_shape[0] != extra_dim:
                helper.logger.warning(f"Cannot optimize MatMul {node.name}: input dimension batch sizes does not match ({extra_dim} vs {input_b_shape[0]}).")
                continue
            input_b_shape = input_b_shape[1:]
        # Check input B dimension
        # If B is 1 x W x V, it is the same as normal case.
        if input_b_shape[0] == 1:
            continue
        # If B is B x W x V, but B is a constant.
        input_b_node = helper.find_node_by_output_name(g, input_b_name)
        if input_b_node is not None and input_b_node.op_type == 'Constant':
            # Constant input
            helper.logger.debug(f"Optimizing MatMul node {node.name}: split constant input.")
            split_MatMul_Constant_input_then_concat(g, node)
        # If B is B x W x V and A is 1 x H x W, do the swap.
        elif len(input_a_shape) == 2 or (input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1)):
            helper.logger.debug(f"Optimizing MatMul node {node.name}: swap input.")
            swap_MatMul_inputs(g, node)
        # If B is B x W x V and A is B x H x W, do the split.
        elif input_b_shape[0] == input_a_shape[0]:
            helper.logger.debug(f"Optimizing MatMul node {node.name}: split input batch.")
            split_MatMul_batch_then_concat(g, node)
        # Other cases are not supported: If B is B x W x V but A is X x H x W.
        else:
            helper.logger.warning(f"Cannot optimize MatMul {node.name}: unknown reason. Might be shape mismatch.")
            continue
    other.topological_sort(g)