feat: integrate kneron optimizer script to pytorch2onnx_kneron.py

This commit is contained in:
parent b83243c5f4
commit 5b99260c9b

1 tools/optimizer_scripts/.clang-format Normal file
@@ -0,0 +1 @@
BasedOnStyle: Google
7 tools/optimizer_scripts/.gitignore vendored Normal file
@@ -0,0 +1,7 @@
__pycache__
.vscode
*.pyc
models.py
temp.py
.ssh/
docker/test_models/
189 tools/optimizer_scripts/README.md Normal file
@@ -0,0 +1,189 @@
# Converter Scripts

[](http://192.168.200.1:8088/jiyuan/converter_scripts/commits/master)

This project collects various optimization scripts and converter scripts for
the Kneron toolchain. This collection does not include the Keras to ONNX
converter and the Caffe to ONNX converter. They are in separate projects.

**The scripts not listed below are used as libraries and cannot be used
directly.**

## onnx2onnx.py

### 1.1. Description

General optimizations on ONNX models for the Kneron toolchain. Though Kneron
toolchains are designed to take ONNX models as input, they place some
restrictions on the models (e.g. inferred shapes for all value_info). Thus, we
have this tool to do some general optimization and conversion on ONNX models.
**Notice that this script should take a valid ONNX model as input.** It cannot
turn an invalid ONNX model into a valid one.

### 1.2. Basic Usage

```bash
python onnx2onnx.py input.onnx -o output.onnx
```

### 1.3. Optimizations Included

* Fusing BN into Conv.
* Fusing BN into Gemm.
* Fusing consecutive Gemm.
* Eliminating Identity layers and Dropout layers.
* Eliminating trailing shape-changing nodes.
* Replacing initializers with Constant nodes.
* Replacing global AveragePool with GAP.
* Replacing Squeeze and Unsqueeze with Reshape.
* Replacing 1x1 depthwise Conv with BN.
* Inferencing Upsample shapes.
* Transposing B in Gemm.
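The first bullet, fusing BN into Conv, rests on the fact that a BatchNormalization following a Conv is an affine map per output channel and can be folded into the Conv's weight and bias. A minimal numpy sketch of the folding (the function name and tensor layout are our assumptions, not the repo's API):

```python
import numpy as np

def fuse_bn_into_conv(W, b, gamma, beta, mean, var, eps=1e-5):
    """Fold a BatchNormalization that directly follows a Conv into the
    Conv's weight W (out_ch, in_ch, kH, kW) and bias b (out_ch,).
    All BN parameters are per-output-channel vectors of shape (out_ch,)."""
    scale = gamma / np.sqrt(var + eps)        # per-channel BN scale
    W_fused = W * scale.reshape(-1, 1, 1, 1)  # scale every output filter
    b_fused = (b - mean) * scale + beta       # fold the shift into the bias
    return W_fused, b_fused
```

Running the original Conv followed by BN and the fused Conv on the same input gives identical outputs, which is why the pass is safe for inference-only models.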
## pytorch2onnx.py

### 2.1. Description

Convert Pytorch models or Pytorch-generated ONNX models into Kneron toolchain
compatible ONNX files. This script includes most of the optimizations in
`onnx2onnx.py`. It also includes some optimizations specific to Pytorch models.

### 2.2. Basic Usage

```bash
# Take a Pytorch model, with input channel number, input height and input width.
python pytorch2onnx.py input.pth output.onnx --input-size 3 224 224
# Or take a Pytorch-exported ONNX model.
python pytorch2onnx.py input.onnx output.onnx
```

### 2.3. Optimizations Included

* Adding names to nodes.
* Constant folding of Unsqueeze nodes.
* Constant folding of Reshape nodes.
* Optimizations in `onnx2onnx.py`.
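Constant folding here means evaluating shape-manipulation nodes whose input is already a constant, so the Unsqueeze/Reshape chains the Pytorch exporter emits collapse into a single Constant. A sketch of the Unsqueeze case in numpy (names are ours; the repo implements this on ONNX graph nodes rather than raw arrays):

```python
import numpy as np

def fold_unsqueeze(const_value, axes):
    """Constant-fold an Unsqueeze node: insert size-1 dimensions into an
    already-constant tensor at the given axes (assumed non-negative)."""
    out = np.asarray(const_value)
    for ax in sorted(axes):          # insert from the smallest axis up
        out = np.expand_dims(out, ax)
    return out
```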
## editor.py

### 3.1. Description

This is a simple ONNX editor which achieves the following functions:

* Add nop BN or Conv nodes.
* Delete specific nodes or inputs.
* Cut the graph from a certain node (delete all the nodes following that node).
* Reshape inputs and outputs.

### 3.2. Usage

```
usage: editor.py [-h] [-c CUT_NODE [CUT_NODE ...]]
                 [--cut-type CUT_TYPE [CUT_TYPE ...]]
                 [-d DELETE_NODE [DELETE_NODE ...]]
                 [--delete-input DELETE_INPUT [DELETE_INPUT ...]]
                 [-i INPUT_CHANGE [INPUT_CHANGE ...]]
                 [-o OUTPUT_CHANGE [OUTPUT_CHANGE ...]]
                 [--add-conv ADD_CONV [ADD_CONV ...]]
                 [--add-bn ADD_BN [ADD_BN ...]]
                 in_file out_file

Edit an ONNX model. The processing sequence is 'delete nodes/values' -> 'add
nodes' -> 'change shapes'. Cutting cannot be done together with other
operations.

positional arguments:
  in_file               input ONNX file
  out_file              output ONNX file

optional arguments:
  -h, --help            show this help message and exit
  -c CUT_NODE [CUT_NODE ...], --cut CUT_NODE [CUT_NODE ...]
                        remove nodes from the given nodes (inclusive)
  --cut-type CUT_TYPE [CUT_TYPE ...]
                        remove nodes by type from the given nodes (inclusive)
  -d DELETE_NODE [DELETE_NODE ...], --delete DELETE_NODE [DELETE_NODE ...]
                        delete nodes by names and only those nodes
  --delete-input DELETE_INPUT [DELETE_INPUT ...]
                        delete inputs by names
  -i INPUT_CHANGE [INPUT_CHANGE ...], --input INPUT_CHANGE [INPUT_CHANGE ...]
                        change input shape (e.g. -i 'input_0 1 3 224 224')
  -o OUTPUT_CHANGE [OUTPUT_CHANGE ...], --output OUTPUT_CHANGE [OUTPUT_CHANGE ...]
                        change output shape (e.g. -o 'output_0 1 3 224 224')
  --add-conv ADD_CONV [ADD_CONV ...]
                        add nop conv using specific input
  --add-bn ADD_BN [ADD_BN ...]
                        add nop bn using specific input
```

### 3.3. Example

Here is an example of when and how to use editor.py.

```bash
# In the `res` folder, there is a vdsr model from tensorflow.
# We need to convert this model first.
./tf2onnx.sh res/vdsr_41_20layer_1.pb res/tmp.onnx images:0 output:0
# This onnx file seems valid, but its input and output are channel-last.
# It uses Transpose to convert to channel-first, affecting the performance.
# Thus, here we use the editor to delete these Transpose nodes and reset the shapes.
python editor.py debug.onnx new.onnx -d Conv2D__6 Conv2D_19__84 -i 'images:0 1 3 41 41' -o 'output:0 1 3 41 41'
# Now, it has no Transpose and takes channel-first inputs directly.
```
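The `-i`/`-o` shape arguments are quoted strings of a value name followed by the dimensions, as in the example above. A minimal parser sketch for that format (the actual helper inside the repo may differ):

```python
def parse_shape_change(arg):
    """Parse an editor.py-style shape argument such as
    'images:0 1 3 41 41' into (value_name, [dims])."""
    parts = arg.split()
    return parts[0], [int(d) for d in parts[1:]]
```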
## test_models_opt.py

### 4.1. Description

Compare all original and optimized onnx models under a specified directory,
using different file-name endings to locate the original and optimized model
paths. Apply onnxruntime inference to the models and compare the results from
the original and optimized models. Calculate basic statistics and store them
in a csv file.

### 4.2. Usage

```bash
python test_models_opt.py DIR ending1 ending2 csv_out_file -p=Y/N

# csv_out_file is the file path for the stats data.
# -p --plot is the plot option; if Y, stats plots will be generated.
```

### 4.3. Statistics

* max_rel_diff
* max_abs_diff
* mean_rel_diff
* mean_abs_diff
* std_rel_diff
* std_abs_diff
* acc_with_diff_precision
* percentile
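A numpy sketch of how the absolute and relative difference statistics above can be computed from two flattened output arrays (the function name and the epsilon guard are our assumptions, not the script's exact implementation):

```python
import numpy as np

def diff_stats(raw_out, opt_out, eps=1e-9):
    """Basic difference statistics between raw and optimized model outputs."""
    raw = np.asarray(raw_out, dtype=np.float64).ravel()
    opt = np.asarray(opt_out, dtype=np.float64).ravel()
    abs_diff = np.abs(raw - opt)
    rel_diff = abs_diff / (np.abs(raw) + eps)  # eps guards against /0
    return {
        'max_abs_diff': float(abs_diff.max()),
        'max_rel_diff': float(rel_diff.max()),
        'mean_abs_diff': float(abs_diff.mean()),
        'mean_rel_diff': float(rel_diff.mean()),
        'std_abs_diff': float(abs_diff.std()),
        'std_rel_diff': float(rel_diff.std()),
    }
```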
### 4.4. Plots

* Max Relative Difference Histogram
* Max Absolute Difference Histogram
* Rel_diff Percentiles of Raw and Optimized Models
* Abs_diff Percentiles of Raw and Optimized Models
* Accuracies with Different Precisions

## tensorflow2onnx.py

### 5.1. Description

Convert and optimize tensorflow models. If the input file is a frozen
tensorflow .pb model, convert it to an onnx model and do the customized
optimization afterwards. If the input model is already an onnx model, apply
the optimization and save the optimized model.

### 5.2. Dependency

This script depends on the tensorflow-onnx project. Please [check and install it](https://github.com/onnx/tensorflow-onnx/tree/r1.5) before using this script. We currently support up to version 1.5.5. For other versions, you may need to try it out yourself.

### 5.3. Basic Usage

```bash
python tensorflow2onnx.py in_file out_file -t=True/False

# -t --test is the option for test mode; if True, the shape change after the input will not be eliminated.
```

### 5.4. Model Save Paths

`in_file` is the input model path; `out_file` specifies the output optimized
model path. If the input file is a `.pb` model, an unoptimized onnx model will
be saved to the output directory as well.
59 tools/optimizer_scripts/consecutive_conv_opt.py Normal file
@@ -0,0 +1,59 @@
import numpy as np
import onnx
import sys

from tools.other import topological_sort
from tools import helper

def fuse_bias_in_consecutive_1x1_conv(g):
    for second in g.node:
        # Find two consecutive Conv nodes.
        if second.op_type != 'Conv':
            continue
        first = helper.find_node_by_output_name(g, second.input[0])
        if first is None or first.op_type != 'Conv':
            continue
        # Check that the first one has only one following node.
        if len(helper.find_following_nodes_by_input_value_name(g, first.output[0])) != 1:
            continue
        # If the first node has no bias, continue.
        if len(first.input) == 2:
            continue
        # Check their kernel sizes: both must be 1x1.
        first_kernel_shape = helper.get_list_attribute_by_name(first, 'kernel_shape', 'int')
        second_kernel_shape = helper.get_list_attribute_by_name(second, 'kernel_shape', 'int')
        prod = first_kernel_shape[0] * first_kernel_shape[1] * second_kernel_shape[0] * second_kernel_shape[1]
        if prod != 1:
            continue
        print('Found: ', first.name, ' ', second.name)
        # Get the bias and weight constant nodes.
        first_bias_node = helper.find_node_by_output_name(g, first.input[2])
        second_weight_node = helper.find_node_by_output_name(g, second.input[1])
        second_bias_node = helper.find_node_by_output_name(g, second.input[2])
        first_bias = helper.constant_to_numpy(first_bias_node)
        second_weight = helper.constant_to_numpy(second_weight_node)
        second_bias = helper.constant_to_numpy(second_bias_node)
        # Calculate the new bias for the second node.
        first_bias = np.reshape(first_bias, (1, first_bias.size))
        second_weight = np.reshape(second_weight, (second_weight.shape[0], second_weight.shape[1]))
        second_weight = np.transpose(second_weight)
        new_second_bias = second_bias + np.matmul(first_bias, second_weight)
        new_second_bias = np.reshape(new_second_bias, (new_second_bias.size,))
        # Zero out the first bias.
        new_first_bias = np.reshape(first_bias, (first_bias.size, ))
        for i in range(new_first_bias.shape[0]):
            new_first_bias[i] = 0.0
        new_first_bias_node = helper.numpy_to_constant(first_bias_node.output[0], new_first_bias)
        new_second_bias_node = helper.numpy_to_constant(second_bias_node.output[0], new_second_bias)
        # Delete the old bias nodes and add the new ones.
        g.node.remove(first_bias_node)
        g.node.remove(second_bias_node)
        g.node.extend([new_first_bias_node, new_second_bias_node])
    topological_sort(g)

if __name__ == "__main__":
    if len(sys.argv) != 3:
        exit(1)
    m = onnx.load(sys.argv[1])
    fuse_bias_in_consecutive_1x1_conv(m.graph)
    onnx.save(m, sys.argv[2])
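The folding in this script relies on the identity that, for 1x1 kernels, a bias on the first Conv can be pushed through the second Conv's weights: Conv2(Conv1(x) + b1) + b2 = Conv2(Conv1(x)) + (b2 + W2·b1). A quick numpy check of that identity, treating a 1x1 Conv at a single pixel as a matrix product (names are ours):

```python
import numpy as np

def fold_first_bias(W2, b1, b2):
    """Fold the first Conv's bias b1 into the second 1x1 Conv's bias:
    the new second bias is b2 + W2 @ b1, and b1 becomes zero."""
    return b2 + W2 @ b1
```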
24 tools/optimizer_scripts/docker/Dockerfile Normal file
@@ -0,0 +1,24 @@
|
||||
FROM continuumio/miniconda3:latest
|
||||
LABEL maintainer="jiyuan@kneron.us"
|
||||
|
||||
# Install python packages
|
||||
RUN conda update -y conda && \
|
||||
conda install -y python=3.6 && \
|
||||
conda install -y -c intel caffe && \
|
||||
conda install -y -c pytorch pytorch=1.3.1 torchvision=0.4.2 cpuonly && \
|
||||
conda install -y -c conda-forge tensorflow=1.5.1 keras=2.2.4 && \
|
||||
pip install onnx==1.4.1 onnxruntime==1.1.0 tf2onnx==1.5.4 && \
|
||||
ln -s /opt/conda/lib/libgflags.so.2.2.2 /opt/conda/lib/libgflags.so.2
|
||||
|
||||
# Install git lfs packages
|
||||
RUN apt-get update && apt-get install -y curl apt-utils && \
|
||||
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \
|
||||
apt-get install -y git-lfs
|
||||
|
||||
RUN conda clean -a -y && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# copy the test data
|
||||
COPY ./test_models /test_models
|
||||
|
||||
# Clean the environment and finalize the process
|
||||
WORKDIR /root
|
||||
118 tools/optimizer_scripts/editor.py Normal file
@@ -0,0 +1,118 @@
|
||||
import onnx
|
||||
import onnx.utils
|
||||
try:
|
||||
from onnx import optimizer
|
||||
except ImportError:
|
||||
import onnxoptimizer as optimizer
|
||||
import argparse
|
||||
|
||||
import tools.modhelper as helper
|
||||
import tools.other as other
|
||||
import tools.replacing as replacing
|
||||
# Main process
|
||||
# Argument parser
|
||||
parser = argparse.ArgumentParser(description="Edit an ONNX model.\nThe processing sequense is 'delete nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting cannot be done with other operations together")
|
||||
parser.add_argument('in_file', type=str, help='input ONNX FILE')
|
||||
parser.add_argument('out_file', type=str, help="ouput ONNX FILE")
|
||||
parser.add_argument('-c', '--cut', dest='cut_node', type=str, nargs='+', help="remove nodes from the given nodes(inclusive)")
|
||||
parser.add_argument('--cut-type', dest='cut_type', type=str, nargs='+', help="remove nodes by type from the given nodes(inclusive)")
|
||||
parser.add_argument('-d', '--delete', dest='delete_node', type=str, nargs='+', help="delete nodes by names and only those nodes")
|
||||
parser.add_argument('--delete-input', dest='delete_input', type=str, nargs='+', help="delete inputs by names")
|
||||
parser.add_argument('--delete-output', dest='delete_output', type=str, nargs='+', help="delete outputs by names")
|
||||
parser.add_argument('-i', '--input', dest='input_change', type=str, nargs='+', help="change input shape (e.g. -i 'input_0 1 3 224 224')")
|
||||
parser.add_argument('-o', '--output', dest='output_change', type=str, nargs='+', help="change output shape (e.g. -o 'input_0 1 3 224 224')")
|
||||
parser.add_argument('--add-conv', dest='add_conv', type=str, nargs='+', help='add nop conv using specific input')
|
||||
parser.add_argument('--add-bn', dest='add_bn', type=str, nargs='+', help='add nop bn using specific input')
|
||||
parser.add_argument('--rename-output', dest='rename_output', type=str, nargs='+', help='Rename the specific output(e.g. --rename-output old_name new_name)')
|
||||
parser.add_argument('--pixel-bias-value', dest='pixel_bias_value', type=str, nargs='+', help='(per channel) set pixel value bias bn layer at model front for normalization( e.g. --pixel_bias_value "[104.0, 117.0, 123.0]" )')
|
||||
parser.add_argument('--pixel-scale-value', dest='pixel_scale_value', type=str, nargs='+', help='(per channel) set pixel value scale bn layer at model front for normalization( e.g. --pixel_scale_value "[0.0078125, 0.0078125, 0.0078125]" )')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load model and polish
|
||||
m = onnx.load(args.in_file)
|
||||
m = other.polish_model(m)
|
||||
g = m.graph
|
||||
replacing.replace_initializer_with_Constant(g)
|
||||
other.topological_sort(g)
|
||||
|
||||
# Remove nodes according to the given arguments.
|
||||
if args.delete_node is not None:
|
||||
helper.delete_nodes(g, args.delete_node)
|
||||
|
||||
if args.delete_input is not None:
|
||||
helper.delete_input(g, args.delete_input)
|
||||
|
||||
if args.delete_output is not None:
|
||||
helper.delete_output(g, args.delete_output)
|
||||
|
||||
# Add do-nothing Conv node
|
||||
if args.add_conv is not None:
|
||||
other.add_nop_conv_after(g, args.add_conv)
|
||||
other.topological_sort(g)
|
||||
|
||||
# Add do-nothing BN node
|
||||
if args.add_bn is not None:
|
||||
other.add_nop_bn_after(g, args.add_bn)
|
||||
other.topological_sort(g)
|
||||
|
||||
# Add bias scale BN node
|
||||
if args.pixel_bias_value is not None or args.pixel_scale_value is not None:
|
||||
|
||||
if len(g.input) > 1:
|
||||
raise ValueError(" '--pixel-bias-value' and '--pixel-scale-value' only support one input node model currently")
|
||||
|
||||
i_n = g.input[0]
|
||||
|
||||
pixel_bias_value = [0] * i_n.type.tensor_type.shape.dim[1].dim_value
|
||||
pixel_scale_value = [1] * i_n.type.tensor_type.shape.dim[1].dim_value
|
||||
|
||||
if args.pixel_bias_value is not None and len(args.pixel_bias_value) == 1:
|
||||
pixel_bias_value = [float(n) for n in args.pixel_bias_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')]
|
||||
|
||||
if args.pixel_scale_value is not None and len(args.pixel_scale_value) == 1:
|
||||
pixel_scale_value = [float(n) for n in args.pixel_scale_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')]
|
||||
|
||||
|
||||
if i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_bias_value) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value):
|
||||
raise ValueError("--pixel-bias-value (" + str(pixel_bias_value) + ") and --pixel-scale-value (" + str(pixel_scale_value) + ") should be same as input dimension:" + str(i_n.type.tensor_type.shape.dim[1].dim_value) )
|
||||
other.add_bias_scale_bn_after(g, i_n.name, pixel_bias_value, pixel_scale_value)
|
||||
|
||||
# Change input and output shapes as requested
|
||||
if args.input_change is not None:
|
||||
other.change_input_shape(g, args.input_change)
|
||||
if args.output_change is not None:
|
||||
other.change_output_shape(g, args.output_change)
|
||||
|
||||
# Cutting nodes according to the given arguments.
|
||||
if args.cut_node is not None or args.cut_type is not None:
|
||||
if args.cut_node is None:
|
||||
other.remove_nodes(g, cut_types=args.cut_type)
|
||||
elif args.cut_type is None:
|
||||
other.remove_nodes(g, cut_nodes=args.cut_node)
|
||||
else:
|
||||
other.remove_nodes(g, cut_nodes=args.cut_node, cut_types=args.cut_type)
|
||||
other.topological_sort(g)
|
||||
|
||||
# Rename nodes
|
||||
if args.rename_output:
|
||||
if len(args.rename_output) % 2 != 0:
|
||||
print("Rename output should be paires of names.")
|
||||
else:
|
||||
for i in range(0, len(args.rename_output), 2):
|
||||
other.rename_output_name(g, args.rename_output[i], args.rename_output[i + 1])
|
||||
|
||||
# Remove useless nodes
|
||||
if args.delete_node or args.delete_input or args.input_change or args.output_change:
|
||||
# If shape changed during the modification, redo shape inference.
|
||||
while(len(g.value_info) > 0):
|
||||
g.value_info.pop()
|
||||
passes = ['extract_constant_to_initializer']
|
||||
m = optimizer.optimize(m, passes)
|
||||
g = m.graph
|
||||
replacing.replace_initializer_with_Constant(g)
|
||||
other.topological_sort(g)
|
||||
# Polish and output
|
||||
m = other.polish_model(m)
|
||||
other.add_output_to_value_info(m.graph)
|
||||
onnx.save(m, args.out_file)
|
||||
52 tools/optimizer_scripts/norm_on_scaled_onnx.py Normal file
@@ -0,0 +1,52 @@
|
||||
import onnx
|
||||
import sys
|
||||
import json
|
||||
|
||||
from tools import special
|
||||
|
||||
if len(sys.argv) != 3:
|
||||
print("python norm_on_scaled_onnx.py input.onnx input.json")
|
||||
exit(1)
|
||||
|
||||
# Modify onnx
|
||||
m = onnx.load(sys.argv[1])
|
||||
special.add_0_5_to_normalized_input(m)
|
||||
onnx.save(m, sys.argv[1][:-4] + 'norm.onnx')
|
||||
|
||||
# Change input node
|
||||
origin_file = open(sys.argv[2], 'r')
|
||||
origin_json = json.load(origin_file)
|
||||
origin_json["input_node"]["output_datapath_radix"] = [8]
|
||||
new_json_str = json.dumps(origin_json)
|
||||
|
||||
# Modify json
|
||||
file = open(sys.argv[1][:-4] + 'norm.onnx' + '.json', 'w')
|
||||
s = """{{
|
||||
\"{0}\" :
|
||||
{{
|
||||
\"bias_bitwidth\" : 16,
|
||||
\"{0}_bias\" : [15],
|
||||
\"{0}_weight\" : [3,3,3],
|
||||
\"conv_coarse_shift\" : [-4,-4,-4],
|
||||
\"conv_fine_shift\" : [0,0,0],
|
||||
\"conv_total_shift\" : [-4,-4,-4],
|
||||
\"cpu_mode\" : false,
|
||||
\"delta_input_bitwidth\" : [0],
|
||||
\"delta_output_bitwidth\" : 8,
|
||||
\"flag_radix_bias_eq_output\" : true,
|
||||
\"input_scale\" : [[1.0,1.0,1.0]],
|
||||
\"output_scale\" : [1.0, 1.0, 1.0],
|
||||
\"psum_bitwidth\" : 16,
|
||||
\"weight_bitwidth\" : 8,
|
||||
\"input_datapath_bitwidth\" : [8],
|
||||
\"input_datapath_radix\" : [8],
|
||||
\"working_input_bitwidth\" : 8,
|
||||
\"working_input_radix\" : [8],
|
||||
\"working_output_bitwidth\" : 16,
|
||||
\"working_output_radix\" : 15,
|
||||
\"output_datapath_bitwidth\" : 8,
|
||||
\"output_datapath_radix\" : 7
|
||||
}},\n""".format('input_norm')
|
||||
file.write(s + new_json_str[1:])
|
||||
file.close()
|
||||
origin_file.close()
|
||||
135 tools/optimizer_scripts/onnx1_3to1_4.py Normal file
@@ -0,0 +1,135 @@
|
||||
# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git
|
||||
|
||||
import sys
|
||||
import onnx
|
||||
import numpy as np
|
||||
from onnx import numpy_helper
|
||||
from tools import other, helper
|
||||
|
||||
"""
|
||||
Change onnx model from version 1.3 to version 1.4.
|
||||
Modify the BN node by removing the spatial attribute
|
||||
Modify the Upsample node by removing the 'scales' attribute, and adding a constant node instead.
|
||||
Model's ir_version and opset_import are updated.
|
||||
"""
|
||||
|
||||
def remove_BN_spatial(g):
|
||||
for node in g.node:
|
||||
if node.op_type != 'BatchNormalization':
|
||||
continue
|
||||
for att in node.attribute:
|
||||
if att.name == 'spatial':
|
||||
node.attribute.remove(att)
|
||||
|
||||
|
||||
def upsample_attribute_to_const(g):
|
||||
for node in g.node:
|
||||
if node.op_type != 'Upsample':
|
||||
continue
|
||||
scales_exist = False
|
||||
for att in node.attribute:
|
||||
if att.name == 'scales':
|
||||
scales_exist = True
|
||||
break
|
||||
if not scales_exist:
|
||||
continue
|
||||
|
||||
shape = [len(att.floats)]
|
||||
node.attribute.remove(att)
|
||||
new_node = helper.list_to_constant(node.name+'_input', shape, att.floats)
|
||||
|
||||
g.node.extend([new_node])
|
||||
value_info = onnx.helper.make_tensor_value_info(node.name+'_input', onnx.TensorProto.FLOAT, shape)
|
||||
node.input.extend([node.name+'_input'])
|
||||
g.value_info.extend([value_info])
|
||||
|
||||
def relu6_to_clip(g):
|
||||
for node in g.node:
|
||||
if node.op_type != 'Relu':
|
||||
continue
|
||||
max_val = helper.get_var_attribute_by_name(node, 'max', 'float')
|
||||
if max_val is None:
|
||||
continue
|
||||
new_node = onnx.helper.make_node(
|
||||
"Clip",
|
||||
node.input,
|
||||
node.output,
|
||||
name=node.name,
|
||||
max=max_val,
|
||||
min=0.0
|
||||
)
|
||||
g.node.remove(node)
|
||||
g.node.extend([new_node])
|
||||
|
||||
def PRelu_weight_reshape(g):
|
||||
# For PRelu with single dimension weight. Expand it to 1, x, 1, 1
|
||||
for node in g.node:
|
||||
if node.op_type != "PRelu":
|
||||
continue
|
||||
slope = helper.find_node_by_output_name(g, node.input[1])
|
||||
if slope is not None:
|
||||
# Constant node
|
||||
if len(slope.attribute[0].t.dims) != 1:
|
||||
continue
|
||||
slope.attribute[0].t.dims.append(slope.attribute[0].t.dims[0])
|
||||
slope.attribute[0].t.dims[0] = 1
|
||||
slope.attribute[0].t.dims.append(1)
|
||||
slope.attribute[0].t.dims.append(1)
|
||||
else:
|
||||
# Initializer
|
||||
for i in g.initializer:
|
||||
if i.name == node.input[1]:
|
||||
slope = i
|
||||
break
|
||||
if len(slope.dims) != 1:
|
||||
continue
|
||||
slope.dims.append(slope.dims[0])
|
||||
slope.dims[0] = 1
|
||||
slope.dims.append(1)
|
||||
slope.dims.append(1)
|
||||
input_value = helper.find_input_by_name(g, node.input[1])
|
||||
new_input = onnx.helper.make_tensor_value_info(
|
||||
node.input[1],
|
||||
input_value.type.tensor_type.elem_type,
|
||||
(1, slope.dims[1], 1, 1))
|
||||
g.input.remove(input_value)
|
||||
g.input.append(new_input)
|
||||
value_info = helper.find_value_by_name(g, node.input[1])
|
||||
if value_info is not None:
|
||||
g.value_info.remove(value_info)
|
||||
|
||||
def do_convert(m):
|
||||
graph = m.graph
|
||||
|
||||
# Modify the nodes.
|
||||
remove_BN_spatial(graph)
|
||||
upsample_attribute_to_const(graph)
|
||||
relu6_to_clip(graph)
|
||||
PRelu_weight_reshape(graph)
|
||||
other.topological_sort(graph)
|
||||
|
||||
# Change model properties.
|
||||
m.ir_version = 4
|
||||
m.opset_import[0].version = 9
|
||||
return m
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print("Usage:{} file_in file_out".format(sys.argv[0]))
|
||||
exit(1)
|
||||
|
||||
model = onnx.load(sys.argv[1])
|
||||
graph = model.graph
|
||||
|
||||
# Modify the nodes.
|
||||
remove_BN_spatial(graph)
|
||||
upsample_attribute_to_const(graph)
|
||||
relu6_to_clip(graph)
|
||||
PRelu_weight_reshape(graph)
|
||||
other.topological_sort(graph)
|
||||
|
||||
# Change model properties.
|
||||
model.ir_version = 4
|
||||
model.opset_import[0].version = 9
|
||||
|
||||
onnx.save(model, sys.argv[2])
|
||||
184 tools/optimizer_scripts/onnx1_4to1_6.py Normal file
@@ -0,0 +1,184 @@
|
||||
# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git
|
||||
|
||||
import sys
|
||||
import onnx
|
||||
import onnx.utils
|
||||
import numpy as np
|
||||
from onnx import numpy_helper
|
||||
from tools import other, helper, replacing
|
||||
|
||||
"""
|
||||
Change onnx model from version 1.4 to version 1.6.
|
||||
"""
|
||||
|
||||
def replace_all_attribute_to_const_node_in_pad_node(g):
|
||||
node_to_remove = []
|
||||
node_to_extend = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Pad':
|
||||
continue
|
||||
|
||||
pad_loc_node = None # must have
|
||||
pad_mode = 'constant'
|
||||
pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [0.0]) # need scalar
|
||||
for att in node.attribute:
|
||||
if att.name == 'mode':
|
||||
pad_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
|
||||
if att.name == 'pads':
|
||||
pad_loc_node = helper.list_to_constant(node.name+'_pad_loc', [len(att.ints)], att.ints)
|
||||
if att.name == 'value':
|
||||
pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [att.f])
|
||||
|
||||
new_node = onnx.helper.make_node(
|
||||
"Pad",
|
||||
[node.input[0], pad_loc_node.name, pad_value_node.name],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
mode=pad_mode,
|
||||
)
|
||||
node_to_remove.append(node)
|
||||
node_to_extend.append(new_node)
|
||||
node_to_extend.append(pad_loc_node)
|
||||
node_to_extend.append(pad_value_node)
|
||||
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
for node in node_to_extend:
|
||||
g.node.extend([node])
|
||||
|
||||
|
||||
def upsampling_to_resize(g):
|
||||
for node in g.node:
|
||||
if node.op_type != 'Upsample':
|
||||
continue
|
||||
upsampling_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
|
||||
|
||||
scale_value_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
if scale_value_node.op_type != "Constant":
|
||||
raise TypeError('seems there is a dynamic "scales" param in Upsampling node: ' + node.name + ' , you might need to do constant folding first')
|
||||
|
||||
roi_node = helper.list_to_constant(node.name+'_roi_value', [0], [])
|
||||
|
||||
new_node = onnx.helper.make_node(
|
||||
"Resize",
|
||||
[node.input[0], roi_node.name, scale_value_node.name],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
mode=upsampling_mode,
|
||||
coordinate_transformation_mode = 'asymmetric'
|
||||
)
|
||||
|
||||
g.node.remove(node)
|
||||
g.node.extend([new_node])
|
||||
g.node.extend([roi_node])
|
||||
|
||||
|
||||
def replace_all_attribute_to_const_node_in_slice_node(g):
|
||||
for node in g.node:
|
||||
if node.op_type != 'Slice':
|
||||
continue
|
||||
|
||||
axes_const_node = None
|
||||
ends_const_node = None
|
||||
starts_const_node = None
|
||||
steps_const_node = None
|
||||
for att in node.attribute:
|
||||
if att.name == 'axes':
|
||||
axes_const_node = helper.list_to_constant(node.name+'_axes_value', [len(att.ints)], att.ints)
|
||||
|
||||
if att.name == 'ends':
|
||||
ends_const_node = helper.list_to_constant(node.name+'_ends_value', [len(att.ints)], att.ints)
|
||||
|
||||
if att.name == 'starts':
|
||||
starts_const_node = helper.list_to_constant(node.name+'_starts_value', [len(att.ints)], att.ints)
|
||||
|
||||
if att.name == 'steps':
|
||||
steps_const_node = helper.list_to_constant(node.name+'_steps_value',[ len(att.ints)], att.ints)
|
||||
|
||||
## pop out from back
|
||||
attr_len = len(node.attribute)
|
||||
for i in range(attr_len):
|
||||
node.attribute.remove(node.attribute[ attr_len -1 - i ])
|
||||
|
||||
## according the spec, we need to add node in specific order
|
||||
if starts_const_node != None:
|
||||
g.node.extend([starts_const_node])
|
||||
node.input.extend([starts_const_node.name])
|
||||
if ends_const_node != None:
|
||||
g.node.extend([ends_const_node])
|
||||
node.input.extend([ends_const_node.name])
|
||||
if axes_const_node != None:
|
||||
g.node.extend([axes_const_node])
|
||||
node.input.extend([axes_const_node.name])
|
||||
if steps_const_node != None:
|
||||
g.node.extend([steps_const_node])
|
||||
node.input.extend([steps_const_node.name])
|
||||
|
||||
|
||||
def replace_min_max_attribute_to_const_node_in_clip_node(g):
|
||||
for node in g.node:
|
||||
if node.op_type != 'Clip':
|
||||
continue
|
||||
|
||||
max_const_node = None
|
||||
min_const_node = None
|
||||
for att in node.attribute:
|
||||
if att.name == 'max':
|
||||
max_const_node = helper.list_to_constant(node.name+'_max_value', [], [att.f])
|
||||
|
||||
if att.name == 'min':
|
||||
min_const_node = helper.list_to_constant(node.name+'_min_value', [], [att.f])
|
||||
|
||||
## pop out from back
|
||||
node.attribute.remove(node.attribute[1])
|
||||
node.attribute.remove(node.attribute[0])
|
||||
|
||||
## according the spec, we need to add node in specific order
|
||||
g.node.extend([min_const_node])
|
||||
g.node.extend([max_const_node])
|
||||
node.input.extend([min_const_node.name])
|
||||
node.input.extend([max_const_node.name])
|
||||
|
||||
def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
|
||||
"""Update ir_version from 4 to 6 and update opset from 9 to 11.
|
||||
|
||||
Args:
|
||||
model (onnx.ModelProto): input onnx model.
|
||||
|
||||
Returns:
|
||||
onnx.ModelProto: updated onnx model.
|
||||
"""
|
||||
graph = model.graph
|
||||
|
||||
if model.opset_import[0].version == 11:
|
||||
print("(Stop) the input model is already opset 11, no need to upgrade")
|
||||
exit(1)
|
||||
|
||||
# deal with empty node name issue
|
||||
other.add_name_to_node(graph)
|
||||
# simplify the node param type from initializer to constant
|
||||
replacing.replace_initializer_with_Constant(graph)
|
||||
|
||||
# Modify the nodes.
|
||||
replace_min_max_attribute_to_const_node_in_clip_node(graph)
|
||||
replace_all_attribute_to_const_node_in_slice_node(graph)
|
||||
replace_all_attribute_to_const_node_in_pad_node(graph)
|
||||
upsampling_to_resize(graph)
|
||||
other.topological_sort(graph)
|
||||
|
||||
# Change model properties.
|
||||
model.ir_version = 6
|
||||
model.opset_import[0].version = 11
|
||||
|
||||
model = other.polish_model(model)
|
||||
return model
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print("Usage:{} file_in file_out".format(sys.argv[0]))
|
||||
exit(1)
|
||||
|
||||
model = onnx.load(sys.argv[1])
|
||||
model = onnx1_4to1_6(model)
|
||||
|
||||
onnx.save(model, sys.argv[2])
|
||||
136 tools/optimizer_scripts/onnx2onnx.py Normal file
@@ -0,0 +1,136 @@
import onnx
import onnx.utils
try:
    from onnx import optimizer
except ImportError:
    import onnxoptimizer as optimizer
import sys
import argparse
import logging

from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import special
from tools import combo
from tools.helper import logger
# from tools import temp


def onnx2onnx_flow(m: onnx.ModelProto,
                   disable_fuse_bn=False,
                   bn_on_skip=False,
                   bn_before_add=False,
                   bgr=False,
                   norm=False,
                   rgba2yynn=False,
                   eliminate_tail=False,
                   opt_matmul=False,
                   duplicate_shared_weights=True) -> onnx.ModelProto:
    """Optimize the onnx.

    Args:
        m (ModelProto): the input onnx ModelProto
        disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.
        bn_on_skip (bool, optional): add BN operators on skip branches. Defaults to False.
        bn_before_add (bool, optional): add BN before Add nodes on every branch. Defaults to False.
        bgr (bool, optional): add a Conv layer to convert RGB input to BGR. Defaults to False.
        norm (bool, optional): add a Conv layer to add 0.5 to the input. Defaults to False.
        rgba2yynn (bool, optional): add a Conv layer to convert RGB input to YYNN. Defaults to False.
        eliminate_tail (bool, optional): remove the trailing NPU-unsupported nodes. Defaults to False.
        opt_matmul (bool, optional): optimize the MatMul layers according to the NPU limit. Defaults to False.
        duplicate_shared_weights (bool, optional): duplicate shared weights. Defaults to True.

    Returns:
        ModelProto: the optimized onnx model object.
    """
    # temp.weight_broadcast(m.graph)
    m = combo.preprocess(m, disable_fuse_bn, duplicate_shared_weights)
    # temp.fuse_bias_in_consecutive_1x1_conv(m.graph)

    # Add BN on the skip branch.
    if bn_on_skip:
        other.add_bn_on_skip_branch(m.graph)
    elif bn_before_add:
        other.add_bn_before_add(m.graph)
        other.add_bn_before_activation(m.graph)

    # Common optimizations.
    m = combo.common_optimization(m)
    # Special options.
    if bgr:
        special.change_input_from_bgr_to_rgb(m)
    if norm:
        special.add_0_5_to_normalized_input(m)
    if rgba2yynn:
        special.add_rgb2yynn_node(m)

    # Remove the useless last node.
    if eliminate_tail:
        eliminating.remove_useless_last_nodes(m.graph)

    # Postprocessing.
    m = combo.postprocess(m)

    # Process MatMul after postprocessing to prevent the transpose from moving downwards.
    if opt_matmul:
        special.special_MatMul_process(m.graph)
        m = other.polish_model(m)

    return m


# Main process
if __name__ == "__main__":
    # Argument parser
    parser = argparse.ArgumentParser(description="Optimize an ONNX model for Kneron compiler")
    parser.add_argument('in_file', help='input ONNX FILE')
    parser.add_argument('-o', '--output', dest='out_file', type=str, help="output ONNX FILE")
    parser.add_argument('--log', default='i', type=str, help="set log level")
    parser.add_argument('--bgr', action='store_true', default=False, help="set if the model is trained in BGR mode")
    parser.add_argument('--norm', action='store_true', default=False, help="set if the input is normalized to -0.5~0.5")
    parser.add_argument('--rgba2yynn', action='store_true', default=False, help="set if the model has YYNN input but you want to take RGBA images")
    parser.add_argument('--add-bn-on-skip', dest='bn_on_skip', action='store_true', default=False,
                        help="set if you only want to add BN on skip branches")
    parser.add_argument('--add-bn', dest='bn_before_add', action='store_true', default=False,
                        help="set if you want to add BN before Add")
    parser.add_argument('-t', '--eliminate-tail-unsupported', dest='eliminate_tail', action='store_true', default=False,
                        help='remove the last nodes which the hardware does not support')
    parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
                        help="set if you have met errors related to inferenced shape mismatch. This option prevents fusing BatchNormalization into Conv.")
    parser.add_argument('--opt-matmul', dest='opt_matmul', action='store_true', default=False,
                        help="set if you want to optimize the MatMul operations for the Kneron hardware.")
    parser.add_argument('--no-duplicate-shared-weights', dest='no_duplicate_shared_weights', action='store_true', default=False,
                        help='do not duplicate shared weights. Defaults to False.')
    args = parser.parse_args()

    if args.out_file is None:
        outfile = args.in_file[:-5] + "_polished.onnx"
    else:
        outfile = args.out_file

    if args.log == 'w':
        logging.basicConfig(level=logging.WARN)
    elif args.log == 'd':
        logging.basicConfig(level=logging.DEBUG)
    elif args.log == 'e':
        logging.basicConfig(level=logging.ERROR)
    else:
        logging.basicConfig(level=logging.INFO)

    # onnx polish_model includes:
    # -- nop
    # -- eliminate_identity
    # -- eliminate_nop_transpose
    # -- eliminate_nop_pad
    # -- eliminate_unused_initializer
    # -- fuse_consecutive_squeezes
    # -- fuse_consecutive_transposes
    # -- fuse_add_bias_into_conv
    # -- fuse_transpose_into_gemm

    # Basic model organization.
    m = onnx.load(args.in_file)

    m = onnx2onnx_flow(m, args.disable_fuse_bn, args.bn_on_skip, args.bn_before_add, args.bgr, args.norm,
                       args.rgba2yynn, args.eliminate_tail, args.opt_matmul, not args.no_duplicate_shared_weights)

    onnx.save(m, outfile)
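When `-o` is omitted, the script derives the output name by stripping the `.onnx` suffix (the last five characters) and appending `_polished.onnx`. A one-line sketch of that default, with a hypothetical helper name:

```python
def default_outfile(in_file, out_file=None):
    # Mirrors the CLI behaviour: "model.onnx" -> "model_polished.onnx".
    return out_file if out_file is not None else in_file[:-5] + "_polished.onnx"

print(default_outfile("model.onnx"))            # → model_polished.onnx
print(default_outfile("model.onnx", "x.onnx"))  # → x.onnx
```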
134
tools/optimizer_scripts/onnx_vs_onnx.py
Normal file
@ -0,0 +1,134 @@
import onnxruntime
import onnx
import argparse
import numpy as np
from tools import helper


# Map ONNX TensorProto data types to numpy dtype names.
onnx2np_dtype = {0: 'float', 1: 'float32', 2: 'uint8', 3: 'int8', 4: 'uint16', 5: 'int16',
                 6: 'int32', 7: 'int64', 8: 'str', 9: 'bool', 10: 'float16', 11: 'double',
                 12: 'uint32', 13: 'uint64', 14: 'complex64', 15: 'complex128', 16: 'float'}


def onnx_model_results(path_a, path_b, total_times=10):
    """Use onnxruntime to run inference on two onnx models and collect their outputs.

    :param path_a: path of the first model
    :param path_b: path of the second model
    :param total_times: number of inference runs, defaults to 10
    :returns: the inference results of the two models
    """
    # Load model a and model b into the runtime.
    session_a = onnxruntime.InferenceSession(path_a, None)
    session_b = onnxruntime.InferenceSession(path_b, None)
    outputs_a = session_a.get_outputs()
    outputs_b = session_b.get_outputs()

    # Check the outputs.
    assert len(outputs_a) == len(outputs_b), 'Two models have different output numbers.'
    for i in range(len(outputs_a)):
        out_shape_a, out_shape_b = outputs_a[i].shape, outputs_b[i].shape
        out_shape_a = list(map(lambda x: x if isinstance(x, int) else 1, out_shape_a))
        out_shape_b = list(map(lambda x: x if isinstance(x, int) else 1, out_shape_b))
        assert out_shape_a == out_shape_b, 'Output {} has unmatched shapes'.format(i)

    # Load graph_a and graph_b to find the initializers and inputs,
    # then remove the input entries that are actually initializers.
    model_a, model_b = onnx.load(path_a), onnx.load(path_b)
    graph_a, graph_b = model_a.graph, model_b.graph
    inputs_a, inputs_b = graph_a.input, graph_b.input
    init_a, init_b = graph_a.initializer, graph_b.initializer

    # Remove the initializers from the raw inputs.
    input_names_a, input_names_b = set([ele.name for ele in inputs_a]), set([ele.name for ele in inputs_b])
    init_names_a, init_names_b = set([ele.name for ele in init_a]), set([ele.name for ele in init_b])
    real_inputs_names_a, real_inputs_names_b = input_names_a - init_names_a, input_names_b - init_names_b

    # Match the real inputs of model a with those of model b,
    # keeping the original order of the inputs.
    real_inputs_a, real_inputs_b = [], []
    for item in inputs_a:
        if item.name in real_inputs_names_a:
            real_inputs_a.append(item)
    for item in inputs_b:
        if item.name in real_inputs_names_b:
            real_inputs_b.append(item)

    # Suppose each model has exactly one real input tensor, and find it.
    real_single_input_a = None
    real_single_input_b = None
    size_a, size_b = 0, 0
    shape_a, shape_b = [], []
    for item_a in real_inputs_a:
        size, shape = helper.find_size_shape_from_value(item_a)
        if size:
            assert real_single_input_a is None, 'Multiple inputs of first model, single input expected.'
            real_single_input_a = item_a
            size_a, shape_a = size, shape
    for item_b in real_inputs_b:
        size, shape = helper.find_size_shape_from_value(item_b)
        if size:
            assert real_single_input_b is None, 'Multiple inputs of second model, single input expected.'
            real_single_input_b = item_b
            size_b, shape_b = size, shape
    assert size_a == size_b, 'Sizes of two models do not match.'

    # Construct the input tensors.
    input_data_type_a = real_single_input_a.type.tensor_type.elem_type
    input_data_type_b = real_single_input_b.type.tensor_type.elem_type
    input_data_type_a = onnx2np_dtype[input_data_type_a]
    input_data_type_b = onnx2np_dtype[input_data_type_b]

    # Run inference.
    times = 0
    results_a = [[] for i in range(len(outputs_a))]
    results_b = [[] for i in range(len(outputs_b))]
    while times < total_times:
        # Initialize the inputs with random data (uniform by default).
        data = np.random.random(size_a)
        input_a = np.reshape(data, shape_a).astype(input_data_type_a)
        input_b = np.reshape(data, shape_b).astype(input_data_type_b)

        input_dict_a = {}
        input_dict_b = {}
        for item_a in real_inputs_a:
            item_type_a = onnx2np_dtype[item_a.type.tensor_type.elem_type]
            input_dict_a[item_a.name] = np.array([]).astype(item_type_a) \
                if item_a.name != real_single_input_a.name else input_a
        for item_b in real_inputs_b:
            item_type_b = onnx2np_dtype[item_b.type.tensor_type.elem_type]
            input_dict_b[item_b.name] = np.array([]).astype(item_type_b) \
                if item_b.name != real_single_input_b.name else input_b

        ra = session_a.run([], input_dict_a)
        rb = session_b.run([], input_dict_b)
        for i in range(len(outputs_a)):
            results_a[i].append(ra[i])
            results_b[i].append(rb[i])
        times += 1

    return results_a, results_b


if __name__ == '__main__':
    # Argument parser.
    parser = argparse.ArgumentParser(description="Compare two ONNX models to check if they have the same output.")
    parser.add_argument('in_file_a', help='input ONNX file a')
    parser.add_argument('in_file_b', help='input ONNX file b')

    args = parser.parse_args()

    results_a, results_b = onnx_model_results(args.in_file_a, args.in_file_b, total_times=10)
    ra_flat = helper.flatten_with_depth(results_a, 0)
    rb_flat = helper.flatten_with_depth(results_b, 0)
    shape_a = [item[1] for item in ra_flat]
    shape_b = [item[1] for item in rb_flat]
    assert shape_a == shape_b, 'The shapes of the two results do not match.'
    ra_raw = [item[0] for item in ra_flat]
    rb_raw = [item[0] for item in rb_flat]

    try:
        np.testing.assert_almost_equal(ra_raw, rb_raw, 4)
        print('Two models have the same behaviour.')
    except Exception as mismatch:
        print(mismatch)
        exit(1)
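The initializer-filtering step above exists because older ONNX exporters list weight tensors in `graph.input`; the tensors you actually have to feed are the input names minus the initializer names. A self-contained sketch of that filter, preserving the original input order as the script does:

```python
def real_input_names(input_names, initializer_names):
    """Keep only the graph inputs that are not backed by an initializer."""
    inits = set(initializer_names)
    return [name for name in input_names if name not in inits]

inputs = ["data", "conv1_w", "conv1_b"]
inits = ["conv1_w", "conv1_b"]
print(real_input_names(inputs, inits))  # → ['data']
```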
221
tools/optimizer_scripts/onnx_vs_onnx_opt.py
Normal file
@ -0,0 +1,221 @@
import onnx
import argparse
import glob
import csv
import numpy as np
import matplotlib.pyplot as plt

from tools import helper
import onnx_vs_onnx as onnx_tester


def compare_results(results_a, results_b):
    """Compare onnx model inference results and calculate basic statistics.

    :param results_a: results from running the first model multiple times
    :param results_b: results from running the second model multiple times
    :returns: a list of basic statistical values
    """
    # The input results can have nonuniform shapes,
    # so flatten the data before comparing.
    ra_flat = helper.flatten_with_depth(results_a, 0)
    rb_flat = helper.flatten_with_depth(results_b, 0)
    shape_a = [item[1] for item in ra_flat]
    shape_b = [item[1] for item in rb_flat]
    assert shape_a == shape_b, 'The shapes of the two results do not match.'
    ra_raw = [item[0] for item in ra_flat]
    rb_raw = [item[0] for item in rb_flat]

    # The statistical values.
    max_rel_diff = 0  # defined as max( { abs(diff) / max(abs(ra), abs(rb)) } )
    max_abs_diff = 0  # defined as max( { abs(ra - rb) } )
    mean_rel_diff = 0
    mean_abs_diff = 0
    std_rel_diff = 0
    std_abs_diff = 0
    acc_with_diff_precision = []
    rel_diff = []
    abs_diff_percentiles = []  # abs_diff percentiles
    rel_diff_percentiles = []  # rel_diff percentiles

    raw_diff = [ra_raw[i] - rb_raw[i] for i in range(len(ra_raw))]
    abs_diff = [abs(num) for num in raw_diff]
    for i in range(len(ra_raw)):
        divider = max([abs(ra_raw[i]), abs(rb_raw[i])])
        val = abs_diff[i] / divider if divider != 0 else 0
        rel_diff.append(val)

    max_rel_diff = max(rel_diff)
    max_abs_diff = max(abs_diff)
    mean_rel_diff = np.average(rel_diff)
    mean_abs_diff = np.average(abs_diff)
    std_rel_diff = np.std(rel_diff)
    std_abs_diff = np.std(abs_diff)

    # Calculate the accuracy at different precisions.
    for digit in range(8):
        correct = 0
        for i in range(len(ra_raw)):
            if format(ra_raw[i], '.' + str(digit) + 'f') \
                    == format(rb_raw[i], '.' + str(digit) + 'f'):
                correct += 1
        acc_with_diff_precision.append([digit, float(format(correct / len(ra_raw), '.3f'))])

    # Analyze the rel_diff and abs_diff distributions.
    rel_diff.sort()
    abs_diff.sort()
    for i in range(20):
        rel_diff_percentiles.append(['{}%'.format(i * 5), rel_diff[int((i / 20) * len(rel_diff))]])
        abs_diff_percentiles.append(['{}%'.format(i * 5), abs_diff[int((i / 20) * len(abs_diff))]])

    results = [
        ['max_rel_diff', max_rel_diff],
        ['max_abs_diff', max_abs_diff],
        ['mean_rel_diff', mean_rel_diff],
        ['mean_abs_diff', mean_abs_diff],
        ['std_rel_diff', std_rel_diff],
        ['std_abs_diff', std_abs_diff],
        ['acc_with_diff_precision', acc_with_diff_precision],
        ['rel_diff_percentiles', rel_diff_percentiles],
        ['abs_diff_percentiles', abs_diff_percentiles]
    ]

    return results


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test model optimization results.')

    parser.add_argument('dir', type=str, help='the directory that stores onnx models')
    parser.add_argument('ending1', type=str, help='model file name ending (e.g. .onnx)')
    parser.add_argument('ending2', type=str, help='optimized model file name ending (e.g. _opt.onnx)')
    parser.add_argument('out_file', type=str, help='output csv file name')
    parser.add_argument('-p', '--plot', default='N', help='generate plots (Y/N)')
    parser.add_argument('-i', '--iter_times', default=10, type=int, help='inference times')

    args = parser.parse_args()

    old_models_paths = glob.glob(args.dir + '*' + args.ending1)
    new_models_paths = glob.glob(args.dir + '*' + args.ending2)

    stats_table = [[
        'Model',
        'max_rel_diff',
        'max_abs_diff',
        'mean_rel_diff',
        'mean_abs_diff',
        'std_rel_diff',
        'std_abs_diff',
        'acc_with_diff_precision',
        'rel_diff_percentiles',
        'abs_diff_percentiles'
    ]]

    for new_model_path in new_models_paths:
        old_model_path = new_model_path[:-len(args.ending2)] + args.ending1
        if old_model_path not in old_models_paths:
            continue

        # Run inference.
        results_a, results_b = onnx_tester.onnx_model_results(old_model_path, new_model_path, total_times=args.iter_times)

        # Compare the inference results.
        comparison = compare_results(results_a, results_b)

        new_line = [old_model_path.split('/')[-1]]
        for item in comparison:
            new_line.append(item[1])

        stats_table.append(new_line)

    # Try to read an existing file.
    old_stats_table = []
    try:
        old_file = open(args.out_file, 'r')
        reader = csv.reader(old_file)
        old_header = next(reader)
        for row in reader:
            old_stats_table.append(row)
        old_file.close()
    except (OSError, StopIteration):
        pass

    # Merge any old stat data with the new stat data.
    header = stats_table[0]
    stats_table = stats_table[1:]
    new_model_names = set([item[0] for item in stats_table])
    for row in old_stats_table:
        if row[0] not in new_model_names:
            stats_table.append(row)
    stats_table.insert(0, header)

    # Write a new stat data file, overwriting the old file.
    new_file = open(args.out_file, 'w', newline='')
    writer = csv.writer(new_file)
    for row in stats_table:
        writer.writerow(row)
    new_file.close()

    # Make some plots.
    if args.plot == 'Y':
        if len(stats_table) < 2:
            exit(0)

        sample_table = stats_table[1:] if len(stats_table) < 6 else stats_table[1:6]

        max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]]
        plt.hist(max_rel_diffs, bins=15)
        plt.title('Max Relative Difference Histogram')
        plt.xlabel('Max Relative Difference')
        plt.ylabel('Counts')
        plt.savefig('max_rel_diff_hist.png')
        plt.close()

        max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]]
        plt.hist(max_abs_diffs, bins=15)
        plt.title('Max Absolute Difference Histogram')
        plt.xlabel('Max Absolute Difference')
        plt.ylabel('Counts')
        plt.savefig('max_abs_diff_hist.png')
        plt.close()

        for line in sample_table:
            model_name = line[0]
            percentiles = line[-2]
            x = [round(i * (1 / len(percentiles)), 2) for i in range(len(percentiles))]
            y = [ele[1] for ele in percentiles]
            plt.plot(x, y, label=model_name)
        plt.title('Rel_diff Percentiles of Raw and Optimized Models')
        plt.xlabel('percentage')
        plt.ylabel('relative difference')
        plt.legend()
        plt.savefig('rel_diff_percentiles.png')
        plt.close()

        for line in sample_table:
            model_name = line[0]
            percentiles = line[-1]
            x = [round(i * (1 / len(percentiles)), 2) for i in range(len(percentiles))]
            y = [ele[1] for ele in percentiles]
            plt.plot(x, y, label=model_name)
        plt.title('Abs_diff Percentiles of Raw and Optimized Models')
        plt.xlabel('percentage')
        plt.ylabel('absolute difference')
        plt.legend()
        plt.savefig('abs_diff_percentiles.png')
        plt.close()

        for line in sample_table:
            model_name = line[0]
            accuracies = line[-3]
            x = [acc[0] for acc in accuracies]
            y = [acc[1] for acc in accuracies]
            plt.plot(x, y, label=model_name)
        plt.title('Accuracies with Different Precisions')
        plt.xlabel('Decimals')
        plt.ylabel('Precision')
        plt.legend()
        plt.savefig('precisions.png')
        plt.close()
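The core comparison metric above is the relative difference `abs(a-b) / max(abs(a), abs(b))`, defined as 0 when both values are 0 to avoid division by zero. A minimal pure-Python sketch of the max statistics (the script additionally computes means, standard deviations, and percentiles over the same lists):

```python
def diff_stats(a, b):
    """Max absolute and max relative difference, as defined in compare_results."""
    abs_diff = [abs(x - y) for x, y in zip(a, b)]
    rel_diff = []
    for x, y, d in zip(a, b, abs_diff):
        divider = max(abs(x), abs(y))
        rel_diff.append(d / divider if divider != 0 else 0.0)
    return max(abs_diff), max(rel_diff)

max_abs, max_rel = diff_stats([1.0, 0.0, -2.0], [1.1, 0.0, -2.0])
print(max_abs, max_rel)
```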
81
tools/optimizer_scripts/pytorch2onnx.py
Normal file
@ -0,0 +1,81 @@
import onnx
import onnx.utils
try:
    from onnx import optimizer
except ImportError:
    import onnxoptimizer as optimizer
import sys
import numpy as np
import struct
import logging
import argparse

from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import combo
from tools import special
from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow

# Debug use
# logging.basicConfig(level=logging.DEBUG)

######################################
#     Generate a prototype onnx      #
######################################

parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
parser.add_argument('in_file', help='input ONNX or PTH FILE')
parser.add_argument('out_file', help="output ONNX FILE")
parser.add_argument('--input-size', dest='input_size', nargs=3,
                    help='if the input is a pth file, use this argument to set up the input size of the model. It should be in \'CH H W\' format, e.g. \'--input-size 3 256 512\'.')
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
                    help="set if you have met errors related to inferenced shape mismatch. This option prevents fusing BatchNormalization into Conv.")

args = parser.parse_args()

if len(args.in_file) <= 4:
    # The filename is too short.
    logging.error("Invalid input file: {}".format(args.in_file))
    exit(1)
elif args.in_file[-4:] == '.pth':
    # Pytorch pth case.
    logging.warning("Converting from pth to onnx is not recommended.")
    onnx_in = args.out_file
    # Import pytorch libraries.
    from torch.autograd import Variable
    import torch
    import torch.onnx
    # import torchvision
    # Standard ImageNet input - 3 channels, 224x224.
    # The values don't matter since we only care about the network structure.
    # But they can also be real inputs.
    if args.input_size is None:
        logging.error("\'--input-size\' is required for the pth input file.")
        exit(1)
    dummy_input = Variable(torch.randn(1, int(args.input_size[0]), int(args.input_size[1]), int(args.input_size[2])))
    # Obtain your model. It can also be constructed in your script explicitly.
    model = torch.load(args.in_file, map_location='cpu')
    # model = torchvision.models.resnet34(pretrained=True)
    # Invoke export.
    # torch.save(model, "resnet34.pth")
    torch.onnx.export(model, dummy_input, args.out_file, opset_version=11)
elif args.in_file[-4:] == 'onnx':
    onnx_in = args.in_file
else:
    # The file is neither an onnx nor a pytorch pth file.
    logging.error("Invalid input file: {}".format(args.in_file))
    exit(1)

onnx_out = args.out_file

######################################
#           Optimize onnx            #
######################################

m = onnx.load(onnx_in)

m = torch_exported_onnx_flow(m, args.disable_fuse_bn)

onnx.save(m, onnx_out)
80
tools/optimizer_scripts/pytorch_exported_onnx_preprocess.py
Normal file
@ -0,0 +1,80 @@
import onnx
import onnx.utils
try:
    from onnx import optimizer
except ImportError:
    import onnxoptimizer as optimizer
import sys
import numpy as np
import struct
import logging
import argparse

from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import combo
from tools import special


# The general optimization process for a Pytorch exported onnx.
def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.ModelProto:
    """Optimize the Pytorch exported onnx.

    Args:
        m (ModelProto): the input onnx model
        disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.

    Returns:
        ModelProto: the optimized onnx model
    """
    m = combo.preprocess(m, disable_fuse_bn)
    m = combo.pytorch_constant_folding(m)
    m = combo.common_optimization(m)
    m = combo.postprocess(m)

    return m


# Main Process
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
    parser.add_argument('in_file', help='input ONNX FILE')
    parser.add_argument('out_file', help="output ONNX FILE")
    parser.add_argument('--log', default='i', type=str, help="set log level")
    parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
                        help="set if you have met errors related to inferenced shape mismatch. This option prevents fusing BatchNormalization into Conv.")

    args = parser.parse_args()

    if args.log == 'w':
        logging.basicConfig(level=logging.WARN)
    elif args.log == 'd':
        logging.basicConfig(level=logging.DEBUG)
    elif args.log == 'e':
        logging.basicConfig(level=logging.ERROR)
    else:
        logging.basicConfig(level=logging.INFO)

    if len(args.in_file) <= 4:
        # The filename is too short.
        logging.error("Invalid input file: {}".format(args.in_file))
        exit(1)
    elif args.in_file[-4:] == 'onnx':
        onnx_in = args.in_file
    else:
        # The file is not an onnx file.
        logging.error("Invalid input file: {}".format(args.in_file))
        exit(1)

    onnx_out = args.out_file

    ######################################
    #           Optimize onnx            #
    ######################################

    m = onnx.load(onnx_in)

    m = torch_exported_onnx_flow(m, args.disable_fuse_bn)

    onnx.save(m, onnx_out)
27
tools/optimizer_scripts/res/first_insert_layer.json
Normal file
@ -0,0 +1,27 @@
{
    "LAYERNAME" :
    {
        "bias_bitwidth" : 16,
        "LAYERNAME_bias" : [15],
        "LAYERNAME_weight" : [3,3,3],
        "conv_coarse_shift" : [-4,-4,-4],
        "conv_fine_shift" : [0,0,0],
        "conv_total_shift" : [-4,-4,-4],
        "cpu_mode" : false,
        "delta_input_bitwidth" : [0],
        "delta_output_bitwidth" : 8,
        "flag_radix_bias_eq_output" : true,
        "input_scale" : [[1.0,1.0,1.0]],
        "output_scale" : [1.0, 1.0, 1.0],
        "psum_bitwidth" : 16,
        "weight_bitwidth" : 8,
        "input_datapath_bitwidth" : [8],
        "input_datapath_radix" : [7],
        "working_input_bitwidth" : 8,
        "working_input_radix" : [7],
        "working_output_bitwidth" : 16,
        "working_output_radix" : 15,
        "output_datapath_bitwidth" : 8,
        "output_datapath_radix" : 7
    }
}
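The `LAYERNAME` strings in this template appear to be placeholders that get substituted with the actual first layer's name before the config is consumed. A hypothetical sketch of that substitution (the real consumer of this file is not shown in this commit, so the usage pattern here is an assumption), using a small subset of the template:

```python
import json

# Subset of first_insert_layer.json; "LAYERNAME" is assumed to be a placeholder.
template = '''{"LAYERNAME": {"bias_bitwidth": 16,
                             "LAYERNAME_bias": [15],
                             "weight_bitwidth": 8}}'''
layer = "conv1"  # hypothetical first-layer name
config = json.loads(template.replace("LAYERNAME", layer))
print(sorted(config["conv1"].keys()))  # → ['bias_bitwidth', 'conv1_bias', 'weight_bitwidth']
```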
@ -0,0 +1,9 @@
#!/bin/bash

python onnx_tester.py /test_models/mobilenet_v2_224.onnx /test_models/mobilenet_v2_224.cut.onnx
if [ $? -eq 0 ]; then
    echo "Those two model results should be different!"
    exit 1
fi

exit 0
BIN
tools/optimizer_scripts/res/vdsr_41_20layer_1.pb
Normal file
Binary file not shown.
147
tools/optimizer_scripts/tensorflow2onnx.py
Normal file
@ -0,0 +1,147 @@
|
||||
import tensorflow as tf
|
||||
import tf2onnx
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import onnx
|
||||
import onnx.utils
|
||||
from tensorflow.python.platform import gfile
|
||||
from tools import combo, eliminating, replacing, other
|
||||
|
||||
def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
|
||||
"""Convert frozen graph pb file into onnx
|
||||
|
||||
Args:
|
||||
pb_path (str): input pb file path
|
||||
test_mode (bool, optional): test mode. Defaults to False.
|
||||
|
||||
Raises:
|
||||
Exception: invalid input file
|
||||
|
||||
Returns:
|
||||
onnx.ModelProto: converted onnx
|
||||
"""
|
||||
TF2ONNX_VERSION = int(tf2onnx.version.version.replace('.', ''))
|
||||
|
||||
if 160 <= TF2ONNX_VERSION:
|
||||
from tf2onnx import tf_loader
|
||||
else:
|
||||
from tf2onnx import loader as tf_loader
|
||||
|
||||
if pb_path[-3:] == '.pb':
|
||||
model_name = pb_path.split('/')[-1][:-3]
|
||||
|
||||
# always reset tensorflow session at begin
|
||||
tf.reset_default_graph()
|
||||
|
||||
with tf.Session() as sess:
|
||||
with gfile.FastGFile(pb_path, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
sess.graph.as_default()
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
|
||||
if 160 <= int(tf2onnx.version.version.replace('.', '')):
|
||||
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions = tf2onnx.tf_utils.tflist_to_onnx(
|
||||
sess.graph,
|
||||
{})
|
||||
else:
|
||||
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tf2onnx.tfonnx.tflist_to_onnx(
|
||||
sess.graph.get_operations(),
|
||||
{})
|
||||
|
||||
for n in onnx_nodes:
|
||||
if len(n.output) == 0:
|
||||
onnx_nodes.remove(n)
|
||||
|
||||
# find inputs and outputs of graph
|
||||
nodes_inputs = set()
|
||||
nodes_outputs = set()
|
||||
|
||||
for n in onnx_nodes:
|
||||
if n.op_type == 'Placeholder':
|
||||
continue
|
||||
for input in n.input:
|
||||
nodes_inputs.add(input)
|
||||
for output in n.output:
|
||||
nodes_outputs.add(output)
|
||||
|
||||
graph_input_names = set()
|
||||
for input_name in nodes_inputs:
|
||||
if input_name not in nodes_outputs:
|
||||
graph_input_names.add(input_name)
|
||||
|
||||
graph_output_names = set()
|
||||
for n in onnx_nodes:
|
||||
if n.input and n.input[0] not in nodes_outputs:
|
||||
continue
|
||||
if len(n.output) == 0:
|
||||
n.output.append(n.name + ':0')
|
||||
graph_output_names.add(n.output[0])
|
||||
else:
|
||||
output_name = n.output[0]
|
||||
if (output_name not in nodes_inputs) and (0 < len(n.input)):
|
||||
graph_output_names.add(output_name)
|
||||
|
||||
logging.info('Model Inputs: %s', str(list(graph_input_names)))
|
||||
logging.info('Model Outputs: %s', str(list(graph_output_names)))
        graph_def, inputs, outputs = tf_loader.from_graphdef(model_path=pb_path,
                                                             input_names=list(graph_input_names),
                                                             output_names=list(graph_output_names))

        with tf.Graph().as_default() as tf_graph:
            tf.import_graph_def(graph_def, name='')

        if 160 <= TF2ONNX_VERSION:
            with tf_loader.tf_session(graph=tf_graph):
                onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
                                                             input_names=inputs,
                                                             output_names=outputs,
                                                             opset=11)
        else:
            with tf.Session(graph=tf_graph):
                onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
                                                             input_names=inputs,
                                                             output_names=outputs,
                                                             opset=11)

        # Optimize with tf2onnx.optimizer
        onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph)
        model_proto = onnx_graph.make_model(model_name)

        # Make the tf2onnx output compatible with the spec of other.polish_model
        replacing.replace_initializer_with_Constant(model_proto.graph)
        model_proto = other.polish_model(model_proto)

    else:
        raise Exception('expect .pb file as input, but got "' + str(pb_path) + '"')
    # rename
    m = model_proto

    m = combo.preprocess(m)
    m = combo.common_optimization(m)
    m = combo.tensorflow_optimization(m)
    m = combo.postprocess(m)

    if not test_mode:
        g = m.graph
        eliminating.eliminate_shape_changing_after_input(g)

    m = other.polish_model(m)
    return m


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert a TensorFlow .pb file into an optimized ONNX file, or just optimize an ONNX file exported from TensorFlow.')
    parser.add_argument('in_file', help='input file')
    parser.add_argument('out_file', help='output optimized model file')
    parser.add_argument('-t', '--test_mode', action='store_true', help='test mode does not eliminate shape changes after the input')

    args = parser.parse_args()
    logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO)
    m = tf2onnx_flow(args.in_file, args.test_mode)
    onnx.save(m, args.out_file)
    logging.info('Save Optimized ONNX: %s', args.out_file)
68
tools/optimizer_scripts/tflite_vs_onnx.py
Normal file
@ -0,0 +1,68 @@
import argparse

import numpy as np
import tensorflow as tf
import onnx
import onnxruntime

from tools import helper


def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
    # Set up the ONNX session and get metadata
    onnx_session = onnxruntime.InferenceSession(onnx_file, None)
    onnx_outputs = onnx_session.get_outputs()
    assert len(onnx_outputs) == 1, "The onnx model has more than one output"
    onnx_model = onnx.load(onnx_file)
    onnx_graph = onnx_model.graph
    onnx_inputs = onnx_graph.input
    assert len(onnx_inputs) == 1, "The onnx model has more than one input"
    _, onnx_input_shape = helper.find_size_shape_from_value(onnx_inputs[0])
    # Set up the TFLite session and get metadata
    tflite_session = tf.lite.Interpreter(model_path=tflite_file)
    tflite_session.allocate_tensors()
    tflite_inputs = tflite_session.get_input_details()
    tflite_outputs = tflite_session.get_output_details()
    tflite_input_shape = tflite_inputs[0]['shape']
    # Compare the input shapes
    assert len(onnx_input_shape) == len(tflite_input_shape), "TFLite and ONNX input shapes do not match."
    assert onnx_input_shape == [tflite_input_shape[0], tflite_input_shape[3], tflite_input_shape[1], tflite_input_shape[2]], "TFLite and ONNX input shapes do not match."
    # Generate random inputs and run both models
    tflite_results = []
    onnx_results = []
    for _ in range(total_times):
        # Generate input
        tflite_input_data = np.array(np.random.random_sample(tflite_input_shape), dtype=np.float32)
        onnx_input_data = np.transpose(tflite_input_data, [0, 3, 1, 2])
        # Run tflite
        tflite_session.set_tensor(tflite_inputs[0]['index'], tflite_input_data)
        tflite_session.invoke()
        tflite_results.append(tflite_session.get_tensor(tflite_outputs[0]['index']))
        # Run onnx
        onnx_input_dict = {onnx_inputs[0].name: onnx_input_data}
        onnx_results.append(onnx_session.run([], onnx_input_dict)[0])

    return tflite_results, onnx_results


if __name__ == '__main__':
    # Argument parser.
    parser = argparse.ArgumentParser(description="Compare a TFLite model and an ONNX model to check whether they produce the same output.")
    parser.add_argument('tflite_file', help='input TFLite file')
    parser.add_argument('onnx_file', help='input ONNX file')

    args = parser.parse_args()

    results_a, results_b = compare_tflite_and_onnx(args.tflite_file, args.onnx_file, total_times=10)
    ra_flat = helper.flatten_with_depth(results_a, 0)
    rb_flat = helper.flatten_with_depth(results_b, 0)
    shape_a = [item[1] for item in ra_flat]
    shape_b = [item[1] for item in rb_flat]
    assert shape_a == shape_b, "the two results' shapes do not match"
    ra_raw = [item[0] for item in ra_flat]
    rb_raw = [item[0] for item in rb_flat]

    try:
        np.testing.assert_almost_equal(ra_raw, rb_raw, 8)
        print('Two models have the same behaviour.')
    except Exception as mismatch:
        print(mismatch)
        exit(1)
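The `np.transpose(tflite_input_data, [0, 3, 1, 2])` call above bridges the two layout conventions: TFLite feeds NHWC tensors while ONNX models conventionally expect NCHW. A quick sketch of what that axis permutation does to a shape:

```python
import numpy as np

# NHWC input as TFLite sees it: batch, height, width, channels.
nhwc = np.zeros((1, 224, 224, 3), dtype=np.float32)

# Axis order [0, 3, 1, 2] moves channels next to the batch dimension (NCHW).
nchw = np.transpose(nhwc, [0, 3, 1, 2])
print(nchw.shape)  # (1, 3, 224, 224)
```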
0
tools/optimizer_scripts/tools/__init__.py
Normal file
258
tools/optimizer_scripts/tools/combo.py
Normal file
@ -0,0 +1,258 @@
"""Combo functions that are usually called together.
"""

import logging
import onnx.utils
try:
    from onnx import optimizer
except ImportError:
    import onnxoptimizer as optimizer

from . import helper
from . import other
from . import replacing
from . import eliminating
from . import fusing
from . import constant_folding
from . import removing_transpose
from . import modhelper
from .common_pattern import torch_pattern_match, tf_pattern_match
from .helper import logger


def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True):
    """The most commonly used functions before other processing.

    Args:
        model_proto: the original model input
        disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.
        duplicate_shared_weights (bool, optional): duplicate shared weights. Defaults to True.

    Return:
        the new model after preprocessing

    It includes:

    - infer shapes
    - optimize the model with the ONNX library
    - give names to the nodes
    - replace initializers with Constant nodes
    - replace -1 batch size with 1
    - eliminate Dropout and Identity nodes
    - eliminate inputs without children
    - topological sort

    The optimizations provided by ONNX:

    - eliminate_identity
    - eliminate_nop_dropout
    - eliminate_nop_transpose
    - eliminate_nop_pad
    - eliminate_unused_initializer
    - eliminate_deadend
    - fuse_consecutive_squeezes
    - fuse_consecutive_transposes
    - fuse_add_bias_into_conv
    - fuse_transpose_into_gemm
    - fuse_matmul_add_bias_into_gemm
    - fuse_bn_into_conv
    - fuse_pad_into_conv

    """
    logger.info("Preprocessing the model...")
    helper.setup_current_opset_version(model_proto)
    eliminating.eliminate_empty_value_infos(model_proto.graph)
    other.add_name_to_node(model_proto.graph)
    other.rename_all_node_name(model_proto.graph)
    replacing.replace_initializer_with_Constant(model_proto.graph)
    other.topological_sort(model_proto.graph)
    m = other.polish_model(model_proto)
    passes = ['extract_constant_to_initializer',
              'eliminate_nop_dropout',
              'eliminate_deadend',
              'fuse_matmul_add_bias_into_gemm',
              'fuse_pad_into_conv']
    if not disable_fuse_bn:
        passes.append('fuse_bn_into_conv')
    m = optimizer.optimize(m, passes)
    g = m.graph
    # Add names again since onnx optimizer 1.7 and higher may remove node names.
    other.add_name_to_node(g)
    if duplicate_shared_weights:
        replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=True)
        other.duplicate_param_shared_constant(g)
    else:
        replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=False)
    other.topological_sort(g)
    m = other.polish_model(m)
    g = m.graph
    eliminating.eliminate_consecutive_Cast(m.graph)
    eliminating.eliminate_Cast_after_input(m.graph)
    eliminating.eliminate_nop_pads(g)
    eliminating.eliminate_nop_cast(g)
    eliminating.eliminate_Identify_and_Dropout(g)
    eliminating.eliminate_trivial_maxpool(g)
    eliminating.eliminate_no_children_input(g)
    other.format_value_info_shape(g)
    other.topological_sort(g)
    m = other.inference_shapes(m)
    g = m.graph
    replacing.replace_split_with_slices(g)
    other.topological_sort(g)

    return m
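`preprocess` optionally runs ONNX's `fuse_bn_into_conv` pass. The arithmetic behind that kind of fusion can be sketched with plain NumPy (toy per-channel scalars, not the real pass implementation):

```python
import numpy as np

# BatchNorm after a Conv computes gamma * (y - mean) / sqrt(var + eps) + beta
# on the conv output y. Folding scales the filter by s = gamma / sqrt(var + eps)
# and rewrites the bias, leaving a single affine op.
rng = np.random.default_rng(0)
y = rng.standard_normal(8)                  # toy conv output for one channel
gamma, beta, mean, var, eps = 2.0, 0.5, 0.1, 4.0, 1e-5

bn = gamma * (y - mean) / np.sqrt(var + eps) + beta

s = gamma / np.sqrt(var + eps)
folded = s * y + (beta - s * mean)          # same result, BN folded away

assert np.allclose(bn, folded)
```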


def common_optimization(m):
    """Common optimizations that can be used in most cases.

    :param m: the original model input\\
    :return: the new model after preprocessing

    It includes:

    - transpose B in Gemm
    - fuse BN into Gemm
    - fuse consecutive Gemm
    - replace AveragePool with GAP
    - replace Squeeze/Unsqueeze with Reshape
    - replace Reshape with Flatten
    """
    logger.info("Doing nodes fusion and replacement... ")
    m = other.polish_model(m)
    g = m.graph
    other.transpose_B_in_Gemm(g)
    fusing.fuse_BN_into_Gemm(g)
    fusing.fuse_BN_with_Reshape_into_Gemm(g)
    fusing.fuse_Gemm_into_Gemm(g)
    fusing.fuse_consecutive_reducemean(g)
    fusing.fuse_slice_nodes_into_conv(g)
    fusing.fuse_relu_min_into_clip(g)
    other.duplicate_shared_Flatten(g)
    replacing.replace_average_pool_with_GAP(g)

    m = other.polish_model(m)
    g = m.graph

    replacing.replace_Squeeze_with_Reshape(g)
    replacing.replace_Unsqueeze_with_Reshape(g)
    replacing.replace_Reshape_with_Flatten(g)
    replacing.replace_ReduceMean_with_GlobalAveragePool(g)
    replacing.replace_Sum_with_Adds(g)
    replacing.replace_constant_input_concat_with_pad(g)
    other.topological_sort(g)
    return m
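The "fuse consecutive Gemm" step (`fuse_Gemm_into_Gemm`) rests on a simple identity: two chained affine layers collapse into one. A NumPy sketch of that identity:

```python
import numpy as np

# (x @ A + b) @ C + d  ==  x @ (A @ C) + (b @ C + d)
rng = np.random.default_rng(1)
x = rng.standard_normal((2, 3))
A, b = rng.standard_normal((3, 4)), rng.standard_normal(4)
C, d = rng.standard_normal((4, 5)), rng.standard_normal(5)

two_gemms = (x @ A + b) @ C + d
one_gemm = x @ (A @ C) + (b @ C + d)   # single fused Gemm

assert np.allclose(two_gemms, one_gemm)
```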


def pytorch_constant_folding(m):
    """Constant folding needed by PyTorch-exported models. It should be done
    before using the onnx optimizers, since the dynamic shape structure may
    affect the optimizations.

    :param m: the original model input\\
    :return: the new model after preprocessing
    """
    logger.info("Working on constant folding.")
    replacing.replace_shape_with_constant(m.graph)
    replacing.replace_ConstantOfShape_with_constant(m.graph)

    # constant folding
    m = other.inference_shapes(m)
    while constant_folding.constant_folding(m.graph):
        logging.debug("After constant folding jobs.")
        other.topological_sort(m.graph)
        while len(m.graph.value_info) != 0:
            m.graph.value_info.pop()

        m = other.inference_shapes(m)
    replacing.replace_shape_with_constant(m.graph)
    other.topological_sort(m.graph)
    m = torch_pattern_match(m)
    m = optimizer.optimize(m, ['eliminate_deadend'])
    return m
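The folding loop above repeats until a full pass changes nothing, because each fold can expose new foldable nodes. A toy fixed-point sketch of that idea (plain dicts standing in for the ONNX graph):

```python
# Each entry is either a constant int or an ('add', lhs, rhs) expression.
graph = {'a': 2, 'b': 3, 'c': ('add', 'a', 'b'), 'd': ('add', 'c', 'c')}

changed = True
while changed:
    changed = False
    for name, expr in list(graph.items()):
        if isinstance(expr, tuple):
            op, lhs, rhs = expr
            # Fold only when both inputs are already constants.
            if isinstance(graph[lhs], int) and isinstance(graph[rhs], int):
                graph[name] = graph[lhs] + graph[rhs]
                changed = True

# Folding 'c' first makes 'd' foldable in the same sweep: d == 10.
```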


def tensorflow_optimization(m):
    """Optimizations for TensorFlow models that can be used in most cases.

    :param m: the original model input\\
    :return: the new model after preprocessing

    It includes:

    - eliminate shape changes after the input
    - eliminate Reshape cast
    - eliminate Squeeze before Reshape
    - fuse Transpose into Constant
    - replace Shape with Constant
    """

    fusing.fuse_Transpose_into_Constant(m.graph)
    fusing.fuse_MatMul_and_Add_into_Gemm(m.graph)
    other.topological_sort(m.graph)

    m = other.polish_model(m)

    # constant folding
    replacing.replace_shape_with_constant(m.graph)

    m = other.inference_shapes(m)
    while constant_folding.constant_folding(m.graph):
        logging.debug("After constant folding jobs.")
        other.topological_sort(m.graph)
        while len(m.graph.value_info) != 0:
            m.graph.value_info.pop()

        m = other.inference_shapes(m)
    replacing.replace_shape_with_constant(m.graph)
    other.topological_sort(m.graph)
    m = tf_pattern_match(m)
    m = optimizer.optimize(m, ['eliminate_deadend'])

    eliminating.eliminate_consecutive_reshape(m.graph)
    eliminating.eliminate_Squeeze_before_Reshape(m.graph)
    other.topological_sort(m.graph)
    return m


def postprocess(m):
    """Infer the shapes and prepare the model for export.

    :param m: the original model input\\
    :return: the new model after preprocessing
    """
    logger.info("Postprocessing the model...")
    while len(m.graph.value_info) > 0:
        m.graph.value_info.pop()
    m = other.polish_model(m)
    eliminating.eliminate_single_input_Concat(m.graph)
    eliminating.eliminate_nop_Maxpool_and_AveragePool(m.graph)
    eliminating.eliminate_trivial_elementwise_calculation(m.graph)
    m = other.polish_model(m)

    replacing.replace_depthwise_1x1_with_bn(m.graph)
    m = other.polish_model(m)

    # removing transpose
    m = removing_transpose.eliminate_transposes(m)
    m = other.polish_model(m)
    removing_transpose.remove_trivial_transpose(m.graph)
    removing_transpose.fuse_Transpose_into_Gemm_weight(m.graph)

    # fuse some nodes
    fusing.fuse_mul_and_add_into_bn(m.graph)
    m = other.polish_model(m)
    fusing.fuse_mul_and_add_into_gemm(m.graph)
    m = other.polish_model(m)
    fusing.fuse_conv_and_add_into_conv(m.graph)
    m = other.polish_model(m)
    replacing.replace_mul_to_bn(m.graph)
    replacing.replace_div_to_bn(m.graph)
    replacing.replace_add_to_bn(m.graph)
    replacing.replace_sub_to_bn(m.graph)
    replacing.replace_sub_with_bn_and_add(m.graph)
    m = other.polish_model(m)

    other.add_output_to_value_info(m.graph)
    m = optimizer.optimize(m, ['eliminate_deadend'])
    m.producer_name = 'kneron_formatter'
    return m
157
tools/optimizer_scripts/tools/common_pattern.py
Normal file
@ -0,0 +1,157 @@
from collections import defaultdict
import numpy as np
import onnx.helper
import onnx.utils

from . import modhelper
from . import helper
from . import other


def torch_pattern_match(m):
    # Create a map from op_type to the nodes.
    optype2node = defaultdict(list)
    for node in m.graph.node:
        optype2node[node.op_type].append(node)
    for matmul_node in optype2node['MatMul']:
        pattern_matmul_mul_add(m.graph, matmul_node)
    for resize_node in optype2node['Resize']:
        # torch nn.UpsamplingBilinear2d gives us 4 inputs: "X, roi, scales, sizes"
        if len(resize_node.input) != 4:
            continue
        make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name)
        m = onnx.shape_inference.infer_shapes(m)
        polish_RESIZE_input_param_node(m.graph, resize_node.name)
    m = other.polish_model(m)
    return m


def tf_pattern_match(m):
    # Create a map from op_type to the nodes.
    optype2node = defaultdict(list)
    for node in m.graph.node:
        optype2node[node.op_type].append(node)
    for matmul_node in optype2node['MatMul']:
        pattern_matmul_mul_add(m.graph, matmul_node)
    for resize_node in optype2node['Resize']:
        # In tensorflow2onnx, ResizeXXX gives us 4 inputs: "X, roi, scales, sizes",
        # and the node output name is "node name + ':0'".
        if len(resize_node.input) != 4:
            continue
        make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name)
        m = onnx.shape_inference.infer_shapes(m)
        polish_RESIZE_input_param_node(m.graph, resize_node.name)
    m = other.polish_model(m)
    return m


def pattern_matmul_mul_add(g, matmul_node):
    # Check node match - Mul node
    next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Mul':
        return
    mul_node = next_nodes[0]
    # Check node match - Add node
    next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Add':
        return
    add_node = next_nodes[0]
    # Check the Mul weight
    mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
    if mul_weight_node.op_type != 'Constant':
        return
    weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
    for i in mul_weight:
        if i != 1:
            return
    channel = weight_size[0]
    # Check the Add weight
    add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
    if add_weight_node.op_type != 'Constant':
        return
    # Check the MatMul weight to see if it needs weight broadcasting
    matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1])
    matmul_weight = helper.constant_to_numpy(matmul_weight_node)
    if matmul_weight.shape[1] == 1:
        # Weight broadcast
        new_matmul_weight = np.tile(matmul_weight, channel)
        new_matmul_weight_node = helper.numpy_to_constant(matmul_weight_node.name, new_matmul_weight)
        g.node.remove(matmul_weight_node)
        g.node.extend([new_matmul_weight_node])
        value = helper.find_value_by_name(g, matmul_weight_node.output[0])
        if value is not None:
            g.value_info.remove(value)
    # Remove the Mul node and its weight
    g.node.remove(mul_weight_node)
    value = helper.find_value_by_name(g, mul_weight_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    g.node.remove(mul_node)
    value = helper.find_value_by_name(g, mul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    # Fuse MatMul and Add
    gemm_node = onnx.helper.make_node(
        'Gemm',
        [matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
        [add_node.output[0]],
        name=matmul_node.name,
        alpha=1.0,
        beta=1.0,
        transA=0,
        transB=0
    )
    g.node.extend([gemm_node])
    # Clean up
    g.node.remove(matmul_node)
    g.node.remove(add_node)
    value = helper.find_value_by_name(g, matmul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    other.topological_sort(g)
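The pattern above only fires when the Mul weight is all ones, so the MatMul -> Mul -> Add chain is numerically identical to a single Gemm with alpha = beta = 1 and no transposes. A NumPy sketch of that equivalence:

```python
import numpy as np

# MatMul -> Mul(by all-ones) -> Add collapses to one Gemm.
rng = np.random.default_rng(2)
x = rng.standard_normal((2, 3))
W = rng.standard_normal((3, 4))
ones = np.ones(4)                 # the no-op Mul weight
b = rng.standard_normal(4)

chain = (x @ W) * ones + b
gemm = x @ W + b                  # Gemm(x, W, b), alpha=1, beta=1

assert np.allclose(chain, gemm)
```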


def make_UpsamplingBilinear2d_value_info(g, resize_node_name):
    resize_node = helper.find_node_by_node_name(g, resize_node_name)

    shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
    shape_data = helper.constant_to_numpy(shape_data_node).astype(int)
    l_shape_data = list(shape_data)
    if l_shape_data[0] == 0:
        l_shape_data[0] = 1 + l_shape_data[0]
    shape_data = np.array(l_shape_data)

    new_output_value_info = onnx.helper.make_tensor_value_info(
        resize_node.output[0],
        onnx.helper.TensorProto.FLOAT,
        shape_data.tolist()
    )

    g.value_info.extend([new_output_value_info])


def polish_RESIZE_input_param_node(g, resize_node_name):
    resize_node = helper.find_node_by_node_name(g, resize_node_name)

    shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
    shape_data = helper.constant_to_numpy(shape_data_node).astype(int)

    # Handle 0 batch size, which is invalid
    if shape_data[0] == 0:
        shape_data[0] = 1

    pre_node_output_value_info = helper.find_value_by_name(g, resize_node.input[0])
    ori_shape = np.array([pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value,
                          pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value,
                          pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value,
                          pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value])

    resize_node.input.remove(resize_node.input[3])

    resize_scales = np.array(shape_data / ori_shape).astype(float)
    resize_scale_node = helper.list_to_constant('resize_scales_node_' + resize_node.name, resize_scales.shape, resize_scales, data_type=onnx.helper.TensorProto.FLOAT)

    resize_node.input[2] = resize_scale_node.name
    g.node.extend([resize_scale_node])

    other.topological_sort(g)
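`polish_RESIZE_input_param_node` converts Resize's explicit target `sizes` input into the `scales` input it also accepts, by dividing elementwise by the input shape. A small sketch of that conversion:

```python
import numpy as np

# Resize accepts either target sizes or per-axis scales; scales = sizes / input_shape.
ori_shape = np.array([1, 3, 16, 16])      # NCHW input shape
sizes = np.array([1, 3, 32, 32])          # requested output sizes
scales = (sizes / ori_shape).astype(float)

# Spatial dims double, batch and channels stay fixed.
assert list(scales) == [1.0, 1.0, 2.0, 2.0]
```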
995
tools/optimizer_scripts/tools/constant_folding.py
Normal file
@ -0,0 +1,995 @@
import onnx.utils
import onnx
import numpy as np
import logging
import traceback

from . import helper
from .general_graph import Graph, Node
from .other import topological_sort
from .replacing import replace_shape_with_constant
from .helper import logger


def are_all_inputs_Constant_with_one_child(g, node):
    for input_name in node.input:
        input_node = helper.find_node_by_output_name(g, input_name)
        if input_node is None or input_node.op_type != 'Constant':
            return False
        relative_outputs = helper.find_nodes_by_input_name(g, input_name)
        if len(relative_outputs) > 1:
            return False
    return True


def constant_folding(g):
    """ Do constant folding until nothing more can be done.

    :param g: The onnx GraphProto\\
    :return: If any node is folded, return True. Otherwise, return False.
    """
    keep_folding = True  # Keep the while loop
    folded = False  # Return value
    try:
        # Before constant folding, duplicate the constant nodes.
        duplicate_constant_node(g)
        while keep_folding:
            keep_folding = False
            for node in g.node:
                # Check if the node is foldable
                if node.op_type not in constant_folding_nodes.keys():
                    continue
                # Check if the inputs of the node are all Constant nodes with a single follower.
                if not are_all_inputs_Constant_with_one_child(g, node):
                    continue
                # Constant folding for the specific node
                if constant_folding_nodes[node.op_type](g, node):
                    logging.debug("Constant nodes and %s %s are folded.",
                                  node.op_type, node.name)
                    folded = True
                    keep_folding = True
                else:
                    logging.debug(
                        "Constant nodes and %s %s are skipped.", node.op_type, node.name)
    except Exception as e:
        logger.error("An exception is raised while constant folding.")
        logger.error(traceback.format_exc())
    return folded


def duplicate_constant_node(g):
    """ Duplicate a constant node if its following nodes contain constant folding
    nodes. Create and link the new constant nodes to the constant folding nodes.
    """
    for node in g.node:
        # Find a valid constant node
        if node.op_type != 'Constant':
            continue
        output_val_info = helper.find_value_by_name(g, node.output[0])
        if output_val_info is None:
            print("Cannot infer the shape of Const node output: " +
                  node.output[0])
            exit(1)
        data_shape = helper.get_shape_from_value_info(output_val_info)
        output_nodes = helper.find_nodes_by_input_name(g, node.output[0])

        # For a constant that has only one following node, no need to duplicate
        if len(output_nodes) < 2:
            continue

        # Check if its following nodes are foldable
        foldable_output_nodes = list(filter(lambda n: n.op_type in
                                            constant_folding_nodes.keys(), output_nodes))
        if not foldable_output_nodes:
            continue

        # Duplicate the node needed by foldable nodes
        for i in range(len(foldable_output_nodes)):
            logging.debug("Found constant %s and %s %s are available for folding. Duplicate constant.",
                          node.name, foldable_output_nodes[i].op_type, foldable_output_nodes[i].name)
            output_name = node.output[0] + '_dup_' + str(i)
            new_constant_node = onnx.helper.make_node(
                'Constant',
                [],
                [output_name],
                name=output_name,
                value=node.attribute[0].t
            )
            new_val_info = onnx.helper.make_tensor_value_info(
                output_name,
                node.attribute[0].t.data_type,
                data_shape
            )
            input_ind = list(foldable_output_nodes[i].input).index(
                node.output[0])
            foldable_output_nodes[i].input[input_ind] = output_name

            g.node.extend([new_constant_node])
            g.value_info.extend([new_val_info])

        # If all following nodes are foldable nodes, delete the original node.
        if len(foldable_output_nodes) == len(output_nodes):
            g.node.remove(node)
            g.value_info.remove(output_val_info)

    topological_sort(g)

    return


def slice_constant_folding(g, node):
    op_version = helper.get_current_opset_version()
    # Only opset 9 and 11 are supported
    if op_version == 11:
        return slice_constant_folding_Opset_11(g, node)
    elif op_version == 9:
        return slice_constant_folding_Opset_9(g, node)


def slice_constant_folding_Opset_11(g, node):
    """ Fold constant and slice nodes into a single constant node.
    """
    pre_node = helper.find_node_by_output_name(g, node.input[0])
    pre_shape, data_list = helper.constant_to_list(pre_node)

    starts_node = helper.find_node_by_output_name(g, node.input[1])
    _, starts = helper.constant_to_list(starts_node)

    ends_node = helper.find_node_by_output_name(g, node.input[2])
    _, ends = helper.constant_to_list(ends_node)

    axes_node = None if len(node.input) <= 3 else helper.find_node_by_output_name(g, node.input[3])
    if not axes_node:
        axes = list(range(len(helper.get_shape(data_list))))
    else:
        _, axes = helper.constant_to_list(axes_node)

    steps_node = None if len(node.input) <= 4 else helper.find_node_by_output_name(g, node.input[4])
    if not steps_node:
        steps = [1] * len(helper.get_shape(data_list))
    else:
        _, steps = helper.constant_to_list(steps_node)

    data_list = list(map(int, data_list))
    starts = list(map(int, starts))
    ends = list(map(int, ends))
    axes = list(map(int, axes))
    steps = list(map(int, steps))

    data_list = np.reshape(data_list, pre_shape)

    # Apply each slice along its declared axis, accumulating the result.
    new_data = data_list
    for idx, axis in enumerate(axes):
        new_data = np.apply_along_axis(lambda x: x[starts[idx]:ends[idx]:steps[idx]], axis, new_data)

    new_node = helper.list_to_constant(node.output[0], helper.get_shape(
        new_data), helper.flatten_to_list(new_data))
    g.node.extend([new_node])
    value_info = helper.find_value_by_name(g, pre_node.output[0])
    if value_info is not None:
        g.value_info.remove(value_info)
    g.node.remove(node)
    g.node.remove(pre_node)

    return True


def slice_constant_folding_Opset_9(g, node):
    """ Fold constant and slice nodes into a single constant node.
    """
    pre_node = helper.find_node_by_output_name(g, node.input[0])
    pre_shape, data_list = helper.constant_to_list(pre_node)

    data_list = np.reshape(data_list, pre_shape)
    axes = helper.get_attribute_by_name(node, 'axes')
    ends = list(helper.get_attribute_by_name(node, 'ends').ints)
    starts = list(helper.get_attribute_by_name(node, 'starts').ints)

    if not axes:
        axes = list(range(len(helper.get_shape(data_list))))
    else:
        axes = list(axes.ints)

    new_data = helper.slice_data(data_list, starts, ends, axes)
    new_node = helper.list_to_constant(node.output[0], helper.get_shape(
        new_data), helper.flatten_to_list(new_data))
    g.node.extend([new_node])
    value_info = helper.find_value_by_name(g, pre_node.output[0])
    if value_info is not None:
        g.value_info.remove(value_info)
    g.node.remove(node)
    g.node.remove(pre_node)

    return True
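The opset-11 folding slices the constant with `np.apply_along_axis`, one axis at a time. Applying `x[start:end:step]` along each declared axis in turn matches direct multi-axis slicing, as this NumPy sketch checks:

```python
import numpy as np

data = np.arange(24).reshape(2, 3, 4)
starts, ends, axes, steps = [0, 1], [2, 4], [1, 2], [1, 2]

# Slice axis 1 with 0:2:1, then axis 2 with 1:4:2.
out = data
for i, axis in enumerate(axes):
    out = np.apply_along_axis(lambda v: v[starts[i]:ends[i]:steps[i]], axis, out)

assert np.array_equal(out, data[:, 0:2:1, 1:4:2])
```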


def cast_constant_folding(g, node):
    """ Fold constant and cast nodes into a single constant node.
    """
    pre_node = helper.find_node_by_output_name(g, node.input[0])
    shape, data = helper.constant_to_list(pre_node)
    data_type = node.attribute[0].i
    if data_type in (6, 7):
        data = list(map(int, data))
    elif data_type == onnx.helper.TensorProto.FLOAT:
        data = list(map(float, data))
    else:
        raise RuntimeError('data type not supported')

    if shape == 1:
        tensor = onnx.helper.make_tensor(
            name=pre_node.attribute[0].name,
            data_type=data_type,
            dims=[],
            vals=data
        )
    else:
        tensor = onnx.helper.make_tensor(
            name=pre_node.attribute[0].name,
            data_type=data_type,
            dims=shape,
            vals=helper.flatten_to_list(data)
        )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=tensor
    )
    g.node.extend([new_node])

    value_info = helper.find_value_by_name(g, pre_node.output[0])
    if value_info is not None:
        g.value_info.remove(value_info)
    value_info = helper.find_value_by_name(g, node.output[0])
    if value_info is not None:
        g.value_info.remove(value_info)
    g.node.remove(pre_node)
    g.node.remove(node)

    return True


def reduceprod_constant_folding(g, node):
    """ Fold constant and reduceprod nodes into a single constant node.
    """
    pre_node = helper.find_node_by_output_name(g, node.input[0])
    shape, data_set = helper.constant_to_list(pre_node)
    tensor = pre_node.attribute[0].t

    data_set = np.reshape(data_set, shape)
    for att in node.attribute:
        if att.name == 'axes':
            axes = list(att.ints)
        else:
            keepdims = int(att.i)

    new_data = np.prod(data_set, axis=tuple(axes), keepdims=keepdims == 1)
    new_shape = helper.get_shape(new_data)
    new_flat_data = helper.flatten_to_list(new_data)
    new_tensor = onnx.helper.make_tensor(
        name=node.output[0],
        data_type=tensor.data_type,
        dims=new_shape,
        vals=new_flat_data
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    g.node.extend([new_node])
    value_info = None
    for item in g.value_info:
        if item.name == pre_node.output[0]:
            value_info = item
    if value_info is not None:
        g.value_info.remove(value_info)
    g.node.remove(pre_node)
    g.node.remove(node)

    return True


def reshape_constant_input_folding(g, node):
    """ Fold constant and reshape nodes into a single constant node.
    """
    pre_data_node = helper.find_node_by_output_name(g, node.input[0])
    pre_shape_node = helper.find_node_by_output_name(g, node.input[1])

    data = helper.constant_to_numpy(pre_data_node)
    _, shape = helper.constant_to_list(pre_shape_node)
    new_data = np.reshape(data, shape)

    new_tensor = onnx.helper.make_tensor(
        name=node.output[0],
        data_type=pre_data_node.attribute[0].t.data_type,
        dims=new_data.shape,
        vals=helper.flatten_to_list(new_data)
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )
    g.node.extend([new_node])

    data_val_info = helper.find_value_by_name(g, pre_data_node.output[0])
    shape_val_info = helper.find_value_by_name(g, pre_shape_node.output[0])

    g.value_info.remove(data_val_info)
    g.value_info.remove(shape_val_info)

    g.node.remove(node)
    g.node.remove(pre_data_node)
    g.node.remove(pre_shape_node)

    return True
|
||||
|
||||
|
||||
def concat_constant_folding(g, node):
    """ Fold constant and concat nodes to a single constant node.
    """
    node_to_del = []
    valid_inputs = True
    for input_name in node.input:
        input_node = helper.find_node_by_output_name(g, input_name)
        input_node_output = helper.find_nodes_by_input_name(g, input_name)
        if len(input_node_output) > 1:
            valid_inputs = False
            break
        if input_node.op_type != 'Constant':
            valid_inputs = False
            break

    if not valid_inputs:
        return False

    input_data = []
    input_shapes = []
    for input_name in node.input:
        input_node = helper.find_node_by_output_name(g, input_name)
        s, d = helper.constant_to_list(input_node)
        d = np.reshape(d, s)
        input_data.append(d)
        input_shapes.append(s)
        node_to_del.append(input_node)

    concat_data = np.concatenate(input_data, axis=node.attribute[0].i)
    node_data_type = input_node.attribute[0].t.data_type
    if concat_data.dtype in [np.int32, np.int64]:
        node_data_type = onnx.helper.TensorProto.INT64
    elif concat_data.dtype in [np.float32, np.float64]:
        node_data_type = onnx.helper.TensorProto.FLOAT

    new_node = helper.list_to_constant(
        node.output[0],
        helper.get_shape(concat_data),
        helper.flatten_to_list(concat_data),
        data_type=node_data_type
    )
    g.node.extend([new_node])
    node_to_del.append(node)

    for input_name in node.input:
        val_info = helper.find_value_by_name(g, input_name)
        if val_info:
            g.value_info.remove(val_info)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def transpose_constant_folding(g, node):
    """Fold constant and transpose nodes to a single constant node.
    """
    node_to_del = []
    pre_node = helper.find_node_by_output_name(g, node.input[0])
    shape, data = helper.constant_to_list(pre_node)
    np_data = np.reshape(data, shape)
    permutation = list(node.attribute[0].ints)

    new_data = np.transpose(np_data, permutation)
    new_shape = new_data.shape
    new_node = helper.list_to_constant(
        node.output[0],
        new_shape,
        new_data.flatten().tolist(),
        data_type=pre_node.attribute[0].t.data_type
    )

    g.node.extend([new_node])
    node_to_del.extend([node, pre_node])

    pre_val_info = helper.find_value_by_name(g, node.input[0])
    g.value_info.remove(pre_val_info)

    next_val_info = helper.find_value_by_name(g, node.output[0])
    g.value_info.remove(next_val_info)

    new_val_info = onnx.helper.make_tensor_value_info(
        node.output[0],
        pre_node.attribute[0].t.data_type,
        new_shape
    )
    g.value_info.extend([new_val_info])

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def unsqueeze_constant_folding(g, node):
    """Fold constant and unsqueeze nodes to a single constant node.
    """
    node_to_del = []
    pre_node = helper.find_node_by_output_name(g, node.input[0])
    shape, data = helper.constant_to_list(pre_node)
    if type(shape) == int:
        np_data = data[0]
    else:
        np_data = np.reshape(data, shape)
    axes = list(node.attribute[0].ints)
    axes.sort()

    for dim in axes:
        np_data = np.expand_dims(np_data, axis=dim)
    new_shape = np_data.shape
    new_node = helper.list_to_constant(
        node.output[0],
        new_shape,
        np_data.flatten().tolist(),
        data_type=pre_node.attribute[0].t.data_type
    )
    g.node.extend([new_node])
    node_to_del.extend([node, pre_node])

    pre_val_info = helper.find_value_by_name(g, node.input[0])
    next_val_info = helper.find_value_by_name(g, node.output[0])
    if pre_val_info is not None:
        g.value_info.remove(pre_val_info)
    else:
        print(node.name)
    if next_val_info is not None:
        g.value_info.remove(next_val_info)

    new_val_info = onnx.helper.make_tensor_value_info(
        node.output[0],
        pre_node.attribute[0].t.data_type,
        new_shape
    )
    g.value_info.extend([new_val_info])

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def gather_constant_folding(g, node):
    """Fold constant and gather nodes to a single constant node.
    """
    node_to_del = []

    pre_data_node = helper.find_node_by_output_name(g, node.input[0])
    pre_indices_node = helper.find_node_by_output_name(g, node.input[1])

    shape, data = helper.constant_to_list(pre_data_node)
    indice_shape, indices = helper.constant_to_list(pre_indices_node)
    if type(indice_shape) == int:
        indices = indices[0]

    np_data = np.reshape(data, shape)
    if len(node.attribute) < 1:
        axis = 0
    else:
        axis = node.attribute[0].i

    new_data = np.take(np_data, indices, axis=axis)
    new_shape = new_data.shape
    new_node = helper.list_to_constant(
        node.output[0],
        new_shape,
        new_data.flatten().tolist(),
        data_type=pre_data_node.attribute[0].t.data_type
    )

    node_to_del.extend([node, pre_data_node, pre_indices_node])
    g.node.extend([new_node])

    val_info_1 = helper.find_value_by_name(g, node.input[0])
    val_info_2 = helper.find_value_by_name(g, node.input[1])
    val_info_3 = helper.find_value_by_name(g, node.output[0])
    new_val_info = onnx.helper.make_tensor_value_info(
        new_node.output[0],
        pre_data_node.attribute[0].t.data_type,
        new_shape
    )

    if val_info_1 is not None:
        g.value_info.remove(val_info_1)
    if val_info_2 is not None:
        g.value_info.remove(val_info_2)
    if val_info_3 is not None:
        g.value_info.remove(val_info_3)
    g.value_info.extend([new_val_info])

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def add_constant_folding(g, node):
    """Fold constant and add nodes to a single constant node.
    """
    node_to_del = []
    pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
    pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
    if not pre_node_1 or not pre_node_2:
        return False

    shape1, data1 = helper.constant_to_list(pre_node_1)
    shape2, data2 = helper.constant_to_list(pre_node_2)
    np_data1 = np.reshape(data1, shape1)
    np_data2 = np.reshape(data2, shape2)
    try:
        new_data = np.add(np_data1, np_data2)
    except ValueError:
        raise RuntimeError('cannot broadcast and add the two data sets')

    new_node = helper.list_to_constant(
        node.output[0],
        new_data.shape,
        new_data.flatten().tolist(),
        data_type=pre_node_1.attribute[0].t.data_type
    )

    g.node.extend([new_node])
    node_to_del.extend([node, pre_node_1, pre_node_2])
    g.value_info.remove(helper.find_value_by_name(g, pre_node_1.output[0]))
    g.value_info.remove(helper.find_value_by_name(g, pre_node_2.output[0]))

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def sqrt_constant_folding(g, node):
    """ Fold constant and sqrt nodes to a single node.
    """
    node_to_del = []
    pre_node = helper.find_node_by_output_name(g, node.input[0])
    shape, data = helper.constant_to_list(pre_node)
    np_data = np.sqrt(np.reshape(data, shape))
    output_val_info = helper.find_value_by_name(g, node.output[0])
    input_val_info = helper.find_value_by_name(g, node.input[0])
    data_type = output_val_info.type.tensor_type.elem_type

    new_tensor = onnx.helper.make_tensor(
        name=node.output[0] + '_data',
        data_type=data_type,
        dims=shape,
        vals=np_data.flatten().tolist()
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    g.value_info.remove(input_val_info)
    node_to_del.extend([pre_node, node])
    g.node.extend([new_node])

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def reciprocal_constant_folding(g, node):
    """ Fold constant and reciprocal nodes to a single constant node.
    """
    node_to_del = []

    pre_node = helper.find_node_by_output_name(g, node.input[0])
    shape, data = helper.constant_to_list(pre_node)
    # Clamp near-zero values to 1e-8 to avoid overflow in the reciprocal.
    data = list(map(lambda x: x if abs(x) > 1.e-8 else 1.e-8, data))
    np_data = np.reshape(data, shape)
    np_data = np.reciprocal(np_data)

    input_val_info = helper.find_value_by_name(g, node.input[0])
    output_val_info = helper.find_value_by_name(g, node.output[0])
    data_type = output_val_info.type.tensor_type.elem_type

    new_tensor = onnx.helper.make_tensor(
        name=node.output[0] + '_data',
        data_type=data_type,
        dims=shape,
        vals=np_data.flatten().tolist()
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    node_to_del.extend([node, pre_node])
    g.node.extend([new_node])

    g.value_info.remove(input_val_info)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def mul_constant_folding(g, node):
    """ Fold constant and mul nodes to a single constant node.
    """
    node_to_del = []
    pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
    pre_node_2 = helper.find_node_by_output_name(g, node.input[1])

    pre_value_info1 = helper.find_value_by_name(g, node.input[0])
    pre_value_info2 = helper.find_value_by_name(g, node.input[1])
    if pre_value_info1 is None or pre_value_info2 is None:
        return False

    shape1, data1 = helper.constant_to_list(pre_node_1)
    shape2, data2 = helper.constant_to_list(pre_node_2)
    np_data1 = np.reshape(data1, shape1)
    np_data2 = np.reshape(data2, shape2)

    try:
        new_data = np.multiply(np_data1, np_data2)
    except ValueError:
        raise RuntimeError('cannot broadcast and multiply the two data sets')

    # Special shape for single element.
    if shape1 == 1 and shape2 == 1:
        new_shape = []
    else:
        new_shape = new_data.shape

    new_tensor = onnx.helper.make_tensor(
        name=node.output[0] + '_data',
        data_type=pre_node_1.attribute[0].t.data_type,
        dims=new_shape,
        vals=new_data.flatten().tolist()
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    node_to_del.extend([node, pre_node_1, pre_node_2])
    g.node.extend([new_node])

    g.value_info.remove(pre_value_info1)
    g.value_info.remove(pre_value_info2)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def div_constant_folding(g, node):
    """ Fold constant and div nodes to a single constant node.
    """
    node_to_del = []
    pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
    pre_node_2 = helper.find_node_by_output_name(g, node.input[1])

    pre_value_info1 = helper.find_value_by_name(g, node.input[0])
    pre_value_info2 = helper.find_value_by_name(g, node.input[1])
    if pre_value_info1 is None or pre_value_info2 is None:
        return False

    shape1, data1 = helper.constant_to_list(pre_node_1)
    shape2, data2 = helper.constant_to_list(pre_node_2)
    np_data1 = np.reshape(data1, shape1)
    np_data2 = np.reshape(data2, shape2)

    try:
        new_data = np.divide(np_data1, np_data2)
    except ValueError:
        raise RuntimeError('cannot broadcast and divide the two data sets')

    # Special shape for single element.
    if shape1 == 1 and shape2 == 1:
        new_shape = []
    else:
        new_shape = new_data.shape

    # Keep INT64 (data_type == 7) results as integers.
    if pre_node_1.attribute[0].t.data_type == 7:
        new_data = new_data.astype('int64')

    new_tensor = onnx.helper.make_tensor(
        name=node.output[0] + '_data',
        data_type=pre_node_1.attribute[0].t.data_type,
        dims=new_shape,
        vals=new_data.flatten().tolist()
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    node_to_del.extend([node, pre_node_1, pre_node_2])
    g.node.extend([new_node])

    g.value_info.remove(pre_value_info1)
    g.value_info.remove(pre_value_info2)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def sub_constant_folding(g, node):
    """ Fold constant and sub nodes to a single node.
    """
    node_to_del = []
    pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
    pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
    pre_val_info_1 = helper.find_value_by_name(g, node.input[0])
    pre_val_info_2 = helper.find_value_by_name(g, node.input[1])

    shape1, data1 = helper.constant_to_list(pre_node_1)
    shape2, data2 = helper.constant_to_list(pre_node_2)

    new_data = np.subtract(data1, data2)
    # Special shape for single element.
    if shape1 == 1 and shape2 == 1:
        new_shape = []
    else:
        new_shape = new_data.shape

    new_tensor = onnx.helper.make_tensor(
        name=node.output[0] + '_data',
        data_type=pre_node_1.attribute[0].t.data_type,
        dims=new_shape,
        vals=helper.flatten_to_list(new_data)
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    g.node.extend([new_node])
    node_to_del.extend([node, pre_node_1, pre_node_2])

    g.value_info.remove(pre_val_info_1)
    g.value_info.remove(pre_val_info_2)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def neg_constant_folding(g, node):
    """Fold constant and neg nodes to a single constant node.
    """
    node_to_del = []
    pre_node = helper.find_node_by_output_name(g, node.input[0])

    shape, data_list = helper.constant_to_list(pre_node)
    new_data_list = [-num for num in data_list]

    new_tensor = onnx.helper.make_tensor(
        name=pre_node.name + '_neg_tensor',
        data_type=pre_node.attribute[0].t.data_type,
        dims=shape,
        vals=new_data_list
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    g.node.extend([new_node])
    node_to_del.extend([pre_node, node])
    g.value_info.remove(helper.find_value_by_name(g, node.input[0]))

    while node_to_del:
        g.node.remove(node_to_del.pop())

    return True


def floor_constant_folding(g, node):
    """Fold constant and floor nodes to a single constant node.
    """
    node_to_del = []
    pre_node = helper.find_node_by_output_name(g, node.input[0])

    shape, data = helper.constant_to_list(pre_node)
    new_data = np.floor(data).flatten().tolist()

    if shape == 1:
        new_shape = []
    else:
        new_shape = shape

    new_tensor = onnx.helper.make_tensor(
        name=node.output[0] + '_data',
        data_type=pre_node.attribute[0].t.data_type,
        dims=new_shape,
        vals=helper.flatten_to_list(new_data)
    )
    new_node = onnx.helper.make_node(
        'Constant',
        [],
        [node.output[0]],
        name=node.output[0],
        value=new_tensor
    )

    g.node.extend([new_node])
    node_to_del.extend([pre_node, node])
    old_value = helper.find_value_by_name(g, node.input[0])
    if old_value is not None:
        g.value_info.remove(old_value)

    while node_to_del:
        g.node.remove(node_to_del.pop())

    return True


def bn_constant_folding(g, node):
    """ Fold a BatchNormalization node with constant inputs to a single constant node.
    """
    # Prepare data
    node_to_del = []
    input_node = helper.find_node_by_output_name(g, node.input[0])
    scale_node = helper.find_node_by_output_name(g, node.input[1])
    bias_node = helper.find_node_by_output_name(g, node.input[2])
    mean_node = helper.find_node_by_output_name(g, node.input[3])
    var_node = helper.find_node_by_output_name(g, node.input[4])

    input_value_info = []
    for i in range(5):
        input_value_info.append(helper.find_value_by_name(g, node.input[i]))

    if input_value_info[0] is None:
        return False

    input_data = helper.constant_to_numpy(input_node)
    scale_data = helper.constant_to_numpy(scale_node)
    bias_data = helper.constant_to_numpy(bias_node)
    mean_data = helper.constant_to_numpy(mean_node)
    var_data = helper.constant_to_numpy(var_node)

    epsilon = helper.get_var_attribute_by_name(node, 'epsilon', 'float')
    if epsilon is None:
        epsilon = 0.00001

    # Calculate the folded result
    new_data = scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon) + bias_data

    new_node = helper.numpy_to_constant(node.output[0], new_data)

    # Reconnect the graph
    node_to_del.extend([node, input_node, scale_node, bias_node, mean_node, var_node])
    g.node.extend([new_node])

    for value in input_value_info:
        if value is not None:
            g.value_info.remove(value)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


def DequantizeLinear_constant_folding(g, node):
    """ Fold a DequantizeLinear node with constant inputs to a single constant node.
    """
    # Prepare data
    node_to_del = []
    x_node = helper.find_node_by_output_name(g, node.input[0])
    x_scale_node = helper.find_node_by_output_name(g, node.input[1])
    if len(node.input) > 2:
        x_zero_point_node = helper.find_node_by_output_name(g, node.input[2])
    else:
        x_zero_point_node = None

    input_value_info = []
    for i in range(len(node.input)):
        input_value_info.append(helper.find_value_by_name(g, node.input[i]))

    if input_value_info[0] is None:
        return False

    x_data = helper.constant_to_numpy(x_node)
    x_scale_data = helper.constant_to_numpy(x_scale_node)
    if x_zero_point_node is not None:
        x_zero_point_data = helper.constant_to_numpy(x_zero_point_node)
    else:
        x_zero_point_data = np.array([0.0])

    # Calculate the dequantized result
    new_data = (x_data.astype(np.float32) - x_zero_point_data.astype(np.float32)) * x_scale_data

    new_node = helper.numpy_to_constant(node.output[0], new_data)

    # Reconnect the graph
    node_to_del.extend([node, x_node, x_scale_node])
    if x_zero_point_node is not None:
        node_to_del.append(x_zero_point_node)
    g.node.extend([new_node])

    for value in input_value_info:
        if value is not None:
            g.value_info.remove(value)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return True


# Available constant folding names to function map.
constant_folding_nodes = {
    'Add': add_constant_folding,
    'BatchNormalization': bn_constant_folding,
    'Cast': cast_constant_folding,
    'Concat': concat_constant_folding,
    'DequantizeLinear': DequantizeLinear_constant_folding,
    'Div': div_constant_folding,
    'Floor': floor_constant_folding,
    'Gather': gather_constant_folding,
    'Mul': mul_constant_folding,
    'Neg': neg_constant_folding,
    'Reciprocal': reciprocal_constant_folding,
    'ReduceProd': reduceprod_constant_folding,
    'Reshape': reshape_constant_input_folding,
    'Slice': slice_constant_folding,
    'Sqrt': sqrt_constant_folding,
    'Sub': sub_constant_folding,
    'Transpose': transpose_constant_folding,
    'Unsqueeze': unsqueeze_constant_folding,
}

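The dispatch table above is consumed by a driver loop elsewhere in the toolchain. As a hedged, self-contained sketch of the pattern only (the names `fold_add`, `FOLDERS`, and `run_constant_folding` are illustrative assumptions, not this repository's API), a driver looks up each node's op type and calls the matching folding function when all of its inputs are already constant:

```python
def fold_add(graph, node):
    # Replace the node's output with the precomputed sum of its constant inputs.
    graph[node["output"]] = sum(graph[i] for i in node["inputs"])
    return True

# op_type -> folding function, mirroring the constant_folding_nodes table above.
FOLDERS = {"Add": fold_add}

def run_constant_folding(graph, nodes):
    changed = False
    for node in nodes:
        fold = FOLDERS.get(node["op_type"])
        # Only fold when every input is a known constant.
        if fold is not None and all(i in graph for i in node["inputs"]):
            changed = fold(graph, node) or changed
    return changed

constants = {"a": 2, "b": 3}
nodes = [{"op_type": "Add", "inputs": ["a", "b"], "output": "c"}]
run_constant_folding(constants, nodes)
print(constants["c"])  # 5
```

In the real scripts the "graph" is an ONNX `GraphProto` and each folding function additionally rewires `value_info` and removes the folded nodes, but the table-lookup dispatch is the same.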
669
tools/optimizer_scripts/tools/eliminating.py
Normal file
@ -0,0 +1,669 @@
import collections
import struct

import onnx
import numpy as np

from . import other
from . import helper
from . import modhelper
from .general_graph import Graph


def eliminate_Identify_and_Dropout(g):
    """
    Eliminate Identity and Dropout layers

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Identity' and node.op_type != 'Dropout':
            continue
        # If this node is the last node, leave it to `remove_useless_last_nodes`.
        if helper.find_output_by_name(g, node.output[0]) is not None:
            continue
        # Replace the parents in all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, node.output[0], node.input[0])
        # Delete value info
        value_between = helper.find_value_by_name(g, node.output[0])
        try:
            g.value_info.remove(value_between)
        except ValueError:
            print("No value info to delete while eliminating identity layers.")
        # Node is waiting for elimination
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)


# Remove last useless nodes
def remove_useless_last_nodes(g):
    """Remove useless nodes from the tail of the graph
    """
    USELESS = ["Reshape", "Identity", "Transpose", "Flatten", "Dropout", "Mystery", "Constant", "Squeeze", "Unsqueeze", "Softmax"]
    graph = Graph(g)
    todo = collections.deque()
    for node in graph.output_nodes:
        if len(node.children) == 0:
            todo.append(node)
    node_to_remove = []
    while todo:
        # BFS find nodes to remove
        cur_node = todo.popleft()
        if cur_node.proto is None:
            continue
        if cur_node.proto.op_type not in USELESS:
            continue
        # Find the output
        cur_node_output = helper.find_output_by_name(g, cur_node.proto.output[0])
        for cur_input in cur_node.parents:
            cur_input.children.remove(cur_node)
            if len(cur_input.children) == 0:
                todo.append(cur_input)
            if cur_node_output is not None:
                cur_input_output = helper.find_value_by_name(g, cur_input.proto.output[0])
                cur_input_output_in_output = helper.find_output_by_name(g, cur_input.proto.output[0])
                if cur_input_output is not None and cur_input_output_in_output is None:
                    g.output.extend([cur_input_output])
        node_to_remove.append(cur_node.proto)
        try:
            g.value_info.remove(helper.find_value_by_name(g, cur_node.proto.output[0]))
        except ValueError:
            pass
        if cur_node_output is not None:
            g.output.remove(cur_node_output)
        cur_node.proto = None
        cur_node.parents.clear()
    for node in node_to_remove:
        g.node.remove(node)


######################################
#    TF only optimization passes     #
######################################


def eliminate_shape_changing_after_input(g):
    """
    Eliminate the Reshape node after input and reshape the input

    :param g: the onnx graph
    """
    node_to_remove = []
    REMOVE_LIST = ["Reshape", "Transpose", "Flatten", "Dropout", "Squeeze", "Unsqueeze"]
    for node in g.node:
        # Find an input and the shape node
        if node.op_type not in REMOVE_LIST:
            continue
        old_input = helper.find_input_by_name(g, node.input[0])
        if old_input is None:
            continue
        # If the input is used by multiple nodes, skip.
        counter = 0
        for tnode in g.node:
            if old_input.name in tnode.input:
                counter += 1
        if counter > 1:
            continue
        # Remove the weight if any.
        output_val_info = helper.find_value_by_name(g, node.output[0])

        if node.op_type == 'Reshape':
            shape_node = helper.find_node_by_output_name(g, node.input[1])
            if shape_node.op_type != 'Constant':
                continue

            # Manually set the input shape
            shape_info = helper.find_value_by_name(g, shape_node.output[0])
            old_size, old_shape = helper.find_size_shape_from_value(shape_info)

            _, new_shape = helper.constant_to_list(shape_node)
            for i in range(len(new_shape)):
                if new_shape[i] == -1:
                    # Resolve the -1 wildcard from the total element count.
                    dim = int(old_size // np.prod(new_shape) * (-1))
                    new_shape[i] = dim
            new_input = onnx.helper.make_tensor_value_info(
                output_val_info.name,
                output_val_info.type.tensor_type.elem_type,
                new_shape
            )

            node_to_remove.append(node)

            shape_outputs = helper.find_nodes_by_input_name(g, shape_node.output[0])
            if len(shape_outputs) == 1:
                node_to_remove.append(shape_node)
                g.value_info.remove(helper.find_value_by_name(g, shape_node.output[0]))

            g.input.remove(old_input)
            g.input.extend([new_input])
            g.value_info.remove(output_val_info)
        elif node.op_type == 'Transpose':
            permutation = list(node.attribute[0].ints)
            pre_shape = helper.get_shape_from_value_info(old_input)
            new_shape = [pre_shape[i] for i in permutation]

            new_input = onnx.helper.make_tensor_value_info(
                output_val_info.name,
                output_val_info.type.tensor_type.elem_type,
                new_shape
            )

            node_to_remove.append(node)

            g.input.remove(old_input)
            g.input.extend([new_input])
            g.value_info.remove(output_val_info)
        elif node.op_type == 'Flatten':
            axis = node.attribute[0].i
            pre_shape = helper.get_shape_from_value_info(old_input)
            dim_1, dim_2 = 1, 1
            if axis == 0:
                dim_1 = 1
                dim_2 = np.prod(pre_shape)
            else:
                dim_1 = np.prod(pre_shape[:axis]).astype(int)
                dim_2 = np.prod(pre_shape[axis:]).astype(int)
            new_shape = [dim_1, dim_2]

            new_input = onnx.helper.make_tensor_value_info(
                output_val_info.name,
                output_val_info.type.tensor_type.elem_type,
                new_shape
            )

            node_to_remove.append(node)

            g.input.remove(old_input)
            g.input.extend([new_input])
            g.value_info.remove(output_val_info)
        elif node.op_type == 'Dropout':
            g.input.remove(old_input)
            g.input.extend([output_val_info])
            g.value_info.remove(output_val_info)

            node_to_remove.append(node)
        elif node.op_type == 'Squeeze':
            axis = list(node.attribute[0].ints)
            pre_shape = helper.get_shape_from_value_info(old_input)
            for pos in sorted(axis)[::-1]:
                if pre_shape[pos] != 1:
                    raise RuntimeError('invalid axis for squeeze')
                else:
                    pre_shape.pop(pos)
            new_shape = pre_shape

            new_input = onnx.helper.make_tensor_value_info(
                output_val_info.name,
                output_val_info.type.tensor_type.elem_type,
                new_shape
            )

            node_to_remove.append(node)

            g.input.remove(old_input)
            g.input.extend([new_input])
            g.value_info.remove(output_val_info)
        elif node.op_type == 'Unsqueeze':
            axis = list(node.attribute[0].ints)
            pre_shape = helper.get_shape_from_value_info(old_input)
            new_shape = pre_shape
            for pos in axis:
                new_shape.insert(pos, 1)
            new_input = onnx.helper.make_tensor_value_info(
                output_val_info.name,
                output_val_info.type.tensor_type.elem_type,
                new_shape
            )
            node_to_remove.append(node)

            g.input.remove(old_input)
            g.input.extend([new_input])
            g.value_info.remove(output_val_info)
        else:
            pass

    for node in node_to_remove:
        g.node.remove(node)

    other.topological_sort(g)


def eliminate_Reshape_Cast(g):
    """Eliminate the cast layer for shape of Reshape layer

    :param g: the onnx graph
    """
    # Find all reshape layers
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Reshape':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[1])
        if prev_node.op_type != 'Cast':
            continue
        # Now we find the cast weight pattern. Cast the weight, delete the cast.
        reshape_node = node
        cast_node = prev_node
        weight_node = helper.find_node_by_output_name(g, cast_node.input[0])
        if weight_node is None:
            raise RuntimeError("Unexpected None before Cast-Reshape.")
        weight_node.attribute[0].t.data_type = 7
        if weight_node.attribute[0].t.raw_data:
            raw_data = weight_node.attribute[0].t.raw_data
            int_data = [i[0] for i in struct.iter_unpack('i', raw_data)]
            # Write the repacked int64 bytes back into the tensor.
            weight_node.attribute[0].t.raw_data = struct.pack('q' * len(int_data), *int_data)
        elif len(weight_node.attribute[0].t.int64_data) > 0\
                or len(weight_node.attribute[0].t.int32_data) > 0:
            # It's already int. Do nothing.
            pass
        else:
            raise NotImplementedError()
        # Change value info
        origin_weight_out = helper.find_value_by_name(g, weight_node.output[0])
        weight_node.output.pop()
        weight_node.output.extend([reshape_node.input[1]])
        # Delete the value info. Collect the cast node first to avoid mutating
        # g.node while iterating over it.
        g.value_info.remove(origin_weight_out)
        node_to_remove.append(cast_node)
    for node in node_to_remove:
        g.node.remove(node)


def eliminate_Cast_after_input(g):
    """Eliminate the cast layer right after the input

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Cast':
            continue
        old_input = helper.find_input_by_name(g, node.input[0])
        if old_input is None:
            continue
        next_val_info = helper.find_value_by_name(g, node.output[0])
        shape = helper.get_shape_from_value_info(next_val_info)
        new_val_info = onnx.helper.make_tensor_value_info(
            next_val_info.name,
            node.attribute[0].i,
            shape
        )
        # Delete old value_info
        g.input.remove(old_input)
        g.value_info.remove(next_val_info)
        # Append nodes to node_to_remove
        node_to_remove.append(node)
        # Add new input
        g.input.extend([new_val_info])
    for node in node_to_remove:
        g.node.remove(node)


def eliminate_consecutive_Cast(g):
|
||||
"""If two cast is next to each other, remove the first cast
|
||||
|
||||
:param g: the onnx graph
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Cast':
|
||||
continue
|
||||
first_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
if first_node is None or first_node.op_type != 'Cast':
|
||||
continue
|
||||
# Here we have two consecutive Cast Node
|
||||
# Reset the input of the later node
|
||||
node.input[0] = first_node.input[0]
|
||||
# Remove the first node and its output value info
|
||||
node_to_remove.append(first_node)
|
||||
first_output = helper.find_value_by_name(g, first_node.output[0])
|
||||
g.value_info.remove(first_output)
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
def eliminate_Squeeze_before_Reshape(g):
|
||||
"""If Squeeze and Reshape is next to each other, remove the first node
|
||||
|
||||
:param g: the onnx graph
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Reshape':
|
||||
continue
|
||||
first_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
if not first_node:
|
||||
continue
|
||||
if first_node.op_type != 'Squeeze':
|
||||
continue
|
||||
# Here we have two consecutive Cast Node
|
||||
# Reset the input of the later node
|
||||
node.input[0] = first_node.input[0]
|
||||
# Remove the first node and its output value info
|
||||
node_to_remove.append(first_node)
|
||||
first_output = helper.find_value_by_name(g, first_node.output[0])
|
||||
g.value_info.remove(first_output)
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
def eliminate_no_children_input(g):
|
||||
"""Eliminate inputs with no children at all.
|
||||
"""
|
||||
# Create a set of input names
|
||||
input_names = set([i.name for i in g.input])
|
||||
# If a name is used in any node, remove this name from the set.
|
||||
for n in g.node:
|
||||
for i in n.input:
|
||||
input_names.discard(i)
|
||||
# Remove the inputs with the left names.
|
||||
for i in input_names:
|
||||
info = helper.find_input_by_name(g, i)
|
||||
g.input.remove(info)
|
||||
|
||||
def eliminate_consecutive_reshape(g):
|
||||
"""Replace consecutive reshape nodes by a single node.
|
||||
"""
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Reshape':
|
||||
continue
|
||||
pre_data_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
pre_shape_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
if not pre_data_node or not pre_shape_node:
|
||||
continue
|
||||
if pre_shape_node.op_type != 'Constant':
|
||||
continue
|
||||
if pre_data_node.op_type != 'Reshape':
|
||||
continue
|
||||
|
||||
pre_pre_shape_node = helper.find_node_by_output_name(g, pre_data_node.input[1])
|
||||
if pre_pre_shape_node.op_type != 'Constant':
|
||||
continue
|
||||
|
||||
new_reshape_node = onnx.helper.make_node(
|
||||
'Reshape',
|
||||
[pre_data_node.input[0], node.input[1]],
|
||||
[node.output[0]],
|
||||
name = node.output[0]
|
||||
)
|
||||
|
||||
g.node.extend([new_reshape_node])
|
||||
node_to_del.append(node)
|
||||
node_to_del.append(pre_data_node)
|
||||
node_to_del.append(pre_pre_shape_node)
|
||||
|
||||
val_info_to_del1 = helper.find_value_by_name(g, node.input[0])
|
||||
val_info_to_del2 = helper.find_value_by_name(g, pre_data_node.input[1])
|
||||
g.value_info.remove(val_info_to_del1)
|
||||
g.value_info.remove(val_info_to_del2)
|
||||
|
||||
while node_to_del:
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
|
||||
def eliminate_single_input_Concat(g):
|
||||
"""
|
||||
Eliminate single input Concat layers
|
||||
|
||||
:param g: the onnx graph
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Concat':
|
||||
continue
|
||||
# If this node has more than 1 input, continue.
|
||||
if len(node.input) > 1:
|
||||
continue
|
||||
# If this node is the output node, set its previous node as output nodes.
|
||||
if helper.find_output_by_name(g, node.output[0]) is not None:
|
||||
todel_output = helper.find_output_by_name(g, node.output[0])
|
||||
the_input_value = helper.find_value_by_name(g, node.input[0])
|
||||
g.output.remove(todel_output)
|
||||
g.output.extend([the_input_value])
|
||||
node_to_remove.append(node)
|
||||
continue
|
||||
# Replace the parents in all the following nodes
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
for following_node in following_nodes:
|
||||
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
|
||||
# Delete value info
|
||||
value_between = helper.find_value_by_name(g, node.output[0])
|
||||
try:
|
||||
g.value_info.remove(value_between)
|
||||
except:
|
||||
print("No value info to delete while eliminating identity layers.")
|
||||
# Node is waiting for elimination
|
||||
node_to_remove.append(node)
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
def eliminate_nop_Maxpool_and_AveragePool(g):
|
||||
"""
|
||||
Eliminate do nothing MaxPool and AveragePool layers.
|
||||
Those layers have valid padding, 1x1 kernel and [1,1] strides.
|
||||
|
||||
:param g: the onnx graph
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'MaxPool' and node.op_type != 'AveragePool':
|
||||
continue
|
||||
# If this node is actually working, continue.
|
||||
kernel = helper.get_list_attribute_by_name(node, "kernel_shape", "int")
|
||||
pads = helper.get_list_attribute_by_name(node, "pads", "int")
|
||||
strides = helper.get_list_attribute_by_name(node, "strides", "int")
|
||||
if kernel != [1, 1] or pads != [0, 0, 0, 0] or strides != [1, 1]:
|
||||
continue
|
||||
# If this node is the output node, set its previous node as output nodes.
|
||||
if helper.find_output_by_name(g, node.output[0]) is not None:
|
||||
todel_output = helper.find_output_by_name(g, node.output[0])
|
||||
the_input_value = helper.find_value_by_name(g, node.input[0])
|
||||
g.output.remove(todel_output)
|
||||
g.output.extend([the_input_value])
|
||||
node_to_remove.append(node)
|
||||
continue
|
||||
# Replace the parents in all the following nodes
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
for following_node in following_nodes:
|
||||
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
|
||||
# Delete value info
|
||||
value_between = helper.find_value_by_name(g, node.output[0])
|
||||
try:
|
||||
g.value_info.remove(value_between)
|
||||
except:
|
||||
print("No value info to delete while eliminating identity layers.")
|
||||
# Node is waiting for elimination
|
||||
node_to_remove.append(node)
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
def eliminate_trivial_maxpool(g):
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'MaxPool':
|
||||
continue
|
||||
pads = None
|
||||
strides = None
|
||||
dilation = None
|
||||
kernel_shape = None
|
||||
for att in node.attribute:
|
||||
if att.name == 'pads':
|
||||
pads = list(att.ints)
|
||||
elif att.name == 'strides':
|
||||
strides = list(att.ints)
|
||||
elif att.name == 'kernel_shape':
|
||||
kernel_shape = list(att.ints)
|
||||
elif att.name == 'dilation':
|
||||
dilation = list(att.ints)
|
||||
else:
|
||||
pass
|
||||
if pads and any([pad != 0 for pad in pads]):
|
||||
continue
|
||||
if strides and any([stride != 1 for stride in strides]):
|
||||
continue
|
||||
if dilation and any([dila != 1 for dila in dilation]):
|
||||
continue
|
||||
if any([dim != 1 for dim in kernel_shape]):
|
||||
continue
|
||||
|
||||
node_to_del.append(node)
|
||||
|
||||
next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
|
||||
|
||||
if next_nodes[0] == None:
|
||||
output_value = helper.find_output_by_name(g, node.output[0])
|
||||
if not output_value:
|
||||
continue
|
||||
else:
|
||||
pre_val_info = helper.find_value_by_name(g, node.input[0])
|
||||
g.output.extend([pre_val_info])
|
||||
g.output.remove(output_value)
|
||||
|
||||
for next_node in next_nodes:
|
||||
modhelper.replace_node_input(next_node, node.output[0], node.input[0])
|
||||
|
||||
next_val_info = helper.find_value_by_name(g, node.output[0])
|
||||
g.value_info.remove(next_val_info)
|
||||
|
||||
while node_to_del:
|
||||
g.node.remove(node_to_del.pop())
|
||||
|
||||
other.topological_sort(g)
|
||||
|
||||
def eliminate_empty_value_infos(g):
|
||||
to_remove = []
|
||||
for value_info in g.value_info:
|
||||
if len(value_info.type.tensor_type.shape.dim) == 0:
|
||||
to_remove.append(value_info)
|
||||
for value_info in to_remove:
|
||||
g.value_info.remove(value_info)
|
||||
|
||||
def eliminate_nop_pads(g):
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Pad':
|
||||
continue
|
||||
# Check if the Pad is empty or not
|
||||
pads_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
pads_np = helper.constant_to_numpy(pads_node)
|
||||
all_zero = True
|
||||
for value in pads_np:
|
||||
if value != 0:
|
||||
all_zero = False
|
||||
if not all_zero:
|
||||
continue
|
||||
# Check if it has the constant_value_node
|
||||
constant_value_node = None
|
||||
if len(node.input) > 2:
|
||||
constant_value_node = helper.find_node_by_output_name(g, node.input[2])
|
||||
# If this node is the output node, set its previous node as output nodes.
|
||||
if helper.find_output_by_name(g, node.output[0]) is not None:
|
||||
todel_output = helper.find_output_by_name(g, node.output[0])
|
||||
g.output.remove(todel_output)
|
||||
if helper.find_output_by_name(g, node.input[0]) is None:
|
||||
the_input_value = helper.find_value_by_name(g, node.input[0])
|
||||
if the_input_value is not None:
|
||||
g.output.extend([the_input_value])
|
||||
# Replace the parents in all the following nodes
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
for following_node in following_nodes:
|
||||
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
|
||||
# Delete value info
|
||||
value_between = helper.find_value_by_name(g, node.output[0])
|
||||
try:
|
||||
g.value_info.remove(value_between)
|
||||
except:
|
||||
helper.logger.info("No value info to delete while eliminating identity layers.")
|
||||
# Node is waiting for elimination
|
||||
node_to_remove.append(node)
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
def eliminate_trivial_elementwise_calculation(g):
|
||||
"""Eliminate Add, Sub, Mul, Sub nodes which do nothing.
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
weight_node = None
|
||||
if node.op_type == 'Add' or node.op_type == 'Sub':
|
||||
# For add and sub, check if the weights are 0s.
|
||||
weight_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
if weight_node is None or weight_node.op_type != 'Constant':
|
||||
continue
|
||||
weight_np = helper.constant_to_numpy(weight_node)
|
||||
if np.any(weight_np):
|
||||
continue
|
||||
elif node.op_type == 'Mul' or node.op_type == 'Div':
|
||||
# For Mul and Div, check if the weights are 1s.
|
||||
weight_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
if weight_node is None or weight_node.op_type != 'Constant':
|
||||
continue
|
||||
weight_np = helper.constant_to_numpy(weight_node)
|
||||
weight_np = weight_np - 1
|
||||
if np.any(weight_np):
|
||||
continue
|
||||
else:
|
||||
# For other nodes, just skip
|
||||
continue
|
||||
# Remove the node
|
||||
node_to_remove.append(node)
|
||||
output_value_info = helper.find_value_by_name(g, node.output[0])
|
||||
if output_value_info is not None:
|
||||
g.value_info.remove(output_value_info)
|
||||
# Replace next node input if any.
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
for following_node in following_nodes:
|
||||
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
|
||||
todel_output = helper.find_output_by_name(g, node.output[0])
|
||||
if todel_output is not None:
|
||||
g.output.remove(todel_output)
|
||||
previous_output = helper.find_output_by_name(g, node.input[0])
|
||||
if previous_output is None:
|
||||
the_input_value = helper.find_value_by_name(g, node.input[0])
|
||||
g.output.extend([the_input_value])
|
||||
# Delete the constant node if it is not used by other nodes
|
||||
constant_following_nodes = helper.find_following_nodes_by_input_value_name(g, weight_node.output[0])
|
||||
if len(constant_following_nodes) == 1:
|
||||
node_to_remove.append(weight_node)
|
||||
output_value_info = helper.find_value_by_name(g, weight_node.output[0])
|
||||
if output_value_info is not None:
|
||||
g.value_info.remove(output_value_info)
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
def eliminate_nop_cast(g):
|
||||
"""Eliminate do nothing Cast nodes.
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Cast':
|
||||
continue
|
||||
# Get input value_info
|
||||
input_value = helper.find_value_by_name(g, node.input[0])
|
||||
if input_value is None:
|
||||
helper.logger.debug(f"Cannot find the input value_info for Cast node {node.name}. Skip elimination check.")
|
||||
continue
|
||||
# Get output value_info
|
||||
output_value = helper.find_value_by_name(g, node.output[0])
|
||||
if output_value is None:
|
||||
output_value = helper.find_output_by_name(g, node.output[0])
|
||||
if output_value is None:
|
||||
helper.logger.debug(f"Cannot find the output value_info for Cast node {node.name}. Skip elimination check.")
|
||||
continue
|
||||
# Compare the type.
|
||||
if input_value.type.tensor_type.elem_type != output_value.type.tensor_type.elem_type:
|
||||
continue
|
||||
# If this node is the output node, set its previous node as output nodes.
|
||||
if helper.find_output_by_name(g, node.output[0]) is not None:
|
||||
todel_output = helper.find_output_by_name(g, node.output[0])
|
||||
g.output.remove(todel_output)
|
||||
if helper.find_output_by_name(g, node.input[0]) is None:
|
||||
the_input_value = helper.find_value_by_name(g, node.input[0])
|
||||
if the_input_value is not None:
|
||||
g.output.extend([the_input_value])
|
||||
# Replace the parents in all the following nodes
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
for following_node in following_nodes:
|
||||
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
|
||||
# Delete value info
|
||||
value_between = helper.find_value_by_name(g, node.output[0])
|
||||
if value_between is not None:
|
||||
g.value_info.remove(value_between)
|
||||
# Node is waiting for elimination
|
||||
node_to_remove.append(node)
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
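All of the elimination passes above share one rewiring pattern: reconnect every consumer of a no-op node's output to the node's input, then drop the node. This is a minimal, self-contained sketch of that pattern on a toy graph representation; `TinyNode` and `bypass_nop_nodes` are illustrative stand-ins, not the ONNX protos and helpers used by the real passes.

```python
class TinyNode:
    """A toy node: named op with a list of input value names and one output name."""
    def __init__(self, name, op_type, inputs, output):
        self.name = name
        self.op_type = op_type
        self.inputs = list(inputs)
        self.output = output

def bypass_nop_nodes(nodes, is_nop):
    """Remove nodes for which is_nop(node) is True, reconnecting their consumers."""
    kept = []
    for node in nodes:
        if not is_nop(node):
            kept.append(node)
            continue
        # Rewire: every consumer that read this node's output now reads its input.
        for consumer in nodes:
            consumer.inputs = [node.inputs[0] if i == node.output else i
                               for i in consumer.inputs]
    return kept

nodes = [
    TinyNode('conv', 'Conv', ['x'], 'a'),
    TinyNode('identity', 'Identity', ['a'], 'b'),
    TinyNode('relu', 'Relu', ['b'], 'y'),
]
kept = bypass_nop_nodes(nodes, lambda n: n.op_type == 'Identity')
print([n.name for n in kept])  # ['conv', 'relu']
print(kept[1].inputs)          # ['a']
```

The real passes additionally fix up `value_info` entries and graph outputs, as shown in the functions above.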
1064
tools/optimizer_scripts/tools/fusing.py
Normal file
File diff suppressed because it is too large
Load Diff
83
tools/optimizer_scripts/tools/general_graph.py
Normal file
@@ -0,0 +1,83 @@
from collections import deque


class Node:
    """A Node which maps a node proto. It has pointers to its parents and
    children.
    """
    def __init__(self, onnx_node):
        """Initialize a node. This initialization only sets up the mapping to
        the node proto. The pointers should be set up from outside.
        """
        self.name = None
        self.parents = []
        self.children = []
        self.proto = None
        self.output_value = None
        if onnx_node is not None:
            self.name = onnx_node.name
            self.proto = onnx_node


class Graph:
    """A graph which is constructed from the onnx proto.
    """
    def __init__(self, onnx_graph):
        """Construct the graph from onnx.
        """
        self.input_nodes = []
        self.output_nodes = []
        self.name2node = {}
        self.output2node = {}
        self.proto = onnx_graph
        # Add input nodes
        for value in onnx_graph.input:
            input_node = Node(None)
            input_node.name = "Input_" + value.name
            input_node.output_value = value
            self.name2node[input_node.name] = input_node
            self.output2node[value.name] = input_node
            self.input_nodes.append(input_node)
        output_value_names = [value.name for value in onnx_graph.output]
        # Add regular nodes
        for onnx_node in onnx_graph.node:
            node = Node(onnx_node)
            self.name2node[node.name] = node
            self.output2node[onnx_node.output[0]] = node
            for value_name in onnx_node.input:
                node.parents.append(self.output2node[value_name])
                self.output2node[value_name].children.append(node)
            if onnx_node.output[0] in output_value_names:
                self.output_nodes.append(node)
        # Add value infos
        for value in onnx_graph.value_info:
            node = self.output2node[value.name]
            node.output_value = value

    def get_sorted_node_list(self):
        """Return a node list in topological order.
        """
        visited = set()
        todo = deque()
        result = []
        for node in self.input_nodes:
            todo.append(node)
            visited.add(node)
        for onnx_node in self.proto.node:
            if onnx_node.op_type == "Constant":
                node = self.name2node[onnx_node.name]
                todo.append(node)
                visited.add(node)
        while todo:
            node = todo.popleft()
            result.append(node)
            for child in node.children:
                if child in visited:
                    continue
                ready = True
                for child_parent in child.parents:
                    if child_parent in visited:
                        continue
                    ready = False
                    break
                if ready:
                    todo.append(child)
                    visited.add(child)
        return result
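The traversal in `Graph.get_sorted_node_list` above is Kahn's algorithm: seed the queue with nodes that have no pending parents (graph inputs and Constants), then emit a node only once all of its parents have been emitted. This stand-alone sketch shows the same idea on plain dicts instead of the `Node` class; the node names are illustrative.

```python
from collections import deque

def topo_order(parents):
    """parents maps node name -> list of parent names. Returns a topological order."""
    children = {n: [] for n in parents}
    for n, ps in parents.items():
        for p in ps:
            children[p].append(n)
    # Seed with parentless nodes (sorted for a deterministic start order).
    visited = set(n for n, ps in parents.items() if not ps)
    todo = deque(sorted(visited))
    result = []
    while todo:
        n = todo.popleft()
        result.append(n)
        for c in children[n]:
            # A child becomes ready only when all of its parents were visited.
            if c not in visited and all(p in visited for p in parents[c]):
                todo.append(c)
                visited.add(c)
    return result

order = topo_order({'x': [], 'w': [], 'conv': ['x', 'w'], 'relu': ['conv']})
print(order)  # ['w', 'x', 'conv', 'relu']
```

Like the class above, this variant silently drops nodes that are part of a cycle, which is acceptable because a valid ONNX graph is acyclic.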
||||
621
tools/optimizer_scripts/tools/helper.py
Normal file
621
tools/optimizer_scripts/tools/helper.py
Normal file
@ -0,0 +1,621 @@
|
||||
"""This module contains helper functions that do not modify the graph.
|
||||
"""
|
||||
import onnx
|
||||
import onnx.helper
|
||||
import struct
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
__ONNX_VERSION__ = -1
|
||||
|
||||
logger = logging.getLogger("optimizer_scripts")
|
||||
|
||||
def setup_current_opset_version(m):
|
||||
global __ONNX_VERSION__
|
||||
__ONNX_VERSION__ = m.opset_import[0].version
|
||||
if __ONNX_VERSION__ not in [11]:
|
||||
raise RuntimeError('Only support opset 11, but got ' + str(__ONNX_VERSION__))
|
||||
|
||||
def get_current_opset_version():
|
||||
if __ONNX_VERSION__ == -1:
|
||||
raise RuntimeError('do setup_current_opset_version first please')
|
||||
return __ONNX_VERSION__
|
||||
|
||||
def find_nodes_by_input_name(g, name):
|
||||
nodes = []
|
||||
for node in g.node:
|
||||
if name in node.input:
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
def find_node_by_output_name(g, name):
|
||||
"""
|
||||
Find a node in the graph by its output name
|
||||
|
||||
:param g: the onnx graph\\
|
||||
:param name: the target node output name\\
|
||||
:returns: the node find by name
|
||||
"""
|
||||
for i in g.node:
|
||||
if name in i.output:
|
||||
return i
|
||||
return None
|
||||
|
||||
def find_node_by_node_name(g, name):
|
||||
"""
|
||||
Find a node in the graph by its output name
|
||||
|
||||
:param g: the onnx graph\\
|
||||
:param name: the target node output name\\
|
||||
:returns: the node find by name
|
||||
"""
|
||||
for i in g.node:
|
||||
if i.name == name:
|
||||
return i
|
||||
return None
|
||||
|
||||
def find_following_nodes_by_input_value_name(g, name):
|
||||
""" Find the following nodes of a specific value.
|
||||
|
||||
:param g: the onnx graph. \\
|
||||
:param name: the value name. \\
|
||||
:return: a list of following nodes.
|
||||
"""
|
||||
return find_nodes_by_input_name(g, name)
|
||||
|
||||
def find_value_by_name(g, name):
|
||||
"""
|
||||
Find a value_info in the graph by name
|
||||
|
||||
:param g: the onnx graph\\
|
||||
:param name: the target value_info name\\
|
||||
:returns: the value_info find by name
|
||||
"""
|
||||
for i in g.value_info:
|
||||
if i.name == name:
|
||||
return i
|
||||
return None
|
||||
|
||||
def find_output_by_name(g, name):
|
||||
"""
|
||||
Find a value_info in the graph by name
|
||||
|
||||
:param g: the onnx graph\\
|
||||
:param name: the target value_info name\\
|
||||
:returns: the value_info find by name
|
||||
"""
|
||||
for i in g.output:
|
||||
if i.name == name:
|
||||
return i
|
||||
return None
|
||||
|
||||
def find_input_by_name(g, name):
|
||||
"""
|
||||
Find a input in the graph by name
|
||||
|
||||
:param g: the onnx graph\\
|
||||
:param name: the target input name\\
|
||||
:returns: the input find by name
|
||||
"""
|
||||
for i in g.input:
|
||||
if i.name == name:
|
||||
return i
|
||||
return None
|
||||
|
||||
def list_to_constant(name, shape, data, data_type=None):
|
||||
"""Generate a constant node using the given infomation.
|
||||
|
||||
:name: the node name and the output value name\\
|
||||
:shape: the data shape\\
|
||||
:data: the data itself\\
|
||||
:returns: the generated onnx constant node
|
||||
"""
|
||||
if not data_type:
|
||||
if isinstance(data, int):
|
||||
data_type = onnx.helper.TensorProto.INT64
|
||||
elif isinstance(data, float):
|
||||
data_type = onnx.helper.TensorProto.FLOAT
|
||||
elif len(data) > 0 and isinstance(data[0], int):
|
||||
data_type = onnx.helper.TensorProto.INT64
|
||||
else:
|
||||
data_type = onnx.helper.TensorProto.FLOAT
|
||||
tensor = onnx.helper.make_tensor(
|
||||
name,
|
||||
data_type,
|
||||
shape,
|
||||
data
|
||||
)
|
||||
new_w_node = onnx.helper.make_node(
|
||||
"Constant",
|
||||
[],
|
||||
[name],
|
||||
name = name,
|
||||
value = tensor
|
||||
)
|
||||
return new_w_node
|
||||
|
||||
|
||||
def scaler_to_constant(name, data, data_type=None):
|
||||
"""Generate a constant node using the given infomation.
|
||||
|
||||
:name: the node name and the output value name\\
|
||||
:shape: the data shape\\
|
||||
:data: the data itself\\
|
||||
:returns: the generated onnx constant node
|
||||
"""
|
||||
if not data_type:
|
||||
if isinstance(data, int):
|
||||
data_type = onnx.helper.TensorProto.INT64
|
||||
elif isinstance(data, float):
|
||||
data_type = onnx.helper.TensorProto.FLOAT
|
||||
else:
|
||||
logger.error("Cannot create scaler constant with a list.")
|
||||
exit(1)
|
||||
tensor = onnx.helper.make_tensor(
|
||||
name,
|
||||
data_type,
|
||||
None,
|
||||
[data]
|
||||
)
|
||||
new_w_node = onnx.helper.make_node(
|
||||
"Constant",
|
||||
[],
|
||||
[name],
|
||||
name = name,
|
||||
value = tensor
|
||||
)
|
||||
return new_w_node
|
||||
|
||||
|
||||
def numpy_to_constant(name, np_array):
|
||||
return list_to_constant(name, np_array.shape, np_array.flatten().tolist())
|
||||
|
||||
def constant_to_list(node):
|
||||
"""Generate a list from the constant node
|
||||
|
||||
:node: the Constant node\\
|
||||
:returns: the shape of the constant node, the data of the constant node
|
||||
"""
|
||||
tensor = node.attribute[0].t
|
||||
# 1. check data type
|
||||
# 2. get data from raw or data
|
||||
# 3. get shape from dim
|
||||
if tensor.data_type == onnx.helper.TensorProto.INT32:
|
||||
if len(tensor.int32_data) != 0:
|
||||
data = list(tensor.int32_data)
|
||||
else:
|
||||
data = [i[0] for i in struct.iter_unpack('i', tensor.raw_data)]
|
||||
elif tensor.data_type == onnx.helper.TensorProto.INT64:
|
||||
if len(tensor.int64_data) != 0:
|
||||
data = list(tensor.int64_data)
|
||||
else:
|
||||
data = [i[0] for i in struct.iter_unpack('q', tensor.raw_data)]
|
||||
elif tensor.data_type == onnx.helper.TensorProto.INT8:
|
||||
if len(tensor.int32_data) != 0:
|
||||
data = list(tensor.int32_data)
|
||||
else:
|
||||
data = [i[0] for i in struct.iter_unpack('b', tensor.raw_data)]
|
||||
elif tensor.data_type == onnx.helper.TensorProto.FLOAT:
|
||||
if len(tensor.float_data) != 0:
|
||||
data = list(tensor.float_data)
|
||||
else:
|
||||
data = [i[0] for i in struct.iter_unpack('f', tensor.raw_data)]
|
||||
elif tensor.data_type == onnx.helper.TensorProto.DOUBLE:
|
||||
if len(tensor.double_data) != 0:
|
||||
data = list(tensor.double_data)
|
||||
else:
|
||||
data = [i[0] for i in struct.iter_unpack('d', tensor.raw_data)]
|
||||
else:
|
||||
print("Not supported data type {}".format(tensor.data_type))
|
||||
raise RuntimeError
|
||||
if len(tensor.dims) == 0:
|
||||
shape = len(data)
|
||||
else:
|
||||
shape = list(tensor.dims)
|
||||
return shape, data
|
||||
|
||||
def constant_to_numpy(node):
|
||||
"""Generate a numpy array from the constant node
|
||||
|
||||
:node: the Constant node\\
|
||||
:returns: the numpy array
|
||||
"""
|
||||
shape, data = constant_to_list(node)
|
||||
return np.array(data).reshape(shape)
|
||||
|
||||
def all_constant_input(node):
|
||||
"""Find the inputs of the given node. If the inputs of this node are all\\
|
||||
constant nodes, return True. Otherwise, return False.
|
||||
|
||||
:param node: the input node which has a Node structure\\
|
||||
:return: whether the node of this node are all constant
|
||||
"""
|
||||
if node.proto is None:
|
||||
return False
|
||||
isConstant = True
|
||||
for parent in node.parents:
|
||||
if parent.proto is None or parent.proto.op_type != 'Constant':
|
||||
isConstant = False
|
||||
break
|
||||
return isConstant
|
||||
|
||||
def get_padding(size, kernel_size, strides):
|
||||
""" Calculate the padding array for same padding in the Tensorflow fashion.\\
|
||||
See https://www.tensorflow.org/api_guides/python/nn#Convolution for more.
|
||||
"""
|
||||
if size[0] % strides[0] == 0:
|
||||
pad_h = max(kernel_size[0] - strides[0], 0)
|
||||
else:
|
||||
pad_h = max(kernel_size[0] - (size[0] % strides[0]), 0)
|
||||
if size[1] % strides[1] == 0:
|
||||
pad_w = max(kernel_size[1] - strides[1], 0)
|
||||
else:
|
||||
pad_w = max(kernel_size[1] - (size[1] % strides[1]), 0)
|
||||
return [pad_h//2, pad_w//2, pad_h-pad_h//2, pad_w-pad_w//2]
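The rule in `get_padding` above is worth a worked example: the total padding per axis is `kernel - stride` when the size divides evenly, else `kernel - (size % stride)`, with the odd remainder going to the bottom/right half (TensorFlow's convention). This is a self-contained copy of that logic purely for illustration; `tf_same_padding` is a hypothetical name, not part of this module.

```python
def tf_same_padding(size, kernel_size, strides):
    """Compute [top, left, bottom, right] same-padding, TF-style (illustrative copy)."""
    def pad_for(dim, k, s):
        # Total padding needed on this axis.
        if dim % s == 0:
            return max(k - s, 0)
        return max(k - (dim % s), 0)
    pad_h = pad_for(size[0], kernel_size[0], strides[0])
    pad_w = pad_for(size[1], kernel_size[1], strides[1])
    # Split the total: the larger half goes to the bottom/right.
    return [pad_h // 2, pad_w // 2, pad_h - pad_h // 2, pad_w - pad_w // 2]

print(tf_same_padding([224, 224], [3, 3], [2, 2]))  # [0, 0, 1, 1]
print(tf_same_padding([7, 7], [3, 3], [2, 2]))      # [1, 1, 1, 1]
```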
|
||||
|
||||
def get_shape_from_value_info(value):
|
||||
"""Get shape from a value info.
|
||||
|
||||
:param value: the value_info proto\\
|
||||
:return: list of the shape
|
||||
"""
|
||||
return [d.dim_value for d in value.type.tensor_type.shape.dim]
|
||||
|
||||
def find_size_shape_from_value(value):
|
||||
'''
|
||||
Find the size of data within the value_info object.
|
||||
:param value: value_info
|
||||
:return: int size and list shape of the data in the value_info
|
||||
'''
|
||||
if not value:
|
||||
return None, None
|
||||
if not value.type.tensor_type.shape.dim:
|
||||
return 0, []
|
||||
size = 1
|
||||
shape = []
|
||||
for i in range(len(value.type.tensor_type.shape.dim)):
|
||||
size *= max(1, value.type.tensor_type.shape.dim[i].dim_value)
|
||||
shape.append(max(1, value.type.tensor_type.shape.dim[i].dim_value))
|
||||
|
||||
return size, shape
|
||||
|
||||
|
||||
def get_attribute_by_name(node, attr_name):
|
||||
"""Get attribute proto with specific name in the given node proto.
|
||||
|
||||
:param node: the node proto.\\
|
||||
:param attr_name: a str for the name of the target.\\
|
||||
:return: if found, return the attribute_proto. Else, return None.
|
||||
"""
|
||||
for attr in node.attribute:
|
||||
if attr.name == attr_name:
|
||||
return attr
|
||||
return None
|
||||
|
||||
def get_list_attribute_by_name(node, attr_name: str, attr_type: str):
|
||||
"""Get list attribute with specific name in the given node proto.
|
||||
|
||||
:param node: the node proto.\\
|
||||
:param attr_name: a str for the name of the target.\\
|
||||
:param attr_type: a str which should be "float" or "int".\\
|
||||
:return: if found, return the list. Else, return None.
|
||||
"""
|
||||
attr_proto = get_attribute_by_name(node, attr_name)
|
||||
if attr_proto is None:
|
||||
return None
|
||||
if attr_type == "int":
|
||||
if len(attr_proto.ints) == 0:
|
||||
return None
|
||||
else:
|
||||
return list(attr_proto.ints)
|
||||
elif attr_type == "float":
|
||||
if len(attr_proto.ints) == 0:
|
||||
return None
|
||||
else:
|
||||
return list(attr_proto.floats)
|
||||
else:
|
||||
print("Warning: undefined type for list attribute extraction")
|
||||
return None
|
||||
|
||||
def get_var_attribute_by_name(node, attr_name: str, attr_type: str):
|
||||
"""Get variable attribute with specific name in the given node proto.
|
||||
|
||||
:param node: the node proto.\\
|
||||
:param attr_name: str for the name of the target.\\
|
||||
:param attr_type: str which should be "float", "int", "string" or "tensor".\\
|
||||
:return: if found, return the variable. Else, return None.
|
||||
"""
|
||||
attr_proto = get_attribute_by_name(node, attr_name)
|
||||
if attr_proto is None:
|
||||
return None
|
||||
if attr_type == "int":
|
||||
return attr_proto.i
|
||||
elif attr_type == "float":
|
||||
return attr_proto.f
|
||||
elif attr_type == "string":
|
||||
if type(attr_proto.s) == type(b'abc'):
|
||||
return attr_proto.s.decode("utf-8")
|
||||
else:
|
||||
return attr_proto.s
|
||||
elif attr_type == "tensor":
|
||||
return attr_proto.t
|
||||
else:
|
||||
print("Warning: undefined type for variable attribute extraction")
|
||||
return None
|
||||
|
||||
def flatten_with_depth(data, depth):
|
||||
output = []
|
||||
if type(data) not in [type(np.array([1])), type([1])]:
|
||||
return [[data, 0]]
|
||||
for item in data:
|
||||
if type(item) not in [type(np.array([1])), type([1])]:
|
||||
output.append([item, depth+1])
|
||||
else:
|
||||
output += flatten_with_depth(item, depth+1)
|
||||
return output
|
||||
|
||||
def flatten_to_list(data):
|
||||
flatten_depth = flatten_with_depth(data, 0)
|
||||
flat_data = [item[0] for item in flatten_depth]
|
||||
return flat_data
|
||||
|
||||
def get_shape(data):
|
||||
shape = []
|
||||
if type(data) not in [type(np.array([1])), type([1])]:
|
||||
return []
|
||||
sub_data = data[0]
|
||||
shape.append(len(data))
|
||||
while type(sub_data) in [type(np.array([1])), type([1])]:
|
||||
shape.append(len(sub_data))
|
||||
sub_data = sub_data[0]
|
||||
return shape
|
||||
|
||||
|
||||
def slice_data(data, starts, ends, axes):
    """Slice nested-list data along the given axes, following ONNX Slice semantics."""
    flat_data = [item[0] for item in flatten_with_depth(data, 0)]
    shape = get_shape(data)

    # Clamp starts/ends into the valid range of their target axes and resolve
    # negative indices. Note that starts/ends are parallel to axes.
    starts_updated = []
    ends_updated = []
    for i in range(len(starts)):
        dim = shape[axes[i]]
        start_updated = min(starts[i], dim - 1) % dim
        starts_updated.append(start_updated)
    for j in range(len(ends)):
        dim = shape[axes[j]]
        if ends[j] >= dim:
            end_updated = dim
        else:
            end_updated = min(ends[j], dim) % dim
        ends_updated.append(end_updated)

    # For each axis, collect the indices that are kept.
    index_slices = []
    for i in range(len(shape)):
        if i not in axes:
            index_slices.append(list(range(shape[i])))
        else:
            axe_ind = axes.index(i)
            index_slices.append(list(range(starts_updated[axe_ind], ends_updated[axe_ind])))

    # Turn the per-axis index lists into flat offsets.
    indices = [1]
    for i in range(len(shape) - 1, -1, -1):
        step = np.prod(shape[i + 1:])
        temp_pos = indices
        new_indices = []
        for n in index_slices[i]:
            for pos in temp_pos:
                new_indices.append(int(n * step + pos))
        indices = new_indices

    sliced_data = [flat_data[k - 1] for k in indices]

    # Reshape to the correct shape.
    new_shape = []
    for i in range(len(shape)):
        if i not in axes:
            new_shape.append(shape[i])
        else:
            axe_ind = axes.index(i)
            new_shape.append(ends_updated[axe_ind] - starts_updated[axe_ind])
    if any([dim < 1 for dim in new_shape]):
        raise RuntimeError('Invalid starts ends.')

    sliced_data = np.reshape(sliced_data, new_shape)

    return sliced_data

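`slice_data` is meant to reproduce ONNX Slice semantics on nested lists, which for non-negative bounds is the same as numpy basic indexing on the chosen axes. A sketch of the intended equivalence:

```python
import numpy as np

# slice_data(data, starts=[0, 1], ends=[2, 3], axes=[1, 2]) should match
# numpy basic indexing on axes 1 and 2.
data = np.arange(24).reshape(2, 3, 4)
sliced = data[:, 0:2, 1:3]
print(sliced.shape)  # (2, 2, 2)
```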
def concatenate(data_sets, axis):
    """Concatenate nested-list data sets along the given axis."""
    # Check that all shapes match except on the concatenation axis.
    shapes = []
    shapes_ = []
    for data_set in data_sets:
        shape = get_shape(data_set)
        shapes.append(list(shape))
        shape.pop(axis)
        shapes_.append(shape)
    if not all([s == shapes_[0] for s in shapes_]):
        raise RuntimeError('data sets shapes do not match')

    new_dim = sum([s[axis] for s in shapes])
    new_shape = list(shapes[0])
    new_shape[axis] = new_dim

    flat_data_sets = []
    for data_set in data_sets:
        flat_data_sets.append(flatten_to_list(data_set))

    # Size of one block below the concatenation axis.
    sub_block_size = 1
    for i in range(axis + 1, len(shapes[0])):
        sub_block_size *= shapes[0][i]

    # Number of independent blocks above the concatenation axis.
    split_num = 1
    for i in range(axis):
        split_num *= shapes[0][i]

    # Interleave the per-set blocks for each outer index.
    total_flat_data = []
    for i in range(split_num):
        for j in range(len(shapes)):
            block_size = sub_block_size * shapes[j][axis]
            total_flat_data.extend(flat_data_sets[j][i * block_size:(i + 1) * block_size])

    new_data = np.reshape(total_flat_data, new_shape)

    return new_data

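The flat-block interleaving above produces the same layout as numpy concatenation; a quick check of the expected result shape and contents:

```python
import numpy as np

a = np.ones((2, 2, 3))
b = np.zeros((2, 1, 3))
# concatenate([a, b], axis=1) interleaves flat sub-blocks per outer index,
# which matches numpy's memory layout for concatenation.
c = np.concatenate([a, b], axis=1)
print(c.shape)  # (2, 3, 3)
```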
def broadcast_data_sets(data_set_1, data_set_2):
    """Broadcast two nested-list data sets to a common shape."""
    shape1 = get_shape(data_set_1)
    shape2 = get_shape(data_set_2)

    # Compare shapes and get the broadcasted shape. The shorter shape is
    # left-padded with 0, which is treated as a wildcard below.
    list_a, list_b = (shape1, shape2) if len(shape1) > len(shape2) else (shape2, shape1)
    while len(list_a) > len(list_b):
        list_b.insert(0, 0)
    broadcasted_shape = []
    for i in range(len(list_a)):
        if list_b[i] == 0:
            broadcasted_shape.append(list_a[i])
        elif list_b[i] == 1:
            broadcasted_shape.append(list_a[i])
        elif list_a[i] == 1:
            broadcasted_shape.append(list_b[i])
        elif list_a[i] == list_b[i]:
            broadcasted_shape.append(list_a[i])
        else:
            raise RuntimeError('Cannot broadcast two data sets')

    # Prepare data for broadcasting.
    shape1 = list(map(lambda x: x if x != 0 else 1, shape1))
    shape2 = list(map(lambda x: x if x != 0 else 1, shape2))
    data_1 = np.reshape(data_set_1, shape1)
    data_2 = np.reshape(data_set_2, shape2)

    # Tile any size-1 axis up to the broadcasted size by self-concatenation.
    for i in range(len(shape1)):
        if shape1[i] != broadcasted_shape[i]:
            new_data_total = [list(data_1) for _ in range(broadcasted_shape[i])]
            data_1 = concatenate(new_data_total, axis=i)
    for i in range(len(shape2)):
        if shape2[i] != broadcasted_shape[i]:
            new_data_total = [list(data_2) for _ in range(broadcasted_shape[i])]
            data_2 = concatenate(new_data_total, axis=i)

    return data_1, data_2

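The tiling above implements standard numpy-style broadcasting; `np.broadcast_arrays` reaches the same result shapes, which is a handy reference point:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
b = np.arange(3).reshape(1, 3)
# broadcast_data_sets tiles every size-1 axis up to the target size;
# np.broadcast_arrays produces the same shapes and values.
a2, b2 = np.broadcast_arrays(a, b)
print(a2.shape, b2.shape)  # (2, 3) (2, 3)
```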
def add(data_set_1, data_set_2):
    """Element-wise addition with broadcasting."""
    broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(data_set_1, data_set_2)

    flat_data_1 = flatten_to_list(broadcasted_data_1)
    flat_data_2 = flatten_to_list(broadcasted_data_2)
    shape = get_shape(broadcasted_data_1)
    res = []
    for i in range(len(flat_data_1)):
        res.append(flat_data_1[i] + flat_data_2[i])

    res = np.reshape(res, shape)

    return res

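`add` flattens both broadcasted operands and sums position by position, which is exactly numpy's broadcasting addition; a sketch of the expected behavior:

```python
import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]])
y = np.array([10.0, 20.0])
# The row vector y is broadcast over both rows of x before the sum.
z = x + y
print(z.tolist())  # [[11.0, 22.0], [13.0, 24.0]]
```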
def reduceprod(data_set, axis, keepdims=1):
    """Reduce the data set by multiplication along the given axes."""
    flat_data = flatten_to_list(data_set)
    old_shape = get_shape(data_set)

    temp_shape = old_shape
    temp_flat_data = flat_data
    for ax in axis:
        # split_num: number of blocks above the axis; step: stride below it.
        split_num = 1
        step = 1
        for i in range(ax):
            split_num *= temp_shape[i]
        for i in range(ax + 1, len(temp_shape)):
            step *= temp_shape[i]

        block_size = len(temp_flat_data) // split_num
        new_flat_data = []
        for j in range(split_num):
            block_data = temp_flat_data[j * block_size:(j + 1) * block_size]
            reduced_block_data = []
            for k in range(step):
                val = block_data[k]
                for l in range(1, block_size // step):
                    val *= block_data[k + l * step]
                reduced_block_data.append(val)
            new_flat_data.extend(reduced_block_data)
        temp_flat_data = new_flat_data
        temp_shape[ax] = 1

    new_flat_data = temp_flat_data
    new_shape = temp_shape
    if not keepdims:
        axis = sorted(list(axis))
        for pos in axis[::-1]:
            new_shape.pop(pos)

    return np.reshape(new_flat_data, new_shape)

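`reduceprod` mirrors `np.prod` with an `axis` list and the ONNX `keepdims` attribute; a sketch of both behaviors with numpy as reference:

```python
import numpy as np

x = np.arange(1, 7).reshape(2, 3)
# reduceprod(x, axis=[1], keepdims=1) keeps the reduced axis with size 1.
kept = np.prod(x, axis=1, keepdims=True)
# keepdims=0 drops the reduced axis instead.
dropped = np.prod(x, axis=1)
print(kept.shape, dropped.tolist())  # (2, 1) [6, 120]
```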
def transpose(data_set, permutation):
    """Transpose the data set by decomposing the permutation into adjacent-axis swaps."""
    data_set = list(data_set)
    perm = list(permutation)
    shape = get_shape(data_set)
    flat_data = flatten_to_list(data_set)
    assert set(perm) == set(range(len(shape))), 'invalid permutation'

    new_shape = [shape[i] for i in perm]
    # Bubble-sort the permutation and record the adjacent swaps used.
    swaps = []
    bubbled = True
    while bubbled:
        bubbled = False
        for i in range(len(new_shape) - 1):
            if perm[i] > perm[i + 1]:
                swaps.append([i, i + 1])
                p_1, p_2 = perm[i], perm[i + 1]
                perm[i], perm[i + 1] = p_2, p_1
                bubbled = True

    # Apply the local swaps to the data in reverse order.
    current_shape = list(shape)
    temp_flat_data = flat_data

    for swap in swaps[::-1]:
        ind_1, ind_2 = swap[0], swap[1]
        dim_1 = current_shape[ind_1]
        dim_2 = current_shape[ind_2]
        split_num = 1
        block_size = 1

        for i in range(ind_1):
            split_num *= current_shape[i]
        for i in range(ind_2 + 1, len(current_shape)):
            block_size *= current_shape[i]

        data_blocks = np.reshape(temp_flat_data, [-1, block_size])
        flat_data_1 = []
        for k in range(split_num):
            block = []
            for m in range(dim_2):
                for n in range(dim_1):
                    block_pos = k * dim_1 * dim_2 + n * dim_2 + m
                    block.extend(data_blocks[block_pos])
            flat_data_1.extend(block)

        temp_flat_data = flat_data_1
        current_shape[ind_1] = dim_2
        current_shape[ind_2] = dim_1

    return np.reshape(temp_flat_data, current_shape)

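The same plan, decomposing an arbitrary permutation into bubble-sort adjacent swaps and applying them in reverse, can be checked against `np.transpose` using `np.swapaxes` for the swap step (`transpose_via_swaps` is a hypothetical name for this sketch):

```python
import numpy as np

def transpose_via_swaps(x, perm):
    # Bubble-sort the permutation, recording each adjacent swap, then apply
    # the swaps to the array in reverse order, as in transpose() above.
    perm = list(perm)
    swaps = []
    bubbled = True
    while bubbled:
        bubbled = False
        for i in range(len(perm) - 1):
            if perm[i] > perm[i + 1]:
                swaps.append(i)
                perm[i], perm[i + 1] = perm[i + 1], perm[i]
                bubbled = True
    for i in reversed(swaps):
        x = np.swapaxes(x, i, i + 1)
    return x

x = np.arange(24).reshape(2, 3, 4)
result = transpose_via_swaps(x, [2, 0, 1])
print(result.shape)  # (4, 2, 3)
```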
def subtract(data_set_1, data_set_2):
    """Element-wise subtraction with broadcasting."""
    broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(data_set_1, data_set_2)

    shape = get_shape(broadcasted_data_1)
    flat_data_1 = flatten_to_list(broadcasted_data_1)
    flat_data_2 = flatten_to_list(broadcasted_data_2)

    subtracted_data = [flat_data_1[i] - flat_data_2[i] for i in range(len(flat_data_1))]

    new_data = np.reshape(subtracted_data, shape)

    return new_data

78
tools/optimizer_scripts/tools/modhelper.py
Normal file
@ -0,0 +1,78 @@
"""This module contains helper functions that do graph modifications.
"""

import onnx
from . import helper


def replace_node_input(node, old_input, new_input):
    for i, input_name in enumerate(node.input):
        if input_name == old_input:
            node.input[i] = new_input

def delete_nodes(g, node_list):
    node_to_delete = []
    # Find target nodes
    for node in g.node:
        if node.name not in node_list:
            continue
        else:
            node_to_delete.append(node)
    if len(node_list) != len(node_to_delete):
        print("Some nodes do not exist in the graph. Skipping them.")
    for node in node_to_delete:
        # Check whether the node is valid to delete
        if len(node.input) == 0:
            print("Deleting a Constant node. Please make sure you also delete all its following nodes.")
        elif len(node.input) > 1:
            print("Warning: Node {} has more than one input. This script cannot delete merge nodes.".format(node.name))
        # Connect the nodes around the target node.
        # Set the following node input as the previous node output.
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        if len(node.input) == 0:
            for following_node in following_nodes:
                following_node.input.remove(node.output[0])
        elif len(following_nodes) > 0 and len(node.input) == 1 and helper.find_input_by_name(g, node.input[0]) is not None:
            # The node input is a graph input.
            new_input = helper.find_value_by_name(g, node.output[0])
            g.input.append(new_input)
            g.input.remove(helper.find_input_by_name(g, node.input[0]))
            g.value_info.remove(new_input)
        elif len(following_nodes) > 0:
            for following_node in following_nodes:
                replace_node_input(following_node, node.output[0], node.input[0])
        else:
            # If the node is the output, replace the output with the previous input.
            value = helper.find_value_by_name(g, node.input[0])
            output_values = []
            while len(g.output):
                output_values.append(g.output.pop())
            while output_values:
                output_value = output_values.pop()
                if output_value.name == node.output[0]:
                    g.output.extend([value])
                else:
                    g.output.extend([output_value])
        # Remove the node and value info.
        g.node.remove(node)

def delete_input(g, target_list):
    for name in target_list:
        input_value = helper.find_input_by_name(g, name)
        if input_value is None:
            print("Cannot find input {}".format(name))
            continue
        g.input.remove(input_value)

def delete_output(g, target_list):
    for name in target_list:
        output_value = helper.find_output_by_name(g, name)
        if output_value is None:
            print("Cannot find output {}".format(name))
            continue
        g.output.remove(output_value)

def delete_value_with_name_if_exists(g, name):
    value = helper.find_value_by_name(g, name)
    if value is not None:
        g.value_info.remove(value)
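The rewiring that `delete_nodes` performs (each consumer of a deleted node is pointed at the deleted node's own input) can be illustrated without onnx; the `Node` class below is a hypothetical stand-in for `NodeProto`, and the helper is the same loop as `replace_node_input` above:

```python
class Node:
    # Minimal stand-in for an onnx NodeProto (illustration only).
    def __init__(self, name, inputs, outputs):
        self.name, self.input, self.output = name, inputs, outputs

def replace_node_input(node, old_input, new_input):
    # Same loop as modhelper.replace_node_input above.
    for i, input_name in enumerate(node.input):
        if input_name == old_input:
            node.input[i] = new_input

a = Node('relu', ['x'], ['relu_out'])
b = Node('identity', ['relu_out'], ['id_out'])
c = Node('conv', ['id_out', 'w'], ['y'])
# Deleting the Identity node: its consumer now reads the Identity's input.
replace_node_input(c, b.output[0], b.input[0])
print(c.input)  # ['relu_out', 'w']
```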
1200
tools/optimizer_scripts/tools/other.py
Normal file
File diff suppressed because it is too large
317
tools/optimizer_scripts/tools/removing_transpose.py
Normal file
@ -0,0 +1,317 @@
from . import helper
from . import other
from . import modhelper
from . import fusing
import numpy as np
import onnx
import onnx.utils

def eliminate_transposes(m):
    g = m.graph
    keep_eliminating = True
    while keep_eliminating:
        while swap_transpose_with_single_next_node(g):
            pass
        splitted = split_transpose_for_multiple_next_nodes(g)
        annihilated = annihilate_transposes(g)
        multiple_trans_swapped = swap_multiple_transposes_with_node(g)
        keep_eliminating = splitted or annihilated or multiple_trans_swapped

        if keep_eliminating:
            m = other.polish_model(m)
            g = m.graph

    return m

def swap_transpose_with_single_next_node(g):
    swapped = False
    passable_nodes = set(['Relu', 'Neg', 'LeakyRelu', 'Sqrt', 'Reciprocal', 'Add', 'Mul', 'Tanh'])
    for node in g.node:
        trans_node = node
        # Check for transpose node
        if trans_node.op_type != 'Transpose':
            continue
        next_nodes = helper.find_nodes_by_input_name(g, trans_node.output[0])
        if len(next_nodes) != 1:
            continue
        next_node = next_nodes[0]
        # Check if the next node is of a type that can be swapped
        if next_node.op_type not in passable_nodes:
            continue

        input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in next_node.input]

        # Check if the node has a nonconstant input other than the Transpose node itself
        nonconstant_input = False
        for input_node in input_nodes:
            if input_node is None:
                nonconstant_input = True
                break
            if input_node.name == trans_node.name:
                continue
            elif input_node.op_type == 'Constant':
                continue
            else:
                nonconstant_input = True
                break
        if nonconstant_input:
            continue

        for input_node in input_nodes:
            if input_node.name == trans_node.name:
                # The input is the transpose node itself: swap the two nodes.
                next_value_info = helper.find_value_by_name(g, next_node.output[0])
                mid_value_info = helper.find_value_by_name(g, trans_node.output[0])

                output_nodes = helper.find_nodes_by_input_name(g, next_node.output[0])
                for out_node in output_nodes:
                    modhelper.replace_node_input(out_node, next_node.output[0], trans_node.name)

                next_node.input[0] = trans_node.input[0]
                next_node.output[0] = next_node.name
                trans_node.input[0] = next_node.name
                trans_node.output[0] = trans_node.name

                if next_value_info:
                    next_value_info.name = trans_node.name
                if mid_value_info:
                    g.value_info.remove(mid_value_info)
            else:
                # The input is a constant node: transpose its data instead.
                old_tensor = input_node.attribute[0].t
                old_shape, data = helper.constant_to_list(input_node)
                # A scalar constant still needs a list shape for reshaping.
                if type(old_shape) == int:
                    old_shape = [old_shape]
                permutation = list(trans_node.attribute[0].ints)
                while len(old_shape) < len(permutation):
                    old_shape.insert(0, 1)
                np_data = np.reshape(data, old_shape)
                reverse_perm = []
                for i in range(len(permutation)):
                    reverse_perm.append(permutation.index(i))
                np_data = np.transpose(np_data, reverse_perm)
                new_shape = np_data.shape
                new_tensor = onnx.helper.make_tensor(
                    name=old_tensor.name,
                    data_type=old_tensor.data_type,
                    dims=new_shape,
                    vals=np_data.flatten().tolist()
                )
                new_node = onnx.helper.make_node(
                    'Constant',
                    [],
                    [input_node.output[0]],
                    name=input_node.name,
                    value=new_tensor
                )
                g.node.extend([new_node])

                g.value_info.remove(helper.find_value_by_name(g, input_node.output[0]))
                g.node.remove(input_node)

        swapped = True

    other.topological_sort(g)
    return swapped

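The swap is sound because every op in `passable_nodes` is elementwise, so it commutes with Transpose; a quick numpy check of that identity for Relu:

```python
import numpy as np

x = np.random.default_rng(0).standard_normal((2, 3, 4))
perm = [0, 2, 1]
# Relu(Transpose(x)) == Transpose(Relu(x)) for any elementwise op.
lhs = np.maximum(np.transpose(x, perm), 0)
rhs = np.transpose(np.maximum(x, 0), perm)
print(np.allclose(lhs, rhs))  # True
```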
def swap_multiple_transposes_with_node(g):
    # Here we only consider input transposes that share the same permutation.
    swapped = False
    passable_nodes = set(['Add', 'Mul'])
    node_to_del = []
    for node in g.node:
        if node.op_type not in passable_nodes:
            continue
        input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in node.input]
        if any([input_node is None for input_node in input_nodes]):
            continue
        if any([input_node.op_type != 'Transpose' for input_node in input_nodes]):
            continue

        permutation = list(input_nodes[0].attribute[0].ints)
        if any([list(input_node.attribute[0].ints) != permutation for input_node in input_nodes]):
            continue

        # Bypass the input transposes.
        for input_name in node.input:
            input_node = helper.find_node_by_output_name(g, input_name)
            modhelper.replace_node_input(node, input_name, input_node.input[0])

        node_to_del.extend(input_nodes)
        for input_node in input_nodes:
            input_val_info = helper.find_value_by_name(g, input_node.output[0])
            if input_val_info is not None:
                g.value_info.remove(input_val_info)
        output_val_info = helper.find_value_by_name(g, node.output[0])
        if output_val_info is not None:
            g.value_info.remove(output_val_info)

        # Insert a transpose after the node for each of its consumers.
        output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
        for i in range(len(output_nodes)):
            new_trans_node_name = node.name + '_trans_' + str(i)
            new_trans_node = onnx.helper.make_node(
                'Transpose',
                [node.output[0]],
                [new_trans_node_name],
                name=new_trans_node_name,
                perm=permutation
            )
            modhelper.replace_node_input(output_nodes[i], node.output[0], new_trans_node_name)

            g.node.extend([new_trans_node])

        swapped = True

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    other.topological_sort(g)
    return swapped

def annihilate_transposes(g):
    node_to_del = []
    annihilated = False
    for node in g.node:
        if node.op_type != 'Transpose':
            continue
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        if not pre_node or pre_node.op_type != 'Transpose':
            continue
        nodes_from_top_transpose = helper.find_nodes_by_input_name(g, pre_node.output[0])
        if len(nodes_from_top_transpose) > 1:
            continue

        perm_1 = list(pre_node.attribute[0].ints)
        perm_2 = list(node.attribute[0].ints)
        # Two consecutive transposes cancel out only when the second
        # permutation is the inverse of the first.
        if any(perm_1[perm_2[i]] != i for i in range(len(perm_1))):
            continue

        out_nodes = helper.find_nodes_by_input_name(g, node.output[0])
        for out_node in out_nodes:
            modhelper.replace_node_input(out_node, node.output[0], pre_node.input[0])

        node_to_del.extend([node, pre_node])
        mid_value_info = helper.find_value_by_name(g, pre_node.output[0])
        out_value_info = helper.find_value_by_name(g, node.output[0])
        g.value_info.remove(mid_value_info)
        g.value_info.remove(out_value_info)

        annihilated = True
    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    return annihilated

def split_transpose_for_multiple_next_nodes(g):
    splitted = False
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Transpose':
            continue
        output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
        if len(output_nodes) < 2:
            continue
        # Give each consumer its own copy of the transpose.
        for i in range(len(output_nodes)):
            output_node = output_nodes[i]
            new_trans_node_name = node.name + '_' + str(i)
            new_trans_node = onnx.helper.make_node(
                'Transpose',
                [node.input[0]],
                [new_trans_node_name],
                name=new_trans_node_name,
                perm=list(node.attribute[0].ints)
            )
            modhelper.replace_node_input(output_node, node.output[0], new_trans_node.output[0])
            g.node.extend([new_trans_node])

        node_to_del.append(node)
        val_info = helper.find_value_by_name(g, node.output[0])
        g.value_info.remove(val_info)

        splitted = True

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    other.topological_sort(g)
    return splitted

def remove_trivial_transpose(g):
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Transpose':
            continue
        permutation = list(node.attribute[0].ints)
        # Only an identity permutation is trivial.
        if permutation != list(range(len(permutation))):
            continue

        next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
        if not next_nodes:
            # The transpose produces a graph output.
            input_val_info = helper.find_value_by_name(g, node.input[0])
            out_val_info = helper.find_output_by_name(g, node.output[0])
            if not input_val_info:
                input_val_info = helper.find_input_by_name(g, node.input[0])
            g.output.remove(out_val_info)
            g.output.extend([input_val_info])
        else:
            out_val_info = helper.find_value_by_name(g, node.output[0])
            for next_node in next_nodes:
                modhelper.replace_node_input(next_node, node.output[0], node.input[0])
            g.value_info.remove(out_val_info)

        node_to_del.append(node)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    other.topological_sort(g)

def fuse_Transpose_into_Gemm_weight(g):
    node_to_del = []
    for node in g.node:
        # Check pattern: Transpose -> Flatten -> Gemm
        if node.op_type != 'Gemm':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None or prev_node.op_type != 'Flatten':
            continue
        transpose_node = helper.find_node_by_output_name(g, prev_node.input[0])
        if transpose_node.op_type != 'Transpose':
            continue
        # Check attributes
        perm = helper.get_list_attribute_by_name(transpose_node, 'perm', 'int')
        if perm != [0, 2, 3, 1]:
            continue
        transB = helper.get_var_attribute_by_name(node, 'transB', 'int')
        if transB is not None and transB == 1:
            continue
        # Get the original weight
        origin_weight = helper.find_node_by_output_name(g, node.input[1])
        origin_np = helper.constant_to_numpy(origin_weight)
        # Calculate a new weight that absorbs the transpose
        shape = helper.get_shape_from_value_info(helper.find_value_by_name(g, prev_node.input[0]))
        shape.append(-1)
        new_np = np.reshape(origin_np, shape)
        new_np = np.transpose(new_np, [0, 3, 1, 2, 4])
        new_np = np.reshape(new_np, [-1, new_np.shape[-1]])
        new_weight = helper.numpy_to_constant(origin_weight.output[0], new_np)
        # Replace and eliminate
        prev_node.input[0] = transpose_node.input[0]
        node_to_del.append(transpose_node)
        node_to_del.append(origin_weight)
        g.value_info.remove(helper.find_value_by_name(g, transpose_node.output[0]))
        g.node.extend([new_weight])

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    other.topological_sort(g)

1171
tools/optimizer_scripts/tools/replacing.py
Normal file
File diff suppressed because it is too large
423
tools/optimizer_scripts/tools/special.py
Normal file
@ -0,0 +1,423 @@
"""Special operations on model.
"""
import logging
import onnx.helper
import numpy as np
from . import helper
from . import other
from . import modhelper

def change_first_conv_from_bgr_to_rgb(m):
    """For a model with BGR input channel format, change the first conv
    weight so that the model adapts to RGB input.

    :param m: the model proto
    """
    # Check the first node.
    g = m.graph
    input_name = g.input[0].name
    first_nodes = helper.find_following_nodes_by_input_value_name(g, input_name)
    if len(first_nodes) > 1:
        return False
    first_node = first_nodes[0]
    # Now we have the first node. Check this first node.
    if first_node.op_type != 'Conv':
        return False
    weight_value = helper.find_value_by_name(g, first_node.input[1])
    weight_shape = helper.get_shape_from_value_info(weight_value)
    if weight_shape[1] != 3:
        return False
    # Shuffle the input channels of the weight (BGR -> RGB).
    weight_node = helper.find_node_by_output_name(g, weight_value.name)
    weight_np = helper.constant_to_numpy(weight_node)
    b_channel = np.expand_dims(weight_np[:, 0, :, :], axis=1)
    g_channel = np.expand_dims(weight_np[:, 1, :, :], axis=1)
    r_channel = np.expand_dims(weight_np[:, 2, :, :], axis=1)
    new_np = np.concatenate((r_channel, g_channel, b_channel), axis=1)
    new_node = helper.numpy_to_constant(weight_value.name, new_np)
    # Replace the weight and topologically sort
    g.node.remove(weight_node)
    g.node.extend([new_node])
    other.topological_sort(g)
    return True

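The channel shuffle above reverses the input-channel axis of a `(out_ch, in_ch, kH, kW)` conv weight, so a BGR-trained conv consumes RGB input unchanged; a numpy sketch:

```python
import numpy as np

# Flip the in_ch axis of a (out_ch, in_ch, kH, kW) weight: BGR -> RGB.
w_bgr = np.arange(2 * 3 * 1 * 1).reshape(2, 3, 1, 1).astype('float32')
w_rgb = w_bgr[:, ::-1, :, :]
print(w_rgb[0, :, 0, 0].tolist())  # [2.0, 1.0, 0.0]
```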
def change_input_from_bgr_to_rgb(m):
    """For a model with BGR input channel format, modify the model to accept
    RGB images. If the first node is a non-group Conv, modify its weight to
    adapt the input to RGB. Otherwise, create a new channel-shuffle node.

    :param m: the model proto
    """
    g = m.graph
    if len(g.input) > 1:
        print("This model has multiple inputs. Cannot change to RGB input.")
        return
    input_shape = helper.get_shape_from_value_info(g.input[0])
    if len(input_shape) != 4 or input_shape[1] != 3:
        print("The input shape is invalid for bgr conversion.")
        return
    # Try changing the first conv weight first
    if change_first_conv_from_bgr_to_rgb(m):
        return
    # Otherwise, create a special conv node and replace the input
    # Construct weight
    weight_np = np.zeros((3, 3, 3, 3)).astype('float32')
    weight_np[0, 2, 1, 1] = 1.0
    weight_np[1, 1, 1, 1] = 1.0
    weight_np[2, 0, 1, 1] = 1.0
    new_weight = helper.numpy_to_constant("bgr_shuffle_weight", weight_np)
    # Construct Conv
    new_conv = onnx.helper.make_node(
        'Conv',
        ['rgb_input', "bgr_shuffle_weight"],
        [g.input[0].name],
        name='bgr_shuffle',
        dilations=[1, 1],
        kernel_shape=[3, 3],
        pads=[1, 1, 1, 1],
        strides=[1, 1]
    )
    # Connect the graph
    old_input_value = g.input.pop()
    new_input_value = onnx.helper.make_tensor_value_info(
        'rgb_input',
        old_input_value.type.tensor_type.elem_type,
        input_shape
    )
    g.input.extend([new_input_value])
    g.node.extend([new_weight, new_conv])
    # Topological sort
    other.topological_sort(g)

def add_0_5_to_normalized_input(m):
    """For input normalized to -0.5 ~ 0.5, add 0.5 to the input to keep it
    between 0 ~ 1.

    :param m: the model proto
    """
    g = m.graph
    if len(g.input) > 1:
        print("This model has multiple inputs. Cannot normalize input.")
        return
    input_shape = helper.get_shape_from_value_info(g.input[0])
    if len(input_shape) != 4:
        print("The input shape is not BCHW. Cannot normalize input.")
        return
    # Construct an identity weight (center tap of a 3x3 kernel)
    ch = input_shape[1]
    weight_np = np.zeros((ch, ch, 3, 3)).astype('float32')
    for i in range(ch):
        weight_np[i, i, 1, 1] = 1.0
    new_weight = helper.numpy_to_constant("input_norm_weight", weight_np)
    # Construct bias
    bias_np = np.array([0.5] * ch).astype('float32')
    new_bias = helper.numpy_to_constant("input_norm_bias", bias_np)
    # Construct Conv
    new_conv = onnx.helper.make_node(
        'Conv',
        ['origin_input', "input_norm_weight", "input_norm_bias"],
        [g.input[0].name],
        name='input_norm',
        dilations=[1, 1],
        kernel_shape=[3, 3],
        pads=[1, 1, 1, 1],
        strides=[1, 1]
    )
    # Construct value_infos
    old_input_value = g.input.pop()
    weight_value = onnx.helper.make_tensor_value_info(
        'input_norm_weight',
        old_input_value.type.tensor_type.elem_type,
        [3, 3, 3, 3]
    )
    bias_value = onnx.helper.make_tensor_value_info(
        'input_norm_bias',
        old_input_value.type.tensor_type.elem_type,
        [3]
    )
    # Connect the graph
    new_input_value = onnx.helper.make_tensor_value_info(
        'origin_input',
        old_input_value.type.tensor_type.elem_type,
        input_shape
    )
    g.input.extend([new_input_value])
    g.node.extend([new_weight, new_bias, new_conv])
    g.value_info.extend([weight_value, bias_value, old_input_value])
    # Topological sort
    other.topological_sort(g)

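The inserted conv is an identity: a 3x3 kernel with a single center tap passes each pixel through unchanged, and the 0.5 bias does the shift. A naive single-channel convolution (hypothetical `conv3x3_same` helper, for illustration only) makes that concrete:

```python
import numpy as np

def conv3x3_same(x, w, b):
    # Naive single-channel 'same' convolution, just to check the identity.
    h, wd = x.shape
    padded = np.pad(x, 1)
    out = np.zeros_like(x, dtype='float32')
    for i in range(h):
        for j in range(wd):
            out[i, j] = (padded[i:i + 3, j:j + 3] * w).sum() + b
    return out

kernel = np.zeros((3, 3), dtype='float32')
kernel[1, 1] = 1.0  # center tap: identity convolution
x = np.array([[0.2, -0.4], [0.1, 0.0]], dtype='float32')
y = conv3x3_same(x, kernel, 0.5)
print(np.allclose(y, x + 0.5))  # True
```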
def add_rgb2yynn_node(m):
    """Add a conv layer that converts RGB input to YYNN input.
    """
    g = m.graph
    if len(g.input) > 1:
        print("This model has multiple inputs. Cannot change to rgb input.")
        return
    input_shape = helper.get_shape_from_value_info(g.input[0])
    if len(input_shape) != 4:
        print("The input shape is not BCHW. Cannot normalize input.")
        return
    # Construct weight
    ch = input_shape[1]
    weight_np = np.zeros((3, 3, 4, 4)).astype('float32')
    weight_np[1, 1, :3, :2] = np.array([[[[0.299],
                                          [0.587],
                                          [0.114]]]])
    weight_np[1, 1, 3, 2:] = 1.
    weight_np = np.transpose(weight_np, (3, 2, 0, 1))
    new_weight = helper.numpy_to_constant("input_rgb2yynn_weight", weight_np)
    # Construct conv node
    new_conv = onnx.helper.make_node(
        'Conv',
        ['new_input', "input_rgb2yynn_weight"],
        [g.input[0].name],
        name='input_rgba2yynn',
        dilations=[1, 1],
        kernel_shape=[3, 3],
        pads=[1, 1, 1, 1],
        strides=[1, 1]
    )
    # Construct value_infos
    old_input_value = g.input.pop()
    weight_value = onnx.helper.make_tensor_value_info(
        'input_rgb2yynn_weight',
        old_input_value.type.tensor_type.elem_type,
        [4, 4, 3, 3]
    )
    # Connect the graph
    new_input_value = onnx.helper.make_tensor_value_info(
        'new_input',
        old_input_value.type.tensor_type.elem_type,
        input_shape
    )
    g.input.extend([new_input_value])
    g.node.extend([new_weight, new_conv])
    g.value_info.extend([weight_value, old_input_value])
    # Topological sort
    other.topological_sort(g)

def swap_MatMul_inputs(g, original_matmul_node):
    # (A @ B) == ((B^T) @ (A^T))^T, so the transposed B becomes the first
    # input of the new MatMul and a final Transpose restores the output.
    # Create Transpose nodes
    input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0])
    input_a_shape = helper.get_shape_from_value_info(input_a_value)
    if len(input_a_shape) == 2:
        perm = [1, 0]
    else:
        perm = [0, 2, 1]
    new_input_b_node = onnx.helper.make_node(
        'Transpose',
        inputs=[input_a_value.name],
        outputs=[input_a_value.name + '_transposed'],
        name=f"{input_a_value.name}_transposed_for_{original_matmul_node.name}",
        perm=perm
    )
    input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1])
    input_b_shape = helper.get_shape_from_value_info(input_b_value)
    if len(input_b_shape) == 3:
        perm = [0, 2, 1]
    else:
        perm = [0, 1, 3, 2]
    new_input_a_node = onnx.helper.make_node(
        'Transpose',
        inputs=[input_b_value.name],
        outputs=[input_b_value.name + '_transposed'],
        name=f'{input_b_value.name}_transposed_for_{original_matmul_node.name}',
        perm=perm
    )
    # Create new MatMul node
    new_matmul_node = onnx.helper.make_node(
        'MatMul',
        inputs=[new_input_a_node.output[0], new_input_b_node.output[0]],
        outputs=[original_matmul_node.output[0] + '_transposed'],
        name=original_matmul_node.name + '_transposed'
    )
    # Create final Transpose node
    output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
    output_shape = helper.get_shape_from_value_info(output_value)
    if len(output_shape) == 3:
        perm = [0, 2, 1]
    else:
        perm = [0, 1, 3, 2]
    new_final_transpose_node = onnx.helper.make_node(
        'Transpose',
        inputs=[new_matmul_node.output[0]],
        outputs=[original_matmul_node.output[0]],
        name=original_matmul_node.name + '_final_transpose',
        perm=perm
    )
    # Add new nodes
    g.node.extend([new_input_a_node, new_input_b_node, new_matmul_node, new_final_transpose_node])
    # Delete the original node
    g.node.remove(original_matmul_node)

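The rewrite rests on the transpose identity for matrix products; a numpy check:

```python
import numpy as np

a = np.arange(6).reshape(2, 3).astype('float32')
b = np.arange(12).reshape(3, 4).astype('float32')
# (A @ B) == ((B^T) @ (A^T))^T — the identity the input swap relies on.
direct = a @ b
swapped = (b.T @ a.T).T
print(np.allclose(direct, swapped))  # True
```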
def split_MatMul_batch_then_concat(g, original_matmul_node):
    new_nodes = []
    final_concat_inputs = []
    # Get the batch count
    input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0])
    input_a_shape = helper.get_shape_from_value_info(input_a_value)
    input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1])
    input_b_shape = helper.get_shape_from_value_info(input_b_value)
    if len(input_a_shape) == 3:
        batch_count = input_a_shape[0]
    else:
        batch_count = input_a_shape[1]
    for i in range(batch_count):
        # Create Slice nodes for input A
        starts_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_starts", (1, ), [i])
        ends_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_ends", (1, ), [i + 1])
        axes_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_axes", (1, ), [len(input_a_shape) - 3])
        new_sliced_a_node = onnx.helper.make_node(
            'Slice',
            inputs=[input_a_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]],
            outputs=[f"{input_a_value.name}_sliced_{i}"],
            name=f"{input_a_value.name}_sliced_{i}_for_{original_matmul_node.name}"
        )
        new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_a_node])
        # Create Slice nodes for input B
        starts_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_starts", (1, ), [i])
        ends_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_ends", (1, ), [i + 1])
        axes_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_axes", (1, ), [len(input_b_shape) - 3])
        new_sliced_b_node = onnx.helper.make_node(
            'Slice',
            inputs=[input_b_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]],
            outputs=[f"{input_b_value.name}_sliced_{i}"],
            name=f"{input_b_value.name}_sliced_{i}_for_{original_matmul_node.name}"
        )
        new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_b_node])
        # Create MatMul nodes
        new_matmul_node = onnx.helper.make_node(
            'MatMul',
            inputs=[new_sliced_a_node.output[0], new_sliced_b_node.output[0]],
            outputs=[f"{original_matmul_node.output[0]}_sliced_{i}"],
            name=f"{original_matmul_node.name}_sliced_{i}"
        )
        new_nodes.append(new_matmul_node)
        final_concat_inputs.append(new_matmul_node.output[0])
    # Create Concat nodes
    output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
    if output_value is None:
        output_value = helper.find_output_by_name(g, original_matmul_node.output[0])
    if output_value is None:
        helper.logger.error(f"Cannot find value_info for {original_matmul_node.output[0]}")
    output_shape = helper.get_shape_from_value_info(output_value)
    new_concat_node = onnx.helper.make_node(
|
||||
"Concat",
|
||||
inputs = final_concat_inputs,
|
||||
outputs = [original_matmul_node.output[0]],
|
||||
name = f"{original_matmul_node.name}_final_concat",
|
||||
axis = len(output_shape) - 3
|
||||
)
|
||||
new_nodes.append(new_concat_node)
|
||||
# Add new nodes
|
||||
g.node.extend(new_nodes)
|
||||
# Delete original nodes
|
||||
g.node.remove(original_matmul_node)
|
||||
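The equivalence behind this rewrite is that a batched MatMul equals per-batch slice, multiply, and concatenate. A numpy sketch with made-up shapes (illustrative only):

```python
import numpy as np

a = np.random.rand(3, 4, 5)   # B x H x W
b = np.random.rand(3, 5, 6)   # B x W x V

expected = a @ b  # batched MatMul, shape 3 x 4 x 6

# Slice -> MatMul per batch, then Concat on the batch axis,
# mirroring the node chain built by the function above.
parts = [a[i:i + 1] @ b[i:i + 1] for i in range(a.shape[0])]
result = np.concatenate(parts, axis=0)

assert np.allclose(expected, result)
```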


def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
    new_nodes = []
    final_concat_inputs = []
    # Get the batch count
    input_b_node = helper.find_node_by_output_name(g, original_matmul_node.input[1])
    input_b_np = helper.constant_to_numpy(input_b_node)
    if len(input_b_np.shape) == 3:
        batch_count = input_b_np.shape[0]
    else:
        batch_count = input_b_np.shape[1]
    for i in range(batch_count):
        # Create a new Constant node for this weight slice
        if len(input_b_np.shape) == 3:
            new_np = input_b_np[i:i+1, ...]
        else:
            new_np = input_b_np[:, i:i+1, ...]
        new_weight = helper.numpy_to_constant(f"{input_b_node.name}_sliced_{i}", new_np)
        new_nodes.append(new_weight)
        # Create a MatMul node for this slice
        new_matmul_node = onnx.helper.make_node(
            'MatMul',
            inputs = [original_matmul_node.input[0], new_weight.output[0]],
            outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"],
            name = f"{original_matmul_node.name}_sliced_{i}"
        )
        new_nodes.append(new_matmul_node)
        final_concat_inputs.append(new_matmul_node.output[0])
    # Create the final Concat node
    output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
    output_shape = helper.get_shape_from_value_info(output_value)
    new_concat_node = onnx.helper.make_node(
        "Concat",
        inputs = final_concat_inputs,
        outputs = [original_matmul_node.output[0]],
        name = f"{original_matmul_node.name}_final_concat",
        axis = len(output_shape) - 3
    )
    new_nodes.append(new_concat_node)
    # Add new nodes
    g.node.extend(new_nodes)
    # Delete original value info
    input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1])
    if input_b_value is not None:
        g.value_info.remove(input_b_value)
    # Delete original nodes
    g.node.remove(original_matmul_node)
    g.node.remove(input_b_node)
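The constant-weight variant slices only input B; the shared activation A broadcasts against each weight slice. A numpy sketch of the equivalence (shapes are made up for illustration):

```python
import numpy as np

a = np.random.rand(1, 4, 5)   # activation input, broadcast over the batch
b = np.random.rand(3, 5, 6)   # constant weight with a batch dimension

expected = a @ b  # shape 3 x 4 x 6

# Slice only the constant weight; each sliced MatMul reuses the same A,
# then the partial results are concatenated, as in the function above.
parts = [a @ b[i:i + 1] for i in range(b.shape[0])]
result = np.concatenate(parts, axis=0)

assert np.allclose(expected, result)
```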


def special_MatMul_process(g):
    for node in g.node:
        if node.op_type != 'MatMul':
            continue
        input_a_name = node.input[0]
        input_a_value = helper.find_value_by_name(g, input_a_name)
        input_b_name = node.input[1]
        input_b_value = helper.find_value_by_name(g, input_b_name)
        if input_a_value is None or input_b_value is None:
            continue
        input_a_shape = helper.get_shape_from_value_info(input_a_value)
        input_b_shape = helper.get_shape_from_value_info(input_b_value)
        # Check shapes and choose the process.
        # Normal case. Skip.
        if len(input_b_shape) == 2:
            continue
        # Too many or too few dimensions. Not supported. Skip.
        if len(input_a_shape) > 4 or len(input_b_shape) > 4:
            helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have too many dimensions.")
            continue
        if len(input_a_shape) < 2 or len(input_b_shape) < 2:
            helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have too few dimensions.")
            continue
        # For 4-dimensional inputs, check the first dimension (should be 1) and treat them as 3-dimensional.
        extra_dim = None
        if len(input_a_shape) == 4:
            extra_dim = input_a_shape[0]
            input_a_shape = input_a_shape[1:]
        if len(input_b_shape) == 4:
            if input_b_shape[0] != extra_dim:
                helper.logger.warning(f"Cannot optimize MatMul {node.name}: input dimension batch sizes do not match ({extra_dim} vs {input_b_shape[0]}).")
                continue
            input_b_shape = input_b_shape[1:]
        # Check the input B dimensions.
        # If B is 1 x W x V, it is the same as the normal case.
        if input_b_shape[0] == 1:
            continue
        # If B is B x W x V and B is a constant, split the constant.
        input_b_node = helper.find_node_by_output_name(g, input_b_name)
        if input_b_node is not None and input_b_node.op_type == 'Constant':
            # Constant input
            helper.logger.debug(f"Optimizing MatMul node {node.name}: split constant input.")
            split_MatMul_Constant_input_then_concat(g, node)
        # If B is B x W x V and A is 1 x H x W, do the swap.
        elif len(input_a_shape) == 2 or (input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1)):
            helper.logger.debug(f"Optimizing MatMul node {node.name}: swap input.")
            swap_MatMul_inputs(g, node)
        # If B is B x W x V and A is B x H x W, do the split.
        elif input_b_shape[0] == input_a_shape[0]:
            helper.logger.debug(f"Optimizing MatMul node {node.name}: split input batch.")
            split_MatMul_batch_then_concat(g, node)
        # Other cases are not supported: B is B x W x V but A is X x H x W.
        else:
            helper.logger.warning(f"Cannot optimize MatMul {node.name}: unknown reason. Might be shape mismatch.")
            continue
    other.topological_sort(g)
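The dispatch above can be summarized as a small decision function. The sketch below is illustrative only: it covers 3-D shapes and simplifies the swap condition by omitting the 4-D leading-dimension (`extra_dim`) handling, but it preserves the branch order (constant split is checked before the swap):

```python
def choose_matmul_rewrite(a_shape, b_shape, b_is_constant):
    """Mirror the branch order of special_MatMul_process for 3-D inputs."""
    if len(b_shape) == 2 or b_shape[0] == 1:
        return 'skip'            # normal case, nothing to rewrite
    if b_is_constant:
        return 'split_constant'  # slice the constant weight, then concat
    if len(a_shape) == 2 or a_shape[0] == 1:
        return 'swap'            # transpose, swap inputs, transpose back
    if a_shape[0] == b_shape[0]:
        return 'split_batch'     # slice both inputs per batch, then concat
    return 'unsupported'         # mismatched batch sizes
```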

@@ -2,6 +2,7 @@
# Original: tools/pytorch2onnx.py, modified by Kneron
import argparse

import onnx
import mmcv
import numpy as np
import onnxruntime as rt
@@ -18,6 +19,10 @@ from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor

from optimizer_scripts.pytorch_exported_onnx_preprocess import (
    torch_exported_onnx_flow,
)

torch.manual_seed(3)


@@ -117,10 +122,12 @@ def pytorch2onnx(model,
        dynamic_axes=None)
    print(f'Successfully exported ONNX model: {output_file}')
    model.forward = origin_forward
    # NOTE: optimizing onnx for kneron inference
    m = onnx.load(output_file)
    m = torch_exported_onnx_flow(m, disable_fuse_bn=False)
    onnx.save(m, output_file)

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)