feat: integrate kneron optimizer script to pytorch2onnx_kneron.py

chingning.chen 2022-03-24 15:12:06 +08:00
parent b83243c5f4
commit 5b99260c9b
33 changed files with 8717 additions and 2 deletions


@@ -0,0 +1 @@
BasedOnStyle: Google

tools/optimizer_scripts/.gitignore

@@ -0,0 +1,7 @@
__pycache__
.vscode
*.pyc
models.py
temp.py
.ssh/
docker/test_models/


@@ -0,0 +1,189 @@
# Converter Scripts
[![pipeline status](http://192.168.200.1:8088/jiyuan/converter_scripts/badges/master/pipeline.svg)](http://192.168.200.1:8088/jiyuan/converter_scripts/commits/master)

This project collects various optimization and converter scripts for the
Kneron toolchain. This collection does not include the Keras to ONNX converter
or the Caffe to ONNX converter. They live in separate projects.

**The scripts not listed below are used as libraries and cannot be used
directly.**
## onnx2onnx.py
### 1.1. Description
General optimizations on an ONNX model for the Kneron toolchain. Though Kneron
toolchains are designed to take ONNX models as input, they have some
restrictions on the models (e.g. inferenced shapes for all value_info). Thus, we
have this tool to do some general optimization and conversion on ONNX models.

**Notice that this script takes a valid ONNX model as input.** It cannot
turn an invalid ONNX model into a valid one.
### 1.2. Basic Usage
```bash
python onnx2onnx.py input.onnx -o output.onnx
```
### 1.3. Optimizations Included
* Fusing BN into Conv.
* Fusing BN into Gemm.
* Fusing consecutive Gemm.
* Eliminating Identity layers and Dropout layers.
* Eliminating last shape changing nodes.
* Replacing initializers into Constant nodes.
* Replacing global AveragePool with GAP.
* Replacing Squeeze and Unsqueeze with Reshape.
* Replacing 1x1 depthwise with BN.
* Inferencing Upsample shapes.
* Transposing B in Gemm.
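The first two fusions rely on the fact that BatchNormalization is a per-channel affine transform, so it can be folded into the preceding Conv (or Gemm) weight and bias. Below is a minimal numpy sketch of the math, not the tool's actual implementation (a 1x1 Conv over flattened pixels is treated as a plain matrix multiply):

```python
import numpy as np

rng = np.random.default_rng(0)
c_in, c_out, n_pix = 4, 3, 5

# 1x1 Conv parameters, viewed as a matrix multiply over flattened pixels.
W = rng.standard_normal((c_out, c_in))
b = rng.standard_normal(c_out)

# BatchNormalization parameters (per output channel).
gamma = rng.standard_normal(c_out)
beta = rng.standard_normal(c_out)
mean = rng.standard_normal(c_out)
var = rng.random(c_out) + 0.1
eps = 1e-5

x = rng.standard_normal((c_in, n_pix))

# Reference: Conv followed by BN.
y_conv = W @ x + b[:, None]
scale = gamma / np.sqrt(var + eps)
y_ref = scale[:, None] * (y_conv - mean[:, None]) + beta[:, None]

# Fused: fold the BN scale/shift into the Conv weight and bias.
W_fused = scale[:, None] * W
b_fused = scale * (b - mean) + beta
y_fused = W_fused @ x + b_fused[:, None]

assert np.allclose(y_ref, y_fused)
```

Fusing BN into Gemm works the same way, with the Gemm bias absorbing `b_fused`.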
## pytorch2onnx.py
### 2.1. Description
Convert Pytorch models or Pytorch-generated ONNX models into Kneron toolchain
compatible ONNX files. This script includes most of the optimizations in
`onnx2onnx.py`. It also includes some optimizations specific to Pytorch models.
### 2.2. Basic Usage
```bash
# Take Pytorch model name, input channel number, input height, input width
python pytorch2onnx.py input.pth output.onnx --input-size 3 224 224
# Or take Pytorch exported ONNX.
python pytorch2onnx.py input.onnx output.onnx
```
### 2.3. Optimizations Included
* Adding name to nodes.
* Unsqueeze nodes constant folding.
* Reshape nodes constant folding.
* Optimizations in `onnx2onnx.py`.
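To illustrate what the constant folding above means: when every input of an Unsqueeze or Reshape node is a constant (as Pytorch exports often produce), the node can be replaced by a single pre-computed constant. A numpy sketch, not the converter's actual code:

```python
import numpy as np

# Unsqueeze folding: a constant of shape (3,) with axes=[0]
# becomes a constant of shape (1, 3).
const = np.array([1.0, 2.0, 3.0], dtype=np.float32)
unsqueezed = np.expand_dims(const, axis=0)
assert unsqueezed.shape == (1, 3)

# Reshape folding: when both the data and the target shape are
# constants, the whole node collapses to one constant.
data = np.arange(6, dtype=np.float32)
folded = data.reshape(2, 3)
assert folded.shape == (2, 3)
```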
## editor.py
### 3.1. Description
This is a simple ONNX editor which provides the following functions:

* Add nop BN or Conv nodes.
* Delete specific nodes or inputs.
* Cut the graph from a certain node (delete all the nodes following it).
* Reshape inputs and outputs.
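The cutting operation can be pictured as a reachability walk: delete the given node together with everything downstream of it. A toy sketch with hypothetical node names, not the editor's actual implementation:

```python
from collections import deque

# Toy graph: node name -> downstream node names.
edges = {
    'conv1': ['bn1'],
    'bn1': ['relu1'],
    'relu1': ['conv2', 'pool1'],
    'conv2': ['out_a'],
    'pool1': ['out_b'],
}

def cut_from(node, edges):
    """Return the nodes to delete: `node` and every node after it."""
    to_delete = set()
    queue = deque([node])
    while queue:
        n = queue.popleft()
        if n in to_delete:
            continue
        to_delete.add(n)
        queue.extend(edges.get(n, []))
    return to_delete

print(sorted(cut_from('relu1', edges)))
# → ['conv2', 'out_a', 'out_b', 'pool1', 'relu1']
```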
### 3.2. Usage
```
usage: editor.py [-h] [-c CUT_NODE [CUT_NODE ...]]
                 [--cut-type CUT_TYPE [CUT_TYPE ...]]
                 [-d DELETE_NODE [DELETE_NODE ...]]
                 [--delete-input DELETE_INPUT [DELETE_INPUT ...]]
                 [-i INPUT_CHANGE [INPUT_CHANGE ...]]
                 [-o OUTPUT_CHANGE [OUTPUT_CHANGE ...]]
                 [--add-conv ADD_CONV [ADD_CONV ...]]
                 [--add-bn ADD_BN [ADD_BN ...]]
                 in_file out_file

Edit an ONNX model. The processing sequence is 'delete nodes/values' -> 'add
nodes' -> 'change shapes'. Cutting cannot be combined with other operations.

positional arguments:
  in_file               input ONNX FILE
  out_file              output ONNX FILE

optional arguments:
  -h, --help            show this help message and exit
  -c CUT_NODE [CUT_NODE ...], --cut CUT_NODE [CUT_NODE ...]
                        remove nodes from the given nodes (inclusive)
  --cut-type CUT_TYPE [CUT_TYPE ...]
                        remove nodes by type from the given nodes (inclusive)
  -d DELETE_NODE [DELETE_NODE ...], --delete DELETE_NODE [DELETE_NODE ...]
                        delete nodes by names and only those nodes
  --delete-input DELETE_INPUT [DELETE_INPUT ...]
                        delete inputs by names
  -i INPUT_CHANGE [INPUT_CHANGE ...], --input INPUT_CHANGE [INPUT_CHANGE ...]
                        change input shape (e.g. -i 'input_0 1 3 224 224')
  -o OUTPUT_CHANGE [OUTPUT_CHANGE ...], --output OUTPUT_CHANGE [OUTPUT_CHANGE ...]
                        change output shape (e.g. -o 'input_0 1 3 224 224')
  --add-conv ADD_CONV [ADD_CONV ...]
                        add nop conv using specific input
  --add-bn ADD_BN [ADD_BN ...]
                        add nop bn using specific input
```
### 3.3. Example
Here is an example of when and how to use `editor.py`.
```bash
# In the `res` folder, there is a vdsr model from TensorFlow.
# We need to convert this model first.
./tf2onnx.sh res/vdsr_41_20layer_1.pb res/tmp.onnx images:0 output:0
# This onnx file seems valid, but its input and output are channel-last.
# It uses Transpose nodes to convert to channel-first, affecting performance.
# Thus, we use the editor to delete these Transpose nodes and reset the shapes.
python editor.py res/tmp.onnx new.onnx -d Conv2D__6 Conv2D_19__84 -i 'images:0 1 3 41 41' -o 'output:0 1 3 41 41'
# Now, it has no Transpose and takes channel-first inputs directly.
```
## test_models_opt.py
### 4.1. Description
Compare all pairs of original and optimized onnx models under a specified
directory, using different file-name endings to locate the original and
optimized model paths. The script runs onnxruntime inference on both models,
compares their results, computes basic statistics, and stores them in a csv file.
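As a rough sketch of the kind of per-output statistics involved (these definitions, including the `eps` guard against division by zero, are assumptions for illustration, not the script's exact code):

```python
import numpy as np

def diff_stats(raw, opt, eps=1e-12):
    """Basic absolute/relative difference statistics between two outputs."""
    raw = np.asarray(raw, dtype=np.float64)
    opt = np.asarray(opt, dtype=np.float64)
    abs_diff = np.abs(raw - opt)
    rel_diff = abs_diff / (np.abs(raw) + eps)
    return {
        'max_abs_diff': abs_diff.max(),
        'mean_abs_diff': abs_diff.mean(),
        'std_abs_diff': abs_diff.std(),
        'max_rel_diff': rel_diff.max(),
        'mean_rel_diff': rel_diff.mean(),
        'std_rel_diff': rel_diff.std(),
    }

stats = diff_stats([1.0, 2.0, 4.0], [1.0, 2.0002, 3.999])
print(stats)
```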
### 4.2. Usage
```bash
python test_models_opt.py DIR ending1 ending2 csv_out_file -p=Y/N
# csv_out_file is the file path for the stats data.
# -p/--plot is the plot option; if Y, stats plots will be generated.
```
### 4.3. Statistics
* max_rel_diff
* max_abs_diff
* mean_rel_diff
* mean_abs_diff
* std_rel_diff
* std_abs_diff
* acc_with_diff_precision
* percentile
### 4.4. Plots
* Max Relative Difference Histogram
* Max Absolute Difference Histogram
* Rel_diff Percentiles of Raw and Optimized Models
* Abs_diff Percentiles of Raw and Optimized Models
* Accuracies with Different Precisions
## tensorflow2onnx.py
### 5.1. Description
Convert and optimize TensorFlow models. If the input file is a frozen TensorFlow
`.pb` model, it is converted to an onnx model and the customized optimizations are
applied afterwards. If the input is already an onnx model, the optimizations are
applied and the optimized model is saved.
### 5.2. Dependency
This script depends on the tensorflow-onnx project. Please [check and install it](https://github.com/onnx/tensorflow-onnx/tree/r1.5) before using this script. We currently support up to version 1.5.5. For other versions, you may need to try them out yourself.
### 5.3. Basic Usage
```bash
python tensorflow2onnx.py in_file out_file -t=True/False
# -t/--test is the option for test mode; if True, the shape change after the input will not be eliminated.
```
### 5.4. Model Save Paths
`in_file` is the input model path; `out_file` specifies the output optimized model path.
If the input file is a `.pb` model, an unoptimized onnx model will be saved to the output directory as well.


@@ -0,0 +1,59 @@
import numpy as np
import onnx
import sys

from tools.other import topological_sort
from tools import helper


def fuse_bias_in_consecutive_1x1_conv(g):
    for second in g.node:
        # Find two consecutive Conv nodes.
        if second.op_type != 'Conv':
            continue
        first = helper.find_node_by_output_name(g, second.input[0])
        if first is None or first.op_type != 'Conv':
            continue
        # Check that the first one has only one following node.
        if len(helper.find_following_nodes_by_input_value_name(g, first.output[0])) != 1:
            continue
        # If the first node has no bias, continue.
        if len(first.input) == 2:
            continue
        # Check their kernel sizes: both must be 1x1.
        first_kernel_shape = helper.get_list_attribute_by_name(first, 'kernel_shape', 'int')
        second_kernel_shape = helper.get_list_attribute_by_name(second, 'kernel_shape', 'int')
        prod = first_kernel_shape[0] * first_kernel_shape[1] * second_kernel_shape[0] * second_kernel_shape[1]
        if prod != 1:
            continue
        print('Found: ', first.name, ' ', second.name)
        # Get the bias and weight constants of the nodes.
        first_bias_node = helper.find_node_by_output_name(g, first.input[2])
        second_weight_node = helper.find_node_by_output_name(g, second.input[1])
        second_bias_node = helper.find_node_by_output_name(g, second.input[2])
        first_bias = helper.constant_to_numpy(first_bias_node)
        second_weight = helper.constant_to_numpy(second_weight_node)
        second_bias = helper.constant_to_numpy(second_bias_node)
        # Calculate the new bias for the second node.
        first_bias = np.reshape(first_bias, (1, first_bias.size))
        second_weight = np.reshape(second_weight, (second_weight.shape[0], second_weight.shape[1]))
        second_weight = np.transpose(second_weight)
        new_second_bias = second_bias + np.matmul(first_bias, second_weight)
        new_second_bias = np.reshape(new_second_bias, (new_second_bias.size, ))
        # Generate the new (zeroed) bias for the first node.
        new_first_bias = np.reshape(first_bias, (first_bias.size, ))
        for i in range(new_first_bias.shape[0]):
            new_first_bias[i] = 0.0
        new_first_bias_node = helper.numpy_to_constant(first_bias_node.output[0], new_first_bias)
        new_second_bias_node = helper.numpy_to_constant(second_bias_node.output[0], new_second_bias)
        # Delete the old bias nodes and add the new ones.
        g.node.remove(first_bias_node)
        g.node.remove(second_bias_node)
        g.node.extend([new_first_bias_node, new_second_bias_node])
    topological_sort(g)


if __name__ == "__main__":
    if len(sys.argv) != 3:
        exit(1)
    m = onnx.load(sys.argv[1])
    fuse_bias_in_consecutive_1x1_conv(m.graph)
    onnx.save(m, sys.argv[2])


@@ -0,0 +1,24 @@
FROM continuumio/miniconda3:latest
LABEL maintainer="jiyuan@kneron.us"

# Install python packages
RUN conda update -y conda && \
    conda install -y python=3.6 && \
    conda install -y -c intel caffe && \
    conda install -y -c pytorch pytorch=1.3.1 torchvision=0.4.2 cpuonly && \
    conda install -y -c conda-forge tensorflow=1.5.1 keras=2.2.4 && \
    pip install onnx==1.4.1 onnxruntime==1.1.0 tf2onnx==1.5.4 && \
    ln -s /opt/conda/lib/libgflags.so.2.2.2 /opt/conda/lib/libgflags.so.2

# Install git lfs packages
RUN apt-get update && apt-get install -y curl apt-utils && \
    curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \
    apt-get install -y git-lfs

RUN conda clean -a -y && rm -rf /var/lib/apt/lists/*

# Copy the test data
COPY ./test_models /test_models

# Clean the environment and finalize the process
WORKDIR /root


@@ -0,0 +1,118 @@
import onnx
import onnx.utils
try:
    from onnx import optimizer
except ImportError:
    import onnxoptimizer as optimizer
import argparse

import tools.modhelper as helper
import tools.other as other
import tools.replacing as replacing

# Main process
# Argument parser
parser = argparse.ArgumentParser(description="Edit an ONNX model.\nThe processing sequence is 'delete nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting cannot be combined with other operations.")
parser.add_argument('in_file', type=str, help='input ONNX FILE')
parser.add_argument('out_file', type=str, help="output ONNX FILE")
parser.add_argument('-c', '--cut', dest='cut_node', type=str, nargs='+', help="remove nodes from the given nodes (inclusive)")
parser.add_argument('--cut-type', dest='cut_type', type=str, nargs='+', help="remove nodes by type from the given nodes (inclusive)")
parser.add_argument('-d', '--delete', dest='delete_node', type=str, nargs='+', help="delete nodes by names and only those nodes")
parser.add_argument('--delete-input', dest='delete_input', type=str, nargs='+', help="delete inputs by names")
parser.add_argument('--delete-output', dest='delete_output', type=str, nargs='+', help="delete outputs by names")
parser.add_argument('-i', '--input', dest='input_change', type=str, nargs='+', help="change input shape (e.g. -i 'input_0 1 3 224 224')")
parser.add_argument('-o', '--output', dest='output_change', type=str, nargs='+', help="change output shape (e.g. -o 'input_0 1 3 224 224')")
parser.add_argument('--add-conv', dest='add_conv', type=str, nargs='+', help='add nop conv using specific input')
parser.add_argument('--add-bn', dest='add_bn', type=str, nargs='+', help='add nop bn using specific input')
parser.add_argument('--rename-output', dest='rename_output', type=str, nargs='+', help='rename the specific output (e.g. --rename-output old_name new_name)')
parser.add_argument('--pixel-bias-value', dest='pixel_bias_value', type=str, nargs='+', help='(per channel) set a pixel-value bias BN layer at the model front for normalization (e.g. --pixel-bias-value "[104.0, 117.0, 123.0]")')
parser.add_argument('--pixel-scale-value', dest='pixel_scale_value', type=str, nargs='+', help='(per channel) set a pixel-value scale BN layer at the model front for normalization (e.g. --pixel-scale-value "[0.0078125, 0.0078125, 0.0078125]")')

args = parser.parse_args()

# Load model and polish
m = onnx.load(args.in_file)
m = other.polish_model(m)
g = m.graph
replacing.replace_initializer_with_Constant(g)
other.topological_sort(g)

# Remove nodes according to the given arguments.
if args.delete_node is not None:
    helper.delete_nodes(g, args.delete_node)
if args.delete_input is not None:
    helper.delete_input(g, args.delete_input)
if args.delete_output is not None:
    helper.delete_output(g, args.delete_output)

# Add do-nothing Conv node
if args.add_conv is not None:
    other.add_nop_conv_after(g, args.add_conv)
    other.topological_sort(g)

# Add do-nothing BN node
if args.add_bn is not None:
    other.add_nop_bn_after(g, args.add_bn)
    other.topological_sort(g)

# Add bias scale BN node
if args.pixel_bias_value is not None or args.pixel_scale_value is not None:
    if len(g.input) > 1:
        raise ValueError("'--pixel-bias-value' and '--pixel-scale-value' only support single-input models currently")
    i_n = g.input[0]
    pixel_bias_value = [0] * i_n.type.tensor_type.shape.dim[1].dim_value
    pixel_scale_value = [1] * i_n.type.tensor_type.shape.dim[1].dim_value
    if args.pixel_bias_value is not None and len(args.pixel_bias_value) == 1:
        pixel_bias_value = [float(n) for n in args.pixel_bias_value[0].replace('[', '').replace(']', '').split(',')]
    if args.pixel_scale_value is not None and len(args.pixel_scale_value) == 1:
        pixel_scale_value = [float(n) for n in args.pixel_scale_value[0].replace('[', '').replace(']', '').split(',')]
    if i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_bias_value) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value):
        raise ValueError("--pixel-bias-value (" + str(pixel_bias_value) + ") and --pixel-scale-value (" + str(pixel_scale_value) + ") should match the input channel dimension: " + str(i_n.type.tensor_type.shape.dim[1].dim_value))
    other.add_bias_scale_bn_after(g, i_n.name, pixel_bias_value, pixel_scale_value)

# Change input and output shapes as requested
if args.input_change is not None:
    other.change_input_shape(g, args.input_change)
if args.output_change is not None:
    other.change_output_shape(g, args.output_change)

# Cut nodes according to the given arguments.
if args.cut_node is not None or args.cut_type is not None:
    if args.cut_node is None:
        other.remove_nodes(g, cut_types=args.cut_type)
    elif args.cut_type is None:
        other.remove_nodes(g, cut_nodes=args.cut_node)
    else:
        other.remove_nodes(g, cut_nodes=args.cut_node, cut_types=args.cut_type)
    other.topological_sort(g)

# Rename outputs
if args.rename_output:
    if len(args.rename_output) % 2 != 0:
        print("Rename output arguments should be pairs of names.")
    else:
        for i in range(0, len(args.rename_output), 2):
            other.rename_output_name(g, args.rename_output[i], args.rename_output[i + 1])

# Remove useless nodes
if args.delete_node or args.delete_input or args.input_change or args.output_change:
    # If shapes changed during the modification, redo shape inference.
    while len(g.value_info) > 0:
        g.value_info.pop()

passes = ['extract_constant_to_initializer']
m = optimizer.optimize(m, passes)
g = m.graph
replacing.replace_initializer_with_Constant(g)
other.topological_sort(g)

# Polish and output
m = other.polish_model(m)
other.add_output_to_value_info(m.graph)
onnx.save(m, args.out_file)


@@ -0,0 +1,52 @@
import onnx
import sys
import json

from tools import special

if len(sys.argv) != 3:
    print("python norm_on_scaled_onnx.py input.onnx input.json")
    exit(1)

# Modify onnx
m = onnx.load(sys.argv[1])
special.add_0_5_to_normalized_input(m)
onnx.save(m, sys.argv[1][:-4] + 'norm.onnx')

# Change input node
origin_file = open(sys.argv[2], 'r')
origin_json = json.load(origin_file)
origin_json["input_node"]["output_datapath_radix"] = [8]
new_json_str = json.dumps(origin_json)

# Modify json
file = open(sys.argv[1][:-4] + 'norm.onnx' + '.json', 'w')
s = """{{
    \"{0}\" :
    {{
        \"bias_bitwidth\" : 16,
        \"{0}_bias\" : [15],
        \"{0}_weight\" : [3,3,3],
        \"conv_coarse_shift\" : [-4,-4,-4],
        \"conv_fine_shift\" : [0,0,0],
        \"conv_total_shift\" : [-4,-4,-4],
        \"cpu_mode\" : false,
        \"delta_input_bitwidth\" : [0],
        \"delta_output_bitwidth\" : 8,
        \"flag_radix_bias_eq_output\" : true,
        \"input_scale\" : [[1.0,1.0,1.0]],
        \"output_scale\" : [1.0, 1.0, 1.0],
        \"psum_bitwidth\" : 16,
        \"weight_bitwidth\" : 8,
        \"input_datapath_bitwidth\" : [8],
        \"input_datapath_radix\" : [8],
        \"working_input_bitwidth\" : 8,
        \"working_input_radix\" : [8],
        \"working_output_bitwidth\" : 16,
        \"working_output_radix\" : 15,
        \"output_datapath_bitwidth\" : 8,
        \"output_datapath_radix\" : 7
    }},\n""".format('input_norm')
file.write(s + new_json_str[1:])
file.close()
origin_file.close()


@@ -0,0 +1,135 @@
# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git
import sys
import onnx
import numpy as np
from onnx import numpy_helper

from tools import other, helper

"""
Change an onnx model from version 1.3 to version 1.4.
Modify BN nodes by removing the 'spatial' attribute.
Modify Upsample nodes by removing the 'scales' attribute and adding a Constant node instead.
The model's ir_version and opset_import are updated.
"""


def remove_BN_spatial(g):
    for node in g.node:
        if node.op_type != 'BatchNormalization':
            continue
        for att in node.attribute:
            if att.name == 'spatial':
                node.attribute.remove(att)


def upsample_attribute_to_const(g):
    for node in g.node:
        if node.op_type != 'Upsample':
            continue
        scales_exist = False
        for att in node.attribute:
            if att.name == 'scales':
                scales_exist = True
                break
        if not scales_exist:
            continue
        shape = [len(att.floats)]
        node.attribute.remove(att)
        new_node = helper.list_to_constant(node.name + '_input', shape, att.floats)
        g.node.extend([new_node])
        value_info = onnx.helper.make_tensor_value_info(node.name + '_input', onnx.TensorProto.FLOAT, shape)
        node.input.extend([node.name + '_input'])
        g.value_info.extend([value_info])


def relu6_to_clip(g):
    for node in g.node:
        if node.op_type != 'Relu':
            continue
        max_val = helper.get_var_attribute_by_name(node, 'max', 'float')
        if max_val is None:
            continue
        new_node = onnx.helper.make_node(
            "Clip",
            node.input,
            node.output,
            name=node.name,
            max=max_val,
            min=0.0
        )
        g.node.remove(node)
        g.node.extend([new_node])


def PRelu_weight_reshape(g):
    # For PRelu with a single-dimension weight, expand it to (1, x, 1, 1).
    for node in g.node:
        if node.op_type != "PRelu":
            continue
        slope = helper.find_node_by_output_name(g, node.input[1])
        if slope is not None:
            # Constant node
            if len(slope.attribute[0].t.dims) != 1:
                continue
            slope.attribute[0].t.dims.append(slope.attribute[0].t.dims[0])
            slope.attribute[0].t.dims[0] = 1
            slope.attribute[0].t.dims.append(1)
            slope.attribute[0].t.dims.append(1)
        else:
            # Initializer
            for i in g.initializer:
                if i.name == node.input[1]:
                    slope = i
                    break
            if len(slope.dims) != 1:
                continue
            slope.dims.append(slope.dims[0])
            slope.dims[0] = 1
            slope.dims.append(1)
            slope.dims.append(1)
            input_value = helper.find_input_by_name(g, node.input[1])
            new_input = onnx.helper.make_tensor_value_info(
                node.input[1],
                input_value.type.tensor_type.elem_type,
                (1, slope.dims[1], 1, 1))
            g.input.remove(input_value)
            g.input.append(new_input)
        value_info = helper.find_value_by_name(g, node.input[1])
        if value_info is not None:
            g.value_info.remove(value_info)


def do_convert(m):
    graph = m.graph
    # Modify the nodes.
    remove_BN_spatial(graph)
    upsample_attribute_to_const(graph)
    relu6_to_clip(graph)
    PRelu_weight_reshape(graph)
    other.topological_sort(graph)
    # Change model properties.
    m.ir_version = 4
    m.opset_import[0].version = 9
    return m


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: {} file_in file_out".format(sys.argv[0]))
        exit(1)
    model = onnx.load(sys.argv[1])
    model = do_convert(model)
    onnx.save(model, sys.argv[2])


@@ -0,0 +1,184 @@
# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git
import sys
import onnx
import onnx.utils
import numpy as np
from onnx import numpy_helper

from tools import other, helper, replacing

"""
Change an onnx model from version 1.4 to version 1.6.
"""


def replace_all_attribute_to_const_node_in_pad_node(g):
    node_to_remove = []
    node_to_extend = []
    for node in g.node:
        if node.op_type != 'Pad':
            continue
        pad_loc_node = None  # must have
        pad_mode = 'constant'
        pad_value_node = helper.list_to_constant(node.name + '_pad_value', [], [0.0])  # needs a scalar
        for att in node.attribute:
            if att.name == 'mode':
                pad_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
            if att.name == 'pads':
                pad_loc_node = helper.list_to_constant(node.name + '_pad_loc', [len(att.ints)], att.ints)
            if att.name == 'value':
                pad_value_node = helper.list_to_constant(node.name + '_pad_value', [], [att.f])
        new_node = onnx.helper.make_node(
            "Pad",
            [node.input[0], pad_loc_node.name, pad_value_node.name],
            [node.output[0]],
            name=node.output[0],
            mode=pad_mode,
        )
        node_to_remove.append(node)
        node_to_extend.append(new_node)
        node_to_extend.append(pad_loc_node)
        node_to_extend.append(pad_value_node)
    for node in node_to_remove:
        g.node.remove(node)
    for node in node_to_extend:
        g.node.extend([node])


def upsampling_to_resize(g):
    for node in g.node:
        if node.op_type != 'Upsample':
            continue
        upsampling_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
        scale_value_node = helper.find_node_by_output_name(g, node.input[1])
        if scale_value_node.op_type != "Constant":
            raise TypeError('there seems to be a dynamic "scales" param in Upsample node: ' + node.name + ', you might need to do constant folding first')
        roi_node = helper.list_to_constant(node.name + '_roi_value', [0], [])
        new_node = onnx.helper.make_node(
            "Resize",
            [node.input[0], roi_node.name, scale_value_node.name],
            [node.output[0]],
            name=node.output[0],
            mode=upsampling_mode,
            coordinate_transformation_mode='asymmetric'
        )
        g.node.remove(node)
        g.node.extend([new_node])
        g.node.extend([roi_node])


def replace_all_attribute_to_const_node_in_slice_node(g):
    for node in g.node:
        if node.op_type != 'Slice':
            continue
        axes_const_node = None
        ends_const_node = None
        starts_const_node = None
        steps_const_node = None
        for att in node.attribute:
            if att.name == 'axes':
                axes_const_node = helper.list_to_constant(node.name + '_axes_value', [len(att.ints)], att.ints)
            if att.name == 'ends':
                ends_const_node = helper.list_to_constant(node.name + '_ends_value', [len(att.ints)], att.ints)
            if att.name == 'starts':
                starts_const_node = helper.list_to_constant(node.name + '_starts_value', [len(att.ints)], att.ints)
            if att.name == 'steps':
                steps_const_node = helper.list_to_constant(node.name + '_steps_value', [len(att.ints)], att.ints)
        # Pop the attributes out from the back
        attr_len = len(node.attribute)
        for i in range(attr_len):
            node.attribute.remove(node.attribute[attr_len - 1 - i])
        # According to the spec, we need to add the inputs in a specific order
        if starts_const_node is not None:
            g.node.extend([starts_const_node])
            node.input.extend([starts_const_node.name])
        if ends_const_node is not None:
            g.node.extend([ends_const_node])
            node.input.extend([ends_const_node.name])
        if axes_const_node is not None:
            g.node.extend([axes_const_node])
            node.input.extend([axes_const_node.name])
        if steps_const_node is not None:
            g.node.extend([steps_const_node])
            node.input.extend([steps_const_node.name])


def replace_min_max_attribute_to_const_node_in_clip_node(g):
    for node in g.node:
        if node.op_type != 'Clip':
            continue
        max_const_node = None
        min_const_node = None
        for att in node.attribute:
            if att.name == 'max':
                max_const_node = helper.list_to_constant(node.name + '_max_value', [], [att.f])
            if att.name == 'min':
                min_const_node = helper.list_to_constant(node.name + '_min_value', [], [att.f])
        # Fall back to the float32 extremes when an attribute is absent
        if max_const_node is None:
            max_const_node = helper.list_to_constant(node.name + '_max_value', [], [3.402823466e+38])
        if min_const_node is None:
            min_const_node = helper.list_to_constant(node.name + '_min_value', [], [-3.402823466e+38])
        # Pop the attributes out from the back
        attr_len = len(node.attribute)
        for i in range(attr_len):
            node.attribute.remove(node.attribute[attr_len - 1 - i])
        # According to the spec, we need to add the inputs in a specific order
        g.node.extend([min_const_node])
        g.node.extend([max_const_node])
        node.input.extend([min_const_node.name])
        node.input.extend([max_const_node.name])


def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
    """Update ir_version from 4 to 6 and update opset from 9 to 11.

    Args:
        model (onnx.ModelProto): input onnx model.

    Returns:
        onnx.ModelProto: updated onnx model.
    """
    graph = model.graph
    if model.opset_import[0].version == 11:
        print("(Stop) the input model is already opset 11, no need to upgrade")
        exit(1)
    # Deal with the empty node name issue
    other.add_name_to_node(graph)
    # Simplify the node param type from initializer to constant
    replacing.replace_initializer_with_Constant(graph)
    # Modify the nodes.
    replace_min_max_attribute_to_const_node_in_clip_node(graph)
    replace_all_attribute_to_const_node_in_slice_node(graph)
    replace_all_attribute_to_const_node_in_pad_node(graph)
    upsampling_to_resize(graph)
    other.topological_sort(graph)
    # Change model properties.
    model.ir_version = 6
    model.opset_import[0].version = 11
    model = other.polish_model(model)
    return model


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: {} file_in file_out".format(sys.argv[0]))
        exit(1)
    model = onnx.load(sys.argv[1])
    model = onnx1_4to1_6(model)
    onnx.save(model, sys.argv[2])


@@ -0,0 +1,136 @@
import onnx
import onnx.utils
try:
    from onnx import optimizer
except ImportError:
    import onnxoptimizer as optimizer
import sys
import argparse
import logging

from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import special
from tools import combo
from tools.helper import logger
# from tools import temp


def onnx2onnx_flow(m: onnx.ModelProto,
                   disable_fuse_bn=False,
                   bn_on_skip=False,
                   bn_before_add=False,
                   bgr=False,
                   norm=False,
                   rgba2yynn=False,
                   eliminate_tail=False,
                   opt_matmul=False,
                   duplicate_shared_weights=True) -> onnx.ModelProto:
    """Optimize the onnx.

    Args:
        m (ModelProto): the input onnx ModelProto
        disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.
        bn_on_skip (bool, optional): add BN operators on skip branches. Defaults to False.
        bn_before_add (bool, optional): add BN before Add nodes on every branch. Defaults to False.
        bgr (bool, optional): add a Conv layer to convert RGB input to BGR. Defaults to False.
        norm (bool, optional): add a Conv layer to add 0.5 to the input. Defaults to False.
        rgba2yynn (bool, optional): add a Conv layer to convert RGB input to YYNN. Defaults to False.
        eliminate_tail (bool, optional): remove the trailing NPU-unsupported nodes. Defaults to False.
        opt_matmul (bool, optional): optimize the MatMul layers according to the NPU limit. Defaults to False.
        duplicate_shared_weights (bool, optional): duplicate shared weights. Defaults to True.

    Returns:
        ModelProto: the optimized onnx model object.
    """
    # temp.weight_broadcast(m.graph)
    m = combo.preprocess(m, disable_fuse_bn, duplicate_shared_weights)
    # temp.fuse_bias_in_consecutive_1x1_conv(m.graph)

    # Add BN on skip branch
    if bn_on_skip:
        other.add_bn_on_skip_branch(m.graph)
    elif bn_before_add:
        other.add_bn_before_add(m.graph)
        other.add_bn_before_activation(m.graph)

    # My optimization
    m = combo.common_optimization(m)

    # Special options
    if bgr:
        special.change_input_from_bgr_to_rgb(m)
    if norm:
        special.add_0_5_to_normalized_input(m)
    if rgba2yynn:
        special.add_rgb2yynn_node(m)

    # Remove useless last node
    if eliminate_tail:
        eliminating.remove_useless_last_nodes(m.graph)

    # Postprocessing
    m = combo.postprocess(m)

    # Put MatMul optimization after postprocess to avoid transposes moving downwards
    if opt_matmul:
        special.special_MatMul_process(m.graph)
    m = other.polish_model(m)
    return m


# Main process
if __name__ == "__main__":
    # Argument parser
    parser = argparse.ArgumentParser(description="Optimize an ONNX model for Kneron compiler")
    parser.add_argument('in_file', help='input ONNX FILE')
    parser.add_argument('-o', '--output', dest='out_file', type=str, help="output ONNX FILE")
    parser.add_argument('--log', default='i', type=str, help="set log level")
    parser.add_argument('--bgr', action='store_true', default=False, help="set if the model is trained in BGR mode")
    parser.add_argument('--norm', action='store_true', default=False, help="set if you have the input in the range -0.5~0.5")
    parser.add_argument('--rgba2yynn', action='store_true', default=False, help="set if the model has YYNN input but you want to feed RGBA images")
    parser.add_argument('--add-bn-on-skip', dest='bn_on_skip', action='store_true', default=False,
                        help="set if you only want to add BN on skip branches")
    parser.add_argument('--add-bn', dest='bn_before_add', action='store_true', default=False,
                        help="set if you want to add BN before Add")
    parser.add_argument('-t', '--eliminate-tail-unsupported', dest='eliminate_tail', action='store_true', default=False,
                        help='whether to remove the last hardware-unsupported nodes')
    parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
                        help="set if you have met errors related to inferenced shape mismatch. This option prevents fusing BatchNormalization into Conv.")
    parser.add_argument('--opt-matmul', dest='opt_matmul', action='store_true', default=False,
                        help="set if you want to optimize the MatMul operations for the Kneron hardware.")
    parser.add_argument('--no-duplicate-shared-weights', dest='no_duplicate_shared_weights', action='store_true', default=False,
                        help='do not duplicate shared weights. Defaults to False.')
    args = parser.parse_args()

    if args.out_file is None:
        outfile = args.in_file[:-5] + "_polished.onnx"
    else:
        outfile = args.out_file

    if args.log == 'w':
        logging.basicConfig(level=logging.WARN)
    elif args.log == 'd':
        logging.basicConfig(level=logging.DEBUG)
    elif args.log == 'e':
        logging.basicConfig(level=logging.ERROR)
    else:
        logging.basicConfig(level=logging.INFO)

    # onnx polish_model includes:
    # -- nop
    # -- eliminate_identity
    # -- eliminate_nop_transpose
    # -- eliminate_nop_pad
    # -- eliminate_unused_initializer
    # -- fuse_consecutive_squeezes
    # -- fuse_consecutive_transposes
    # -- fuse_add_bias_into_conv
    # -- fuse_transpose_into_gemm

    # Basic model organization
    m = onnx.load(args.in_file)
    m = onnx2onnx_flow(m, args.disable_fuse_bn, args.bn_on_skip, args.bn_before_add,
                       args.bgr, args.norm, args.rgba2yynn, args.eliminate_tail,
                       args.opt_matmul, not args.no_duplicate_shared_weights)
    onnx.save(m, outfile)


@@ -0,0 +1,134 @@
import onnxruntime
import onnx
import argparse
import numpy as np

from tools import helper

onnx2np_dtype = {0: 'float', 1: 'float32', 2: 'uint8', 3: 'int8', 4: 'uint16', 5: 'int16', 6: 'int32', 7: 'int64', 8: 'str', 9: 'bool', 10: 'float16', 11: 'double', 12: 'uint32', 13: 'uint64', 14: 'complex64', 15: 'complex128', 16: 'float'}


def onnx_model_results(path_a, path_b, total_times=10):
    """Use onnxruntime to run inference on two onnx models and collect their outputs.

    :param path_a, path_b: the two model paths
    :param total_times: number of inference runs, defaults to 10
    :returns: inference results of the two models
    """
    # Load model a and model b into the runtime
    session_a = onnxruntime.InferenceSession(path_a, None)
    session_b = onnxruntime.InferenceSession(path_b, None)
    outputs_a = session_a.get_outputs()
    outputs_b = session_b.get_outputs()

    # Check outputs
    assert len(outputs_a) == len(outputs_b), 'Two models have different output numbers.'
    for i in range(len(outputs_a)):
        out_shape_a, out_shape_b = outputs_a[i].shape, outputs_b[i].shape
        out_shape_a = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_a))
        out_shape_b = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_b))
        assert out_shape_a == out_shape_b, 'Output {} has unmatched shapes'.format(i)

    # Load graph_a and graph_b to find the initializers and inputs,
    # then remove the items in the inputs which are initializers.
    model_a, model_b = onnx.load(path_a), onnx.load(path_b)
    graph_a, graph_b = model_a.graph, model_b.graph
    inputs_a, inputs_b = graph_a.input, graph_b.input
    init_a, init_b = graph_a.initializer, graph_b.initializer

    # Remove initializers from the raw inputs
    input_names_a, input_names_b = set([ele.name for ele in inputs_a]), set([ele.name for ele in inputs_b])
    init_names_a, init_names_b = set([ele.name for ele in init_a]), set([ele.name for ele in init_b])
    real_inputs_names_a, real_inputs_names_b = input_names_a - init_names_a, input_names_b - init_names_b

    # Figure out the matching of real inputs a and real inputs b,
    # trying to keep the original order of the inputs.
    real_inputs_a, real_inputs_b = [], []
    for item in inputs_a:
        if item.name in real_inputs_names_a:
            real_inputs_a.append(item)
    for item in inputs_b:
        if item.name in real_inputs_names_b:
            real_inputs_b.append(item)

    # Suppose there is only one real input tensor for each model.
    # Find the real single inputs for model_a and model_b.
    real_single_input_a = None
    real_single_input_b = None
    size_a, size_b = 0, 0
    shape_a, shape_b = [], []
    for item_a in real_inputs_a:
        size, shape = helper.find_size_shape_from_value(item_a)
        if size:
            assert real_single_input_a is None, 'Multiple inputs of first model, single input expected.'
            real_single_input_a = item_a
            size_a, shape_a = size, shape
    for item_b in real_inputs_b:
        size, shape = helper.find_size_shape_from_value(item_b)
        if size:
            assert real_single_input_b is None, 'Multiple inputs of second model, single input expected.'
            real_single_input_b = item_b
            size_b, shape_b = size, shape
    assert size_a == size_b, 'Input sizes of the two models do not match.'

    # Construct the input tensors
    input_data_type_a = real_single_input_a.type.tensor_type.elem_type
    input_data_type_b = real_single_input_b.type.tensor_type.elem_type
    input_data_type_a = onnx2np_dtype[input_data_type_a]
    input_data_type_b = onnx2np_dtype[input_data_type_b]

    # Run inference
    times = 0
    results_a = [[] for i in range(len(outputs_a))]
    results_b = [[] for i in range(len(outputs_b))]
while times < total_times:
# initialize inputs by random data, default to be uniform
data = np.random.random(size_a)
input_a = np.reshape(data, shape_a).astype(input_data_type_a)
input_b = np.reshape(data, shape_b).astype(input_data_type_b)
input_dict_a = {}
input_dict_b = {}
for item_a in real_inputs_a:
item_type_a = onnx2np_dtype[item_a.type.tensor_type.elem_type]
input_dict_a[item_a.name] = np.array([]).astype(item_type_a) \
if item_a.name != real_single_input_a.name else input_a
for item_b in real_inputs_b:
item_type_b = onnx2np_dtype[item_b.type.tensor_type.elem_type]
input_dict_b[item_b.name] = np.array([]).astype(item_type_b) \
if item_b.name != real_single_input_b.name else input_b
ra = session_a.run([], input_dict_a)
rb = session_b.run([], input_dict_b)
for i in range(len(outputs_a)):
results_a[i].append(ra[i])
results_b[i].append(rb[i])
times += 1
return results_a, results_b
if __name__ == '__main__':
# Argument parser.
parser = argparse.ArgumentParser(description="Compare two ONNX models to check if they have the same output.")
parser.add_argument('in_file_a', help='input ONNX file a')
parser.add_argument('in_file_b', help='input ONNX file b')
args = parser.parse_args()
results_a, results_b = onnx_model_results(args.in_file_a, args.in_file_b, total_times=10)
ra_flat = helper.flatten_with_depth(results_a, 0)
rb_flat = helper.flatten_with_depth(results_b, 0)
shape_a = [item[1] for item in ra_flat]
shape_b = [item[1] for item in rb_flat]
assert shape_a == shape_b, "The two models' result shapes do not match."
ra_raw = [item[0] for item in ra_flat]
rb_raw = [item[0] for item in rb_flat]
try:
np.testing.assert_almost_equal(ra_raw, rb_raw, 4)
print('Two models have the same behaviour.')
except Exception as mismatch:
print(mismatch)
exit(1)
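The filtering above separates real runtime inputs from initializer-backed inputs by set subtraction while preserving declaration order. A standalone sketch of that pattern, using plain name lists (`all_inputs` and `initializers` are hypothetical stand-ins for `graph.input` and `graph.initializer`):

```python
# Hypothetical stand-ins for graph.input and graph.initializer name lists.
all_inputs = ["image", "conv1_weight", "conv1_bias"]
initializers = ["conv1_weight", "conv1_bias"]

# Set subtraction finds the names that are true runtime inputs...
real_names = set(all_inputs) - set(initializers)
# ...while iterating the original list preserves the declared order.
real_inputs = [name for name in all_inputs if name in real_names]
print(real_inputs)  # ['image']
```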


@ -0,0 +1,221 @@
import onnx
import argparse
import glob
import csv
import numpy as np
import matplotlib.pyplot as plt
from tools import helper
import onnx_vs_onnx as onnx_tester
def compare_results(results_a, results_b):
""" compare onnx model inference results
calculate basic statistical values
results: results from inference multiple times
returns: list of basic statistical values
"""
# input results data can be of nonuniform shape
# get flatten data to compare
ra_flat = helper.flatten_with_depth(results_a, 0)
rb_flat = helper.flatten_with_depth(results_b, 0)
shape_a = [item[1] for item in ra_flat]
shape_b = [item[1] for item in rb_flat]
assert shape_a == shape_b, "The two models' result shapes do not match."
ra_raw = [item[0] for item in ra_flat]
rb_raw = [item[0] for item in rb_flat]
# the statistical values
max_rel_diff = 0 # defined to be max( { abs(diff)/max(abs(ra), abs(rb) ) } )
max_abs_diff = 0 # defined to be max( { abs(ra-rb) } )
mean_rel_diff = 0
mean_abs_diff = 0
std_rel_diff = 0
std_abs_diff = 0
acc_with_diff_precision = []
rel_diff = []
abs_diff_percentiles = [] # abs_diff percentiles
rel_diff_percentiles = [] # rel_diff percentiles
raw_diff = [ra_raw[i]-rb_raw[i] for i in range(len(ra_raw))]
abs_diff = [abs(num) for num in raw_diff]
for i in range(len(ra_raw)):
divider = max([abs(ra_raw[i]), abs(rb_raw[i])])
val = abs_diff[i]/divider if divider != 0 else 0
rel_diff.append(val)
max_rel_diff = max(rel_diff)
max_abs_diff = max(abs_diff)
mean_rel_diff = np.average(rel_diff)
mean_abs_diff = np.average(abs_diff)
std_rel_diff = np.std(rel_diff)
std_abs_diff = np.std(abs_diff)
# calculate accuracy at different precisions
for digit in range(8):
correct = 0
for i in range(len(ra_raw)):
if format(ra_raw[i], '.'+str(digit)+'f')\
== format(rb_raw[i], '.'+str(digit)+'f'):
correct += 1
acc_with_diff_precision.append([digit, float(format(correct/len(ra_raw), '.3f'))])
# analyze rel_diff distribution
rel_diff.sort()
abs_diff.sort()
for i in range(20):
rel_diff_percentiles.append(['{}%'.format(i*5), rel_diff[int((i/20)*len(rel_diff))]])
abs_diff_percentiles.append(['{}%'.format(i*5), abs_diff[int((i/20)*len(abs_diff))]])
results = [
['max_rel_diff', max_rel_diff],
['max_abs_diff', max_abs_diff],
['mean_rel_diff', mean_rel_diff],
['mean_abs_diff', mean_abs_diff],
['std_rel_diff', std_rel_diff],
['std_abs_diff', std_abs_diff],
['acc_with_diff_precision', acc_with_diff_precision],
['rel_diff_percentiles', rel_diff_percentiles],
['abs_diff_percentiles', abs_diff_percentiles]
]
return results
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='test model optimization results')
parser.add_argument('dir', type=str, help='the directory that stores onnx models')
parser.add_argument('ending1', type=str, help='model file name ending (e.g. .onnx)')
parser.add_argument('ending2', type=str, help='optimized model file name ending (e.g. _opt.onnx)')
parser.add_argument('out_file', type=str, help='output csv file name')
parser.add_argument('-p', '--plot', default='N', help='get plots (Y/N)')
parser.add_argument('-i', '--iter_times', default=10, type=int, help='inference times')
args = parser.parse_args()
old_models_paths = glob.glob(args.dir+'*'+args.ending1)
new_models_paths = glob.glob(args.dir+'*'+args.ending2)
stats_table = [[
'Model',
'max_rel_diff',
'max_abs_diff',
'mean_rel_diff',
'mean_abs_diff',
'std_rel_diff',
'std_abs_diff',
'acc_with_diff_precision',
'rel_diff_percentiles',
'abs_diff_percentiles'
]]
for new_model_path in new_models_paths:
old_model_path = new_model_path[:-len(args.ending2)] + args.ending1
if old_model_path not in old_models_paths:
continue
# run inference
results_a, results_b = onnx_tester.onnx_model_results(old_model_path, new_model_path, total_times=args.iter_times)
# compare inference results
comparison = compare_results(results_a, results_b)
new_line = [old_model_path.split('/')[-1]]
for item in comparison:
new_line.append(item[1])
stats_table.append(new_line)
# try to read existing file
old_stats_table = []
try:
old_file = open(args.out_file, 'r')
reader = csv.reader(old_file)
old_header = next(reader)
for row in reader:
old_stats_table.append(row)
old_file.close()
except (OSError, StopIteration):
# No previous stats file, or it is empty: start fresh.
pass
# compare and merge possible old stat data file with new stat data file
header = stats_table[0]
stats_table = stats_table[1:]
new_model_names = set([item[0] for item in stats_table])
for row in old_stats_table:
if row[0] not in new_model_names:
stats_table.append(row)
stats_table.insert(0, header)
# write a new stat data file, overwrite old file
new_file = open(args.out_file, 'w', newline='')
writer = csv.writer(new_file)
for row in stats_table:
writer.writerow(row)
new_file.close()
# make some plots
if args.plot == 'Y':
if len(stats_table) < 2:
exit(0)
sample_table = stats_table[1:] if len(stats_table) < 6 else stats_table[1:6]
max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]]
plt.hist(max_rel_diffs, bins=15)
plt.title('Max Relative Difference Histogram')
plt.xlabel('Max Relative Difference')
plt.ylabel('Counts')
plt.savefig('max_rel_diff_hist.png')
plt.close()
max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]]
plt.hist(max_abs_diffs, bins=15)
plt.title('Max Absolute Difference Histogram')
plt.xlabel('Max Absolute Difference')
plt.ylabel('Counts')
plt.savefig('max_abs_diff_hist.png')
plt.close()
for line in sample_table:
model_name = line[0]
percentiles = line[-2]
x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))]
y = [ele[1] for ele in percentiles]
plt.plot(x, y, label=model_name)
plt.title('Rel_diff Percentiles of Raw and Optimized Models')
plt.xlabel('percentage')
plt.ylabel('relative difference')
plt.legend()
plt.savefig('rel_diff_percentiles.png')
plt.close()
for line in sample_table:
model_name = line[0]
percentiles = line[-1]
x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))]
y = [ele[1] for ele in percentiles]
plt.plot(x, y, label=model_name)
plt.title('Abs_diff Percentiles of Raw and Optimized Models')
plt.xlabel('percentage')
plt.ylabel('absolute difference')
plt.legend()
plt.savefig('abs_diff_percentiles.png')
plt.close()
for line in sample_table:
model_name = line[0]
accuracies = line[-3]
x = [acc[0] for acc in accuracies]
y = [acc[1] for acc in accuracies]
plt.plot(x, y, label=model_name)
plt.title('Accuracies with Different Precisions')
plt.xlabel('Decimals')
plt.ylabel('Precision')
plt.legend()
plt.savefig('precisions.png')
plt.close()
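The difference statistics computed by `compare_results` reduce to a few NumPy reductions over the flattened outputs. A toy sketch with hand-made result vectors (values chosen only for illustration):

```python
import numpy as np

ra = np.array([1.0, 2.0, 4.0])
rb = np.array([1.0, 2.5, 4.0])

abs_diff = np.abs(ra - rb)
# Relative difference uses the larger magnitude as the denominator,
# guarding against division by zero like the script does.
denom = np.maximum(np.abs(ra), np.abs(rb))
rel_diff = np.where(denom != 0, abs_diff / denom, 0.0)

print(abs_diff.max())  # 0.5
print(rel_diff.max())  # 0.2
```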


@ -0,0 +1,81 @@
import onnx
import onnx.utils
try:
from onnx import optimizer
except ImportError:
import onnxoptimizer as optimizer
import sys
import numpy as np
import struct
import logging
import argparse
from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import combo
from tools import special
from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
# Debug use
# logging.basicConfig(level=logging.DEBUG)
######################################
# Generate a prototype onnx #
######################################
parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
parser.add_argument('in_file', help='input ONNX or PTH file')
parser.add_argument('out_file', help="output ONNX file")
parser.add_argument('--input-size', dest='input_size', nargs=3,
help='if you are using a pth file, use this argument to set the input size of the model. It should be in \'C H W\' format, e.g. \'--input-size 3 256 512\'.')
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
help="set this if you have met errors related to inferred shape mismatches. This option prevents fusing BatchNormalization into Conv.")
args = parser.parse_args()
if len(args.in_file) <= 4:
# When the filename is too short.
logging.error("Invalid input file: {}".format(args.in_file))
exit(1)
elif args.in_file.endswith('.pth'):
# Pytorch pth case
logging.warning("Converting from pth to onnx is not recommended.")
onnx_in = args.out_file
# Import pytorch libraries
from torch.autograd import Variable
import torch
import torch.onnx
# import torchvision
# Standard ImageNet input - 3 channels, 224x224.
# Values don't matter as we care about network structure.
# But they can also be real inputs.
if args.input_size is None:
logging.error("\'--input-size\' is required for the pth input file.")
exit(1)
dummy_input = Variable(torch.randn(1, int(args.input_size[0]), int(args.input_size[1]), int(args.input_size[2])))
# Obtain your model, it can be also constructed in your script explicitly.
model = torch.load(args.in_file, map_location='cpu')
# model = torchvision.models.resnet34(pretrained=True)
# Invoke export.
# torch.save(model, "resnet34.pth")
torch.onnx.export(model, dummy_input, args.out_file, opset_version=11)
elif args.in_file.endswith('.onnx'):
onnx_in = args.in_file
else:
# When the file is neither an onnx file nor a pytorch pth file.
logging.error("Invalid input file: {}".format(args.in_file))
exit(1)
onnx_out = args.out_file
######################################
# Optimize onnx #
######################################
m = onnx.load(onnx_in)
m = torch_exported_onnx_flow(m, args.disable_fuse_bn)
onnx.save(m, onnx_out)
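The extension checks in this script compare raw string slices ('.pth' is four characters, while 'onnx' happens to match the tail of '.onnx'). A more robust alternative sketch using `pathlib` (the `classify` helper is hypothetical, not part of the toolchain):

```python
from pathlib import Path

def classify(path: str) -> str:
    """Classify an input file by its suffix, case-insensitively."""
    suffix = Path(path).suffix.lower()
    if suffix == ".pth":
        return "pytorch"
    if suffix == ".onnx":
        return "onnx"
    return "unknown"

print(classify("model.pth"))   # pytorch
print(classify("model.ONNX"))  # onnx
print(classify("model.pb"))    # unknown
```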


@ -0,0 +1,80 @@
import onnx
import onnx.utils
try:
from onnx import optimizer
except ImportError:
import onnxoptimizer as optimizer
import sys
import numpy as np
import struct
import logging
import argparse
from .tools import eliminating
from .tools import fusing
from .tools import replacing
from .tools import other
from .tools import combo
from .tools import special
# Define general pytorch exported onnx optimize process
def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.ModelProto:
"""Optimize the Pytorch exported onnx.
Args:
m (ModelProto): the input onnx model
disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.
Returns:
ModelProto: the optimized onnx model
"""
m = combo.preprocess(m, disable_fuse_bn)
m = combo.pytorch_constant_folding(m)
m = combo.common_optimization(m)
m = combo.postprocess(m)
return m
# Main Process
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
parser.add_argument('in_file', help='input ONNX file')
parser.add_argument('out_file', help="output ONNX file")
parser.add_argument('--log', default='i', type=str, help="set log level: i (info, default), w (warn), d (debug), e (error)")
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
help="set this if you have met errors related to inferred shape mismatches. This option prevents fusing BatchNormalization into Conv.")
args = parser.parse_args()
if args.log == 'w':
logging.basicConfig(level=logging.WARN)
elif args.log == 'd':
logging.basicConfig(level=logging.DEBUG)
elif args.log == 'e':
logging.basicConfig(level=logging.ERROR)
else:
logging.basicConfig(level=logging.INFO)
if len(args.in_file) <= 4:
# When the filename is too short.
logging.error("Invalid input file: {}".format(args.in_file))
exit(1)
elif args.in_file.endswith('.onnx'):
onnx_in = args.in_file
else:
# When the file is not an onnx file.
logging.error("Invalid input file: {}".format(args.in_file))
exit(1)
onnx_out = args.out_file
######################################
# Optimize onnx #
######################################
m = onnx.load(onnx_in)
m = torch_exported_onnx_flow(m, args.disable_fuse_bn)
onnx.save(m, onnx_out)


@ -0,0 +1,27 @@
{
"LAYERNAME" :
{
"bias_bitwidth" : 16,
"LAYERNAME_bias" : [15],
"LAYERNAME_weight" : [3,3,3],
"conv_coarse_shift" : [-4,-4,-4],
"conv_fine_shift" : [0,0,0],
"conv_total_shift" : [-4,-4,-4],
"cpu_mode" : false,
"delta_input_bitwidth" : [0],
"delta_output_bitwidth" : 8,
"flag_radix_bias_eq_output" : true,
"input_scale" : [[1.0,1.0,1.0]],
"output_scale" : [1.0, 1.0, 1.0],
"psum_bitwidth" : 16,
"weight_bitwidth" : 8,
"input_datapath_bitwidth" : [8],
"input_datapath_radix" : [7],
"working_input_bitwidth" : 8,
"working_input_radix" : [7],
"working_output_bitwidth" : 16,
"working_output_radix" : 15,
"output_datapath_bitwidth" : 8,
"output_datapath_radix" : 7
}
}
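A per-layer config like the one above can be loaded and sanity-checked with the standard `json` module before handing it to the toolchain. A minimal sketch (the key names mirror the snippet; the checks are illustrative, not Kneron's actual loader):

```python
import json

snippet = '{"conv1": {"bias_bitwidth": 16, "weight_bitwidth": 8, "cpu_mode": false}}'
config = json.loads(snippet)

for layer, params in config.items():
    # Bit widths are expected to be positive integers.
    for key in ("bias_bitwidth", "weight_bitwidth"):
        assert isinstance(params[key], int) and params[key] > 0, (layer, key)

print(sorted(config))  # ['conv1']
```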


@ -0,0 +1,9 @@
#!/bin/bash
python onnx_tester.py /test_models/mobilenet_v2_224.onnx /test_models/mobilenet_v2_224.cut.onnx
if [ $? -eq 0 ]; then
echo "Those two model results should be different!"
exit 1
fi
exit 0

Binary file not shown.


@ -0,0 +1,147 @@
import tensorflow as tf
import tf2onnx
import argparse
import logging
import sys
import onnx
import onnx.utils
from tensorflow.python.platform import gfile
from tools import combo, eliminating, replacing, other
def tf2onnx_flow(pb_path: str, test_mode=False) -> onnx.ModelProto:
"""Convert frozen graph pb file into onnx
Args:
pb_path (str): input pb file path
test_mode (bool, optional): test mode. Defaults to False.
Raises:
Exception: invalid input file
Returns:
onnx.ModelProto: converted onnx
"""
TF2ONNX_VERSION = int(tf2onnx.version.version.replace('.', ''))
if 160 <= TF2ONNX_VERSION:
from tf2onnx import tf_loader
else:
from tf2onnx import loader as tf_loader
if pb_path[-3:] == '.pb':
model_name = pb_path.split('/')[-1][:-3]
# always reset tensorflow session at begin
tf.reset_default_graph()
with tf.Session() as sess:
with gfile.FastGFile(pb_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
if 160 <= TF2ONNX_VERSION:
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions = tf2onnx.tf_utils.tflist_to_onnx(
sess.graph,
{})
else:
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tf2onnx.tfonnx.tflist_to_onnx(
sess.graph.get_operations(),
{})
# Iterate over a copy: removing from a list while iterating over it skips elements.
for n in list(onnx_nodes):
if len(n.output) == 0:
onnx_nodes.remove(n)
# find inputs and outputs of graph
nodes_inputs = set()
nodes_outputs = set()
for n in onnx_nodes:
if n.op_type == 'Placeholder':
continue
for input in n.input:
nodes_inputs.add(input)
for output in n.output:
nodes_outputs.add(output)
graph_input_names = set()
for input_name in nodes_inputs:
if input_name not in nodes_outputs:
graph_input_names.add(input_name)
graph_output_names = set()
for n in onnx_nodes:
if n.input and n.input[0] not in nodes_outputs:
continue
if len(n.output) == 0:
n.output.append(n.name + ':0')
graph_output_names.add(n.output[0])
else:
output_name = n.output[0]
if (output_name not in nodes_inputs) and (0 < len(n.input)):
graph_output_names.add(output_name)
logging.info('Model Inputs: %s', str(list(graph_input_names)))
logging.info('Model Outputs: %s', str(list(graph_output_names)))
graph_def, inputs, outputs = tf_loader.from_graphdef(model_path=pb_path,
input_names=list(graph_input_names),
output_names=list(graph_output_names))
with tf.Graph().as_default() as tf_graph:
tf.import_graph_def(graph_def, name='')
if 160 <= TF2ONNX_VERSION:
with tf_loader.tf_session(graph=tf_graph):
onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
input_names=inputs,
output_names=outputs,
opset=11)
else:
with tf.Session(graph=tf_graph):
onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
input_names=inputs,
output_names=outputs,
opset=11)
# Optimize with tf2onnx.optimizer
onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph)
model_proto = onnx_graph.make_model(model_name)
# Make tf2onnx output compatible with the spec. of other.polish_model
replacing.replace_initializer_with_Constant(model_proto.graph)
model_proto = other.polish_model(model_proto)
else:
raise Exception('expect .pb file as input, but got "' + str(pb_path) + '"')
# rename
m = model_proto
m = combo.preprocess(m)
m = combo.common_optimization(m)
m = combo.tensorflow_optimization(m)
m = combo.postprocess(m)
if not test_mode:
g = m.graph
eliminating.eliminate_shape_changing_after_input(g)
m = other.polish_model(m)
return m
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Convert tensorflow pb file to onnx file and optimized onnx file. Or just optimize tensorflow onnx file.')
parser.add_argument('in_file', help='input file')
parser.add_argument('out_file', help='output optimized model file')
parser.add_argument('-t', '--test_mode', action='store_true', help='test mode will not eliminate shape changes after the input')
args = parser.parse_args()
logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO)
m = tf2onnx_flow(args.in_file, args.test_mode)
onnx.save(m, args.out_file)
logging.info('Save Optimized ONNX: %s', args.out_file)
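The graph input/output discovery above boils down to edge-set arithmetic: a tensor that is consumed but never produced is a graph input, and one that is produced but never consumed is a graph output. The idea in isolation, over hypothetical (inputs, outputs) node pairs:

```python
# Hypothetical node list: each entry is (inputs, outputs) of tensor names.
nodes = [
    (["image"], ["conv_out"]),
    (["conv_out"], ["relu_out"]),
]

consumed = {name for ins, _ in nodes for name in ins}
produced = {name for _, outs in nodes for name in outs}

graph_inputs = sorted(consumed - produced)
graph_outputs = sorted(produced - consumed)

print(graph_inputs)   # ['image']
print(graph_outputs)  # ['relu_out']
```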


@ -0,0 +1,68 @@
import argparse
import numpy as np
import tensorflow as tf
import onnx
import onnxruntime
from tools import helper
def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
# Setup onnx session and get meta data
onnx_session = onnxruntime.InferenceSession(onnx_file, None)
onnx_outputs = onnx_session.get_outputs()
assert len(onnx_outputs) == 1, "The onnx model has more than one output"
onnx_model = onnx.load(onnx_file)
onnx_graph = onnx_model.graph
onnx_inputs = onnx_graph.input
assert len(onnx_inputs) == 1, "The onnx model has more than one input"
_, onnx_input_shape = helper.find_size_shape_from_value(onnx_inputs[0])
# Setup TFLite session and get meta data
tflite_session = tf.lite.Interpreter(model_path=tflite_file)
tflite_session.allocate_tensors()
tflite_inputs = tflite_session.get_input_details()
tflite_outputs = tflite_session.get_output_details()
tflite_input_shape = tflite_inputs[0]['shape']
# Compare input shape
assert len(onnx_input_shape) == len(tflite_input_shape), "TFLite and ONNX input ranks do not match."
assert onnx_input_shape == [tflite_input_shape[0], tflite_input_shape[3], tflite_input_shape[1], tflite_input_shape[2]], "TFLite and ONNX input shapes do not match."
# Generate random number and run
tflite_results = []
onnx_results = []
for _ in range(total_times):
# Generate input
tflite_input_data = np.array(np.random.random_sample(tflite_input_shape), dtype=np.float32)
onnx_input_data = np.transpose(tflite_input_data, [0, 3, 1, 2])
# Run tflite
tflite_session.set_tensor(tflite_inputs[0]['index'], tflite_input_data)
tflite_session.invoke()
tflite_results.append(tflite_session.get_tensor(tflite_outputs[0]['index']))
# Run onnx
onnx_input_dict = {onnx_inputs[0].name: onnx_input_data}
onnx_results.append(onnx_session.run([], onnx_input_dict)[0])
return tflite_results, onnx_results
if __name__ == '__main__':
# Argument parser.
parser = argparse.ArgumentParser(description="Compare a TFLite model and an ONNX model to check if they have the same output.")
parser.add_argument('tflite_file', help='input tflite file')
parser.add_argument('onnx_file', help='input ONNX file')
args = parser.parse_args()
results_a, results_b = compare_tflite_and_onnx(args.tflite_file, args.onnx_file, total_times=10)
ra_flat = helper.flatten_with_depth(results_a, 0)
rb_flat = helper.flatten_with_depth(results_b, 0)
shape_a = [item[1] for item in ra_flat]
shape_b = [item[1] for item in rb_flat]
assert shape_a == shape_b, "The two models' result shapes do not match."
ra_raw = [item[0] for item in ra_flat]
rb_raw = [item[0] for item in rb_flat]
try:
np.testing.assert_almost_equal(ra_raw, rb_raw, 8)
print('Two models have the same behaviour.')
except Exception as mismatch:
print(mismatch)
exit(1)
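The layout bridge between TFLite (NHWC) and ONNX (NCHW) used above is a single axis permutation. A minimal sketch:

```python
import numpy as np

nhwc = np.random.random_sample((1, 224, 224, 3)).astype(np.float32)
# Move channels from the last axis to the second: N,H,W,C -> N,C,H,W.
nchw = np.transpose(nhwc, [0, 3, 1, 2])

print(nchw.shape)  # (1, 3, 224, 224)
```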


@ -0,0 +1,258 @@
"""Combo functions that are usually called together.
"""
import logging
import onnx.utils
try:
from onnx import optimizer
except ImportError:
import onnxoptimizer as optimizer
from . import helper
from . import other
from . import replacing
from . import eliminating
from . import fusing
from . import constant_folding
from . import removing_transpose
from . import modhelper
from .common_pattern import torch_pattern_match, tf_pattern_match
from .helper import logger
def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True):
"""The most common used functions before other processing.
Args:
model_proto: the original model input
duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True.
Return:
the new model after preprocessing
It includes:
- inference shapes
- optimize model by ONNX library
- give names to the nodes
- replace initializer with Constant node
- replace -1 batch size with 1
- eliminate dropout and identity
- eliminate no children inputs
- topological sort
The optimizations provided by ONNX:
- eliminate_identity
- eliminate_nop_dropout
- eliminate_nop_transpose
- eliminate_nop_pad
- eliminate_unused_initializer
- eliminate_deadend
- fuse_consecutive_squeezes
- fuse_consecutive_transposes
- fuse_add_bias_into_conv
- fuse_transpose_into_gemm
- fuse_matmul_add_bias_into_gemm
- fuse_bn_into_conv
- fuse_pad_into_conv
"""
logger.info("Preprocessing the model...")
helper.setup_current_opset_version(model_proto)
eliminating.eliminate_empty_value_infos(model_proto.graph)
other.add_name_to_node(model_proto.graph)
other.rename_all_node_name(model_proto.graph)
replacing.replace_initializer_with_Constant(model_proto.graph)
other.topological_sort(model_proto.graph)
m = other.polish_model(model_proto)
passes = ['extract_constant_to_initializer',
'eliminate_nop_dropout',
'eliminate_deadend',
'fuse_matmul_add_bias_into_gemm',
'fuse_pad_into_conv']
if not disable_fuse_bn:
passes.append('fuse_bn_into_conv')
m = optimizer.optimize(m, passes)
g = m.graph
# Add name again since onnx optimizer higher than 1.7 may remove node names.
other.add_name_to_node(g)
if duplicate_shared_weights:
replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=True)
other.duplicate_param_shared_constant(g)
else:
replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=False)
other.topological_sort(g)
m = other.polish_model(m)
g = m.graph
eliminating.eliminate_consecutive_Cast(m.graph)
eliminating.eliminate_Cast_after_input(m.graph)
eliminating.eliminate_nop_pads(g)
eliminating.eliminate_nop_cast(g)
eliminating.eliminate_Identify_and_Dropout(g)
eliminating.eliminate_trivial_maxpool(g)
eliminating.eliminate_no_children_input(g)
other.format_value_info_shape(g)
other.topological_sort(g)
m = other.inference_shapes(m)
g = m.graph
replacing.replace_split_with_slices(g)
other.topological_sort(g)
return m
def common_optimization(m):
"""Common optimizations can be used in most cases.
:param m: the original model input\\
:return: the new model after preprocessing
It includes:
- transpose B in Gemm
- fuse BN into Gemm
- fuse consecutive Gemm
- replace AveragePool with GAP
- replace Squeeze/Unsqueeze with Reshape
- replace Reshape with Flatten
"""
logger.info("Doing nodes fusion and replacement... ")
m = other.polish_model(m)
g = m.graph
other.transpose_B_in_Gemm(g)
fusing.fuse_BN_into_Gemm(g)
fusing.fuse_BN_with_Reshape_into_Gemm(g)
fusing.fuse_Gemm_into_Gemm(g)
fusing.fuse_consecutive_reducemean(g)
fusing.fuse_slice_nodes_into_conv(g)
fusing.fuse_relu_min_into_clip(g)
other.duplicate_shared_Flatten(g)
replacing.replace_average_pool_with_GAP(g)
m = other.polish_model(m)
g = m.graph
replacing.replace_Squeeze_with_Reshape(g)
replacing.replace_Unsqueeze_with_Reshape(g)
replacing.replace_Reshape_with_Flatten(g)
replacing.replace_ReduceMean_with_GlobalAveragePool(g)
replacing.replace_Sum_with_Adds(g)
replacing.replace_constant_input_concat_with_pad(g)
other.topological_sort(g)
return m
def pytorch_constant_folding(m):
"""Constant folding needed by Pytorch exported models. It should be done
before using onnx optimizers since the dynamic shape structure may affect
the optimizations.
:param m: the original model input\\
:return: the new model after preprocessing
"""
logger.info("Working on constant folding.")
replacing.replace_shape_with_constant(m.graph)
replacing.replace_ConstantOfShape_with_constant(m.graph)
# constant_folding
m = other.inference_shapes(m)
while constant_folding.constant_folding(m.graph):
logging.debug("After constant folding jobs.")
other.topological_sort(m.graph)
while len(m.graph.value_info) != 0:
m.graph.value_info.pop()
m = other.inference_shapes(m)
replacing.replace_shape_with_constant(m.graph)
other.topological_sort(m.graph)
m = torch_pattern_match(m)
m = optimizer.optimize(m, ['eliminate_deadend'])
return m
def tensorflow_optimization(m):
"""Optimizations for tf models can be used in most cases.
:param m: the original model input\\
:return: the new model after preprocessing
It includes:
- eliminate shape change after input
- eliminate Reshape cast
- eliminate Squeeze before Reshape
- fuse Transpose into Constant
- replace Shape with Constant
"""
fusing.fuse_Transpose_into_Constant(m.graph)
fusing.fuse_MatMul_and_Add_into_Gemm(m.graph)
other.topological_sort(m.graph)
m = other.polish_model(m)
# constant folding
replacing.replace_shape_with_constant(m.graph)
# constant_folding
m = other.inference_shapes(m)
while constant_folding.constant_folding(m.graph):
logging.debug("After constant folding jobs.")
other.topological_sort(m.graph)
while len(m.graph.value_info) != 0:
m.graph.value_info.pop()
m = other.inference_shapes(m)
replacing.replace_shape_with_constant(m.graph)
other.topological_sort(m.graph)
m = tf_pattern_match(m)
m = optimizer.optimize(m, ['eliminate_deadend'])
eliminating.eliminate_consecutive_reshape(m.graph)
eliminating.eliminate_Squeeze_before_Reshape(m.graph)
other.topological_sort(m.graph)
return m
def postprocess(m):
"""Inference the shape and prepare for export.
:param m: the original model input\\
:return: the new model after preprocessing
"""
logger.info("Postprocessing the model...")
while len(m.graph.value_info) > 0:
m.graph.value_info.pop()
m = other.polish_model(m)
eliminating.eliminate_single_input_Concat(m.graph)
eliminating.eliminate_nop_Maxpool_and_AveragePool(m.graph)
eliminating.eliminate_trivial_elementwise_calculation(m.graph)
m = other.polish_model(m)
replacing.replace_depthwise_1x1_with_bn(m.graph)
m = other.polish_model(m)
# removing transpose
m = removing_transpose.eliminate_transposes(m)
m = other.polish_model(m)
removing_transpose.remove_trivial_transpose(m.graph)
removing_transpose.fuse_Transpose_into_Gemm_weight(m.graph)
# fuse some nodes
fusing.fuse_mul_and_add_into_bn(m.graph)
m = other.polish_model(m)
fusing.fuse_mul_and_add_into_gemm(m.graph)
m = other.polish_model(m)
fusing.fuse_conv_and_add_into_conv(m.graph)
m = other.polish_model(m)
replacing.replace_mul_to_bn(m.graph)
replacing.replace_div_to_bn(m.graph)
replacing.replace_add_to_bn(m.graph)
replacing.replace_sub_to_bn(m.graph)
replacing.replace_sub_with_bn_and_add(m.graph)
m = other.polish_model(m)
other.add_output_to_value_info(m.graph)
m = optimizer.optimize(m, ['eliminate_deadend'])
m.producer_name = 'kneron_formatter'
return m
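Several passes above (e.g. `constant_folding`) are driven to a fixed point: the loop repeats while the pass reports a change. That pattern in isolation, with a hypothetical `fold_once` pass that halves even numbers once per call:

```python
def run_to_fixed_point(state, pass_fn, max_iters=100):
    """Apply pass_fn until it reports no change (with a safety cap)."""
    for _ in range(max_iters):
        state, changed = pass_fn(state)
        if not changed:
            break
    return state

# Hypothetical pass: halve every even number once per call.
def fold_once(xs):
    out = [x // 2 if x % 2 == 0 else x for x in xs]
    return out, out != xs

print(run_to_fixed_point([8, 3], fold_once))  # [1, 3]
```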


@ -0,0 +1,157 @@
from collections import defaultdict
import numpy as np
import onnx.helper
import onnx.utils
from . import modhelper
from . import helper
from . import other
def torch_pattern_match(m):
# Create a map from optype to the nodes.
optype2node = defaultdict(list)
for node in m.graph.node:
optype2node[node.op_type].append(node)
for matmul_node in optype2node['MatMul']:
pattern_matmul_mul_add(m.graph, matmul_node)
for resize_node in optype2node['Resize']:
# torch nn.UpsamplingBilinear2d gives us 4 inputs: "X, roi, scales, sizes"
if len(resize_node.input) != 4:
continue
make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name)
m = onnx.shape_inference.infer_shapes(m)
polish_RESIZE_input_param_node(m.graph, resize_node.name)
m = other.polish_model(m)
return m
def tf_pattern_match(m):
# Create a map from optype to the nodes.
optype2node = defaultdict(list)
for node in m.graph.node:
optype2node[node.op_type].append(node)
for matmul_node in optype2node['MatMul']:
pattern_matmul_mul_add(m.graph, matmul_node)
for resize_node in optype2node['Resize']:
# In tensorflow2onnx, ResizeXXX gives us 4 inputs: "X, roi, scales, sizes",
# and the node output name is the node name plus ":0".
if len(resize_node.input) != 4:
continue
make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name)
m = onnx.shape_inference.infer_shapes(m)
polish_RESIZE_input_param_node(m.graph, resize_node.name)
m = other.polish_model(m)
return m
def pattern_matmul_mul_add(g, matmul_node):
# Check node match - Mul node
next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
if len(next_nodes) != 1:
return
if next_nodes[0].op_type != 'Mul':
return
mul_node = next_nodes[0]
# Check node match - Add node
next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
if len(next_nodes) != 1:
return
if next_nodes[0].op_type != 'Add':
return
add_node = next_nodes[0]
# Check Mul weight
mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
if mul_weight_node.op_type != 'Constant':
return
weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
for i in mul_weight:
if i != 1:
return
channel = weight_size[0]
# Check Add weight
add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
if add_weight_node.op_type != 'Constant':
return
# Check MatMul weight to see if it needs weight broadcast
matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1])
matmul_weight = helper.constant_to_numpy(matmul_weight_node)
if matmul_weight.shape[1] == 1:
# Weight broadcast
new_matmul_weight = np.tile(matmul_weight, channel)
new_matmul_weight_node = helper.numpy_to_constant(matmul_weight_node.name, new_matmul_weight)
g.node.remove(matmul_weight_node)
g.node.extend([new_matmul_weight_node])
value = helper.find_value_by_name(g, matmul_weight_node.output[0])
if value is not None:
g.value_info.remove(value)
# Remove Mul node
g.node.remove(mul_weight_node)
value = helper.find_value_by_name(g, mul_weight_node.output[0])
if value is not None:
g.value_info.remove(value)
g.node.remove(mul_node)
value = helper.find_value_by_name(g, mul_node.output[0])
if value is not None:
g.value_info.remove(value)
# Fuse Matmul and Add
gemm_node = onnx.helper.make_node(
'Gemm',
[matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
[add_node.output[0]],
name = matmul_node.name,
alpha = 1.0,
beta = 1.0,
transA = 0,
transB = 0
)
g.node.extend([gemm_node])
# Clean up
g.node.remove(matmul_node)
g.node.remove(add_node)
value = helper.find_value_by_name(g, matmul_node.output[0])
if value is not None:
g.value_info.remove(value)
other.topological_sort(g)
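As a sanity check (not part of the script), the fusion above is numerically sound: a MatMul followed by an elementwise Mul by an all-ones constant and an Add of a bias equals a Gemm with alpha = beta = 1 and no transposes. A minimal numpy sketch with illustrative shapes:

```python
import numpy as np

# Hypothetical shapes for illustration only.
x = np.arange(6, dtype=np.float32).reshape(2, 3)
w = np.arange(12, dtype=np.float32).reshape(3, 4)
b = np.arange(4, dtype=np.float32)

# MatMul -> Mul(ones) -> Add, the pattern matched above.
pattern_out = (x @ w) * np.ones(4, dtype=np.float32) + b
# Gemm with alpha=1, beta=1, transA=transB=0, the fused replacement.
gemm_out = 1.0 * (x @ w) + 1.0 * b

assert np.allclose(pattern_out, gemm_out)
```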
def make_UpsamplingBilinear2d_value_info(g, resize_node_name):
resize_node = helper.find_node_by_node_name(g, resize_node_name)
shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
shape_data = helper.constant_to_numpy(shape_data_node).astype(int)
l_shape_data = list(shape_data)
# Handle invalid 0 batch size.
if l_shape_data[0] == 0:
l_shape_data[0] = 1
shape_data = np.array(l_shape_data)
new_output_value_info = onnx.helper.make_tensor_value_info(
resize_node.output[0],
onnx.helper.TensorProto.FLOAT,
shape_data.tolist()
)
g.value_info.extend([new_output_value_info])
def polish_RESIZE_input_param_node(g, resize_node_name):
resize_node = helper.find_node_by_node_name(g, resize_node_name)
shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
shape_data = helper.constant_to_numpy(shape_data_node).astype(int)
# handle 0 batch size which is invalid
if shape_data[0] == 0:
shape_data[0] = 1
pre_node_output_value_info = helper.find_value_by_name(g, resize_node.input[0])
ori_shape = np.array([pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value,
pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value,
pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value,
pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value])
resize_node.input.remove(resize_node.input[3])
resize_scales = np.array(shape_data/ori_shape).astype(float)
resize_scale_node = helper.list_to_constant('resize_scales_node_' + resize_node.name, resize_scales.shape, resize_scales, data_type=onnx.helper.TensorProto.FLOAT)
resize_node.input[2] = resize_scale_node.name
g.node.extend([resize_scale_node])
other.topological_sort(g)
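The sizes-to-scales conversion in polish_RESIZE_input_param_node reduces to an elementwise division of the target shape by the input shape. A small sketch with assumed NCHW shapes (values chosen only for illustration):

```python
import numpy as np

# Assumed NCHW shapes for illustration.
ori_shape = np.array([1, 3, 16, 16])
target_shape = np.array([1, 3, 32, 32])

# Resize's 'sizes' input is replaced with 'scales' = sizes / input_shape.
resize_scales = (target_shape / ori_shape).astype(np.float32)

assert resize_scales.tolist() == [1.0, 1.0, 2.0, 2.0]
```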


@ -0,0 +1,995 @@
import onnx.utils
import onnx
import numpy as np
import logging
import traceback
from . import helper
from .general_graph import Graph, Node
from .other import topological_sort
from .replacing import replace_shape_with_constant
from .helper import logger
def are_all_inputs_Constant_with_one_child(g, node):
for input_name in node.input:
input_node = helper.find_node_by_output_name(g, input_name)
if input_node is None or input_node.op_type != 'Constant':
return False
relative_outputs = helper.find_nodes_by_input_name(g, input_name)
if len(relative_outputs) > 1:
return False
return True
def constant_folding(g):
""" Do constant folding until nothing more can be done.
:param g: The onnx GraphProto
:return: If any node is folded, return True. Otherwise, return False.
"""
keep_folding = True # Keep the while loop
folded = False # Return value
try:
# Before constant folding, duplicate the constant nodes.
duplicate_constant_node(g)
while keep_folding:
keep_folding = False
for node in g.node:
# Check if the node is foldable
if node.op_type not in constant_folding_nodes.keys():
continue
# Check if the parents of the node are all single follower constant node.
if not are_all_inputs_Constant_with_one_child(g, node):
continue
# Constant folding for the specific node
if constant_folding_nodes[node.op_type](g, node):
logging.debug("Constant nodes and %s %s are folded.",
node.op_type, node.name)
folded = True
keep_folding = True
else:
logging.debug(
"Constant nodes and %s %s are skipped.", node.op_type, node.name)
except Exception:
logger.error("An exception is raised while constant folding.")
logger.error(traceback.format_exc())
return folded
def duplicate_constant_node(g):
""" Duplicate the constant node if its following nodes contain constant folding
nodes. Create and link the new constant nodes to the constant folding nodes.
"""
for node in g.node:
# Find a valid constant node
if node.op_type != 'Constant':
continue
output_val_info = helper.find_value_by_name(g, node.output[0])
if output_val_info is None:
print("Cannot inference the shape of Const node output: " +
node.output[0])
exit(1)
data_shape = helper.get_shape_from_value_info(output_val_info)
output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
# For constant that has only one following node, no need to duplicate
if len(output_nodes) < 2:
continue
# Check if its following nodes are foldable
foldable_output_nodes = list(filter(lambda n: n.op_type in
constant_folding_nodes.keys(), output_nodes))
if not foldable_output_nodes:
continue
# Duplicate the node needed by foldable nodes
for i in range(len(foldable_output_nodes)):
logging.debug("Found constant %s and %s %s are available for folding. Duplicate constant.",
node.name, foldable_output_nodes[i].op_type, foldable_output_nodes[i].name)
output_name = node.output[0] + '_dup_' + str(i)
new_constant_node = onnx.helper.make_node(
'Constant',
[],
[output_name],
name=output_name,
value=node.attribute[0].t
)
new_val_info = onnx.helper.make_tensor_value_info(
output_name,
node.attribute[0].t.data_type,
data_shape
)
input_ind = list(foldable_output_nodes[i].input).index(
node.output[0])
foldable_output_nodes[i].input[input_ind] = output_name
g.node.extend([new_constant_node])
g.value_info.extend([new_val_info])
# If all following nodes are foldable node, delete the original node.
if len(foldable_output_nodes) == len(output_nodes):
g.node.remove(node)
g.value_info.remove(output_val_info)
topological_sort(g)
return
def slice_constant_folding(g, node):
op_version = helper.get_current_opset_version()
# Only opset 9 and 11 are supported.
if op_version == 11:
return slice_constant_folding_Opset_11(g, node)
elif op_version == 9:
return slice_constant_folding_Opset_9(g, node)
return False
def slice_constant_folding_Opset_11(g, node):
""" Fold constant and slice nodes to a single constant node.
"""
pre_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape, data_list = helper.constant_to_list(pre_node)
starts_node = helper.find_node_by_output_name(g, node.input[1])
_, starts = helper.constant_to_list(starts_node)
ends_node = helper.find_node_by_output_name(g, node.input[2])
_, ends = helper.constant_to_list(ends_node)
axes_node = None if len(node.input) <= 3 else helper.find_node_by_output_name(g, node.input[3])
if not axes_node:
axes = list(range(len(helper.get_shape(data_list))))
else:
_, axes = helper.constant_to_list(axes_node)
steps_node = None if len(node.input) <= 4 else helper.find_node_by_output_name(g, node.input[4])
if not steps_node:
steps = [1]*len(helper.get_shape(data_list))
else:
_, steps = helper.constant_to_list(steps_node)
data_list = list(map(int, data_list))
starts = list(map(int, starts))
ends = list(map(int, ends))
axes = list(map(int, axes))
steps = list(map(int, steps))
data_list = np.reshape(data_list, pre_shape)
# Apply the slice of each axis cumulatively; slicing only on the last
# axis (as the old code did by reusing data_list) would drop the others.
new_data = data_list
for idx, axis in enumerate(axes):
new_data = np.apply_along_axis(lambda x: x[starts[idx] : ends[idx] : steps[idx]], axis, new_data)
new_node = helper.list_to_constant(node.output[0], helper.get_shape(
new_data), helper.flatten_to_list(new_data))
g.node.extend([new_node])
value_info = helper.find_value_by_name(g, pre_node.output[0])
if value_info is not None:
g.value_info.remove(value_info)
g.node.remove(node)
g.node.remove(pre_node)
return True
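The opset-11 slice folding above reproduces numpy basic slicing per axis. A sketch of the same computation on toy data (the starts/ends/axes/steps values are chosen only for illustration):

```python
import numpy as np

data = np.arange(24).reshape(2, 3, 4)
starts, ends, axes, steps = [0], [3], [2], [2]

# Apply the slice for each requested axis cumulatively.
out = data
for idx, axis in enumerate(axes):
    out = np.apply_along_axis(
        lambda v: v[starts[idx]:ends[idx]:steps[idx]], axis, out)

assert out.shape == (2, 3, 2)  # indices 0 and 2 kept along the last axis
assert out[0, 0].tolist() == [0, 2]
```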
def slice_constant_folding_Opset_9(g, node):
""" Fold constant and slice nodes to a single constant node.
"""
pre_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape, data_list = helper.constant_to_list(pre_node)
data_list = np.reshape(data_list, pre_shape)
axes = helper.get_attribute_by_name(node, 'axes')
ends = list(helper.get_attribute_by_name(node, 'ends').ints)
starts = list(helper.get_attribute_by_name(node, 'starts').ints)
if not axes:
axes = list(range(len(helper.get_shape(data_list))))
else:
axes = list(axes.ints)
new_data = helper.slice_data(data_list, starts, ends, axes)
new_node = helper.list_to_constant(node.output[0], helper.get_shape(
new_data), helper.flatten_to_list(new_data))
g.node.extend([new_node])
value_info = helper.find_value_by_name(g, pre_node.output[0])
if value_info is not None:
g.value_info.remove(value_info)
g.node.remove(node)
g.node.remove(pre_node)
return True
def cast_constant_folding(g, node):
""" Fold constant and cast node to a single constant node.
"""
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node)
data_type = node.attribute[0].i
if data_type in (onnx.helper.TensorProto.INT32, onnx.helper.TensorProto.INT64):
data = list(map(int, data))
elif data_type == onnx.helper.TensorProto.FLOAT:
data = list(map(float, data))
else:
raise RuntimeError('data type not supported')
if shape == 1:
tensor = onnx.helper.make_tensor(
name=pre_node.attribute[0].name,
data_type=data_type,
dims=[],
vals=data
)
else:
tensor = onnx.helper.make_tensor(
name=pre_node.attribute[0].name,
data_type=data_type,
dims=shape,
vals=helper.flatten_to_list(data)
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=tensor
)
g.node.extend([new_node])
value_info = helper.find_value_by_name(g, pre_node.output[0])
if value_info is not None:
g.value_info.remove(value_info)
value_info = helper.find_value_by_name(g, node.output[0])
if value_info is not None:
g.value_info.remove(value_info)
g.node.remove(pre_node)
g.node.remove(node)
return True
def reduceprod_constant_folding(g, node):
""" Fold constant and reduceprod nodes to a single constant node.
"""
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data_set = helper.constant_to_list(pre_node)
tensor = pre_node.attribute[0].t
data_set = np.reshape(data_set, shape)
# ONNX defaults: reduce over all axes, keepdims enabled. Without these
# defaults, missing attributes would leave the variables unbound.
axes = None
keepdims = 1
for att in node.attribute:
if att.name == 'axes':
axes = list(att.ints)
elif att.name == 'keepdims':
keepdims = int(att.i)
new_data = np.prod(data_set, axis=None if axes is None else tuple(axes), keepdims=keepdims == 1)
new_shape = helper.get_shape(new_data)
new_flat_data = helper.flatten_to_list(new_data)
new_tensor = onnx.helper.make_tensor(
name=node.output[0],
data_type=tensor.data_type,
dims=new_shape,
vals=new_flat_data
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
)
g.node.extend([new_node])
value_info = None
for item in g.value_info:
if item.name == pre_node.output[0]:
value_info = item
if value_info is not None:
g.value_info.remove(value_info)
g.node.remove(pre_node)
g.node.remove(node)
return True
def reshape_constant_input_folding(g, node):
""" Fold constant and reshape nodes to a single constant node.
"""
pre_data_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape_node = helper.find_node_by_output_name(g, node.input[1])
data = helper.constant_to_numpy(pre_data_node)
_, shape = helper.constant_to_list(pre_shape_node)
new_data = np.reshape(data, shape)
new_tensor = onnx.helper.make_tensor(
name=node.output[0],
data_type=pre_data_node.attribute[0].t.data_type,
dims=new_data.shape,
vals=helper.flatten_to_list(new_data)
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
)
g.node.extend([new_node])
data_val_info = helper.find_value_by_name(g, pre_data_node.output[0])
shape_val_info = helper.find_value_by_name(g, pre_shape_node.output[0])
g.value_info.remove(data_val_info)
g.value_info.remove(shape_val_info)
g.node.remove(node)
g.node.remove(pre_data_node)
g.node.remove(pre_shape_node)
return True
def concat_constant_folding(g, node):
""" Fold constant and concat nodes to a single constant node.
"""
node_to_del = []
valid_inputs = True
for input_name in node.input:
input_node = helper.find_node_by_output_name(g, input_name)
input_node_output = helper.find_nodes_by_input_name(g, input_name)
if len(input_node_output) > 1:
valid_inputs = False
break
if input_node.op_type != 'Constant':
valid_inputs = False
break
if not valid_inputs:
return False
input_data = []
input_shapes = []
for input_name in node.input:
input_node = helper.find_node_by_output_name(g, input_name)
s, d = helper.constant_to_list(input_node)
d = np.reshape(d, s)
input_data.append(d)
input_shapes.append(s)
node_to_del.append(input_node)
concat_data = np.concatenate(input_data, axis=node.attribute[0].i)
node_data_type = input_node.attribute[0].t.data_type
if concat_data.dtype in [np.int32, np.int64]:
node_data_type = onnx.helper.TensorProto.INT64
elif concat_data.dtype in [np.float32, np.float64]:
node_data_type = onnx.helper.TensorProto.FLOAT
new_node = helper.list_to_constant(
node.output[0],
helper.get_shape(concat_data),
helper.flatten_to_list(concat_data),
data_type=node_data_type
)
g.node.extend([new_node])
node_to_del.append(node)
for input_name in node.input:
val_info = helper.find_value_by_name(g, input_name)
if val_info:
g.value_info.remove(val_info)
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
def transpose_constant_folding(g, node):
"""Fold constant and transpose nodes to a single constant node.
"""
node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node)
np_data = np.reshape(data, shape)
permutation = list(node.attribute[0].ints)
new_data = np.transpose(np_data, permutation)
new_shape = new_data.shape
new_node = helper.list_to_constant(
node.output[0],
new_shape,
new_data.flatten().tolist(),
data_type=pre_node.attribute[0].t.data_type
)
g.node.extend([new_node])
node_to_del.extend([node, pre_node])
pre_val_info = helper.find_value_by_name(g, node.input[0])
g.value_info.remove(pre_val_info)
next_val_info = helper.find_value_by_name(g, node.output[0])
g.value_info.remove(next_val_info)
new_val_info = onnx.helper.make_tensor_value_info(
node.output[0],
pre_node.attribute[0].t.data_type,
new_shape
)
g.value_info.extend([new_val_info])
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
def unsqueeze_constant_folding(g, node):
"""Fold constant and unsqueeze nodes to a single constant node.
"""
node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node)
if type(shape) == int:
np_data = data[0]
else:
np_data = np.reshape(data, shape)
axes = list(node.attribute[0].ints)
axes.sort()
for dim in axes:
np_data = np.expand_dims(np_data, axis=dim)
new_shape = np_data.shape
new_node = helper.list_to_constant(
node.output[0],
new_shape,
np_data.flatten().tolist(),
data_type=pre_node.attribute[0].t.data_type
)
g.node.extend([new_node])
node_to_del.extend([node, pre_node])
pre_val_info = helper.find_value_by_name(g, node.input[0])
next_val_info = helper.find_value_by_name(g, node.output[0])
if pre_val_info is not None:
g.value_info.remove(pre_val_info)
else:
logging.debug("Cannot find value info %s while folding Unsqueeze.", node.input[0])
if next_val_info is not None:
g.value_info.remove(next_val_info)
new_val_info = onnx.helper.make_tensor_value_info(
node.output[0],
pre_node.attribute[0].t.data_type,
new_shape
)
g.value_info.extend([new_val_info])
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
def gather_constant_folding(g, node):
"""Fold constant and gather nodes to a single constant node.
"""
node_to_del = []
pre_data_node = helper.find_node_by_output_name(g, node.input[0])
pre_indices_node = helper.find_node_by_output_name(g, node.input[1])
shape, data = helper.constant_to_list(pre_data_node)
indice_shape, indices = helper.constant_to_list(pre_indices_node)
if type(indice_shape) == int:
indices = indices[0]
np_data = np.reshape(data, shape)
if len(node.attribute) < 1:
axis = 0
else:
axis = node.attribute[0].i
new_data = np.take(np_data, indices, axis=axis)
new_shape = new_data.shape
new_node = helper.list_to_constant(
node.output[0],
new_shape,
new_data.flatten().tolist(),
data_type=pre_data_node.attribute[0].t.data_type
)
node_to_del.extend([node, pre_data_node, pre_indices_node])
g.node.extend([new_node])
val_info_1 = helper.find_value_by_name(g, node.input[0])
val_info_2 = helper.find_value_by_name(g, node.input[1])
val_info_3 = helper.find_value_by_name(g, node.output[0])
new_val_info = onnx.helper.make_tensor_value_info(
new_node.output[0],
pre_data_node.attribute[0].t.data_type,
new_shape
)
if val_info_1 is not None:
g.value_info.remove(val_info_1)
if val_info_2 is not None:
g.value_info.remove(val_info_2)
if val_info_3 is not None:
g.value_info.remove(val_info_3)
g.value_info.extend([new_val_info])
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
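The Gather folding above relies on np.take, which selects entries along one axis using an index array, matching ONNX Gather semantics. A minimal sketch:

```python
import numpy as np

data = np.arange(12).reshape(3, 4)
# Gather rows 0 and 2 (axis=0), as ONNX Gather would.
out = np.take(data, [0, 2], axis=0)

assert out.shape == (2, 4)
assert out[1].tolist() == [8, 9, 10, 11]
```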
def add_constant_folding(g, node):
"""Fold constant and add nodes to a single constant node.
"""
node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
if not pre_node_1 or not pre_node_2:
return False
shape1, data1 = helper.constant_to_list(pre_node_1)
shape2, data2 = helper.constant_to_list(pre_node_2)
np_data1 = np.reshape(data1, shape1)
np_data2 = np.reshape(data2, shape2)
try:
new_data = np.add(np_data1, np_data2)
except ValueError:
raise RuntimeError("cannot broadcast and add the two data sets")
new_node = helper.list_to_constant(
node.output[0],
new_data.shape,
new_data.flatten().tolist(),
data_type=pre_node_1.attribute[0].t.data_type
)
g.node.extend([new_node])
node_to_del.extend([node, pre_node_1, pre_node_2])
g.value_info.remove(helper.find_value_by_name(g, pre_node_1.output[0]))
g.value_info.remove(helper.find_value_by_name(g, pre_node_2.output[0]))
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
def sqrt_constant_folding(g, node):
""" Fold constant and sqrt nodes to a single node.
"""
node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node)
np_data = np.sqrt(np.reshape(data, shape))
output_val_info = helper.find_value_by_name(g, node.output[0])
input_val_info = helper.find_value_by_name(g, node.input[0])
data_type = output_val_info.type.tensor_type.elem_type
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
data_type=data_type,
dims=shape,
vals=np_data.flatten().tolist()
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
)
g.value_info.remove(input_val_info)
node_to_del.extend([pre_node, node])
g.node.extend([new_node])
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
def reciprocal_constant_folding(g, node):
""" Fold constant and reciprocal nodes to a single constant node.
"""
node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node)
data = list(map(lambda x: x if abs(x) > 1.e-8 else 1.e-8, data))
np_data = np.reshape(data, shape)
np_data = np.reciprocal(np_data)
input_val_info = helper.find_value_by_name(g, node.input[0])
output_val_info = helper.find_value_by_name(g, node.output[0])
data_type = output_val_info.type.tensor_type.elem_type
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
data_type=data_type,
dims=shape,
vals=np_data.flatten().tolist()
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
)
node_to_del.extend([node, pre_node])
g.node.extend([new_node])
g.value_info.remove(input_val_info)
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
def mul_constant_folding(g, node):
""" Fold constant and mul nodes to a single constant node.
"""
node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
pre_value_info1 = helper.find_value_by_name(g, node.input[0])
pre_value_info2 = helper.find_value_by_name(g, node.input[1])
if pre_value_info1 is None or pre_value_info2 is None:
return False
shape1, data1 = helper.constant_to_list(pre_node_1)
shape2, data2 = helper.constant_to_list(pre_node_2)
np_data1 = np.reshape(data1, shape1)
np_data2 = np.reshape(data2, shape2)
try:
new_data = np.multiply(np_data1, np_data2)
except ValueError:
raise RuntimeError('cannot broadcast and multiply the two data sets')
# Special shape for single element.
if shape1 == 1 and shape2 == 1:
new_shape = []
else:
new_shape = new_data.shape
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
data_type=pre_node_1.attribute[0].t.data_type,
dims=new_shape,
vals=new_data.flatten().tolist()
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
)
node_to_del.extend([node, pre_node_1, pre_node_2])
g.node.extend([new_node])
g.value_info.remove(pre_value_info1)
g.value_info.remove(pre_value_info2)
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
def div_constant_folding(g, node):
""" Fold constant and div nodes to a single constant node.
"""
node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
pre_value_info1 = helper.find_value_by_name(g, node.input[0])
pre_value_info2 = helper.find_value_by_name(g, node.input[1])
if pre_value_info1 is None or pre_value_info2 is None:
return False
shape1, data1 = helper.constant_to_list(pre_node_1)
shape2, data2 = helper.constant_to_list(pre_node_2)
np_data1 = np.reshape(data1, shape1)
np_data2 = np.reshape(data2, shape2)
try:
new_data = np.divide(np_data1, np_data2)
except ValueError:
raise RuntimeError('cannot broadcast and divide the two data sets')
# Special shape for single element.
if shape1 == 1 and shape2 == 1:
new_shape = []
else:
new_shape = new_data.shape
# Keep integer results as int64
if pre_node_1.attribute[0].t.data_type == onnx.helper.TensorProto.INT64:
new_data = new_data.astype('int64')
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
data_type=pre_node_1.attribute[0].t.data_type,
dims=new_shape,
vals=new_data.flatten().tolist()
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
)
node_to_del.extend([node, pre_node_1, pre_node_2])
g.node.extend([new_node])
g.value_info.remove(pre_value_info1)
g.value_info.remove(pre_value_info2)
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
def sub_constant_folding(g, node):
""" Fold constant and sub nodes to a single node.
"""
node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
pre_val_info_1 = helper.find_value_by_name(g, node.input[0])
pre_val_info_2 = helper.find_value_by_name(g, node.input[1])
shape1, data1 = helper.constant_to_list(pre_node_1)
shape2, data2 = helper.constant_to_list(pre_node_2)
new_data = np.subtract(np.reshape(data1, shape1), np.reshape(data2, shape2))
# Special shape for single element.
if shape1 == 1 and shape2 == 1:
new_shape = []
else:
new_shape = new_data.shape
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
data_type=pre_node_1.attribute[0].t.data_type,
dims=new_shape,
vals=helper.flatten_to_list(new_data)
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
)
g.node.extend([new_node])
node_to_del.extend([node, pre_node_1, pre_node_2])
g.value_info.remove(pre_val_info_1)
g.value_info.remove(pre_val_info_2)
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
def neg_constant_folding(g, node):
node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data_list = helper.constant_to_list(pre_node)
new_data_list = [-num for num in data_list]
new_tensor = onnx.helper.make_tensor(
name=pre_node.name+'_neg_tensor',
data_type=pre_node.attribute[0].t.data_type,
dims=shape,
vals=new_data_list
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
)
g.node.extend([new_node])
node_to_del.extend([pre_node, node])
g.value_info.remove(helper.find_value_by_name(g, node.input[0]))
while node_to_del:
g.node.remove(node_to_del.pop())
return True
def floor_constant_folding(g, node):
node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node)
new_data = np.floor(data).flatten().tolist()
if shape == 1:
new_shape = []
else:
new_shape = shape
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
data_type=pre_node.attribute[0].t.data_type,
dims=new_shape,
vals=helper.flatten_to_list(new_data)
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
)
g.node.extend([new_node])
node_to_del.extend([pre_node, node])
old_value = helper.find_value_by_name(g, node.input[0])
if old_value is not None:
g.value_info.remove(old_value)
while node_to_del:
g.node.remove(node_to_del.pop())
return True
def bn_constant_folding(g, node):
""" Fold a BatchNormalization node with constant inputs to a single constant node.
"""
# Prepare data
node_to_del = []
input_node = helper.find_node_by_output_name(g, node.input[0])
scale_node = helper.find_node_by_output_name(g, node.input[1])
bias_node = helper.find_node_by_output_name(g, node.input[2])
mean_node = helper.find_node_by_output_name(g, node.input[3])
var_node = helper.find_node_by_output_name(g, node.input[4])
input_value_info = []
for i in range(5):
input_value_info.append(helper.find_value_by_name(g, node.input[i]))
if input_value_info[0] is None:
return False
input_data = helper.constant_to_numpy(input_node)
scale_data = helper.constant_to_numpy(scale_node)
bias_data = helper.constant_to_numpy(bias_node)
mean_data = helper.constant_to_numpy(mean_node)
var_data = helper.constant_to_numpy(var_node)
epsilon = helper.get_var_attribute_by_name(node, 'epsilon', 'float')
if epsilon is None:
epsilon = 0.00001
# Calculate new node
new_data = scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon) + bias_data
new_node = helper.numpy_to_constant(node.output[0], new_data)
# Reconnect the graph
node_to_del.extend([node, input_node, scale_node, bias_node, mean_node, var_node])
g.node.extend([new_node])
for value in input_value_info:
if value is not None:
g.value_info.remove(value)
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
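The BatchNormalization folding above evaluates the standard inference-time formula y = scale * (x - mean) / sqrt(var + epsilon) + bias on the constant input. A numeric sketch with made-up values:

```python
import numpy as np

x = np.array([1.0, 2.0], dtype=np.float32)
scale, bias = np.float32(2.0), np.float32(0.5)
mean, var, eps = np.float32(1.5), np.float32(0.25), np.float32(1e-5)

# Inference-time BatchNormalization on a constant tensor.
y = scale * (x - mean) / np.sqrt(var + eps) + bias

assert np.allclose(y, [-1.5, 2.5], atol=1e-3)
```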
def DequantizeLinear_constant_folding(g, node):
""" Fold a DequantizeLinear node with constant inputs to a single constant node.
"""
# Prepare data
node_to_del = []
x_node = helper.find_node_by_output_name(g, node.input[0])
x_scale_node = helper.find_node_by_output_name(g, node.input[1])
if len(node.input) > 2:
x_zero_point_node = helper.find_node_by_output_name(g, node.input[2])
else:
x_zero_point_node = None
input_value_info = []
for i in range(len(node.input)):
input_value_info.append(helper.find_value_by_name(g, node.input[i]))
if input_value_info[0] is None:
return False
x_data = helper.constant_to_numpy(x_node)
x_scale_data = helper.constant_to_numpy(x_scale_node)
if x_zero_point_node is not None:
x_zero_point_data = helper.constant_to_numpy(x_zero_point_node)
else:
x_zero_point_data = np.array([0.0])
# Calculate new node
new_data = (x_data.astype(np.float32) - x_zero_point_data.astype(np.float32)) * x_scale_data
new_node = helper.numpy_to_constant(node.output[0], new_data)
# Reconnect the graph
node_to_del.extend([node, x_node, x_scale_node])
if x_zero_point_node is not None:
node_to_del.append(x_zero_point_node)
g.node.extend([new_node])
for value in input_value_info:
if value is not None:
g.value_info.remove(value)
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return True
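The DequantizeLinear folding computes y = (x - x_zero_point) * x_scale in float32. A quick sketch with assumed uint8 quantized data:

```python
import numpy as np

x = np.array([0, 128, 255], dtype=np.uint8)
x_scale = np.float32(0.5)
x_zero_point = np.uint8(128)

# Dequantize: subtract the zero point, then scale, in float32.
y = (x.astype(np.float32) - np.float32(x_zero_point)) * x_scale

assert np.allclose(y, [-64.0, 0.0, 63.5])
```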
# Map from foldable op type to its constant folding function.
constant_folding_nodes = {
'Add': add_constant_folding,
'BatchNormalization': bn_constant_folding,
'Cast': cast_constant_folding,
'Concat': concat_constant_folding,
'DequantizeLinear': DequantizeLinear_constant_folding,
'Div': div_constant_folding,
'Floor': floor_constant_folding,
'Gather': gather_constant_folding,
'Mul': mul_constant_folding,
'Reciprocal': reciprocal_constant_folding,
'ReduceProd': reduceprod_constant_folding,
'Reshape': reshape_constant_input_folding,
'Slice': slice_constant_folding,
'Sqrt': sqrt_constant_folding,
'Transpose': transpose_constant_folding,
'Unsqueeze': unsqueeze_constant_folding,
'Sub': sub_constant_folding,
'Neg': neg_constant_folding
}


@ -0,0 +1,669 @@
import collections
import struct
import onnx
import numpy as np
from . import other
from . import helper
from . import modhelper
from .general_graph import Graph
def eliminate_Identify_and_Dropout(g):
"""
Eliminate Identity and Dropout layers
:param g: the onnx graph
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'Identity' and node.op_type != 'Dropout':
continue
# If this node is the last node, leave it to remove_useless_last_nodes.
if helper.find_output_by_name(g, node.output[0]) is not None:
continue
# Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
# Delete value info
value_between = helper.find_value_by_name(g, node.output[0])
try:
g.value_info.remove(value_between)
except ValueError:
print("No value info to delete while eliminating identity layers.")
# Node is waiting for elimination
node_to_remove.append(node)
for node in node_to_remove:
g.node.remove(node)
# Remove last useless nodes
def remove_useless_last_nodes(g):
"""Remove useless nodes from the tail of the graph
"""
USELESS = ["Reshape", "Identity", "Transpose", "Flatten", "Dropout", "Mystery", "Constant", "Squeeze", "Unsqueeze", 'Softmax']
graph = Graph(g)
todo = collections.deque()
for node in graph.output_nodes:
if len(node.children) == 0:
todo.append(node)
node_to_remove = []
while todo:
# BFS find nodes to remove
cur_node = todo.popleft()
if cur_node.proto is None:
continue
if cur_node.proto.op_type not in USELESS:
continue
# Find the output
cur_node_output = helper.find_output_by_name(g, cur_node.proto.output[0])
for cur_input in cur_node.parents:
cur_input.children.remove(cur_node)
if len(cur_input.children) == 0:
todo.append(cur_input)
if cur_node_output is not None:
cur_input_output = helper.find_value_by_name(g, cur_input.proto.output[0])
cur_input_output_in_output = helper.find_output_by_name(g, cur_input.proto.output[0])
if cur_input_output is not None and cur_input_output_in_output is None:
g.output.extend([cur_input_output])
node_to_remove.append(cur_node.proto)
try:
g.value_info.remove(helper.find_value_by_name(g, cur_node.proto.output[0]))
except ValueError:
pass
if cur_node_output is not None:
g.output.remove(cur_node_output)
cur_node.proto = None
cur_node.parents.clear()
for node in node_to_remove:
g.node.remove(node)
######################################
# TF only optimization passes #
######################################
def eliminate_shape_changing_after_input(g):
"""
    Eliminate shape-changing nodes (Reshape, Transpose, Flatten, Dropout, Squeeze, Unsqueeze) right after the input and adjust the input shape accordingly.
:param g: the onnx graph
"""
node_to_remove = []
REMOVE_LIST = ["Reshape", "Transpose", "Flatten", "Dropout", "Squeeze", "Unsqueeze"]
for node in g.node:
# Find an input and the shape node
if node.op_type not in REMOVE_LIST:
continue
old_input = helper.find_input_by_name(g, node.input[0])
if old_input is None:
continue
# If the input is used by multiple nodes, skip.
counter = 0
for tnode in g.node:
if old_input.name in tnode.input:
counter += 1
if counter > 1:
continue
# Remove Weight if any.
output_val_info = helper.find_value_by_name(g, node.output[0])
if node.op_type == 'Reshape':
shape_node = helper.find_node_by_output_name(g, node.input[1])
            if shape_node is None or shape_node.op_type != 'Constant':
                continue
            # Manually set the input shape
shape_info = helper.find_value_by_name(g, shape_node.output[0])
old_size, old_shape = helper.find_size_shape_from_value(shape_info)
_, new_shape = helper.constant_to_list(shape_node)
for i in range(len(new_shape)):
if new_shape[i] == -1:
dim = int(old_size//np.prod(new_shape)*(-1))
new_shape[i] = dim
new_input = onnx.helper.make_tensor_value_info(
output_val_info.name,
output_val_info.type.tensor_type.elem_type,
new_shape
)
node_to_remove.append(node)
shape_outputs = helper.find_nodes_by_input_name(g, shape_node.output[0])
if len(shape_outputs) == 1:
node_to_remove.append(shape_node)
g.value_info.remove(helper.find_value_by_name(g, shape_node.output[0]))
g.input.remove(old_input)
g.input.extend([new_input])
g.value_info.remove(output_val_info)
elif node.op_type == 'Transpose':
permutation = list(node.attribute[0].ints)
pre_shape = helper.get_shape_from_value_info(old_input)
new_shape = [pre_shape[i] for i in permutation]
new_input = onnx.helper.make_tensor_value_info(
output_val_info.name,
output_val_info.type.tensor_type.elem_type,
new_shape
)
node_to_remove.append(node)
g.input.remove(old_input)
g.input.extend([new_input])
g.value_info.remove(output_val_info)
elif node.op_type == 'Flatten':
axis = node.attribute[0].int
pre_shape = helper.get_shape_from_value_info(old_input)
dim_1, dim_2 = 1, 1
if axis == 0:
dim_1 = 1
dim_2 = np.prod(pre_shape)
else:
dim_1 = np.prod(pre_shape[:axis]).astype(int)
dim_2 = np.prod(pre_shape[axis:]).astype(int)
new_shape = [dim_1, dim_2]
new_input = onnx.helper.make_tensor_value_info(
output_val_info.name,
output_val_info.type.tensor_type.elem_type,
new_shape
)
node_to_remove.append(node)
g.input.remove(old_input)
g.input.extend([new_input])
g.value_info.remove(output_val_info)
elif node.op_type == 'Dropout':
g.input.remove(old_input)
g.input.extend([output_val_info])
g.value_info.remove(output_val_info)
node_to_remove.append(node)
elif node.op_type == 'Squeeze':
axis = list(node.attribute[0].ints)
pre_shape = helper.get_shape_from_value_info(old_input)
for pos in sorted(axis)[::-1]:
if pre_shape[pos] != 1:
raise RuntimeError('invalid axis for squeeze')
else:
pre_shape.pop(pos)
new_shape = pre_shape
new_input = onnx.helper.make_tensor_value_info(
output_val_info.name,
output_val_info.type.tensor_type.elem_type,
new_shape
)
node_to_remove.append(node)
g.input.remove(old_input)
g.input.extend([new_input])
g.value_info.remove(output_val_info)
elif node.op_type == 'Unsqueeze':
axis = list(node.attribute[0].ints)
pre_shape = helper.get_shape_from_value_info(old_input)
new_shape = pre_shape
for pos in axis:
new_shape.insert(pos, 1)
new_input = onnx.helper.make_tensor_value_info(
output_val_info.name,
output_val_info.type.tensor_type.elem_type,
new_shape
)
node_to_remove.append(node)
g.input.remove(old_input)
g.input.extend([new_input])
g.value_info.remove(output_val_info)
else:
pass
for node in node_to_remove:
g.node.remove(node)
other.topological_sort(g)
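The `-1` entry in a Reshape target shape is resolved above via `old_size // np.prod(new_shape) * (-1)`: the product is negative because it still contains the single `-1`, so the floor division recovers the missing dimension. A minimal standalone sketch of that arithmetic (the helper name `resolve_minus_one` is illustrative and not part of these scripts):

```python
import numpy as np

def resolve_minus_one(old_size, new_shape):
    # Mirrors the -1 handling in eliminate_shape_changing_after_input:
    # np.prod(new_shape) is negative because it still contains the single
    # -1, so old_size // prod * (-1) yields the missing dimension.
    resolved = list(new_shape)
    for i, dim in enumerate(resolved):
        if dim == -1:
            resolved[i] = int(old_size // np.prod(new_shape) * (-1))
    return resolved
```

For example, a tensor of 24 elements reshaped to `[-1, 3, 4]` resolves to `[2, 3, 4]`.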
def eliminate_Reshape_Cast(g):
"""Eliminate the cast layer for shape of Reshape layer
:param g: the onnx graph
"""
#Find all reshape layers
node_to_remove = []
for node in g.node:
if node.op_type != 'Reshape':
continue
prev_node = helper.find_node_by_output_name(g, node.input[1])
        if prev_node is None or prev_node.op_type != 'Cast':
            continue
# Now we find the cast weight pattern. Cast the weight, delete the cast.
reshape_node = node
cast_node = prev_node
weight_node = helper.find_node_by_output_name(g, cast_node.input[0])
if weight_node is None:
raise RuntimeError("Unexpected None before Cast-Reshape.")
weight_node.attribute[0].t.data_type = 7
if weight_node.attribute[0].t.raw_data:
raw_data = weight_node.attribute[0].t.raw_data
int_data = [i[0] for i in struct.iter_unpack('i', raw_data)]
raw_data = struct.pack('q' * len(int_data), *int_data)
elif len(weight_node.attribute[0].t.int64_data) > 0\
or len(weight_node.attribute[0].t.int32_data) > 0:
# It's already int. Do nothing
pass
else:
raise NotImplementedError()
# Change Value info
origin_weight_out = helper.find_value_by_name(g, weight_node.output[0])
weight_node.output.pop()
weight_node.output.extend([reshape_node.input[1]])
        # Delete the value info and mark the Cast node for removal. Removing
        # it after the loop avoids mutating g.node while iterating over it.
        g.value_info.remove(origin_weight_out)
        node_to_remove.append(cast_node)
    for node in node_to_remove:
        g.node.remove(node)
def eliminate_Cast_after_input(g):
"""Eliminate the cast layer right after the input
:param g: the onnx graph
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'Cast':
continue
old_input = helper.find_input_by_name(g, node.input[0])
if old_input is None:
continue
next_val_info = helper.find_value_by_name(g, node.output[0])
shape = helper.get_shape_from_value_info(next_val_info)
new_val_info = onnx.helper.make_tensor_value_info(
next_val_info.name,
node.attribute[0].i,
shape
)
# Delete old value_info
g.input.remove(old_input)
g.value_info.remove(next_val_info)
# Append nodes to node_to_remove
node_to_remove.append(node)
# Add new input
g.input.extend([new_val_info])
for node in node_to_remove:
g.node.remove(node)
def eliminate_consecutive_Cast(g):
    """If two Cast nodes are adjacent, remove the first one.
:param g: the onnx graph
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'Cast':
continue
first_node = helper.find_node_by_output_name(g, node.input[0])
if first_node is None or first_node.op_type != 'Cast':
continue
# Here we have two consecutive Cast Node
# Reset the input of the later node
node.input[0] = first_node.input[0]
# Remove the first node and its output value info
node_to_remove.append(first_node)
        first_output = helper.find_value_by_name(g, first_node.output[0])
        if first_output is not None:
            g.value_info.remove(first_output)
for node in node_to_remove:
g.node.remove(node)
def eliminate_Squeeze_before_Reshape(g):
    """If a Squeeze node is directly followed by a Reshape node, remove the Squeeze.
:param g: the onnx graph
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'Reshape':
continue
first_node = helper.find_node_by_output_name(g, node.input[0])
if not first_node:
continue
if first_node.op_type != 'Squeeze':
continue
        # Here we have a Squeeze node followed by a Reshape node.
# Reset the input of the later node
node.input[0] = first_node.input[0]
# Remove the first node and its output value info
node_to_remove.append(first_node)
first_output = helper.find_value_by_name(g, first_node.output[0])
g.value_info.remove(first_output)
for node in node_to_remove:
g.node.remove(node)
def eliminate_no_children_input(g):
"""Eliminate inputs with no children at all.
"""
# Create a set of input names
input_names = set([i.name for i in g.input])
# If a name is used in any node, remove this name from the set.
for n in g.node:
for i in n.input:
input_names.discard(i)
# Remove the inputs with the left names.
for i in input_names:
info = helper.find_input_by_name(g, i)
g.input.remove(info)
def eliminate_consecutive_reshape(g):
"""Replace consecutive reshape nodes by a single node.
"""
node_to_del = []
for node in g.node:
if node.op_type != 'Reshape':
continue
pre_data_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape_node = helper.find_node_by_output_name(g, node.input[1])
if not pre_data_node or not pre_shape_node:
continue
if pre_shape_node.op_type != 'Constant':
continue
if pre_data_node.op_type != 'Reshape':
continue
pre_pre_shape_node = helper.find_node_by_output_name(g, pre_data_node.input[1])
        if pre_pre_shape_node is None or pre_pre_shape_node.op_type != 'Constant':
            continue
new_reshape_node = onnx.helper.make_node(
'Reshape',
[pre_data_node.input[0], node.input[1]],
[node.output[0]],
name = node.output[0]
)
g.node.extend([new_reshape_node])
node_to_del.append(node)
node_to_del.append(pre_data_node)
node_to_del.append(pre_pre_shape_node)
val_info_to_del1 = helper.find_value_by_name(g, node.input[0])
val_info_to_del2 = helper.find_value_by_name(g, pre_data_node.input[1])
g.value_info.remove(val_info_to_del1)
g.value_info.remove(val_info_to_del2)
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
def eliminate_single_input_Concat(g):
"""
Eliminate single input Concat layers
:param g: the onnx graph
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'Concat':
continue
# If this node has more than 1 input, continue.
if len(node.input) > 1:
continue
# If this node is the output node, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0])
the_input_value = helper.find_value_by_name(g, node.input[0])
g.output.remove(todel_output)
g.output.extend([the_input_value])
node_to_remove.append(node)
continue
# Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
# Delete value info
value_between = helper.find_value_by_name(g, node.output[0])
try:
g.value_info.remove(value_between)
        except ValueError:
            print("No value info to delete while eliminating single input Concat layers.")
# Node is waiting for elimination
node_to_remove.append(node)
for node in node_to_remove:
g.node.remove(node)
def eliminate_nop_Maxpool_and_AveragePool(g):
"""
Eliminate do nothing MaxPool and AveragePool layers.
Those layers have valid padding, 1x1 kernel and [1,1] strides.
:param g: the onnx graph
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'MaxPool' and node.op_type != 'AveragePool':
continue
# If this node is actually working, continue.
kernel = helper.get_list_attribute_by_name(node, "kernel_shape", "int")
pads = helper.get_list_attribute_by_name(node, "pads", "int")
strides = helper.get_list_attribute_by_name(node, "strides", "int")
if kernel != [1, 1] or pads != [0, 0, 0, 0] or strides != [1, 1]:
continue
# If this node is the output node, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0])
the_input_value = helper.find_value_by_name(g, node.input[0])
g.output.remove(todel_output)
g.output.extend([the_input_value])
node_to_remove.append(node)
continue
# Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
# Delete value info
value_between = helper.find_value_by_name(g, node.output[0])
try:
g.value_info.remove(value_between)
        except ValueError:
            print("No value info to delete while eliminating nop MaxPool and AveragePool layers.")
# Node is waiting for elimination
node_to_remove.append(node)
for node in node_to_remove:
g.node.remove(node)
def eliminate_trivial_maxpool(g):
node_to_del = []
for node in g.node:
if node.op_type != 'MaxPool':
continue
pads = None
strides = None
dilation = None
kernel_shape = None
for att in node.attribute:
if att.name == 'pads':
pads = list(att.ints)
elif att.name == 'strides':
strides = list(att.ints)
elif att.name == 'kernel_shape':
kernel_shape = list(att.ints)
            elif att.name == 'dilations':
dilation = list(att.ints)
else:
pass
if pads and any([pad != 0 for pad in pads]):
continue
if strides and any([stride != 1 for stride in strides]):
continue
if dilation and any([dila != 1 for dila in dilation]):
continue
        if kernel_shape is None or any([dim != 1 for dim in kernel_shape]):
            continue
node_to_del.append(node)
next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
        if not next_nodes:
output_value = helper.find_output_by_name(g, node.output[0])
if not output_value:
continue
else:
pre_val_info = helper.find_value_by_name(g, node.input[0])
g.output.extend([pre_val_info])
g.output.remove(output_value)
for next_node in next_nodes:
modhelper.replace_node_input(next_node, node.output[0], node.input[0])
next_val_info = helper.find_value_by_name(g, node.output[0])
g.value_info.remove(next_val_info)
while node_to_del:
g.node.remove(node_to_del.pop())
other.topological_sort(g)
def eliminate_empty_value_infos(g):
to_remove = []
for value_info in g.value_info:
if len(value_info.type.tensor_type.shape.dim) == 0:
to_remove.append(value_info)
for value_info in to_remove:
g.value_info.remove(value_info)
def eliminate_nop_pads(g):
node_to_remove = []
for node in g.node:
if node.op_type != 'Pad':
continue
# Check if the Pad is empty or not
        pads_node = helper.find_node_by_output_name(g, node.input[1])
        if pads_node is None or pads_node.op_type != 'Constant':
            continue
        pads_np = helper.constant_to_numpy(pads_node)
all_zero = True
for value in pads_np:
if value != 0:
all_zero = False
if not all_zero:
continue
# Check if it has the constant_value_node
constant_value_node = None
if len(node.input) > 2:
constant_value_node = helper.find_node_by_output_name(g, node.input[2])
# If this node is the output node, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0])
g.output.remove(todel_output)
if helper.find_output_by_name(g, node.input[0]) is None:
the_input_value = helper.find_value_by_name(g, node.input[0])
if the_input_value is not None:
g.output.extend([the_input_value])
# Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
# Delete value info
value_between = helper.find_value_by_name(g, node.output[0])
try:
g.value_info.remove(value_between)
        except ValueError:
            helper.logger.info("No value info to delete while eliminating nop Pads.")
# Node is waiting for elimination
node_to_remove.append(node)
for node in node_to_remove:
g.node.remove(node)
def eliminate_trivial_elementwise_calculation(g):
    """Eliminate Add, Sub, Mul and Div nodes which do nothing.
"""
node_to_remove = []
for node in g.node:
weight_node = None
if node.op_type == 'Add' or node.op_type == 'Sub':
# For add and sub, check if the weights are 0s.
weight_node = helper.find_node_by_output_name(g, node.input[1])
if weight_node is None or weight_node.op_type != 'Constant':
continue
weight_np = helper.constant_to_numpy(weight_node)
if np.any(weight_np):
continue
elif node.op_type == 'Mul' or node.op_type == 'Div':
# For Mul and Div, check if the weights are 1s.
weight_node = helper.find_node_by_output_name(g, node.input[1])
if weight_node is None or weight_node.op_type != 'Constant':
continue
weight_np = helper.constant_to_numpy(weight_node)
weight_np = weight_np - 1
if np.any(weight_np):
continue
else:
# For other nodes, just skip
continue
# Remove the node
node_to_remove.append(node)
output_value_info = helper.find_value_by_name(g, node.output[0])
if output_value_info is not None:
g.value_info.remove(output_value_info)
# Replace next node input if any.
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
todel_output = helper.find_output_by_name(g, node.output[0])
if todel_output is not None:
g.output.remove(todel_output)
previous_output = helper.find_output_by_name(g, node.input[0])
if previous_output is None:
the_input_value = helper.find_value_by_name(g, node.input[0])
g.output.extend([the_input_value])
# Delete the constant node if it is not used by other nodes
constant_following_nodes = helper.find_following_nodes_by_input_value_name(g, weight_node.output[0])
if len(constant_following_nodes) == 1:
node_to_remove.append(weight_node)
output_value_info = helper.find_value_by_name(g, weight_node.output[0])
if output_value_info is not None:
g.value_info.remove(output_value_info)
for node in node_to_remove:
g.node.remove(node)
def eliminate_nop_cast(g):
"""Eliminate do nothing Cast nodes.
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'Cast':
continue
# Get input value_info
input_value = helper.find_value_by_name(g, node.input[0])
if input_value is None:
helper.logger.debug(f"Cannot find the input value_info for Cast node {node.name}. Skip elimination check.")
continue
# Get output value_info
output_value = helper.find_value_by_name(g, node.output[0])
if output_value is None:
output_value = helper.find_output_by_name(g, node.output[0])
if output_value is None:
helper.logger.debug(f"Cannot find the output value_info for Cast node {node.name}. Skip elimination check.")
continue
# Compare the type.
if input_value.type.tensor_type.elem_type != output_value.type.tensor_type.elem_type:
continue
# If this node is the output node, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0])
g.output.remove(todel_output)
if helper.find_output_by_name(g, node.input[0]) is None:
the_input_value = helper.find_value_by_name(g, node.input[0])
if the_input_value is not None:
g.output.extend([the_input_value])
# Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
# Delete value info
value_between = helper.find_value_by_name(g, node.output[0])
if value_between is not None:
g.value_info.remove(value_between)
# Node is waiting for elimination
node_to_remove.append(node)
for node in node_to_remove:
g.node.remove(node)

from collections import deque
class Node:
"""A Node which maps a node proto. It has pointers to its parents and
children.
"""
def __init__(self, onnx_node):
        """Initialize a node. This initialization only sets up the mapping to
        the node proto. The pointers should be set up from outside.
        """
self.name = None
self.parents = []
self.children = []
self.proto = None
self.output_value = None
if onnx_node is not None:
self.name = onnx_node.name
self.proto = onnx_node
class Graph:
"""A graph which is constructed from the onnx proto.
"""
def __init__(self, onnx_graph):
"""Construct the graph from onnx.
"""
self.input_nodes = []
self.output_nodes = []
self.name2node = {}
self.output2node = {}
self.proto = onnx_graph
# Add input nodes
for value in onnx_graph.input:
input_node = Node(None)
input_node.name = "Input_" + value.name
input_node.output_value = value
self.name2node[input_node.name] = input_node
self.output2node[value.name] = input_node
self.input_nodes.append(input_node)
output_value_names = [value.name for value in onnx_graph.output]
# Add regular nodes
for onnx_node in onnx_graph.node:
node = Node(onnx_node)
self.name2node[node.name] = node
self.output2node[onnx_node.output[0]] = node
            for value_name in onnx_node.input:
                # Optional inputs may be empty strings and initializers have
                # no producer node; skip names without a known producer.
                if value_name not in self.output2node:
                    continue
                node.parents.append(self.output2node[value_name])
                self.output2node[value_name].children.append(node)
if onnx_node.output[0] in output_value_names:
self.output_nodes.append(node)
# Add value infos
        for value in onnx_graph.value_info:
            if value.name not in self.output2node:
                continue
            node = self.output2node[value.name]
            node.output_value = value
def get_sorted_node_list(self):
"""Return a node list in topological order.
"""
visited = set()
todo = deque()
result = []
for node in self.input_nodes:
todo.append(node)
visited.add(node)
for onnx_node in self.proto.node:
if onnx_node.op_type == "Constant":
node = self.name2node[onnx_node.name]
todo.append(node)
visited.add(node)
while todo:
node = todo.popleft()
result.append(node)
for child in node.children:
if child in visited:
continue
ready = True
for child_parent in child.parents:
if child_parent in visited:
continue
ready = False
break
if ready:
todo.append(child)
visited.add(child)
return result
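`get_sorted_node_list` above is a Kahn-style breadth-first topological sort: a node is enqueued only once every one of its parents has been visited. The same readiness check in isolation, over plain adjacency dicts (the names `topo_order`, `children`, `parents` are illustrative):

```python
from collections import deque

def topo_order(children, parents, roots):
    # Kahn-style traversal mirroring Graph.get_sorted_node_list: a node is
    # enqueued only after every one of its parents has been visited.
    visited = set(roots)
    todo = deque(roots)
    order = []
    while todo:
        node = todo.popleft()
        order.append(node)
        for child in children.get(node, []):
            if child in visited:
                continue
            if all(p in visited for p in parents.get(child, [])):
                todo.append(child)
                visited.add(child)
    return order
```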

"""This module contains helper functions that do not modify the graph.
"""
import onnx
import onnx.helper
import struct
import numpy as np
import logging
__ONNX_VERSION__ = -1
logger = logging.getLogger("optimizer_scripts")
def setup_current_opset_version(m):
global __ONNX_VERSION__
__ONNX_VERSION__ = m.opset_import[0].version
if __ONNX_VERSION__ not in [11]:
raise RuntimeError('Only support opset 11, but got ' + str(__ONNX_VERSION__))
def get_current_opset_version():
if __ONNX_VERSION__ == -1:
        raise RuntimeError('Call setup_current_opset_version first.')
return __ONNX_VERSION__
def find_nodes_by_input_name(g, name):
nodes = []
for node in g.node:
if name in node.input:
nodes.append(node)
return nodes
def find_node_by_output_name(g, name):
"""
Find a node in the graph by its output name
:param g: the onnx graph\\
:param name: the target node output name\\
    :returns: the node found by name
"""
for i in g.node:
if name in i.output:
return i
return None
def find_node_by_node_name(g, name):
    """
    Find a node in the graph by its node name
    :param g: the onnx graph\\
    :param name: the target node name\\
    :returns: the node found by name
    """
for i in g.node:
if i.name == name:
return i
return None
def find_following_nodes_by_input_value_name(g, name):
""" Find the following nodes of a specific value.
:param g: the onnx graph. \\
:param name: the value name. \\
:return: a list of following nodes.
"""
return find_nodes_by_input_name(g, name)
def find_value_by_name(g, name):
"""
Find a value_info in the graph by name
:param g: the onnx graph\\
:param name: the target value_info name\\
    :returns: the value_info found by name
"""
for i in g.value_info:
if i.name == name:
return i
return None
def find_output_by_name(g, name):
"""
    Find an output value_info in the graph by name
    :param g: the onnx graph\\
    :param name: the target output name\\
    :returns: the output found by name
"""
for i in g.output:
if i.name == name:
return i
return None
def find_input_by_name(g, name):
"""
    Find an input in the graph by name
    :param g: the onnx graph\\
    :param name: the target input name\\
    :returns: the input found by name
"""
for i in g.input:
if i.name == name:
return i
return None
def list_to_constant(name, shape, data, data_type=None):
    """Generate a constant node using the given information.
:name: the node name and the output value name\\
:shape: the data shape\\
:data: the data itself\\
:returns: the generated onnx constant node
"""
if not data_type:
if isinstance(data, int):
data_type = onnx.helper.TensorProto.INT64
elif isinstance(data, float):
data_type = onnx.helper.TensorProto.FLOAT
elif len(data) > 0 and isinstance(data[0], int):
data_type = onnx.helper.TensorProto.INT64
else:
data_type = onnx.helper.TensorProto.FLOAT
tensor = onnx.helper.make_tensor(
name,
data_type,
shape,
data
)
new_w_node = onnx.helper.make_node(
"Constant",
[],
[name],
name = name,
value = tensor
)
return new_w_node
def scaler_to_constant(name, data, data_type=None):
    """Generate a scalar constant node using the given information.
    :name: the node name and the output value name\\
    :data: the data itself\\
    :returns: the generated onnx constant node
"""
if not data_type:
if isinstance(data, int):
data_type = onnx.helper.TensorProto.INT64
elif isinstance(data, float):
data_type = onnx.helper.TensorProto.FLOAT
else:
logger.error("Cannot create scaler constant with a list.")
exit(1)
tensor = onnx.helper.make_tensor(
name,
data_type,
None,
[data]
)
new_w_node = onnx.helper.make_node(
"Constant",
[],
[name],
name = name,
value = tensor
)
return new_w_node
def numpy_to_constant(name, np_array):
return list_to_constant(name, np_array.shape, np_array.flatten().tolist())
def constant_to_list(node):
"""Generate a list from the constant node
:node: the Constant node\\
:returns: the shape of the constant node, the data of the constant node
"""
tensor = node.attribute[0].t
# 1. check data type
# 2. get data from raw or data
# 3. get shape from dim
if tensor.data_type == onnx.helper.TensorProto.INT32:
if len(tensor.int32_data) != 0:
data = list(tensor.int32_data)
else:
data = [i[0] for i in struct.iter_unpack('i', tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.INT64:
if len(tensor.int64_data) != 0:
data = list(tensor.int64_data)
else:
data = [i[0] for i in struct.iter_unpack('q', tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.INT8:
if len(tensor.int32_data) != 0:
data = list(tensor.int32_data)
else:
data = [i[0] for i in struct.iter_unpack('b', tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.FLOAT:
if len(tensor.float_data) != 0:
data = list(tensor.float_data)
else:
data = [i[0] for i in struct.iter_unpack('f', tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.DOUBLE:
if len(tensor.double_data) != 0:
data = list(tensor.double_data)
else:
data = [i[0] for i in struct.iter_unpack('d', tensor.raw_data)]
else:
        raise RuntimeError("Not supported data type {}".format(tensor.data_type))
if len(tensor.dims) == 0:
shape = len(data)
else:
shape = list(tensor.dims)
return shape, data
def constant_to_numpy(node):
"""Generate a numpy array from the constant node
:node: the Constant node\\
:returns: the numpy array
"""
shape, data = constant_to_list(node)
return np.array(data).reshape(shape)
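`constant_to_list` reads tensor payloads either from the typed fields or from `raw_data` via `struct.iter_unpack` with the matching format character (`'f'` for FLOAT, `'q'` for INT64, and so on). A small round-trip sketch of the raw_data path:

```python
import struct

import numpy as np

# Pack a float list the way ONNX stores raw_data for a FLOAT tensor, then
# unpack it the way constant_to_list does.
values = [1.5, -2.0, 3.25]
raw = struct.pack('f' * len(values), *values)
unpacked = [v[0] for v in struct.iter_unpack('f', raw)]
# Reshape as constant_to_numpy would, using the tensor dims.
array = np.array(unpacked).reshape([3, 1])
```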
def all_constant_input(node):
"""Find the inputs of the given node. If the inputs of this node are all\\
constant nodes, return True. Otherwise, return False.
:param node: the input node which has a Node structure\\
    :return: whether the inputs of this node are all constant
"""
if node.proto is None:
return False
isConstant = True
for parent in node.parents:
if parent.proto is None or parent.proto.op_type != 'Constant':
isConstant = False
break
return isConstant
def get_padding(size, kernel_size, strides):
""" Calculate the padding array for same padding in the Tensorflow fashion.\\
See https://www.tensorflow.org/api_guides/python/nn#Convolution for more.
"""
if size[0] % strides[0] == 0:
pad_h = max(kernel_size[0] - strides[0], 0)
else:
pad_h = max(kernel_size[0] - (size[0] % strides[0]), 0)
if size[1] % strides[1] == 0:
pad_w = max(kernel_size[1] - strides[1], 0)
else:
pad_w = max(kernel_size[1] - (size[1] % strides[1]), 0)
return [pad_h//2, pad_w//2, pad_h-pad_h//2, pad_w-pad_w//2]
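`get_padding` reproduces TensorFlow-style SAME padding: the total pad per axis is chosen so the output size is `ceil(input / stride)`, split as evenly as possible with the extra cell going to the bottom/right. For example, a 5x5 input with a 3x3 kernel and stride 2 needs 2 pad cells per axis, giving `[1, 1, 1, 1]`. A standalone copy for checking (the name `tf_same_padding` is illustrative):

```python
def tf_same_padding(size, kernel_size, strides):
    # TensorFlow-style SAME padding: pick the total padding per axis so the
    # output size is ceil(input / stride), then split it as evenly as
    # possible with the extra cell on the bottom/right.
    if size[0] % strides[0] == 0:
        pad_h = max(kernel_size[0] - strides[0], 0)
    else:
        pad_h = max(kernel_size[0] - (size[0] % strides[0]), 0)
    if size[1] % strides[1] == 0:
        pad_w = max(kernel_size[1] - strides[1], 0)
    else:
        pad_w = max(kernel_size[1] - (size[1] % strides[1]), 0)
    # Order matches get_padding: [top, left, bottom, right].
    return [pad_h // 2, pad_w // 2, pad_h - pad_h // 2, pad_w - pad_w // 2]
```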
def get_shape_from_value_info(value):
"""Get shape from a value info.
:param value: the value_info proto\\
:return: list of the shape
"""
return [d.dim_value for d in value.type.tensor_type.shape.dim]
def find_size_shape_from_value(value):
'''
Find the size of data within the value_info object.
:param value: value_info
:return: int size and list shape of the data in the value_info
'''
if not value:
return None, None
if not value.type.tensor_type.shape.dim:
return 0, []
size = 1
shape = []
for i in range(len(value.type.tensor_type.shape.dim)):
size *= max(1, value.type.tensor_type.shape.dim[i].dim_value)
shape.append(max(1, value.type.tensor_type.shape.dim[i].dim_value))
return size, shape
def get_attribute_by_name(node, attr_name):
"""Get attribute proto with specific name in the given node proto.
:param node: the node proto.\\
:param attr_name: a str for the name of the target.\\
:return: if found, return the attribute_proto. Else, return None.
"""
for attr in node.attribute:
if attr.name == attr_name:
return attr
return None
def get_list_attribute_by_name(node, attr_name: str, attr_type: str):
"""Get list attribute with specific name in the given node proto.
:param node: the node proto.\\
:param attr_name: a str for the name of the target.\\
:param attr_type: a str which should be "float" or "int".\\
:return: if found, return the list. Else, return None.
"""
attr_proto = get_attribute_by_name(node, attr_name)
if attr_proto is None:
return None
if attr_type == "int":
if len(attr_proto.ints) == 0:
return None
else:
return list(attr_proto.ints)
    elif attr_type == "float":
        if len(attr_proto.floats) == 0:
            return None
        else:
            return list(attr_proto.floats)
else:
        logger.warning("Undefined type for list attribute extraction")
return None
def get_var_attribute_by_name(node, attr_name: str, attr_type: str):
"""Get variable attribute with specific name in the given node proto.
:param node: the node proto.\\
:param attr_name: str for the name of the target.\\
:param attr_type: str which should be "float", "int", "string" or "tensor".\\
:return: if found, return the variable. Else, return None.
"""
attr_proto = get_attribute_by_name(node, attr_name)
if attr_proto is None:
return None
if attr_type == "int":
return attr_proto.i
elif attr_type == "float":
return attr_proto.f
elif attr_type == "string":
        if isinstance(attr_proto.s, bytes):
return attr_proto.s.decode("utf-8")
else:
return attr_proto.s
elif attr_type == "tensor":
return attr_proto.t
else:
        logger.warning("Undefined type for variable attribute extraction")
return None
def flatten_with_depth(data, depth):
output = []
    if not isinstance(data, (np.ndarray, list)):
return [[data, 0]]
for item in data:
        if not isinstance(item, (np.ndarray, list)):
output.append([item, depth+1])
else:
output += flatten_with_depth(item, depth+1)
return output
def flatten_to_list(data):
flatten_depth = flatten_with_depth(data, 0)
flat_data = [item[0] for item in flatten_depth]
return flat_data
def get_shape(data):
shape = []
    if not isinstance(data, (np.ndarray, list)):
return []
sub_data = data[0]
shape.append(len(data))
    while isinstance(sub_data, (np.ndarray, list)):
shape.append(len(sub_data))
sub_data = sub_data[0]
return shape
def slice_data(data, starts, ends, axes):
flat_data = [item[0] for item in flatten_with_depth(data, 0)]
shape = get_shape(data)
starts_updated = []
ends_updated = []
for i in range(len(starts)):
start_updated = min(starts[i], shape[i]-1) % shape[i]
starts_updated.append(start_updated)
for j in range(len(starts)):
if ends[j] >= shape[j]:
end_updated = shape[j]
else:
end_updated = min(ends[j], shape[j]) % shape[j]
ends_updated.append(end_updated)
index_slices = []
for i in range(len(shape)):
if i not in axes:
index_slices.append(list(range(shape[i])))
else:
axe_ind = axes.index(i)
index_slices.append(list(range(starts_updated[axe_ind], ends_updated[axe_ind])))
indices = [1]
for i in range(len(shape)-1, -1, -1):
step = np.prod(shape[i+1:])
temp_pos = indices
new_indices = []
for n in index_slices[i]:
for pos in temp_pos:
new_indices.append(int(n*step+pos))
indices = new_indices
sliced_data = [flat_data[k-1] for k in indices]
# reshape to correct shape.
new_shape = []
for i in range(len(shape)):
if i not in axes:
new_shape.append(shape[i])
else:
axe_ind = axes.index(i)
new_shape.append(ends_updated[axe_ind]-starts_updated[axe_ind])
if any([dim < 1 for dim in new_shape]):
raise RuntimeError('Invalid starts ends.')
sliced_data = np.reshape(sliced_data, new_shape)
return sliced_data
def concatenate(data_sets, axis):
# check shapes
shapes = []
shapes_ = []
for data_set in data_sets:
shape = get_shape(data_set)
shapes.append(list(shape))
shape.pop(axis)
shapes_.append(shape)
if not all([s == shapes_[0] for s in shapes_]):
raise RuntimeError('data sets shapes do not match')
new_dim = sum([s[axis] for s in shapes])
new_shape = list(shapes[0])
new_shape[axis] = new_dim
flat_data_sets = []
for data_set in data_sets:
flat_data_sets.append(flatten_to_list(data_set))
sub_block_size = 1
for i in range(axis+1, len(shapes[0])):
sub_block_size *= shapes[0][i]
split_num = 1
for i in range(axis):
split_num *= shapes[0][i]
total_flat_data = []
for i in range(split_num):
for j in range(len(shapes)):
block_size = sub_block_size*shapes[j][axis]
total_flat_data.extend(flat_data_sets[j][i*block_size:(i+1)*block_size])
new_data = np.reshape(total_flat_data, new_shape)
return new_data
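As a reference point (a sketch, not part of the script), the block interleaving above reproduces `np.concatenate`:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
b = np.arange(6, 12).reshape(2, 3)

# Concatenation along an axis stitches whole sub-blocks together, which is
# exactly what the split_num / sub_block_size loop above materializes.
assert np.concatenate([a, b], axis=0).shape == (4, 3)
assert np.concatenate([a, b], axis=1).tolist() == [[0, 1, 2, 6, 7, 8],
                                                   [3, 4, 5, 9, 10, 11]]
```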
def broadcast_data_sets(data_set_1, data_set_2):
shape1 = get_shape(data_set_1)
shape2 = get_shape(data_set_2)
# compare shapes and get broadcasted shape
list_a, list_b = (shape1, shape2) if len(shape1) > len(shape2) else (shape2, shape1)
while len(list_a) > len(list_b):
list_b.insert(0, 0)
broadcasted_shape = []
for i in range(len(list_a)):
if list_b[i] == 0:
broadcasted_shape.append(list_a[i])
elif list_b[i] == 1:
broadcasted_shape.append(list_a[i])
elif list_a[i] == 1:
broadcasted_shape.append(list_b[i])
elif list_a[i] == list_b[i]:
broadcasted_shape.append(list_a[i])
else:
raise RuntimeError('Can not broadcast two data sets')
# prepare data for broadcasting.
shape1 = list(map(lambda x:x if x != 0 else 1, shape1))
shape2 = list(map(lambda x:x if x != 0 else 1, shape2))
data_1 = np.reshape(data_set_1, shape1)
data_2 = np.reshape(data_set_2, shape2)
for i in range(len(shape1)):
if shape1[i] != broadcasted_shape[i]:
new_data_total = [list(data_1) for _ in range(broadcasted_shape[i])]
data_1 = concatenate(new_data_total, axis=i)
for i in range(len(shape2)):
if shape2[i] != broadcasted_shape[i]:
new_data_total = [list(data_2) for _ in range(broadcasted_shape[i])]
data_2 = concatenate(new_data_total, axis=i)
return data_1, data_2
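For intuition (an illustrative check, not part of the script): tiling each operand along its size-1 axes, as the helper above does with repeated `concatenate` calls, reproduces NumPy broadcasting.

```python
import numpy as np

a = np.ones((3, 1))
b = np.arange(4).reshape(1, 4)

# Both operands expand to the common broadcasted shape (3, 4).
a2, b2 = np.broadcast_arrays(a, b)
assert a2.shape == (3, 4) and b2.shape == (3, 4)
assert np.array_equal(a + b, a2 + b2)
```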
def add(data_set_1, data_set_2):
broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(data_set_1, data_set_2)
flat_data_1 = flatten_to_list(broadcasted_data_1)
flat_data_2 = flatten_to_list(broadcasted_data_2)
shape = get_shape(broadcasted_data_1)
res = []
for i in range(len(flat_data_1)):
res.append(flat_data_1[i]+flat_data_2[i])
res = np.reshape(res, shape)
return res
def reduceprod(data_set, axis, keepdims=1):
flat_data = flatten_to_list(data_set)
old_shape = get_shape(data_set)
temp_shape = old_shape
temp_flat_data = flat_data
for ax in axis:
split_num = 1
step = 1
for i in range(ax):
split_num *= temp_shape[i]
for i in range(ax+1, len(temp_shape)):
step *= temp_shape[i]
block_size = len(temp_flat_data)//split_num
new_flat_data = []
for j in range(split_num):
block_data = temp_flat_data[j*block_size:(j+1)*block_size]
reduced_block_data = []
for k in range(step):
val = block_data[k]
for l in range(1, block_size//step):
val *= block_data[k+l*step]
reduced_block_data.append(val)
new_flat_data.extend(reduced_block_data)
temp_flat_data = new_flat_data
temp_shape[ax] = 1
new_flat_data = temp_flat_data
new_shape = temp_shape
if not keepdims:
axis = sorted(list(axis))
for pos in axis[::-1]:
new_shape.pop(pos)
return np.reshape(new_flat_data, new_shape)
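A quick cross-check (illustrative only): reducing axes block by block with a product, as the loop above does, matches `np.prod` over the same axes.

```python
import numpy as np

x = np.arange(1, 9).reshape(2, 2, 2)

# Reduce axes 0 and 2; keepdims keeps the reduced axes as size 1.
assert np.prod(x, axis=(0, 2), keepdims=True).shape == (1, 2, 1)
assert np.prod(x, axis=(0, 2)).tolist() == [1 * 2 * 5 * 6, 3 * 4 * 7 * 8]
```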
def transpose(data_set, permutation):
# find series of local swaps
data_set = list(data_set)
perm = list(permutation)
shape = get_shape(data_set)
flat_data = flatten_to_list(data_set)
assert set(perm) == set(range(len(shape))), 'invalid permutation'
new_shape = [shape[i] for i in perm]
swaps = []
bubbled = True
while bubbled:
bubbled = False
for i in range(len(new_shape)-1):
if perm[i] > perm[i+1]:
swaps.append([i, i+1])
p_1, p_2 = perm[i], perm[i+1]
perm[i], perm[i+1] = p_2, p_1
bubbled = True
# apply local swaps
current_shape = list(shape)
temp_flat_data = flat_data
for swap in swaps[::-1]:
ind_1, ind_2 = swap[0], swap[1]
dim_1 = current_shape[ind_1]
dim_2 = current_shape[ind_2]
split_num = 1
block_size = 1
for i in range(ind_1):
split_num *= current_shape[i]
for i in range(ind_2+1, len(current_shape)):
block_size *= current_shape[i]
data_blocks = np.reshape(temp_flat_data, [-1, block_size])
flat_data_1 = []
for k in range(split_num):
block = []
for m in range(dim_2):
for n in range(dim_1):
block_pos = k*dim_1*dim_2 + n*dim_2+m
block.extend(data_blocks[block_pos])
flat_data_1.extend(block)
temp_flat_data = flat_data_1
current_shape[ind_1] = dim_2
current_shape[ind_2] = dim_1
return np.reshape(temp_flat_data, current_shape)
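The decomposition used above can be checked with NumPy (illustrative only): any permutation splits into adjacent swaps, bubble-sort style, and chaining the swaps gives the same result as a single transpose.

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
perm = [2, 0, 1]

# Apply the permutation as a chain of adjacent axis swaps.
step1 = np.swapaxes(x, 1, 2)      # perm so far: [0, 2, 1]
step2 = np.swapaxes(step1, 0, 1)  # perm so far: [2, 0, 1]
assert np.array_equal(step2, np.transpose(x, perm))
```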
def subtract(data_set_1, data_set_2):
broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(data_set_1, data_set_2)
shape = get_shape(broadcasted_data_1)
flat_data_1 = flatten_to_list(broadcasted_data_1)
flat_data_2 = flatten_to_list(broadcasted_data_2)
subtracted_data = [flat_data_1[i] - flat_data_2[i] for i in range(len(flat_data_1))]
new_data = np.reshape(subtracted_data, shape)
return new_data


@@ -0,0 +1,78 @@
"""This module contains helper functions that do graph modifications.
"""
import onnx
from . import helper
def replace_node_input(node, old_input, new_input):
for i, input_name in enumerate(node.input):
if input_name == old_input:
node.input[i] = new_input
def delete_nodes(g, node_list):
node_to_delete = []
# Find target nodes
for node in g.node:
if node.name not in node_list:
continue
else:
node_to_delete.append(node)
if len(node_list) != len(node_to_delete):
print("Some nodes do not exist in the graph. Skipping them.")
for node in node_to_delete:
# Check whether the node is valid to delete.
if len(node.input) == 0:
print("Deleting a Constant node. Please make sure you also delete all of its following nodes.")
elif len(node.input) > 1:
print("Warning: Node {} has more than one input. This script cannot delete merge nodes.".format(node.name))
# Connect the nodes around the target node.
# Set the following node input as the previous node output.
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
if len(node.input) == 0:
for following_node in following_nodes:
following_node.input.remove(node.output[0])
elif len(following_nodes) > 0 and len(node.input) == 1 and helper.find_input_by_name(g, node.input[0]) is not None:
# The node input is an input
new_input = helper.find_value_by_name(g, node.output[0])
g.input.append(new_input)
g.input.remove(helper.find_input_by_name(g, node.input[0]))
g.value_info.remove(new_input)
elif len(following_nodes) > 0:
for following_node in following_nodes:
replace_node_input(following_node, node.output[0], node.input[0])
else:
# If the node is the output, replace the output with the previous input.
value = helper.find_value_by_name(g, node.input[0])
output_values = []
while len(g.output):
output_values.append(g.output.pop())
while output_values:
output_value = output_values.pop()
if output_value.name == node.output[0]:
g.output.extend([value])
else:
g.output.extend([output_value])
# Remove the node and value info.
g.node.remove(node)
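The rewiring performed above can be pictured with minimal dict stand-ins for NodeProto (hypothetical structures, for illustration only): deleting a single-input node means pointing its consumers at its input.

```python
# Graph: x -> A -> B -> y; delete the single-input node A by pointing
# B's input at A's input, then dropping A from the node list.
nodes = [
    {'name': 'A', 'input': ['x'], 'output': ['a_out']},
    {'name': 'B', 'input': ['a_out'], 'output': ['y']},
]
target = nodes[0]
for node in nodes:
    # Same logic as replace_node_input above, applied to every consumer.
    node['input'] = [target['input'][0] if i == target['output'][0] else i
                     for i in node['input']]
nodes.remove(target)
assert nodes == [{'name': 'B', 'input': ['x'], 'output': ['y']}]
```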
def delete_input(g, target_list):
for name in target_list:
input_value = helper.find_input_by_name(g, name)
if input_value is None:
print("Cannot find input {}".format(name))
continue
g.input.remove(input_value)
def delete_output(g, target_list):
for name in target_list:
output_value = helper.find_output_by_name(g, name)
if output_value is None:
print("Cannot find output {}".format(name))
continue
g.output.remove(output_value)
def delete_value_with_name_if_exists(g, name):
value = helper.find_value_by_name(g, name)
if value is not None:
g.value_info.remove(value)

File diff suppressed because it is too large


@@ -0,0 +1,317 @@
from . import helper
from . import other
from . import modhelper
from . import fusing
import numpy as np
import onnx
import onnx.utils
def eliminate_transposes(m):
g = m.graph
keep_eliminating = True
while keep_eliminating:
while swap_transpose_with_single_next_node(g):
pass
splitted = split_transpose_for_multiple_next_nodes(g)
annihilated = annihilate_transposes(g)
multiple_trans_swapped = swap_multiple_transposes_with_node(g)
keep_eliminating = splitted or annihilated or multiple_trans_swapped
if keep_eliminating:
m = other.polish_model(m)
g = m.graph
return m
def swap_transpose_with_single_next_node(g):
swapped = False
passable_nodes = set(['Relu', 'Neg', 'LeakyRelu', 'Sqrt', 'Reciprocal', 'Add', 'Mul', 'Tanh'])
for node in g.node:
trans_node = node
# Check for transpose node
if trans_node.op_type != 'Transpose':
continue
next_nodes = helper.find_nodes_by_input_name(g, trans_node.output[0])
if len(next_nodes) != 1:
continue
next_node = next_nodes[0]
# Check if the next node is the type can be swapped
if next_node.op_type not in passable_nodes:
continue
input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in next_node.input]
# Check if the node has nonconstant input other than the Transpose node itself
nonconstant_input = False
for input_node in input_nodes:
if input_node is None:
nonconstant_input = True
break
if input_node.name == trans_node.name:
continue
elif input_node.op_type == 'Constant':
continue
else:
nonconstant_input = True
break
if nonconstant_input:
continue
for input_node in input_nodes:
if input_node.name == trans_node.name:
# if the input is just the transpose node
next_value_info = helper.find_value_by_name(g, next_node.output[0])
mid_value_info = helper.find_value_by_name(g, trans_node.output[0])
output_nodes = helper.find_nodes_by_input_name(g, next_node.output[0])
for out_node in output_nodes:
modhelper.replace_node_input(out_node, next_node.output[0], trans_node.name)
next_node.input[0] = trans_node.input[0]
next_node.output[0] = next_node.name
trans_node.input[0] = next_node.name
trans_node.output[0] = trans_node.name
if next_value_info:
next_value_info.name = trans_node.name
if mid_value_info:
g.value_info.remove(mid_value_info)
else:
# if the input is a constant node
old_tensor = input_node.attribute[0].t
old_shape, data = helper.constant_to_list(input_node)
# If the constant node is a scalar, no action is needed
if type(old_shape) == int:
old_shape = [old_shape]
permutation = list(trans_node.attribute[0].ints)
while len(old_shape) < len(permutation):
old_shape.insert(0, 1)
np_data = np.reshape(data, old_shape)
reverse_perm = []
for i in range(len(permutation)):
reverse_perm.append(permutation.index(i))
np_data = np.transpose(np_data, reverse_perm)
new_shape = np_data.shape
new_tensor = onnx.helper.make_tensor(
name=old_tensor.name,
data_type=old_tensor.data_type,
dims=new_shape,
vals=np_data.flatten().tolist()
)
new_node = onnx.helper.make_node(
'Constant',
[],
[input_node.output[0]],
name=input_node.name,
value=new_tensor
)
g.node.extend([new_node])
g.value_info.remove(helper.find_value_by_name(g, input_node.output[0]))
g.node.remove(input_node)
swapped = True
other.topological_sort(g)
return swapped
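The swap is sound because elementwise ops commute with Transpose; a minimal NumPy check (illustrative only):

```python
import numpy as np

x = np.random.randn(1, 3, 4)
perm = (0, 2, 1)

# Relu applied before or after a Transpose gives the same tensor, which is
# why the pass above can push a Transpose past such nodes.
relu = lambda t: np.maximum(t, 0)
assert np.array_equal(relu(np.transpose(x, perm)),
                      np.transpose(relu(x), perm))
```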
def swap_multiple_transposes_with_node(g):
# here only consider same input transposes
swapped = False
passable_nodes = set(['Add', 'Mul'])
node_to_del = []
for node in g.node:
if node.op_type not in passable_nodes:
continue
input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in node.input]
if any([input_node is None for input_node in input_nodes]):
continue
if any([input_node.op_type != 'Transpose' for input_node in input_nodes]):
continue
permutation = list(input_nodes[0].attribute[0].ints)
if any([list(input_node.attribute[0].ints) != permutation for input_node in input_nodes]):
continue
for input_name in node.input:
input_node = helper.find_node_by_output_name(g, input_name)
modhelper.replace_node_input(node, input_name, input_node.input[0])
node_to_del.extend(input_nodes)
for input_node in input_nodes:
input_val_info = helper.find_value_by_name(g, input_node.output[0])
if input_val_info is not None:
g.value_info.remove(input_val_info)
output_val_info = helper.find_value_by_name(g, node.output[0])
if output_val_info is not None:
g.value_info.remove(output_val_info)
output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
for i in range(len(output_nodes)):
new_trans_node_name = node.name+'_trans_'+str(i)
new_trans_node = onnx.helper.make_node(
'Transpose',
[node.output[0]],
[new_trans_node_name],
name=new_trans_node_name,
perm=permutation
)
modhelper.replace_node_input(output_nodes[i], node.output[0], new_trans_node_name)
g.node.extend([new_trans_node])
swapped = True
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
other.topological_sort(g)
return swapped
def annihilate_transposes(g):
node_to_del = []
annihilated = False
for node in g.node:
if node.op_type != 'Transpose':
continue
pre_node = helper.find_node_by_output_name(g, node.input[0])
if not pre_node or pre_node.op_type != 'Transpose':
continue
nodes_from_top_transpose = helper.find_nodes_by_input_name(g, pre_node.output[0])
if len(nodes_from_top_transpose) > 1:
continue
perm_1 = list(pre_node.attribute[0].ints)
perm_2 = list(node.attribute[0].ints)
# Two consecutive transposes cancel only when the second permutation is
# the inverse of the first, i.e. their composition is the identity.
if [perm_1[p] for p in perm_2] != list(range(len(perm_1))):
continue
out_nodes = helper.find_nodes_by_input_name(g, node.output[0])
for out_node in out_nodes:
modhelper.replace_node_input(out_node, node.output[0], pre_node.input[0])
node_to_del.extend([node, pre_node])
mid_value_info = helper.find_value_by_name(g, pre_node.output[0])
out_value_info = helper.find_value_by_name(g, node.output[0])
g.value_info.remove(mid_value_info)
g.value_info.remove(out_value_info)
annihilated = True
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return annihilated
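Two back-to-back Transposes cancel exactly when the second permutation inverts the first; a quick NumPy check (illustrative only):

```python
import numpy as np

x = np.random.randn(2, 3, 4, 5)
perm = [0, 2, 3, 1]
inverse = [perm.index(i) for i in range(len(perm))]  # [0, 3, 1, 2]

# Transpose followed by its inverse is the identity...
assert np.array_equal(np.transpose(np.transpose(x, perm), inverse), x)
# ...while repeating the same (non-involution) permutation is not.
assert not np.array_equal(np.transpose(np.transpose(x, perm), perm), x)
```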
def split_transpose_for_multiple_next_nodes(g):
splitted = False
node_to_del = []
for node in g.node:
if node.op_type != 'Transpose':
continue
output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
if len(output_nodes) < 2:
continue
for i in range(len(output_nodes)):
output_node = output_nodes[i]
new_trans_node_name = node.name + '_' + str(i)
new_trans_node = onnx.helper.make_node(
'Transpose',
[node.input[0]],
[new_trans_node_name],
name=new_trans_node_name,
perm=list(node.attribute[0].ints)
)
modhelper.replace_node_input(output_node, node.output[0], new_trans_node.output[0])
g.node.extend([new_trans_node])
node_to_del.append(node)
val_info = helper.find_value_by_name(g, node.output[0])
g.value_info.remove(val_info)
splitted = True
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
other.topological_sort(g)
return splitted
def remove_trivial_transpose(g):
node_to_del = []
for node in g.node:
if node.op_type != 'Transpose':
continue
permutation = list(node.attribute[0].ints)
if permutation != list(range(len(permutation))):
continue
next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
if not next_nodes:
input_val_info = helper.find_value_by_name(g, node.input[0])
out_val_info = helper.find_output_by_name(g, node.output[0])
if not input_val_info:
input_val_info = helper.find_input_by_name(g, node.input[0])
g.output.remove(out_val_info)
g.output.extend([input_val_info])
else:
out_val_info = helper.find_value_by_name(g, node.output[0])
for next_node in next_nodes:
modhelper.replace_node_input(next_node, node.output[0], node.input[0])
g.value_info.remove(out_val_info)
node_to_del.append(node)
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
other.topological_sort(g)
def fuse_Transpose_into_Gemm_weight(g):
node_to_del = []
for node in g.node:
# Check pattern
if node.op_type != 'Gemm':
continue
prev_node = helper.find_node_by_output_name(g, node.input[0])
if prev_node is None or prev_node.op_type != 'Flatten':
continue
transpose_node = helper.find_node_by_output_name(g, prev_node.input[0])
if transpose_node.op_type != 'Transpose':
continue
# Check attribute
perm = helper.get_list_attribute_by_name(transpose_node, 'perm', 'int')
if perm != [0, 2, 3, 1]:
continue
transB = helper.get_var_attribute_by_name(node, 'transB', 'int')
if transB is not None and transB == 1:
continue
# Get the original weight
origin_weight = helper.find_node_by_output_name(g, node.input[1])
origin_np = helper.constant_to_numpy(origin_weight)
# Calculate a new weight
shape = helper.get_shape_from_value_info(helper.find_value_by_name(g, prev_node.input[0]))
shape.append(-1)
new_np = np.reshape(origin_np, shape)
new_np = np.transpose(new_np, [0, 3, 1, 2, 4])
new_np = np.reshape(new_np, [-1, new_np.shape[-1]])
new_weight = helper.numpy_to_constant(origin_weight.output[0], new_np)
# Replace and eliminate
prev_node.input[0] = transpose_node.input[0]
node_to_del.append(transpose_node)
node_to_del.append(origin_weight)
g.value_info.remove(helper.find_value_by_name(g, transpose_node.output[0]))
g.node.extend([new_weight])
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
other.topological_sort(g)
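The weight rearrangement above can be verified numerically (an illustrative sketch; like the fusion itself, it assumes batch N == 1): Flatten after an NHWC Transpose feeding a Gemm equals Flatten in NCHW order with the Gemm weight rows reshuffled.

```python
import numpy as np

N, C, H, W, out = 1, 2, 3, 3, 5
x = np.random.randn(N, C, H, W)
w = np.random.randn(H * W * C, out)  # weight consumed after the NHWC flatten

# Reference: Transpose to NHWC, Flatten, then Gemm.
y_ref = np.transpose(x, (0, 2, 3, 1)).reshape(N, -1) @ w

# Fused: reorder the weight rows exactly as the pass above does.
w_new = w.reshape(N, H, W, C, out)
w_new = np.transpose(w_new, (0, 3, 1, 2, 4)).reshape(-1, out)
y_fused = x.reshape(N, -1) @ w_new
assert np.allclose(y_ref, y_fused)
```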

File diff suppressed because it is too large


@@ -0,0 +1,423 @@
"""Special operations on model.
"""
import logging
import onnx.helper
import numpy as np
from . import helper
from . import other
from . import modhelper
def change_first_conv_from_bgr_to_rgb(m):
"""For input channel format BGR model, use this function to change the first
conv weight to adapt the input into RGB.
:param m: the model proto
"""
# Check for first node.
g = m.graph
input_name = g.input[0].name
first_nodes = helper.find_following_nodes_by_input_value_name(g, input_name)
if len(first_nodes) > 1:
return False
first_node = first_nodes[0]
# Now we have the first node. Check this first node.
if first_node.op_type != 'Conv':
return False
weight_value = helper.find_value_by_name(g, first_node.input[1])
weight_shape = helper.get_shape_from_value_info(weight_value)
if weight_shape[1] != 3:
return False
# Do weight shuffle
weight_node = helper.find_node_by_output_name(g, weight_value.name)
weight_np = helper.constant_to_numpy(weight_node)
b_channel = np.expand_dims(weight_np[:, 0, :, :], axis=1)
g_channel = np.expand_dims(weight_np[:, 1, :, :], axis=1)
r_channel = np.expand_dims(weight_np[:, 2, :, :], axis=1)
new_np = np.concatenate((r_channel, g_channel, b_channel), axis=1)
new_node = helper.numpy_to_constant(weight_value.name, new_np)
# Replace the weight and topological sort
g.node.remove(weight_node)
g.node.extend([new_node])
other.topological_sort(g)
return True
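The channel shuffle above works because reversing the input-channel slices of the first conv weight is equivalent to feeding a channel-reversed image. A minimal NumPy sketch for the 1x1-conv case, where the convolution reduces to an einsum over channels (illustrative only):

```python
import numpy as np

x_bgr = np.random.randn(1, 3, 8, 8)   # NCHW image, channels in BGR order
w = np.random.randn(16, 3, 1, 1)      # 1x1 conv weight, OIHW layout

# A 1x1 convolution is a per-pixel contraction over the channel axis.
conv = lambda img, wt: np.einsum('nchw,ocij->nohw', img, wt)
x_rgb = x_bgr[:, ::-1]                # swap B and R channels of the input
w_rgb = w[:, ::-1]                    # same swap on the weight's input axis
assert np.allclose(conv(x_bgr, w), conv(x_rgb, w_rgb))
```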
def change_input_from_bgr_to_rgb(m):
"""For input channel format BGR model, use this function to modify the model
to accepct RGB image.If the first node is a non-group Conv. Modify weight to
adapt the input into RGB. Otherwise create a new node.
:param m: the model proto
"""
g = m.graph
if len(g.input) > 1:
print("This model has multiple inputs. Cannot change to RGB input.")
return
input_shape = helper.get_shape_from_value_info(g.input[0])
if len(input_shape) != 4 or input_shape[1] != 3:
print("The input shape is invalid for bgr conversion.")
return
# Try change conv weight first
if change_first_conv_from_bgr_to_rgb(m):
return
# Otherwise, create a special conv node and replace the input
# Construct weight
weight_np = np.zeros((3, 3, 3, 3)).astype('float32')
weight_np[0, 2, 1, 1] = 1.0
weight_np[1, 1, 1, 1] = 1.0
weight_np[2, 0, 1, 1] = 1.0
new_weight = helper.numpy_to_constant("bgr_shuffle_weight", weight_np)
# Construct Conv
new_conv = onnx.helper.make_node(
'Conv',
['rgb_input', "bgr_shuffle_weight"],
[g.input[0].name],
name='bgr_shuffle',
dilations=[1, 1],
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1]
)
# Connect the graph
old_input_value = g.input.pop()
new_input_value = onnx.helper.make_tensor_value_info(
'rgb_input',
old_input_value.type.tensor_type.elem_type,
input_shape
)
g.input.extend([new_input_value])
g.node.extend([new_weight, new_conv])
# topological sort
other.topological_sort(g)
def add_0_5_to_normalized_input(m):
"""For normalized input between -0.5 ~ 0.5, add 0.5 to the input to keep it
between 0 ~ 1.
:param m: the model proto
"""
g = m.graph
if len(g.input) > 1:
print("This model has multiple inputs. Cannot normalize input.")
return
input_shape = helper.get_shape_from_value_info(g.input[0])
if len(input_shape) != 4:
print("The input shape is not BCHW. Cannot normalize input.")
return
# Construct weight
ch = input_shape[1]
weight_np = np.zeros((ch, ch, 3, 3)).astype('float32')
for i in range(ch):
weight_np[i, i, 1, 1] = 1.0
new_weight = helper.numpy_to_constant("input_norm_weight", weight_np)
# Construct bias
bias_np = np.array([0.5] * ch).astype('float32')
new_bias = helper.numpy_to_constant("input_norm_bias", bias_np)
# Construct Conv
new_conv = onnx.helper.make_node(
'Conv',
['origin_input', "input_norm_weight", "input_norm_bias"],
[g.input[0].name],
name='input_norm',
dilations=[1, 1],
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1]
)
# Construct value_infos
old_input_value = g.input.pop()
weight_value = onnx.helper.make_tensor_value_info(
'input_norm_weight',
old_input_value.type.tensor_type.elem_type,
[3, 3, 3, 3]
)
bias_value = onnx.helper.make_tensor_value_info(
'input_norm_bias',
old_input_value.type.tensor_type.elem_type,
[3]
)
# Connect the graph
new_input_value = onnx.helper.make_tensor_value_info(
'origin_input',
old_input_value.type.tensor_type.elem_type,
input_shape
)
g.input.extend([new_input_value])
g.node.extend([new_weight, new_bias, new_conv])
g.value_info.extend([weight_value, bias_value, old_input_value])
# topological sort
other.topological_sort(g)
def add_rgb2yynn_node(m):
"""Add a conv layer which can convert rgb to yynn input.
"""
g = m.graph
if len(g.input) > 1:
print("This model has multiple inputs. Cannot change to rgb input.")
return
input_shape = helper.get_shape_from_value_info(g.input[0])
if len(input_shape) != 4:
print("The input shape is not BCHW. Cannot normalize input.")
return
# Construct weight
ch = input_shape[1]
weight_np = np.zeros((3, 3, 4, 4)).astype('float32')
weight_np[1, 1, :3, :2] = np.array([[[[0.299],
[0.587],
[0.114]]]])
weight_np[1, 1, 3, 2:] = 1.
weight_np = np.transpose(weight_np, (3, 2, 0, 1))
new_weight = helper.numpy_to_constant("input_rgb2yynn_weight", weight_np)
# Construct conv node
new_conv = onnx.helper.make_node(
'Conv',
['new_input', "input_rgb2yynn_weight"],
[g.input[0].name],
name='input_rgba2yynn',
dilations=[1, 1],
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1]
)
# Construct value_infos
old_input_value = g.input.pop()
weight_value = onnx.helper.make_tensor_value_info(
'input_rgb2yynn_weight',
old_input_value.type.tensor_type.elem_type,
[4, 4, 3, 3]
)
# Connect the graph
new_input_value = onnx.helper.make_tensor_value_info(
'new_input',
old_input_value.type.tensor_type.elem_type,
input_shape
)
g.input.extend([new_input_value])
g.node.extend([new_weight, new_conv])
g.value_info.extend([weight_value, old_input_value])
# topological sort
other.topological_sort(g)
def swap_MatMul_inputs(g, original_matmul_node):
# Create Transpose nodes
input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0])
input_a_shape = helper.get_shape_from_value_info(input_a_value)
if len(input_a_shape) == 2:
perm = [1, 0]
else:
perm = [0, 2, 1]
new_input_b_node = onnx.helper.make_node(
'Transpose',
inputs = [input_a_value.name],
outputs = [input_a_value.name + '_transposed'],
name = f"{input_a_value.name}_transposed_for_{original_matmul_node.name}",
perm = perm
)
input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1])
input_b_shape = helper.get_shape_from_value_info(input_b_value)
if len(input_b_shape) == 3:
perm = [0, 2, 1]
else:
perm = [0, 1, 3, 2]
new_input_a_node = onnx.helper.make_node(
'Transpose',
inputs = [input_b_value.name],
outputs = [input_b_value.name + '_transposed'],
name = f'{input_b_value.name}_transposed_for_{original_matmul_node.name}',
perm = perm
)
# Create new MatMul node
new_matmul_node = onnx.helper.make_node(
'MatMul',
inputs = [new_input_a_node.output[0], new_input_b_node.output[0]],
outputs = [original_matmul_node.output[0] + '_transposed'],
name = original_matmul_node.name + '_transposed'
)
# Create final Transpose node
output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
output_shape = helper.get_shape_from_value_info(output_value)
if len(output_shape) == 3:
perm = [0, 2, 1]
else:
perm = [0, 1, 3, 2]
new_final_transpose_node = onnx.helper.make_node(
'Transpose',
inputs = [new_matmul_node.output[0]],
outputs = [original_matmul_node.output[0]],
name = original_matmul_node.name + '_final_transpose',
perm = perm
)
# Add new nodes
g.node.extend([new_input_a_node, new_input_b_node, new_matmul_node, new_final_transpose_node])
# Delete original nodes
g.node.remove(original_matmul_node)
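The rewrite above relies on the identity (A @ B)^T == B^T @ A^T: the new graph computes transpose(B^T @ A^T) and recovers the original product. A NumPy check (illustrative only):

```python
import numpy as np

a = np.random.randn(4, 6)        # A: a 1 x H x W input collapses to H x W
b = np.random.randn(3, 6, 5)     # B: the batched operand
ref = a @ b                      # NumPy broadcasts A over the batch

# Transpose both inputs, multiply in swapped order, transpose back.
swapped = np.transpose(np.transpose(b, (0, 2, 1)) @ a.T, (0, 2, 1))
assert np.allclose(ref, swapped)
```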
def split_MatMul_batch_then_concat(g, original_matmul_node):
new_nodes = []
final_concat_inputs = []
# Get the batch count
input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0])
input_a_shape = helper.get_shape_from_value_info(input_a_value)
input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1])
input_b_shape = helper.get_shape_from_value_info(input_b_value)
if len(input_a_shape) == 3:
batch_count = input_a_shape[0]
else:
batch_count = input_a_shape[1]
for i in range(batch_count):
# Create Split nodes for input A
starts_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_starts", (1, ), [i])
ends_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_ends", (1, ), [i+1])
axes_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_axes", (1, ), [len(input_a_shape) - 3])
new_sliced_a_node = onnx.helper.make_node(
'Slice',
inputs = [input_a_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]],
outputs = [f"{input_a_value.name}_sliced_{i}"],
name = f"{input_a_value.name}_sliced_{i}_for_{original_matmul_node.name}"
)
new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_a_node])
# Create Split nodes for input B
starts_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_starts", (1, ), [i])
ends_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_ends", (1, ), [i+1])
axes_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_axes", (1, ), [len(input_b_shape) - 3])
new_sliced_b_node = onnx.helper.make_node(
'Slice',
inputs = [input_b_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]],
outputs = [f"{input_b_value.name}_sliced_{i}"],
name = f"{input_b_value.name}_sliced_{i}_for_{original_matmul_node.name}"
)
new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_b_node])
# Create MatMul nodes
new_matmul_node = onnx.helper.make_node(
'MatMul',
inputs = [new_sliced_a_node.output[0], new_sliced_b_node.output[0]],
outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"],
name = f"{original_matmul_node.name}_sliced_{i}"
)
new_nodes.append(new_matmul_node)
final_concat_inputs.append(new_matmul_node.output[0])
# Create Concat nodes
output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
if output_value is None:
output_value = helper.find_output_by_name(g, original_matmul_node.output[0])
if output_value is None:
helper.logger.error(f"Cannot find value_info for {original_matmul_node.output[0]}")
output_shape = helper.get_shape_from_value_info(output_value)
new_concat_node = onnx.helper.make_node(
"Concat",
inputs = final_concat_inputs,
outputs = [original_matmul_node.output[0]],
name = f"{original_matmul_node.name}_final_concat",
axis = len(output_shape) - 3
)
new_nodes.append(new_concat_node)
# Add new nodes
g.node.extend(new_nodes)
# Delete original nodes
g.node.remove(original_matmul_node)
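The equivalence behind the split above is easy to check (illustrative only): a batched MatMul equals per-batch Slice, MatMul, then Concat along the batch axis.

```python
import numpy as np

a = np.random.randn(3, 4, 6)
b = np.random.randn(3, 6, 5)
ref = a @ b

# Slice each batch (keeping the batch dim), multiply, then concatenate,
# mirroring the Slice/MatMul/Concat subgraph built above.
pieces = [a[i:i + 1] @ b[i:i + 1] for i in range(3)]
assert np.allclose(ref, np.concatenate(pieces, axis=0))
```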
def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
new_nodes = []
final_concat_inputs = []
# Get the batch count
input_b_node = helper.find_node_by_output_name(g, original_matmul_node.input[1])
input_b_np = helper.constant_to_numpy(input_b_node)
if len(input_b_np.shape) == 3:
batch_count = input_b_np.shape[0]
else:
batch_count = input_b_np.shape[1]
for i in range(batch_count):
# Create new constant node
if len(input_b_np.shape) == 3:
new_np = input_b_np[i:i+1, ...]
else:
new_np = input_b_np[:, i:i+1, ...]
new_weight = helper.numpy_to_constant(f"{input_b_node.name}_sliced_{i}", new_np)
new_nodes.append(new_weight)
# Create MatMul nodes
new_matmul_node = onnx.helper.make_node(
'MatMul',
inputs = [original_matmul_node.input[0], new_weight.output[0]],
outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"],
name = f"{original_matmul_node.name}_sliced_{i}"
)
new_nodes.append(new_matmul_node)
final_concat_inputs.append(new_matmul_node.output[0])
# Create Concat nodes
output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
output_shape = helper.get_shape_from_value_info(output_value)
new_concat_node = onnx.helper.make_node(
"Concat",
inputs = final_concat_inputs,
outputs = [original_matmul_node.output[0]],
name = f"{original_matmul_node.name}_final_concat",
axis = len(output_shape) - 3
)
new_nodes.append(new_concat_node)
# Add new nodes
g.node.extend(new_nodes)
# Delete original value info
input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1])
if input_b_value is not None:
g.value_info.remove(input_b_value)
# Delete original nodes
g.node.remove(original_matmul_node)
g.node.remove(input_b_node)
def special_MatMul_process(g):
for node in g.node:
if node.op_type != 'MatMul':
continue
input_a_name = node.input[0]
input_a_value = helper.find_value_by_name(g, input_a_name)
input_b_name = node.input[1]
input_b_value = helper.find_value_by_name(g, input_b_name)
if input_a_value is None or input_b_value is None:
continue
input_a_shape = helper.get_shape_from_value_info(input_a_value)
input_b_shape = helper.get_shape_from_value_info(input_b_value)
# Check shapes and choose the process
# Normal case, Skip
if len(input_b_shape) == 2:
continue
# Too many dimensions or too few dimensions. Not supported. Skip
if len(input_a_shape) > 4 or len(input_b_shape) > 4:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have too many dimensions.")
continue
if len(input_a_shape) < 2 or len(input_b_shape) < 2:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have too few dimensions.")
continue
# For 4 dimension, check the first dimension (should be 1) and treated as 3 dimensions.
extra_dim = None
if len(input_a_shape) == 4:
extra_dim = input_a_shape[0]
input_a_shape = input_a_shape[1:]
if len(input_b_shape) == 4:
if input_b_shape[0] != extra_dim:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: input batch dimensions do not match ({extra_dim} vs {input_b_shape[0]}).")
continue
input_b_shape = input_b_shape[1:]
# Check input B dimension
# If B is 1 x W x V, it is the same as normal case.
if input_b_shape[0] == 1:
continue
# If B is B x W x V, but B is a constant.
input_b_node = helper.find_node_by_output_name(g, input_b_name)
if input_b_node is not None and input_b_node.op_type == 'Constant':
# Constant input
helper.logger.debug(f"Optimizing MatMul node {node.name}: split constant input.")
split_MatMul_Constant_input_then_concat(g, node)
# If B is B x W x V and A is 1 x H x W, do the swap.
elif len(input_a_shape) == 2 or (input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1)):
helper.logger.debug(f"Optimizing MatMul node {node.name}: swap input.")
swap_MatMul_inputs(g, node)
# If B is B x W x V and A is B x H x W, do the split.
elif input_b_shape[0] == input_a_shape[0]:
helper.logger.debug(f"Optimizing MatMul node {node.name}: split input batch.")
split_MatMul_batch_then_concat(g, node)
# Other cases are not supported: if B is B x W x V but A is X x H x W.
else:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: unknown reason. Might be shape mismatch.")
continue
other.topological_sort(g)


@@ -2,6 +2,7 @@
# Original: tools/pytorch2onnx.py, modified by Kneron
import argparse
import onnx
import mmcv
import numpy as np
import onnxruntime as rt
@@ -18,6 +19,10 @@ from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor
from optimizer_scripts.pytorch_exported_onnx_preprocess import (
torch_exported_onnx_flow,
)
torch.manual_seed(3)
@@ -117,10 +122,12 @@ def pytorch2onnx(model,
dynamic_axes=None)
print(f'Successfully exported ONNX model: {output_file}')
model.forward = origin_forward
# NOTE: optimizing onnx for kneron inference
m = onnx.load(output_file)
m = torch_exported_onnx_flow(m, disable_fuse_bn=False)
onnx.save(m, output_file)
if verify:
# check by onnx
import onnx
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)