style: fix format so pep8 is satisfied

This commit is contained in:
chingning.chen 2022-04-12 14:26:54 +08:00
parent a783220efa
commit 0136a5b2bd
27 changed files with 3329 additions and 2104 deletions

View File

@ -316,7 +316,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
session_options.register_custom_ops_library(ort_custom_op_path)
providers = ['CPUExecutionProvider']
provider_options = [{}]
is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available()
is_cuda_available = (
ort.get_device() == 'GPU' and torch.cuda.is_available()
)
if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider')
device_id = device_id or 0
@ -334,7 +336,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
self.output_name_list = [sess_outputs[0].name]
self.cfg = cfg # TODO: necessary?
self.test_cfg = cfg.model.test_cfg
self.test_mode = self.test_cfg.mode # NOTE: should be 'whole' or 'slide'
self.test_mode = self.test_cfg.mode # NOTE: either 'whole' or 'slide'
self.is_cuda_available = is_cuda_available
self.count_mat = None
try:

View File

@ -171,7 +171,8 @@ if __name__ == '__main__':
setup(
name='mmsegmentation',
version=get_version(),
description='Open MMLab Semantic Segmentation Toolbox and Benchmark (Kneron Edition)',
description='Open MMLab Semantic Segmentation Toolbox '
'and Benchmark (Kneron Edition)',
long_description=readme(),
long_description_content_type='text/markdown',
author='MMSegmentation Contributors and Kneron',

View File

@ -163,9 +163,9 @@ def main():
efficient_test = eval_kwargs.get('efficient_test', False)
if efficient_test:
warnings.warn(
'``efficient_test=True`` does not have effect in tools/test_kneron.py, '
'the evaluation and format results are CPU memory efficient by '
'default')
'"efficient_test=True" does not have effect in '
'tools/test_kneron.py, the evaluation and format '
'results are CPU memory efficient by default')
eval_on_format_results = (
args.eval is not None and 'cityscapes' in args.eval)

View File

@ -5,55 +5,81 @@ import sys
from tools.other import topological_sort
from tools import helper
def fuse_bias_in_consecutive_1x1_conv(g):
for second in g.node:
# Find two conv
if second.op_type != 'Conv':
if second.op_type != "Conv":
continue
first = helper.find_node_by_output_name(g, second.input[0])
if first is None or first.op_type != 'Conv':
if first is None or first.op_type != "Conv":
continue
# Check if the first one has only one following node
if len(helper.find_following_nodes_by_input_value_name(g, first.output[0])) != 1:
if (
len(
helper.find_following_nodes_by_input_value_name(
g, first.output[0]
)
)
!= 1
):
continue
# If first node has no bias, continue
if len(first.input) == 2:
continue
# Check their kernel size
first_kernel_shape = helper.get_list_attribute_by_name(first, 'kernel_shape', 'int')
second_kernel_shape = helper.get_list_attribute_by_name(second, 'kernel_shape', 'int')
prod = first_kernel_shape[0] * first_kernel_shape[1] * second_kernel_shape[0] * second_kernel_shape[1]
first_kernel_shape = helper.get_list_attribute_by_name(
first, "kernel_shape", "int"
)
second_kernel_shape = helper.get_list_attribute_by_name(
second, "kernel_shape", "int"
)
prod = (
first_kernel_shape[0]
* first_kernel_shape[1]
* second_kernel_shape[0]
* second_kernel_shape[1]
)
if prod != 1:
continue
print('Found: ', first.name, ' ', second.name)
print("Found: ", first.name, " ", second.name)
# Get bias of the nodes
first_bias_node = helper.find_node_by_output_name(g, first.input[2])
second_weight_node = helper.find_node_by_output_name(g, second.input[1])
second_weight_node = helper.find_node_by_output_name(
g, second.input[1]
)
second_bias_node = helper.find_node_by_output_name(g, second.input[2])
first_bias = helper.constant_to_numpy(first_bias_node)
second_weight = helper.constant_to_numpy(second_weight_node)
second_bias = helper.constant_to_numpy(second_bias_node)
# Calculate the weight for second node
first_bias = np.reshape(first_bias, (1, first_bias.size))
second_weight = np.reshape(second_weight, (second_weight.shape[0], second_weight.shape[1]))
second_weight = np.reshape(
second_weight, (second_weight.shape[0], second_weight.shape[1])
)
second_weight = np.transpose(second_weight)
new_second_bias = second_bias + np.matmul(first_bias, second_weight)
new_second_bias = np.reshape(new_second_bias, (new_second_bias.size,))
# Generate new weight
new_first_bias = np.reshape(first_bias, (first_bias.size, ))
new_first_bias = np.reshape(first_bias, (first_bias.size,))
for i in range(new_first_bias.shape[0]):
new_first_bias[i] = 0.0
new_first_bias_node = helper.numpy_to_constant(first_bias_node.output[0], new_first_bias)
new_second_bias_node = helper.numpy_to_constant(second_bias_node.output[0], new_second_bias)
new_first_bias_node = helper.numpy_to_constant(
first_bias_node.output[0], new_first_bias
)
new_second_bias_node = helper.numpy_to_constant(
second_bias_node.output[0], new_second_bias
)
# Delete old weight and add new weights
g.node.remove(first_bias_node)
g.node.remove(second_bias_node)
g.node.extend([new_first_bias_node, new_second_bias_node])
topological_sort(g)
if __name__ == "__main__":
if len(sys.argv) != 3:
exit(1)
m = onnx.load(sys.argv[1])
fuse_bias_in_consecutive_1x1_conv(m.graph)
onnx.save(m, sys.argv[2])
onnx.save(m, sys.argv[2])

View File

@ -1,5 +1,6 @@
import onnx
import onnx.utils
try:
from onnx import optimizer
except ImportError:
@ -9,23 +10,107 @@ import argparse
import tools.modhelper as helper
import tools.other as other
import tools.replacing as replacing
# Main process
# Argument parser
parser = argparse.ArgumentParser(description="Edit an ONNX model.\nThe processing sequense is 'delete nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting cannot be done with other operations together")
parser.add_argument('in_file', type=str, help='input ONNX FILE')
parser.add_argument('out_file', type=str, help="ouput ONNX FILE")
parser.add_argument('-c', '--cut', dest='cut_node', type=str, nargs='+', help="remove nodes from the given nodes(inclusive)")
parser.add_argument('--cut-type', dest='cut_type', type=str, nargs='+', help="remove nodes by type from the given nodes(inclusive)")
parser.add_argument('-d', '--delete', dest='delete_node', type=str, nargs='+', help="delete nodes by names and only those nodes")
parser.add_argument('--delete-input', dest='delete_input', type=str, nargs='+', help="delete inputs by names")
parser.add_argument('--delete-output', dest='delete_output', type=str, nargs='+', help="delete outputs by names")
parser.add_argument('-i', '--input', dest='input_change', type=str, nargs='+', help="change input shape (e.g. -i 'input_0 1 3 224 224')")
parser.add_argument('-o', '--output', dest='output_change', type=str, nargs='+', help="change output shape (e.g. -o 'input_0 1 3 224 224')")
parser.add_argument('--add-conv', dest='add_conv', type=str, nargs='+', help='add nop conv using specific input')
parser.add_argument('--add-bn', dest='add_bn', type=str, nargs='+', help='add nop bn using specific input')
parser.add_argument('--rename-output', dest='rename_output', type=str, nargs='+', help='Rename the specific output(e.g. --rename-output old_name new_name)')
parser.add_argument('--pixel-bias-value', dest='pixel_bias_value', type=str, nargs='+', help='(per channel) set pixel value bias bn layer at model front for normalization( e.g. --pixel_bias_value "[104.0, 117.0, 123.0]" )')
parser.add_argument('--pixel-scale-value', dest='pixel_scale_value', type=str, nargs='+', help='(per channel) set pixel value scale bn layer at model front for normalization( e.g. --pixel_scale_value "[0.0078125, 0.0078125, 0.0078125]" )')
parser = argparse.ArgumentParser(
description="Edit an ONNX model.\nThe processing sequense is 'delete "
"nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting "
"cannot be done with other operations together"
)
parser.add_argument("in_file", type=str, help="input ONNX FILE")
parser.add_argument("out_file", type=str, help="ouput ONNX FILE")
parser.add_argument(
"-c",
"--cut",
dest="cut_node",
type=str,
nargs="+",
help="remove nodes from the given nodes(inclusive)",
)
parser.add_argument(
"--cut-type",
dest="cut_type",
type=str,
nargs="+",
help="remove nodes by type from the given nodes(inclusive)",
)
parser.add_argument(
"-d",
"--delete",
dest="delete_node",
type=str,
nargs="+",
help="delete nodes by names and only those nodes",
)
parser.add_argument(
"--delete-input",
dest="delete_input",
type=str,
nargs="+",
help="delete inputs by names",
)
parser.add_argument(
"--delete-output",
dest="delete_output",
type=str,
nargs="+",
help="delete outputs by names",
)
parser.add_argument(
"-i",
"--input",
dest="input_change",
type=str,
nargs="+",
help="change input shape (e.g. -i 'input_0 1 3 224 224')",
)
parser.add_argument(
"-o",
"--output",
dest="output_change",
type=str,
nargs="+",
help="change output shape (e.g. -o 'input_0 1 3 224 224')",
)
parser.add_argument(
"--add-conv",
dest="add_conv",
type=str,
nargs="+",
help="add nop conv using specific input",
)
parser.add_argument(
"--add-bn",
dest="add_bn",
type=str,
nargs="+",
help="add nop bn using specific input",
)
parser.add_argument(
"--rename-output",
dest="rename_output",
type=str,
nargs="+",
help="Rename the specific output(e.g. --rename-output old_name new_name)",
)
parser.add_argument(
"--pixel-bias-value",
dest="pixel_bias_value",
type=str,
nargs="+",
help='(per channel) set pixel value bias bn layer at model front for '
'normalization( e.g. --pixel_bias_value "[104.0, 117.0, 123.0]" )',
)
parser.add_argument(
"--pixel-scale-value",
dest="pixel_scale_value",
type=str,
nargs="+",
help='(per channel) set pixel value scale bn layer at model front for '
'normalization( e.g. --pixel_scale_value '
'"[0.0078125, 0.0078125, 0.0078125]" )',
)
args = parser.parse_args()
@ -60,23 +145,48 @@ if args.add_bn is not None:
if args.pixel_bias_value is not None or args.pixel_scale_value is not None:
if len(g.input) > 1:
raise ValueError(" '--pixel-bias-value' and '--pixel-scale-value' only support one input node model currently")
raise ValueError(
" '--pixel-bias-value' and '--pixel-scale-value' "
"only support one input node model currently"
)
i_n = g.input[0]
pixel_bias_value = [0] * i_n.type.tensor_type.shape.dim[1].dim_value
pixel_scale_value = [1] * i_n.type.tensor_type.shape.dim[1].dim_value
if args.pixel_bias_value is not None and len(args.pixel_bias_value) == 1:
pixel_bias_value = [float(n) for n in args.pixel_bias_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')]
pixel_bias_value = [
float(n)
for n in args.pixel_bias_value[0]
.replace("[", "")
.replace("]", "")
.split(",")
]
if args.pixel_scale_value is not None and len(args.pixel_scale_value) == 1:
pixel_scale_value = [float(n) for n in args.pixel_scale_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')]
pixel_scale_value = [
float(n)
for n in args.pixel_scale_value[0]
.replace("[", "")
.replace("]", "")
.split(",")
]
if i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_bias_value) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value):
raise ValueError("--pixel-bias-value (" + str(pixel_bias_value) + ") and --pixel-scale-value (" + str(pixel_scale_value) + ") should be same as input dimension:" + str(i_n.type.tensor_type.shape.dim[1].dim_value) )
other.add_bias_scale_bn_after(g, i_n.name, pixel_bias_value, pixel_scale_value)
if i_n.type.tensor_type.shape.dim[1].dim_value != len(
pixel_bias_value
) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value):
raise ValueError(
"--pixel-bias-value ("
+ str(pixel_bias_value)
+ ") and --pixel-scale-value ("
+ str(pixel_scale_value)
+ ") should be same as input dimension:"
+ str(i_n.type.tensor_type.shape.dim[1].dim_value)
)
other.add_bias_scale_bn_after(
g, i_n.name, pixel_bias_value, pixel_scale_value
)
# Change input and output shapes as requested
if args.input_change is not None:
@ -100,14 +210,21 @@ if args.rename_output:
print("Rename output should be paires of names.")
else:
for i in range(0, len(args.rename_output), 2):
other.rename_output_name(g, args.rename_output[i], args.rename_output[i + 1])
other.rename_output_name(
g, args.rename_output[i], args.rename_output[i + 1]
)
# Remove useless nodes
if args.delete_node or args.delete_input or args.input_change or args.output_change:
if (
args.delete_node
or args.delete_input
or args.input_change
or args.output_change
):
# If shape changed during the modification, redo shape inference.
while(len(g.value_info) > 0):
while len(g.value_info) > 0:
g.value_info.pop()
passes = ['extract_constant_to_initializer']
passes = ["extract_constant_to_initializer"]
m = optimizer.optimize(m, passes)
g = m.graph
replacing.replace_initializer_with_Constant(g)
@ -115,4 +232,4 @@ other.topological_sort(g)
# Polish and output
m = other.polish_model(m)
other.add_output_to_value_info(m.graph)
onnx.save(m, args.out_file)
onnx.save(m, args.out_file)

View File

@ -11,42 +11,44 @@ if len(sys.argv) != 3:
# Modify onnx
m = onnx.load(sys.argv[1])
special.add_0_5_to_normalized_input(m)
onnx.save(m, sys.argv[1][:-4] + 'norm.onnx')
onnx.save(m, sys.argv[1][:-4] + "norm.onnx")
# Change input node
origin_file = open(sys.argv[2], 'r')
origin_file = open(sys.argv[2], "r")
origin_json = json.load(origin_file)
origin_json["input_node"]["output_datapath_radix"] = [8]
new_json_str = json.dumps(origin_json)
# Modify json
file = open(sys.argv[1][:-4] + 'norm.onnx' + '.json', 'w')
file = open(sys.argv[1][:-4] + "norm.onnx" + ".json", "w")
s = """{{
\"{0}\" :
{{
\"bias_bitwidth\" : 16,
\"{0}_bias\" : [15],
\"{0}_weight\" : [3,3,3],
\"conv_coarse_shift\" : [-4,-4,-4],
\"conv_fine_shift\" : [0,0,0],
\"conv_total_shift\" : [-4,-4,-4],
\"cpu_mode\" : false,
\"delta_input_bitwidth\" : [0],
\"delta_output_bitwidth\" : 8,
\"flag_radix_bias_eq_output\" : true,
\"input_scale\" : [[1.0,1.0,1.0]],
\"output_scale\" : [1.0, 1.0, 1.0],
\"psum_bitwidth\" : 16,
\"weight_bitwidth\" : 8,
\"input_datapath_bitwidth\" : [8],
\"input_datapath_radix\" : [8],
\"working_input_bitwidth\" : 8,
\"working_input_radix\" : [8],
\"working_output_bitwidth\" : 16,
\"working_output_radix\" : 15,
\"output_datapath_bitwidth\" : 8,
\"output_datapath_radix\" : 7
}},\n""".format('input_norm')
\"{0}\" :
{{
\"bias_bitwidth\" : 16,
\"{0}_bias\" : [15],
\"{0}_weight\" : [3,3,3],
\"conv_coarse_shift\" : [-4,-4,-4],
\"conv_fine_shift\" : [0,0,0],
\"conv_total_shift\" : [-4,-4,-4],
\"cpu_mode\" : false,
\"delta_input_bitwidth\" : [0],
\"delta_output_bitwidth\" : 8,
\"flag_radix_bias_eq_output\" : true,
\"input_scale\" : [[1.0,1.0,1.0]],
\"output_scale\" : [1.0, 1.0, 1.0],
\"psum_bitwidth\" : 16,
\"weight_bitwidth\" : 8,
\"input_datapath_bitwidth\" : [8],
\"input_datapath_radix\" : [8],
\"working_input_bitwidth\" : 8,
\"working_input_radix\" : [8],
\"working_output_bitwidth\" : 16,
\"working_output_radix\" : 15,
\"output_datapath_bitwidth\" : 8,
\"output_datapath_radix\" : 7
}},\n""".format(
"input_norm"
)
file.write(s + new_json_str[1:])
file.close()
origin_file.close()

View File

@ -2,33 +2,33 @@
import sys
import onnx
import numpy as np
from onnx import numpy_helper
from tools import other, helper
"""
Change onnx model from version 1.3 to version 1.4.
Modify the BN node by removing the spatial attribute
Modify the Upsample node by removing the 'scales' attribute, and adding a constant node instead.
Model's ir_version and opset_import are updated.
Change onnx model from version 1.3 to version 1.4.
- Modify the BN node by removing the spatial attribute
- Modify the Upsample node by removing the 'scales' attribute,
and adding a constant node instead.
- Model's ir_version and opset_import are updated.
"""
def remove_BN_spatial(g):
for node in g.node:
if node.op_type != 'BatchNormalization':
if node.op_type != "BatchNormalization":
continue
for att in node.attribute:
if att.name == 'spatial':
if att.name == "spatial":
node.attribute.remove(att)
def upsample_attribute_to_const(g):
for node in g.node:
if node.op_type != 'Upsample':
if node.op_type != "Upsample":
continue
scales_exist = False
for att in node.attribute:
if att.name == 'scales':
if att.name == "scales":
scales_exist = True
break
if not scales_exist:
@ -36,18 +36,23 @@ def upsample_attribute_to_const(g):
shape = [len(att.floats)]
node.attribute.remove(att)
new_node = helper.list_to_constant(node.name+'_input', shape, att.floats)
new_node = helper.list_to_constant(
node.name + "_input", shape, att.floats
)
g.node.extend([new_node])
value_info = onnx.helper.make_tensor_value_info(node.name+'_input', onnx.TensorProto.FLOAT, shape)
node.input.extend([node.name+'_input'])
value_info = onnx.helper.make_tensor_value_info(
node.name + "_input", onnx.TensorProto.FLOAT, shape
)
node.input.extend([node.name + "_input"])
g.value_info.extend([value_info])
def relu6_to_clip(g):
for node in g.node:
if node.op_type != 'Relu':
if node.op_type != "Relu":
continue
max_val = helper.get_var_attribute_by_name(node, 'max', 'float')
max_val = helper.get_var_attribute_by_name(node, "max", "float")
if max_val is None:
continue
new_node = onnx.helper.make_node(
@ -56,11 +61,12 @@ def relu6_to_clip(g):
node.output,
name=node.name,
max=max_val,
min=0.0
min=0.0,
)
g.node.remove(node)
g.node.extend([new_node])
def PRelu_weight_reshape(g):
# For PRelu with single dimension weight. Expand it to 1, x, 1, 1
for node in g.node:
@ -91,16 +97,18 @@ def PRelu_weight_reshape(g):
new_input = onnx.helper.make_tensor_value_info(
node.input[1],
input_value.type.tensor_type.elem_type,
(1, slope.dims[1], 1, 1))
(1, slope.dims[1], 1, 1),
)
g.input.remove(input_value)
g.input.append(new_input)
value_info = helper.find_value_by_name(g, node.input[1])
if value_info is not None:
g.value_info.remove(value_info)
def do_convert(m):
graph = m.graph
# Modify the nodes.
remove_BN_spatial(graph)
upsample_attribute_to_const(graph)
@ -113,6 +121,7 @@ def do_convert(m):
m.opset_import[0].version = 9
return m
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage:{} file_in file_out".format(sys.argv[0]))

View File

@ -3,31 +3,38 @@
import sys
import onnx
import onnx.utils
import numpy as np
from onnx import numpy_helper
from tools import other, helper, replacing
"""
Change onnx model from version 1.4 to version 1.6.
"""
def replace_all_attribute_to_const_node_in_pad_node(g):
node_to_remove = []
node_to_extend = []
for node in g.node:
if node.op_type != 'Pad':
if node.op_type != "Pad":
continue
pad_loc_node = None # must have
pad_mode = 'constant'
pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [0.0]) # need scalar
pad_loc_node = None # must have
pad_mode = "constant"
pad_value_node = helper.list_to_constant(
node.name + "_pad_value", [], [0.0]
) # need scalar
for att in node.attribute:
if att.name == 'mode':
pad_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
if att.name == 'pads':
pad_loc_node = helper.list_to_constant(node.name+'_pad_loc', [len(att.ints)], att.ints)
if att.name == 'value':
pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [att.f])
if att.name == "mode":
pad_mode = helper.get_var_attribute_by_name(
node, "mode", "string"
)
if att.name == "pads":
pad_loc_node = helper.list_to_constant(
node.name + "_pad_loc", [len(att.ints)], att.ints
)
if att.name == "value":
pad_value_node = helper.list_to_constant(
node.name + "_pad_value", [], [att.f]
)
new_node = onnx.helper.make_node(
"Pad",
@ -40,24 +47,30 @@ def replace_all_attribute_to_const_node_in_pad_node(g):
node_to_extend.append(new_node)
node_to_extend.append(pad_loc_node)
node_to_extend.append(pad_value_node)
for node in node_to_remove:
for node in node_to_remove:
g.node.remove(node)
for node in node_to_extend:
for node in node_to_extend:
g.node.extend([node])
def upsampling_to_resize(g):
for node in g.node:
if node.op_type != 'Upsample':
if node.op_type != "Upsample":
continue
upsampling_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
upsampling_mode = helper.get_var_attribute_by_name(
node, "mode", "string"
)
scale_value_node = helper.find_node_by_output_name(g, node.input[1])
if scale_value_node.op_type != "Constant":
raise TypeError('seems there is a dynamic "scales" param in Upsampling node: ' + node.name + ' , you might need to do constant folding first')
raise TypeError(
'seems there is a dynamic "scales" param in Upsampling node: '
+ node.name
+ " , you might need to do constant folding first"
)
roi_node = helper.list_to_constant(node.name+'_roi_value', [0], [])
roi_node = helper.list_to_constant(node.name + "_roi_value", [0], [])
new_node = onnx.helper.make_node(
"Resize",
@ -65,7 +78,7 @@ def upsampling_to_resize(g):
[node.output[0]],
name=node.output[0],
mode=upsampling_mode,
coordinate_transformation_mode = 'asymmetric'
coordinate_transformation_mode="asymmetric",
)
g.node.remove(node)
@ -75,7 +88,7 @@ def upsampling_to_resize(g):
def replace_all_attribute_to_const_node_in_slice_node(g):
for node in g.node:
if node.op_type != 'Slice':
if node.op_type != "Slice":
continue
axes_const_node = None
@ -83,62 +96,75 @@ def replace_all_attribute_to_const_node_in_slice_node(g):
starts_const_node = None
steps_const_node = None
for att in node.attribute:
if att.name == 'axes':
axes_const_node = helper.list_to_constant(node.name+'_axes_value', [len(att.ints)], att.ints)
if att.name == 'ends':
ends_const_node = helper.list_to_constant(node.name+'_ends_value', [len(att.ints)], att.ints)
if att.name == "axes":
axes_const_node = helper.list_to_constant(
node.name + "_axes_value", [len(att.ints)], att.ints
)
if att.name == 'starts':
starts_const_node = helper.list_to_constant(node.name+'_starts_value', [len(att.ints)], att.ints)
if att.name == "ends":
ends_const_node = helper.list_to_constant(
node.name + "_ends_value", [len(att.ints)], att.ints
)
if att.name == 'steps':
steps_const_node = helper.list_to_constant(node.name+'_steps_value',[ len(att.ints)], att.ints)
if att.name == "starts":
starts_const_node = helper.list_to_constant(
node.name + "_starts_value", [len(att.ints)], att.ints
)
## pop out from back
if att.name == "steps":
steps_const_node = helper.list_to_constant(
node.name + "_steps_value", [len(att.ints)], att.ints
)
# pop out from back
attr_len = len(node.attribute)
for i in range(attr_len):
node.attribute.remove(node.attribute[ attr_len -1 - i ])
node.attribute.remove(node.attribute[attr_len - 1 - i])
## according to the spec, we need to add nodes in a specific order
if starts_const_node != None:
# according to the spec, we need to add nodes in a specific order
if starts_const_node is not None:
g.node.extend([starts_const_node])
node.input.extend([starts_const_node.name])
if ends_const_node != None:
if ends_const_node is not None:
g.node.extend([ends_const_node])
node.input.extend([ends_const_node.name])
if axes_const_node != None:
node.input.extend([ends_const_node.name])
if axes_const_node is not None:
g.node.extend([axes_const_node])
node.input.extend([axes_const_node.name])
if steps_const_node != None:
if steps_const_node is not None:
g.node.extend([steps_const_node])
node.input.extend([steps_const_node.name])
def replace_min_max_attribute_to_const_node_in_clip_node(g):
for node in g.node:
if node.op_type != 'Clip':
if node.op_type != "Clip":
continue
max_const_node = None
min_const_node = None
for att in node.attribute:
if att.name == 'max':
max_const_node = helper.list_to_constant(node.name+'_max_value', [], [att.f])
if att.name == 'min':
min_const_node = helper.list_to_constant(node.name+'_min_value', [], [att.f])
if att.name == "max":
max_const_node = helper.list_to_constant(
node.name + "_max_value", [], [att.f]
)
## pop out from back
node.attribute.remove(node.attribute[1])
node.attribute.remove(node.attribute[0])
## according to the spec, we need to add nodes in a specific order
if att.name == "min":
min_const_node = helper.list_to_constant(
node.name + "_min_value", [], [att.f]
)
# pop out from back
node.attribute.remove(node.attribute[1])
node.attribute.remove(node.attribute[0])
# according to the spec, we need to add nodes in a specific order
g.node.extend([min_const_node])
g.node.extend([max_const_node])
node.input.extend([min_const_node.name])
node.input.extend([max_const_node.name])
def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
"""Update ir_version from 4 to 6 and update opset from 9 to 11.
@ -173,6 +199,7 @@ def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
model = other.polish_model(model)
return model
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage:{} file_in file_out".format(sys.argv[0]))

View File

@ -1,45 +1,51 @@
import onnx
import onnx.utils
try:
from onnx import optimizer
except ImportError:
import onnxoptimizer as optimizer
import sys
import argparse
import logging
from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import special
from tools import combo
from tools.helper import logger
# from tools import temp
def onnx2onnx_flow(m: onnx.ModelProto,
disable_fuse_bn=False,
bn_on_skip=False,
bn_before_add=False,
bgr=False,
norm=False,
rgba2yynn=False,
eliminate_tail=False,
opt_matmul=False,
duplicate_shared_weights=True) -> onnx.ModelProto:
def onnx2onnx_flow(
m: onnx.ModelProto,
disable_fuse_bn=False,
bn_on_skip=False,
bn_before_add=False,
bgr=False,
norm=False,
rgba2yynn=False,
eliminate_tail=False,
opt_matmul=False,
duplicate_shared_weights=True,
) -> onnx.ModelProto:
"""Optimize the onnx.
Args:
m (ModelProto): the input onnx ModelProto
disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.
bn_on_skip (bool, optional): add BN operator on skip branches. Defaults to False.
bn_before_add (bool, optional): add BN before Add node on every branches. Defaults to False.
bgr (bool, optional): add a Conv layer to convert rgb input to bgr. Defaults to False.
norm (bool, optional): add a Conv layer to add 0.5 to the input. Defaults to False.
rgba2yynn (bool, optional): add a Conv layer to convert rgb input to yynn. Defaults to False.
eliminate_tail (bool, optional): remove the trailing NPU unsupported nodes. Defaults to False.
opt_matmul(bool, optional): optimize the MatMul layers according to the NPU limit. Defaults to False.
duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True.
disable_fuse_bn (bool, optional): do not fuse BN into Conv.
Defaults to False.
bn_on_skip (bool, optional): add BN operator on skip branches.
Defaults to False.
bn_before_add (bool, optional): add BN before Add node on every branch.
Defaults to False.
bgr (bool, optional): add a Conv layer to convert rgb input to bgr.
Defaults to False.
norm (bool, optional): add a Conv layer to add 0.5 to the input.
Defaults to False.
rgba2yynn (bool, optional): add a Conv layer to convert rgb to yynn.
Defaults to False.
eliminate_tail (bool, optional): remove trailing NPU unsupported nodes.
Defaults to False.
opt_matmul(bool, optional): optimize MatMul layers due to NPU limit.
Defaults to False.
duplicate_shared_weights(bool, optional): duplicate shared weights.
Defaults to True.
Returns:
ModelProto: the optimized onnx model object.
@ -79,28 +85,83 @@ def onnx2onnx_flow(m: onnx.ModelProto,
return m
# Main process
if __name__ == "__main__":
# Argument parser
parser = argparse.ArgumentParser(description="Optimize an ONNX model for Kneron compiler")
parser.add_argument('in_file', help='input ONNX FILE')
parser.add_argument('-o', '--output', dest='out_file', type=str, help="ouput ONNX FILE")
parser.add_argument('--log', default='i', type=str, help="set log level")
parser.add_argument('--bgr', action='store_true', default=False, help="set if the model is trained in BGR mode")
parser.add_argument('--norm', action='store_true', default=False, help="set if you have the input -0.5~0.5")
parser.add_argument('--rgba2yynn', action='store_true', default=False, help="set if the model has yynn input but you want to take rgba images")
parser.add_argument('--add-bn-on-skip', dest='bn_on_skip', action='store_true', default=False,
help="set if you only want to add BN on skip branches")
parser.add_argument('--add-bn', dest='bn_before_add', action='store_true', default=False,
help="set if you want to add BN before Add")
parser.add_argument('-t', '--eliminate-tail-unsupported', dest='eliminate_tail', action='store_true', default=False,
help='whether remove the last unsupported node for hardware')
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.")
parser.add_argument('--opt-matmul', dest='opt_matmul', action='store_true', default=False,
help="set if you want to optimize the MatMul operations for the kneron hardware.")
parser.add_argument('--no-duplicate-shared-weights', dest='no_duplicate_shared_weights', action='store_true', default=False,
help='do not duplicate shared weights. Defaults to False.')
parser = argparse.ArgumentParser(
description="Optimize an ONNX model for Kneron compiler"
)
parser.add_argument("in_file", help="input ONNX FILE")
parser.add_argument(
"-o", "--output", dest="out_file", type=str, help="ouput ONNX FILE"
)
parser.add_argument("--log", default="i", type=str, help="set log level")
parser.add_argument(
"--bgr",
action="store_true",
default=False,
help="set if the model is trained in BGR mode",
)
parser.add_argument(
"--norm",
action="store_true",
default=False,
help="set if you have the input -0.5~0.5",
)
parser.add_argument(
"--rgba2yynn",
action="store_true",
default=False,
help="set if the model has yynn input but you want "
"to take rgba images",
)
parser.add_argument(
"--add-bn-on-skip",
dest="bn_on_skip",
action="store_true",
default=False,
help="set if you only want to add BN on skip branches",
)
parser.add_argument(
"--add-bn",
dest="bn_before_add",
action="store_true",
default=False,
help="set if you want to add BN before Add",
)
parser.add_argument(
"-t",
"--eliminate-tail-unsupported",
dest="eliminate_tail",
action="store_true",
default=False,
help="whether remove the last unsupported node for hardware",
)
parser.add_argument(
"--no-bn-fusion",
dest="disable_fuse_bn",
action="store_true",
default=False,
help="set if you have met errors which related to inferenced "
"shape mismatch. This option will prevent fusing "
"BatchNormalization into Conv.",
)
parser.add_argument(
"--opt-matmul",
dest="opt_matmul",
action="store_true",
default=False,
help="set if you want to optimize MatMul operations "
"for kneron hardware.",
)
parser.add_argument(
"--no-duplicate-shared-weights",
dest="no_duplicate_shared_weights",
action="store_true",
default=False,
help="do not duplicate shared weights. Defaults to False.",
)
args = parser.parse_args()
if args.out_file is None:
@ -108,11 +169,11 @@ if __name__ == "__main__":
else:
outfile = args.out_file
if args.log == 'w':
if args.log == "w":
logging.basicConfig(level=logging.WARN)
elif args.log == 'd':
elif args.log == "d":
logging.basicConfig(level=logging.DEBUG)
elif args.log == 'e':
elif args.log == "e":
logging.basicConfig(level=logging.ERROR)
else:
logging.basicConfig(level=logging.INFO)
@ -131,6 +192,17 @@ if __name__ == "__main__":
# Basic model organize
m = onnx.load(args.in_file)
m = onnx2onnx_flow(m, args.disable_fuse_bn, args.bn_on_skip, args.bn_before_add, args.bgr, args.norm, args.rgba2yynn, args.eliminate_tail, args.opt_matmul, not args.no_duplicate_shared_weights)
m = onnx2onnx_flow(
m,
args.disable_fuse_bn,
args.bn_on_skip,
args.bn_before_add,
args.bgr,
args.norm,
args.rgba2yynn,
args.eliminate_tail,
args.opt_matmul,
not args.no_duplicate_shared_weights,
)
onnx.save(m, outfile)

View File

@ -5,12 +5,30 @@ import numpy as np
from tools import helper
onnx2np_dtype = {0: 'float', 1: 'float32', 2: 'uint8', 3: 'int8', 4: 'uint16', 5: 'int16', 6: 'int32', 7: 'int64', 8: 'str', 9: 'bool', 10: 'float16', 11: 'double', 12: 'uint32', 13: 'uint64', 14: 'complex64', 15: 'complex128', 16: 'float'}
# Lookup table from an ONNX tensor elem_type code (as read from
# `value.type.tensor_type.elem_type`) to the numpy dtype name that is
# handed to `ndarray.astype` when building placeholder inputs.
# NOTE(review): codes 0 and 16 both map to the generic "float"; in the ONNX
# TensorProto.DataType enum these are UNDEFINED and BFLOAT16, so this is
# presumably a fallback for types numpy cannot express directly - confirm.
onnx2np_dtype = {
0: "float",
1: "float32",
2: "uint8",
3: "int8",
4: "uint16",
5: "int16",
6: "int32",
7: "int64",
8: "str",
9: "bool",
10: "float16",
11: "double",
12: "uint32",
13: "uint64",
14: "complex64",
15: "complex128",
16: "float",
}
def onnx_model_results(path_a, path_b, total_times=10):
""" using onnxruntime to inference two onnx models' ouputs
"""run two onnx models with onnxruntime and collect their outputs
:onnx model paths: two model paths
:total_times: inference times, default to be 10
:returns: inference results of two models
@ -22,13 +40,20 @@ def onnx_model_results(path_a, path_b, total_times=10):
outputs_b = session_b.get_outputs()
# check outputs
assert len(outputs_a) == len(outputs_b), 'Two models have different output numbers.'
assert len(outputs_a) == len(
outputs_b
), "Two models have different output numbers."
for i in range(len(outputs_a)):
out_shape_a, out_shape_b = outputs_a[i].shape, outputs_b[i].shape
out_shape_a = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_a))
out_shape_b = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_b))
assert out_shape_a == out_shape_b, 'Output {} has unmatched shapes'.format(i)
out_shape_a = list(
map(lambda x: x if isinstance(x, int) else 1, out_shape_a)
)
out_shape_b = list(
map(lambda x: x if isinstance(x, int) else 1, out_shape_b)
)
assert (
out_shape_a == out_shape_b
), "Output {} has unmatched shapes".format(i)
# load onnx graph_a and graph_b, to find the initializer and inputs
# then compare to remove the items in the inputs which will be initialized
@ -38,9 +63,16 @@ def onnx_model_results(path_a, path_b, total_times=10):
init_a, init_b = graph_a.initializer, graph_b.initializer
# remove initializer from raw inputs
input_names_a, input_names_b = set([ele.name for ele in inputs_a]), set([ele.name for ele in inputs_b])
init_names_a, init_names_b = set([ele.name for ele in init_a]), set([ele.name for ele in init_b])
real_inputs_names_a, real_inputs_names_b = input_names_a - init_names_a, input_names_b - init_names_b
input_names_a, input_names_b = set([ele.name for ele in inputs_a]), set(
[ele.name for ele in inputs_b]
)
init_names_a, init_names_b = set([ele.name for ele in init_a]), set(
[ele.name for ele in init_b]
)
real_inputs_names_a, real_inputs_names_b = (
input_names_a - init_names_a,
input_names_b - init_names_b,
)
# prepare and figure out matching of real inputs a and real inputs b
# try to keep original orders of each inputs
@ -61,17 +93,20 @@ def onnx_model_results(path_a, path_b, total_times=10):
for item_a in real_inputs_a:
size, shape = helper.find_size_shape_from_value(item_a)
if size:
assert real_single_input_a is None, 'Multiple inputs of first model, single input expected.'
assert (
real_single_input_a is None
), "Multiple inputs of first model, single input expected."
real_single_input_a = item_a
size_a, shape_a = size, shape
for item_b in real_inputs_b:
size, shape = helper.find_size_shape_from_value(item_b)
if size:
assert real_single_input_b is None, 'Multiple inputs of second model, single input expected.'
assert (
real_single_input_b is None
), "Multiple inputs of second model, single input expected."
real_single_input_b = item_b
size_b, shape_b = size, shape
assert size_a == size_b, 'Sizes of two models do not match.'
assert size_a == size_b, "Sizes of two models do not match."
# construct inputs tensors
input_data_type_a = real_single_input_a.type.tensor_type.elem_type
@ -84,7 +119,7 @@ def onnx_model_results(path_a, path_b, total_times=10):
results_a = [[] for i in range(len(outputs_a))]
results_b = [[] for i in range(len(outputs_b))]
while times < total_times:
# initialize inputs by random data, default to be uniform
# initialize inputs by random data, default to be uniform
data = np.random.random(size_a)
input_a = np.reshape(data, shape_a).astype(input_data_type_a)
input_b = np.reshape(data, shape_b).astype(input_data_type_b)
@ -93,12 +128,18 @@ def onnx_model_results(path_a, path_b, total_times=10):
input_dict_b = {}
for item_a in real_inputs_a:
item_type_a = onnx2np_dtype[item_a.type.tensor_type.elem_type]
input_dict_a[item_a.name] = np.array([]).astype(item_type_a) \
if item_a.name != real_single_input_a.name else input_a
input_dict_a[item_a.name] = (
np.array([]).astype(item_type_a)
if item_a.name != real_single_input_a.name
else input_a
)
for item_b in real_inputs_b:
item_type_b = onnx2np_dtype[item_b.type.tensor_type.elem_type]
input_dict_b[item_b.name] = np.array([]).astype(item_type_b) \
if item_b.name != real_single_input_b.name else input_b
input_dict_b[item_b.name] = (
np.array([]).astype(item_type_b)
if item_b.name != real_single_input_b.name
else input_b
)
ra = session_a.run([], input_dict_a)
rb = session_b.run([], input_dict_b)
@ -109,26 +150,32 @@ def onnx_model_results(path_a, path_b, total_times=10):
return results_a, results_b
if __name__ == '__main__':
if __name__ == "__main__":
# Argument parser.
parser = argparse.ArgumentParser(description="Compare two ONNX models to check if they have the same output.")
parser.add_argument('in_file_a', help='input ONNX file a')
parser.add_argument('in_file_b', help='input ONNX file b')
parser = argparse.ArgumentParser(
description="Compare two ONNX models to check if "
"they have the same output."
)
parser.add_argument("in_file_a", help="input ONNX file a")
parser.add_argument("in_file_b", help="input ONNX file b")
args = parser.parse_args()
results_a, results_b = onnx_model_results(args.in_file_a, args.in_file_b, total_times=10)
results_a, results_b = onnx_model_results(
args.in_file_a, args.in_file_b, total_times=10
)
ra_flat = helper.flatten_with_depth(results_a, 0)
rb_flat = helper.flatten_with_depth(results_b, 0)
shape_a = [item[1] for item in ra_flat]
shape_b = [item[1] for item in rb_flat]
assert shape_a == shape_b, 'two results data shape doesn\'t match'
assert shape_a == shape_b, "two results data shape doesn't match"
ra_raw = [item[0] for item in ra_flat]
rb_raw = [item[0] for item in rb_flat]
try:
np.testing.assert_almost_equal(ra_raw, rb_raw, 4)
print('Two models have the same behaviour.')
print("Two models have the same behaviour.")
except Exception as mismatch:
print(mismatch)
exit(1)

View File

@ -1,4 +1,3 @@
import onnx
import argparse
import glob
import csv
@ -8,214 +7,242 @@ import matplotlib.pyplot as plt
from tools import helper
import onnx_vs_onnx as onnx_tester
def compare_results(results_a, results_b):
""" compare onnx model inference results
calculate basic statistical values
results: results from inference multiple times
returns: list of basic statistical values
"""
# input results data can be of nonuniform shape
# get flatten data to compare
ra_flat = helper.flatten_with_depth(results_a, 0)
rb_flat = helper.flatten_with_depth(results_b, 0)
shape_a = [item[1] for item in ra_flat]
shape_b = [item[1] for item in rb_flat]
assert shape_a == shape_b, 'two results data shape doesn\'t match'
ra_raw = [item[0] for item in ra_flat]
rb_raw = [item[0] for item in rb_flat]
"""compare onnx model inference results
calculate basic statistical values
results: results from inference multiple times
returns: list of basic statistical values
"""
# input results data can be of nonuniform shape
# get flatten data to compare
ra_flat = helper.flatten_with_depth(results_a, 0)
rb_flat = helper.flatten_with_depth(results_b, 0)
shape_a = [item[1] for item in ra_flat]
shape_b = [item[1] for item in rb_flat]
assert shape_a == shape_b, "two results data shape doesn't match"
ra_raw = [item[0] for item in ra_flat]
rb_raw = [item[0] for item in rb_flat]
# the statistical values
max_rel_diff = 0 # defined to be max( { abs(diff)/max(abs(ra), abs(rb) ) } )
max_abs_diff = 0 # defined to be max( { abs(ra-rb) } )
mean_rel_diff = 0
mean_abs_diff = 0
std_rel_diff = 0
std_abs_diff = 0
acc_with_diff_precision = []
rel_diff = []
abs_diff_percentiles = [] # rel_diff percentiles
rel_diff_percentiles = [] # abs_diff precentiles
# the statistical values
max_rel_diff = (
0 # defined to be max( { abs(diff)/max(abs(ra), abs(rb) ) } )
)
max_abs_diff = 0 # defined to be max( { abs(ra-rb) } )
mean_rel_diff = 0
mean_abs_diff = 0
std_rel_diff = 0
std_abs_diff = 0
acc_with_diff_precision = []
rel_diff = []
abs_diff_percentiles = [] # abs_diff percentiles
rel_diff_percentiles = [] # rel_diff percentiles
raw_diff = [ra_raw[i]-rb_raw[i] for i in range(len(ra_raw))]
abs_diff = [abs(num) for num in raw_diff]
for i in range(len(ra_raw)):
divider = max([abs(ra_raw[i]), abs(rb_raw[i])])
val = abs_diff[i]/divider if divider != 0 else 0
rel_diff.append(val)
max_rel_diff = max(rel_diff)
max_abs_diff = max(abs_diff)
mean_rel_diff = np.average(rel_diff)
mean_abs_diff = np.average(abs_diff)
std_rel_diff = np.std(rel_diff)
std_abs_diff = np.std(abs_diff)
# calculate accuracy with different precison
for digit in range(8):
correct = 0
raw_diff = [ra_raw[i] - rb_raw[i] for i in range(len(ra_raw))]
abs_diff = [abs(num) for num in raw_diff]
for i in range(len(ra_raw)):
if format(ra_raw[i], '.'+str(digit)+'f')\
== format(rb_raw[i], '.'+str(digit)+'f'):
correct += 1
acc_with_diff_precision.append([digit, float(format(correct/len(ra_raw), '.3f'))])
divider = max([abs(ra_raw[i]), abs(rb_raw[i])])
val = abs_diff[i] / divider if divider != 0 else 0
rel_diff.append(val)
# analyze rel_diff distribution
rel_diff.sort()
abs_diff.sort()
for i in range(20):
rel_diff_percentiles.append(['{}%'.format(i*5), rel_diff[int((i/20)*len(rel_diff))]])
abs_diff_percentiles.append(['{}%'.format(i*5), abs_diff[int((i/20)*len(abs_diff))]])
max_rel_diff = max(rel_diff)
max_abs_diff = max(abs_diff)
mean_rel_diff = np.average(rel_diff)
mean_abs_diff = np.average(abs_diff)
std_rel_diff = np.std(rel_diff)
std_abs_diff = np.std(abs_diff)
results = [
['max_rel_diff', max_rel_diff],
['max_abs_diff', max_abs_diff],
['mean_rel_diff', mean_rel_diff],
['mean_abs_diff', mean_abs_diff],
['std_rel_diff', std_rel_diff],
['std_abs_diff', std_abs_diff],
['acc_with_diff_precision', acc_with_diff_precision],
['rel_diff_percentiles', rel_diff_percentiles],
['abs_diff_percentiles', abs_diff_percentiles]
]
return results
# calculate accuracy at different precisions (0 to 7 decimal places)
for digit in range(8):
correct = 0
for i in range(len(ra_raw)):
if format(ra_raw[i], "." + str(digit) + "f") == format(
rb_raw[i], "." + str(digit) + "f"
):
correct += 1
acc_with_diff_precision.append(
[digit, float(format(correct / len(ra_raw), ".3f"))]
)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='test model optimization results')
parser.add_argument('dir', type=str, help='the directory that stores onnx models')
parser.add_argument('ending1', type=str, help='model file name ending(eg, .onnx)')
parser.add_argument('ending2', type=str, help='opt model file name ending(eg. _opt.onnx)')
parser.add_argument('out_file', type=str, help='output csv file name')
parser.add_argument('-p', '--plot', default='N', help='get plots (Y/N)')
parser.add_argument('-i', '--iter_times', default=10, type=int, help='inference times')
# analyze rel_diff distribution
rel_diff.sort()
abs_diff.sort()
for i in range(20):
rel_diff_percentiles.append(
["{}%".format(i * 5), rel_diff[int((i / 20) * len(rel_diff))]]
)
abs_diff_percentiles.append(
["{}%".format(i * 5), abs_diff[int((i / 20) * len(abs_diff))]]
)
args = parser.parse_args()
results = [
["max_rel_diff", max_rel_diff],
["max_abs_diff", max_abs_diff],
["mean_rel_diff", mean_rel_diff],
["mean_abs_diff", mean_abs_diff],
["std_rel_diff", std_rel_diff],
["std_abs_diff", std_abs_diff],
["acc_with_diff_precision", acc_with_diff_precision],
["rel_diff_percentiles", rel_diff_percentiles],
["abs_diff_percentiles", abs_diff_percentiles],
]
old_models_paths = glob.glob(args.dir+'*'+args.ending1)
new_models_paths = glob.glob(args.dir+'*'+args.ending2)
stats_table = [[
'Model',
'max_rel_diff',
'max_abs_diff',
'mean_rel_diff',
'mean_abs_diff',
'std_rel_diff',
'std_abs_diff',
'acc_with_diff_precision',
'rel_diff_percentiles',
'abs_diff_percentiles'
]]
for new_model_path in new_models_paths:
old_model_path = new_model_path[:-len(args.ending2)] + args.ending1
if old_model_path not in old_models_paths:
continue
# run inference
results_a, results_b = onnx_tester.onnx_model_results(old_model_path, new_model_path, total_times=args.iter_times)
# compare inference results
comparision = compare_results(results_a, results_b)
new_line = [old_model_path.split('/')[-1]]
for item in comparision:
new_line.append(item[1])
stats_table.append(new_line)
# try to read existing file
old_stats_table = []
try:
old_file = open(args.out_file, 'r')
reader = csv.reader(old_file)
old_header = reader.__next__()
for row in reader:
old_stats_table.append(row)
old_file.close()
except:
pass
# compare and merge possible old stat data file with new stat data file
header = stats_table[0]
stats_table = stats_table[1:]
new_model_names = set([item[0] for item in stats_table])
for row in old_stats_table:
if row[0] not in new_model_names:
stats_table.append(row)
stats_table.insert(0, header)
# write a new stat data file, overwrite old file
new_file = open(args.out_file, 'w', newline='')
writer = csv.writer(new_file)
for row in stats_table:
writer.writerow(row)
new_file.close()
# make some plots
if args.plot == 'Y':
if len(stats_table) < 2:
exit(0)
sample_table = stats_table[1:] if len(stats_table) < 6 else stats_table[1:6]
max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]]
plt.hist(max_rel_diffs, bins=15)
plt.title('Max Relavtive Difference Histogram')
plt.xlabel('Max Relative Difference')
plt.ylabel('Counts')
plt.savefig('max_rel_diff_hist.png')
plt.close()
max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]]
plt.hist(max_abs_diffs, bins=15)
plt.title('Max Absolute Difference Histogram')
plt.xlabel('Max Absolute Difference')
plt.ylabel('Counts')
plt.savefig('max_abs_diff_hist.png')
plt.close()
for line in sample_table:
model_name = line[0]
percentiles = line[-2]
x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))]
y = [ele[1] for ele in percentiles]
plt.plot(x, y, label=model_name)
plt.title('Rel_diff Percentiles of Raw and Optimized Models')
plt.xlabel('percentage')
plt.ylabel('relative difference')
plt.legend()
plt.savefig('rel_diff_percentiles.png')
plt.close()
for line in sample_table:
model_name = line[0]
percentiles = line[-1]
x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))]
y = [ele[1] for ele in percentiles]
plt.plot(x, y, label=model_name)
plt.title('Abs_diff Percentiles of Raw and Optimized Models')
plt.xlabel('percentage')
plt.ylabel('absolute difference')
plt.legend()
plt.savefig('abs_diff_percentiles.png')
plt.close()
for line in sample_table:
model_name = line[0]
accuracies = line[-3]
x = [acc[0] for acc in accuracies]
y = [acc[1] for acc in accuracies]
plt.plot(x, y, label=model_name)
plt.title('Accuracies with Different Precisions')
plt.xlabel('Decimals')
plt.ylabel('Precision')
plt.legend()
plt.savefig('precisions.png')
plt.close()
return results
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="test model optimization results"
)
parser.add_argument(
"dir", type=str, help="the directory that stores onnx models"
)
parser.add_argument(
"ending1", type=str, help="model file name ending(eg, .onnx)"
)
parser.add_argument(
"ending2", type=str, help="opt model file name ending(eg. _opt.onnx)"
)
parser.add_argument("out_file", type=str, help="output csv file name")
parser.add_argument("-p", "--plot", default="N", help="get plots (Y/N)")
parser.add_argument(
"-i", "--iter_times", default=10, type=int, help="inference times"
)
args = parser.parse_args()
old_models_paths = glob.glob(args.dir + "*" + args.ending1)
new_models_paths = glob.glob(args.dir + "*" + args.ending2)
stats_table = [
[
"Model",
"max_rel_diff",
"max_abs_diff",
"mean_rel_diff",
"mean_abs_diff",
"std_rel_diff",
"std_abs_diff",
"acc_with_diff_precision",
"rel_diff_percentiles",
"abs_diff_percentiles",
]
]
for new_model_path in new_models_paths:
old_model_path = new_model_path[: -len(args.ending2)] + args.ending1
if old_model_path not in old_models_paths:
continue
# run inference
results_a, results_b = onnx_tester.onnx_model_results(
old_model_path, new_model_path, total_times=args.iter_times
)
# compare inference results
comparision = compare_results(results_a, results_b)
new_line = [old_model_path.split("/")[-1]]
for item in comparision:
new_line.append(item[1])
stats_table.append(new_line)
# try to read existing file
old_stats_table = []
try:
old_file = open(args.out_file, "r")
reader = csv.reader(old_file)
old_header = reader.__next__()
for row in reader:
old_stats_table.append(row)
old_file.close()
except Exception:
pass
# compare and merge possible old stat data file with new stat data file
header = stats_table[0]
stats_table = stats_table[1:]
new_model_names = set([item[0] for item in stats_table])
for row in old_stats_table:
if row[0] not in new_model_names:
stats_table.append(row)
stats_table.insert(0, header)
# write a new stat data file, overwrite old file
new_file = open(args.out_file, "w", newline="")
writer = csv.writer(new_file)
for row in stats_table:
writer.writerow(row)
new_file.close()
# make some plots
if args.plot == "Y":
if len(stats_table) < 2:
exit(0)
sample_table = (
stats_table[1:] if len(stats_table) < 6 else stats_table[1:6]
)
max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]]
plt.hist(max_rel_diffs, bins=15)
plt.title("Max Relavtive Difference Histogram")
plt.xlabel("Max Relative Difference")
plt.ylabel("Counts")
plt.savefig("max_rel_diff_hist.png")
plt.close()
max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]]
plt.hist(max_abs_diffs, bins=15)
plt.title("Max Absolute Difference Histogram")
plt.xlabel("Max Absolute Difference")
plt.ylabel("Counts")
plt.savefig("max_abs_diff_hist.png")
plt.close()
for line in sample_table:
model_name = line[0]
percentiles = line[-2]
x = [
round(i * (1 / len(percentiles)), 2)
for i in range(len(percentiles))
]
y = [ele[1] for ele in percentiles]
plt.plot(x, y, label=model_name)
plt.title("Rel_diff Percentiles of Raw and Optimized Models")
plt.xlabel("percentage")
plt.ylabel("relative difference")
plt.legend()
plt.savefig("rel_diff_percentiles.png")
plt.close()
for line in sample_table:
model_name = line[0]
percentiles = line[-1]
x = [
round(i * (1 / len(percentiles)), 2)
for i in range(len(percentiles))
]
y = [ele[1] for ele in percentiles]
plt.plot(x, y, label=model_name)
plt.title("Abs_diff Percentiles of Raw and Optimized Models")
plt.xlabel("percentage")
plt.ylabel("absolute difference")
plt.legend()
plt.savefig("abs_diff_percentiles.png")
plt.close()
for line in sample_table:
model_name = line[0]
accuracies = line[-3]
x = [acc[0] for acc in accuracies]
y = [acc[1] for acc in accuracies]
plt.plot(x, y, label=model_name)
plt.title("Accuracies with Different Precisions")
plt.xlabel("Decimals")
plt.ylabel("Precision")
plt.legend()
plt.savefig("precisions.png")
plt.close()

View File

@ -1,21 +1,10 @@
import onnx
import onnx.utils
try:
from onnx import optimizer
except ImportError:
import onnxoptimizer as optimizer
import sys
import numpy as np
import struct
import logging
import argparse
from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import combo
from tools import special
from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
# Debug use
@ -25,13 +14,28 @@ from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
# Generate a prototype onnx #
######################################
parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
parser.add_argument('in_file', help='input ONNX or PTH FILE')
parser.add_argument('out_file', help="ouput ONNX FILE")
parser.add_argument('--input-size', dest='input_size', nargs=3,
help='if you using pth, please use this argument to set up the input size of the model. It should be in \'CH H W\' format, e.g. \'--input-size 3 256 512\'.')
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.")
parser = argparse.ArgumentParser(
description="Optimize a Pytorch generated model for Kneron compiler"
)
parser.add_argument("in_file", help="input ONNX or PTH FILE")
parser.add_argument("out_file", help="ouput ONNX FILE")
parser.add_argument(
"--input-size",
dest="input_size",
nargs=3,
help="if you using pth, please use this argument to set up the input "
"size of the model. It should be in 'CH H W' format, "
"e.g. '--input-size 3 256 512'.",
)
parser.add_argument(
"--no-bn-fusion",
dest="disable_fuse_bn",
action="store_true",
default=False,
help="set if you have met errors which related to inferenced shape "
"mismatch. This option will prevent fusing BatchNormalization "
"into Conv.",
)
args = parser.parse_args()
@ -39,7 +43,7 @@ if len(args.in_file) <= 4:
# When the filename is too short.
logging.error("Invalid input file: {}".format(args.in_file))
exit(1)
elif args.in_file[-4:] == '.pth':
elif args.in_file[-4:] == ".pth":
# Pytorch pth case
logging.warning("Converting from pth to onnx is not recommended.")
onnx_in = args.out_file
@ -47,21 +51,29 @@ elif args.in_file[-4:] == '.pth':
from torch.autograd import Variable
import torch
import torch.onnx
# import torchvision
# Standard ImageNet input - 3 channels, 224x224.
# Values don't matter as we care about network structure.
# But they can also be real inputs.
if args.input_size is None:
logging.error("\'--input-size\' is required for the pth input file.")
logging.error("'--input-size' is required for the pth input file.")
exit(1)
dummy_input = Variable(torch.randn(1, int(args.input_size[0]), int(args.input_size[1]), int(args.input_size[2])))
dummy_input = Variable(
torch.randn(
1,
int(args.input_size[0]),
int(args.input_size[1]),
int(args.input_size[2]),
)
)
# Obtain your model, it can be also constructed in your script explicitly.
model = torch.load(sys.argv[1], map_location='cpu')
# NOTE(review): sys.argv[1] equals args.in_file only when no option
# precedes the positional arguments on the command line; args.in_file
# looks like the intended value - confirm.
model = torch.load(sys.argv[1], map_location="cpu")
# model = torchvision.models.resnet34(pretrained=True)
# Invoke export.
# torch.save(model, "resnet34.pth")
torch.onnx.export(model, dummy_input, args.out_file, opset_version=11)
elif args.in_file[-4:] == 'onnx':
elif args.in_file[-4:] == "onnx":
onnx_in = args.in_file
else:
# When the file is neither an onnx or a pytorch pth.

View File

@ -1,29 +1,22 @@
import onnx
import onnx.utils
try:
from onnx import optimizer
except ImportError:
import onnxoptimizer as optimizer
import sys
import numpy as np
import struct
import logging
import argparse
from .tools import eliminating
from .tools import fusing
from .tools import replacing
from .tools import other
from .tools import combo
from .tools import special
# Define general pytorch exported onnx optimize process
def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.ModelProto:
def torch_exported_onnx_flow(
m: onnx.ModelProto, disable_fuse_bn=False
) -> onnx.ModelProto:
"""Optimize the Pytorch exported onnx.
Args:
m (ModelProto): the input onnx model
disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.
disable_fuse_bn (bool, optional): do not fuse BN into Conv.
Defaults to False.
Returns:
ModelProto: the optimized onnx model
@ -38,20 +31,29 @@ def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.
# Main Process
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
parser.add_argument('in_file', help='input ONNX')
parser.add_argument('out_file', help="ouput ONNX FILE")
parser.add_argument('--log', default='i', type=str, help="set log level")
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.")
parser = argparse.ArgumentParser(
description="Optimize a Pytorch generated model for Kneron compiler"
)
parser.add_argument("in_file", help="input ONNX")
parser.add_argument("out_file", help="ouput ONNX FILE")
parser.add_argument("--log", default="i", type=str, help="set log level")
parser.add_argument(
"--no-bn-fusion",
dest="disable_fuse_bn",
action="store_true",
default=False,
help="set if you have met errors which related to inferenced shape "
"mismatch. This option will prevent fusing BatchNormalization "
"into Conv.",
)
args = parser.parse_args()
if args.log == 'w':
if args.log == "w":
logging.basicConfig(level=logging.WARN)
elif args.log == 'd':
elif args.log == "d":
logging.basicConfig(level=logging.DEBUG)
elif args.log == 'e':
elif args.log == "e":
logging.basicConfig(level=logging.ERROR)
else:
logging.basicConfig(level=logging.INFO)
@ -60,7 +62,7 @@ if __name__ == "__main__":
# When the filename is too short.
logging.error("Invalid input file: {}".format(args.in_file))
exit(1)
elif args.in_file[-4:] == 'onnx':
elif args.in_file[-4:] == "onnx":
onnx_in = args.in_file
else:
# When the file is not an onnx file.

View File

@ -8,7 +8,8 @@ import onnx.utils
from tensorflow.python.platform import gfile
from tools import combo, eliminating, replacing, other
def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
def tf2onnx_flow(pb_path: str, test_mode=False) -> onnx.ModelProto:
"""Convert frozen graph pb file into onnx
Args:
@ -21,34 +22,45 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
Returns:
onnx.ModelProto: converted onnx
"""
TF2ONNX_VERSION = int(tf2onnx.version.version.replace('.', ''))
TF2ONNX_VERSION = int(tf2onnx.version.version.replace(".", ""))
if 160 <= TF2ONNX_VERSION:
from tf2onnx import tf_loader
else:
from tf2onnx import loader as tf_loader
if pb_path[-3:] == '.pb':
model_name = pb_path.split('/')[-1][:-3]
# always reset tensorflow session at begin
if pb_path[-3:] == ".pb":
model_name = pb_path.split("/")[-1][:-3]
# always reset the default tensorflow graph before opening a session
tf.reset_default_graph()
with tf.Session() as sess:
with gfile.FastGFile(pb_path, 'rb') as f:
with gfile.FastGFile(pb_path, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
tf.import_graph_def(graph_def, name="")
if 160 <= int(tf2onnx.version.version.replace('.', '')):
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions = tf2onnx.tf_utils.tflist_to_onnx(
sess.graph,
{})
if 160 <= int(tf2onnx.version.version.replace(".", "")):
(
onnx_nodes,
op_cnt,
attr_cnt,
output_shapes,
dtypes,
functions,
) = tf2onnx.tf_utils.tflist_to_onnx(sess.graph, {})
else:
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tf2onnx.tfonnx.tflist_to_onnx(
sess.graph.get_operations(),
{})
(
onnx_nodes,
op_cnt,
attr_cnt,
output_shapes,
dtypes,
) = tf2onnx.tfonnx.tflist_to_onnx(
sess.graph.get_operations(), {}
)
for n in onnx_nodes:
if len(n.output) == 0:
@ -59,12 +71,12 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
nodes_outputs = set()
for n in onnx_nodes:
if n.op_type == 'Placeholder':
if n.op_type == "Placeholder":
continue
for input in n.input:
nodes_inputs.add(input)
for output in n.output:
nodes_outputs.add(output)
nodes_outputs.add(output)
graph_input_names = set()
for input_name in nodes_inputs:
@ -76,35 +88,43 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
if n.input and n.input[0] not in nodes_outputs:
continue
if len(n.output) == 0:
n.output.append(n.name + ':0')
n.output.append(n.name + ":0")
graph_output_names.add(n.output[0])
else:
output_name = n.output[0]
if (output_name not in nodes_inputs) and (0 < len(n.input)):
if (output_name not in nodes_inputs) and (
0 < len(n.input)
):
graph_output_names.add(output_name)
logging.info('Model Inputs: %s', str(list(graph_input_names)))
logging.info('Model Outputs: %s', str(list(graph_output_names)))
logging.info("Model Inputs: %s", str(list(graph_input_names)))
logging.info("Model Outputs: %s", str(list(graph_output_names)))
graph_def, inputs, outputs = tf_loader.from_graphdef(model_path=pb_path,
input_names=list(graph_input_names),
output_names=list(graph_output_names))
graph_def, inputs, outputs = tf_loader.from_graphdef(
model_path=pb_path,
input_names=list(graph_input_names),
output_names=list(graph_output_names),
)
with tf.Graph().as_default() as tf_graph:
tf.import_graph_def(graph_def, name='')
tf.import_graph_def(graph_def, name="")
if 160 <= TF2ONNX_VERSION:
with tf_loader.tf_session(graph=tf_graph):
onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
input_names=inputs,
output_names=outputs,
opset=11)
onnx_graph = tf2onnx.tfonnx.process_tf_graph(
tf_graph=tf_graph,
input_names=inputs,
output_names=outputs,
opset=11,
)
else:
with tf.Session(graph=tf_graph):
onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
input_names=inputs,
output_names=outputs,
opset=11)
onnx_graph = tf2onnx.tfonnx.process_tf_graph(
tf_graph=tf_graph,
input_names=inputs,
output_names=outputs,
opset=11,
)
# Optimize with tf2onnx.optimizer
onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph)
@ -115,7 +135,9 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
model_proto = other.polish_model(model_proto)
else:
raise Exception('expect .pb file as input, but got "' + str(pb_path) + '"')
raise Exception(
'expect .pb file as input, but got "' + str(pb_path) + '"'
)
# rename
m = model_proto
@ -133,15 +155,26 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
return m
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Convert tensorflow pb file to onnx file and optimized onnx file. Or just optimize tensorflow onnx file.')
parser.add_argument('in_file', help='input file')
parser.add_argument('out_file', help='output optimized model file')
parser.add_argument('-t', '--test_mode', default=False, help='test mode will not eliminate shape changes after input')
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert tensorflow pb file to onnx file and optimized "
"onnx file. Or just optimize tensorflow onnx file."
)
parser.add_argument("in_file", help="input file")
parser.add_argument("out_file", help="output optimized model file")
# NOTE(review): no `action` or `type` is given, so argparse stores whatever
# string follows -t/--test_mode; even "-t False" yields a non-empty
# (truthy) string. If a plain on/off switch is intended,
# action="store_true" would express it - confirm against callers.
parser.add_argument(
"-t",
"--test_mode",
default=False,
help="test mode will not eliminate shape changes after input",
)
args = parser.parse_args()
logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO)
logging.basicConfig(
stream=sys.stdout,
format="[%(asctime)s] %(levelname)s: %(message)s",
level=logging.INFO,
)
m = tf2onnx_flow(args.in_file, args.test_mode)
onnx.save(m, args.out_file)
logging.info('Save Optimized ONNX: %s', args.out_file)
logging.info("Save Optimized ONNX: %s", args.out_file)

View File

@ -6,6 +6,7 @@ import onnxruntime
from tools import helper
def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
# Setup onnx session and get meta data
onnx_session = onnxruntime.InferenceSession(onnx_file, None)
@ -21,21 +22,32 @@ def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
tflite_session.allocate_tensors()
tflite_inputs = tflite_session.get_input_details()
tflite_outputs = tflite_session.get_output_details()
tflite_input_shape = tflite_inputs[0]['shape']
tflite_input_shape = tflite_inputs[0]["shape"]
# Compare input shape
assert(len(onnx_input_shape) == len(tflite_input_shape)), "TFLite and ONNX shape unmatch."
assert(onnx_input_shape == [tflite_input_shape[0], tflite_input_shape[3], tflite_input_shape[1], tflite_input_shape[2]]), "TFLite and ONNX shape unmatch."
assert len(onnx_input_shape) == len(
tflite_input_shape
), "TFLite and ONNX shape unmatch."
assert onnx_input_shape == [
tflite_input_shape[0],
tflite_input_shape[3],
tflite_input_shape[1],
tflite_input_shape[2],
], "TFLite and ONNX shape unmatch."
# Generate random number and run
tflite_results = []
onnx_results = []
for _ in range(total_times):
# Generate input
tflite_input_data = np.array(np.random.random_sample(tflite_input_shape), dtype=np.float32)
tflite_input_data = np.array(
np.random.random_sample(tflite_input_shape), dtype=np.float32
)
onnx_input_data = np.transpose(tflite_input_data, [0, 3, 1, 2])
# Run tflite
tflite_session.set_tensor(tflite_inputs[0]['index'], tflite_input_data)
tflite_session.set_tensor(tflite_inputs[0]["index"], tflite_input_data)
tflite_session.invoke()
tflite_results.append(tflite_session.get_tensor(tflite_outputs[0]['index']))
tflite_results.append(
tflite_session.get_tensor(tflite_outputs[0]["index"])
)
# Run onnx
onnx_input_dict = {onnx_inputs[0].name: onnx_input_data}
onnx_results.append(onnx_session.run([], onnx_input_dict)[0])
@ -43,26 +55,31 @@ def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
return tflite_results, onnx_results
if __name__ == '__main__':
if __name__ == "__main__":
# Argument parser.
parser = argparse.ArgumentParser(description="Compare a TFLite model and an ONNX model to check if they have the same output.")
parser.add_argument('tflite_file', help='input tflite file')
parser.add_argument('onnx_file', help='input ONNX file')
parser = argparse.ArgumentParser(
description="Compare a TFLite model and an ONNX model to check "
"if they have the same output."
)
parser.add_argument("tflite_file", help="input tflite file")
parser.add_argument("onnx_file", help="input ONNX file")
args = parser.parse_args()
results_a, results_b = compare_tflite_and_onnx(args.tflite_file, args.onnx_file, total_times=10)
results_a, results_b = compare_tflite_and_onnx(
args.tflite_file, args.onnx_file, total_times=10
)
ra_flat = helper.flatten_with_depth(results_a, 0)
rb_flat = helper.flatten_with_depth(results_b, 0)
shape_a = [item[1] for item in ra_flat]
shape_b = [item[1] for item in rb_flat]
assert shape_a == shape_b, 'two results data shape doesn\'t match'
assert shape_a == shape_b, "two results data shape doesn't match"
ra_raw = [item[0] for item in ra_flat]
rb_raw = [item[0] for item in rb_flat]
try:
np.testing.assert_almost_equal(ra_raw, rb_raw, 8)
print('Two models have the same behaviour.')
print("Two models have the same behaviour.")
except Exception as mismatch:
print(mismatch)
exit(1)
exit(1)

View File

@ -2,7 +2,7 @@
"""
import logging
import onnx.utils
try:
from onnx import optimizer
except ImportError:
@ -15,16 +15,19 @@ from . import eliminating
from . import fusing
from . import constant_folding
from . import removing_transpose
from . import modhelper
from .common_pattern import torch_pattern_match, tf_pattern_match
from .helper import logger
def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True):
def preprocess(
model_proto, disable_fuse_bn=False, duplicate_shared_weights=True
):
"""The most common used functions before other processing.
Args:
model_proto: the original model input
duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True.
duplicate_shared_weights(bool, optional): duplicate shared weights.
Defaults to True.
Return:
the new model after preprocessing
@ -65,22 +68,28 @@ def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True
replacing.replace_initializer_with_Constant(model_proto.graph)
other.topological_sort(model_proto.graph)
m = other.polish_model(model_proto)
passes = ['extract_constant_to_initializer',
'eliminate_nop_dropout',
'eliminate_deadend',
'fuse_matmul_add_bias_into_gemm',
'fuse_pad_into_conv']
passes = [
"extract_constant_to_initializer",
"eliminate_nop_dropout",
"eliminate_deadend",
"fuse_matmul_add_bias_into_gemm",
"fuse_pad_into_conv",
]
if not disable_fuse_bn:
passes.append('fuse_bn_into_conv')
passes.append("fuse_bn_into_conv")
m = optimizer.optimize(m, passes)
g = m.graph
# Add name again since onnx optimizer higher than 1.7 may remove node names.
# Add name again since onnx optimizer higher than 1.7 may remove node names
other.add_name_to_node(g)
if duplicate_shared_weights:
replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=True)
replacing.replace_initializer_with_Constant(
g, duplicate_shared_weights=True
)
other.duplicate_param_shared_constant(g)
else:
replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=False)
replacing.replace_initializer_with_Constant(
g, duplicate_shared_weights=False
)
other.topological_sort(g)
m = other.polish_model(m)
g = m.graph
@ -161,12 +170,12 @@ def pytorch_constant_folding(m):
other.topological_sort(m.graph)
while len(m.graph.value_info) != 0:
m.graph.value_info.pop()
m = other.inference_shapes(m)
replacing.replace_shape_with_constant(m.graph)
other.topological_sort(m.graph)
m = torch_pattern_match(m)
m = optimizer.optimize(m, ['eliminate_deadend'])
m = optimizer.optimize(m, ["eliminate_deadend"])
return m
@ -206,7 +215,7 @@ def tensorflow_optimization(m):
replacing.replace_shape_with_constant(m.graph)
other.topological_sort(m.graph)
m = tf_pattern_match(m)
m = optimizer.optimize(m, ['eliminate_deadend'])
m = optimizer.optimize(m, ["eliminate_deadend"])
eliminating.eliminate_consecutive_reshape(m.graph)
eliminating.eliminate_Squeeze_before_Reshape(m.graph)
@ -253,6 +262,6 @@ def postprocess(m):
m = other.polish_model(m)
other.add_output_to_value_info(m.graph)
m = optimizer.optimize(m, ['eliminate_deadend'])
m.producer_name = 'kneron_formatter'
m = optimizer.optimize(m, ["eliminate_deadend"])
m.producer_name = "kneron_formatter"
return m

View File

@ -3,19 +3,20 @@ import numpy as np
import onnx.helper
import onnx.utils
from . import modhelper
from . import helper
from . import other
def torch_pattern_match(m):
# Create a map from optype to the nodes.
optype2node = defaultdict(list)
for node in m.graph.node:
optype2node[node.op_type].append(node)
for matmul_node in optype2node['MatMul']:
for matmul_node in optype2node["MatMul"]:
pattern_matmul_mul_add(m.graph, matmul_node)
for resize_node in optype2node['Resize']:
# torch nn.UpsamplingBilinear2d will be given us 4 input: "X, roi, scales, sizes"
for resize_node in optype2node["Resize"]:
# torch nn.UpsamplingBilinear2d gives us 4 inputs:
# "X, roi, scales, sizes"
if len(resize_node.input) != 4:
continue
make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name)
@ -24,15 +25,17 @@ def torch_pattern_match(m):
m = other.polish_model(m)
return m
def tf_pattern_match(m):
# Create a map from optype to the nodes.
optype2node = defaultdict(list)
for node in m.graph.node:
optype2node[node.op_type].append(node)
for matmul_node in optype2node['MatMul']:
for matmul_node in optype2node["MatMul"]:
pattern_matmul_mul_add(m.graph, matmul_node)
for resize_node in optype2node['Resize']:
# In tensorflow2onnx, ReizeXXX will be given us 4 input: "X, roi, scales, sizes"
for resize_node in optype2node["Resize"]:
# In tensorflow2onnx, ResizeXXX gives us 4 inputs:
# "X, roi, scales, sizes"
# and the node output name is given as "node name + :0"
if len(resize_node.input) != 4:
continue
@ -42,24 +45,25 @@ def tf_pattern_match(m):
m = other.polish_model(m)
return m
def pattern_matmul_mul_add(g, matmul_node):
# Check node match - Mul node
next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
if len(next_nodes) != 1:
return
if next_nodes[0].op_type != 'Mul':
if next_nodes[0].op_type != "Mul":
return
mul_node = next_nodes[0]
# Check node match - Add node
next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
if len(next_nodes) != 1:
return
if next_nodes[0].op_type != 'Add':
if next_nodes[0].op_type != "Add":
return
add_node = next_nodes[0]
# Check Mul weight
mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
if mul_weight_node.op_type != 'Constant':
if mul_weight_node.op_type != "Constant":
return
weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
for i in mul_weight:
@ -68,15 +72,19 @@ def pattern_matmul_mul_add(g, matmul_node):
channel = weight_size[0]
# Check Add weight
add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
if add_weight_node.op_type != 'Constant':
if add_weight_node.op_type != "Constant":
return
# Check MatMul weight to see if it needs weight broadcast
matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1])
matmul_weight_node = helper.find_node_by_output_name(
g, matmul_node.input[1]
)
matmul_weight = helper.constant_to_numpy(matmul_weight_node)
if matmul_weight.shape[1] == 1:
# Weight broadcast
new_matmul_weight = np.tile(matmul_weight, channel)
new_matmul_weight_node = helper.numpy_to_constant(matmul_weight_node.name, new_matmul_weight)
new_matmul_weight_node = helper.numpy_to_constant(
matmul_weight_node.name, new_matmul_weight
)
g.node.remove(matmul_weight_node)
g.node.extend([new_matmul_weight_node])
value = helper.find_value_by_name(g, matmul_weight_node.output[0])
@ -93,14 +101,14 @@ def pattern_matmul_mul_add(g, matmul_node):
g.value_info.remove(value)
# Fuse Matmul and Add
gemm_node = onnx.helper.make_node(
'Gemm',
"Gemm",
[matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
[add_node.output[0]],
name = matmul_node.name,
alpha = 1.0,
beta = 1.0,
transA = 0,
transB = 0
name=matmul_node.name,
alpha=1.0,
beta=1.0,
transA=0,
transB=0,
)
g.node.extend([gemm_node])
# Clean up
@ -111,6 +119,7 @@ def pattern_matmul_mul_add(g, matmul_node):
g.value_info.remove(value)
other.topological_sort(g)
def make_UpsamplingBilinear2d_value_info(g, resize_node_name):
resize_node = helper.find_node_by_node_name(g, resize_node_name)
@ -124,34 +133,45 @@ def make_UpsamplingBilinear2d_value_info(g, resize_node_name):
new_output_value_info = onnx.helper.make_tensor_value_info(
resize_node.output[0],
onnx.helper.TensorProto.FLOAT,
shape_data.tolist()
shape_data.tolist(),
)
g.value_info.extend([new_output_value_info])
def polish_RESIZE_input_param_node(g, resize_node_name):
resize_node = helper.find_node_by_node_name(g, resize_node_name)
shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
shape_data = helper.constant_to_numpy(shape_data_node).astype(int)
# handle 0 batch size which is invalid
# handle 0 batch size which is invalid
if shape_data[0] == 0:
shape_data[0] = 1
pre_node_output_value_info = helper.find_value_by_name(g, resize_node.input[0])
ori_shape = np.array([pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value,
pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value,
pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value,
pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value])
resize_node.input.remove(resize_node.input[3])
pre_node_output_value_info = helper.find_value_by_name(
g, resize_node.input[0]
)
ori_shape = np.array(
[
pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value,
pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value,
pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value,
pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value,
]
)
resize_scales = np.array(shape_data/ori_shape).astype(float)
resize_scale_node = helper.list_to_constant('resize_scales_node_' + resize_node.name, resize_scales.shape, resize_scales, data_type=onnx.helper.TensorProto.FLOAT)
resize_node.input.remove(resize_node.input[3])
resize_scales = np.array(shape_data / ori_shape).astype(float)
resize_scale_node = helper.list_to_constant(
"resize_scales_node_" + resize_node.name,
resize_scales.shape,
resize_scales,
data_type=onnx.helper.TensorProto.FLOAT,
)
resize_node.input[2] = resize_scale_node.name
g.node.extend([resize_scale_node])
other.topological_sort(g)

View File

@ -5,15 +5,14 @@ import logging
import traceback
from . import helper
from .general_graph import Graph, Node
from .other import topological_sort
from .replacing import replace_shape_with_constant
from .helper import logger
def are_all_inputs_Constant_with_one_child(g, node):
for input_name in node.input:
input_node = helper.find_node_by_output_name(g, input_name)
if input_node is None or input_node.op_type != 'Constant':
if input_node is None or input_node.op_type != "Constant":
return False
relative_outputs = helper.find_nodes_by_input_name(g, input_name)
if len(relative_outputs) > 1:
@ -28,7 +27,7 @@ def constant_folding(g):
:return: If any node is folded, return True. Otherwise, return False.
"""
keep_folding = True # Keep the while loop
folded = False # Return value
folded = False # Return value
try:
# Before constant folding, duplicate the constant nodes.
duplicate_constant_node(g)
@ -38,37 +37,47 @@ def constant_folding(g):
# Check if the node is foldable
if node.op_type not in constant_folding_nodes.keys():
continue
# Check if the parents of the node are all single follower constant node.
# Check if the parents of the node are all
# single-follower constant nodes.
if not are_all_inputs_Constant_with_one_child(g, node):
continue
# Constant folding for the specific node
if constant_folding_nodes[node.op_type](g, node):
logging.debug("Constant nodes and %s %s are folded.",
node.op_type, node.name)
logging.debug(
"Constant nodes and %s %s are folded.",
node.op_type,
node.name,
)
folded = True
keep_folding = True
else:
logging.debug(
"Constant nodes and %s %s are skipped.", node.op_type, node.name)
except Exception as e:
"Constant nodes and %s %s are skipped.",
node.op_type,
node.name,
)
except Exception:
logger.error("An exception is raised while constant folding.")
logger.error(traceback.format_exc())
return folded
def duplicate_constant_node(g):
""" Duplicate the constant node if its following nodes contain constant folding
nodes. Create and link the new constant nodes to the constant folding nodes.
"""
Duplicate the constant node if its following nodes contain
constant folding nodes. Create and link the new constant nodes
to the constant folding nodes.
"""
for node in g.node:
# Find a valid constant node
if node.op_type != 'Constant':
if node.op_type != "Constant":
continue
output_val_info = helper.find_value_by_name(g, node.output[0])
if output_val_info is None:
print("Cannot inference the shape of Const node output: " +
node.output[0])
print(
"Cannot inference the shape of Const node output: "
+ node.output[0]
)
exit(1)
data_shape = helper.get_shape_from_value_info(output_val_info)
output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
@ -78,30 +87,37 @@ def duplicate_constant_node(g):
continue
# Check if its following nodes are foldable
foldable_output_nodes = list(filter(lambda n: n.op_type in
constant_folding_nodes.keys(), output_nodes))
foldable_output_nodes = list(
filter(
lambda n: n.op_type in constant_folding_nodes.keys(),
output_nodes,
)
)
if not foldable_output_nodes:
continue
# Duplicate the node needed by foldable nodes
for i in range(len(foldable_output_nodes)):
logging.debug("Found constant %s and %s %s are availble for folding. Duplicate constant.",
node.name, foldable_output_nodes[i].op_type, foldable_output_nodes[i].name)
output_name = node.output[0] + '_dup_' + str(i)
logging.debug(
f"Found constant {node.name} and "
f"{foldable_output_nodes[i].op_type} "
f"{foldable_output_nodes[i].name} are availble for folding. "
"Duplicate constant.",
)
output_name = node.output[0] + "_dup_" + str(i)
new_constant_node = onnx.helper.make_node(
'Constant',
"Constant",
[],
[output_name],
name=output_name,
value=node.attribute[0].t
value=node.attribute[0].t,
)
new_val_info = onnx.helper.make_tensor_value_info(
output_name,
node.attribute[0].t.data_type,
data_shape
output_name, node.attribute[0].t.data_type, data_shape
)
input_ind = list(foldable_output_nodes[i].input).index(
node.output[0])
node.output[0]
)
foldable_output_nodes[i].input[input_ind] = output_name
g.node.extend([new_constant_node])
@ -116,6 +132,7 @@ def duplicate_constant_node(g):
return
def slice_constant_folding(g, node):
op_version = helper.get_current_opset_version()
# only support opset 9 & 11
@ -124,9 +141,9 @@ def slice_constant_folding(g, node):
elif op_version == 9:
return slice_constant_folding_Opset_9(g, node)
def slice_constant_folding_Opset_11(g, node):
""" Fold constant and slice nodes to a single constant node.
"""
"""Fold constant and slice nodes to a single constant node."""
pre_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape, data_list = helper.constant_to_list(pre_node)
@ -136,20 +153,26 @@ def slice_constant_folding_Opset_11(g, node):
ends_node = helper.find_node_by_output_name(g, node.input[2])
_, ends = helper.constant_to_list(ends_node)
axes_node = None if len(node.input) <= 3 else helper.find_node_by_output_name(g, node.input[3])
axes_node = (
None
if len(node.input) <= 3
else helper.find_node_by_output_name(g, node.input[3])
)
if not axes_node:
axes = list(range(len(helper.get_shape(data_list))))
else:
_, axes = helper.constant_to_list(axes_node)
steps_node = None if len(node.input) <= 4 else helper.find_node_by_output_name(g, node.input[4])
steps_node = (
None
if len(node.input) <= 4
else helper.find_node_by_output_name(g, node.input[4])
)
if not steps_node:
steps = [1]*len(helper.get_shape(data_list))
steps = [1] * len(helper.get_shape(data_list))
else:
_, steps = helper.constant_to_list(steps_node)
data_list = list(map(int, data_list))
starts = list(map(int, starts))
ends = list(map(int, ends))
@ -160,10 +183,15 @@ def slice_constant_folding_Opset_11(g, node):
new_data = None
for idx, _ in enumerate(axes):
new_data = np.apply_along_axis( lambda x: x[starts[idx] : ends[idx] : steps[idx]], idx, data_list )
new_data = np.apply_along_axis(
lambda x: x[starts[idx]:ends[idx]:steps[idx]], idx, data_list
)
new_node = helper.list_to_constant(node.output[0], helper.get_shape(
new_data), helper.flatten_to_list(new_data))
new_node = helper.list_to_constant(
node.output[0],
helper.get_shape(new_data),
helper.flatten_to_list(new_data),
)
g.node.extend([new_node])
value_info = helper.find_value_by_name(g, pre_node.output[0])
if value_info is not None:
@ -173,16 +201,16 @@ def slice_constant_folding_Opset_11(g, node):
return True
def slice_constant_folding_Opset_9(g, node):
""" Fold constant and slice nodes to a single constant node.
"""
"""Fold constant and slice nodes to a single constant node."""
pre_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape, data_list = helper.constant_to_list(pre_node)
data_list = np.reshape(data_list, pre_shape)
axes = helper.get_attribute_by_name(node, 'axes')
ends = list(helper.get_attribute_by_name(node, 'ends').ints)
starts = list(helper.get_attribute_by_name(node, 'starts').ints)
axes = helper.get_attribute_by_name(node, "axes")
ends = list(helper.get_attribute_by_name(node, "ends").ints)
starts = list(helper.get_attribute_by_name(node, "starts").ints)
if not axes:
axes = list(range(len(helper.get_shape(data_list))))
@ -190,8 +218,11 @@ def slice_constant_folding_Opset_9(g, node):
axes = list(axes.ints)
new_data = helper.slice_data(data_list, starts, ends, axes)
new_node = helper.list_to_constant(node.output[0], helper.get_shape(
new_data), helper.flatten_to_list(new_data))
new_node = helper.list_to_constant(
node.output[0],
helper.get_shape(new_data),
helper.flatten_to_list(new_data),
)
g.node.extend([new_node])
value_info = helper.find_value_by_name(g, pre_node.output[0])
if value_info is not None:
@ -201,9 +232,9 @@ def slice_constant_folding_Opset_9(g, node):
return True
def cast_constant_folding(g, node):
""" Fold constant and cast node to a single constant node.
"""
"""Fold constant and cast node to a single constant node."""
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node)
data_type = node.attribute[0].i
@ -212,28 +243,24 @@ def cast_constant_folding(g, node):
elif data_type == onnx.helper.TensorProto.FLOAT:
data = list(map(float, data))
else:
raise RuntimeError('data type not supported')
raise RuntimeError("data type not supported")
if shape == 1:
tensor = onnx.helper.make_tensor(
name=pre_node.attribute[0].name,
data_type=data_type,
dims=[],
vals=data
vals=data,
)
else:
tensor = onnx.helper.make_tensor(
name=pre_node.attribute[0].name,
data_type=data_type,
dims=shape,
vals=helper.flatten_to_list(data)
vals=helper.flatten_to_list(data),
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=tensor
"Constant", [], [node.output[0]], name=node.output[0], value=tensor
)
g.node.extend([new_node])
@ -250,15 +277,14 @@ def cast_constant_folding(g, node):
def reduceprod_constant_folding(g, node):
""" Fold constant and reduceprod nodes to a single constant node.
"""
"""Fold constant and reduceprod nodes to a single constant node."""
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data_set = helper.constant_to_list(pre_node)
tensor = pre_node.attribute[0].t
data_set = np.reshape(data_set, shape)
for att in node.attribute:
if att.name == 'axes':
if att.name == "axes":
axes = list(att.ints)
else:
keepdims = int(att.i)
@ -270,14 +296,10 @@ def reduceprod_constant_folding(g, node):
name=node.output[0],
data_type=tensor.data_type,
dims=new_shape,
vals=new_flat_data
vals=new_flat_data,
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
)
g.node.extend([new_node])
@ -294,8 +316,7 @@ def reduceprod_constant_folding(g, node):
def reshape_constant_input_folding(g, node):
""" Fold constant and reshape nodes to a single constant node.
"""
"""Fold constant and reshape nodes to a single constant node."""
pre_data_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape_node = helper.find_node_by_output_name(g, node.input[1])
@ -307,14 +328,10 @@ def reshape_constant_input_folding(g, node):
name=node.output[0],
data_type=pre_data_node.attribute[0].t.data_type,
dims=new_data.shape,
vals=helper.flatten_to_list(new_data)
vals=helper.flatten_to_list(new_data),
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
)
g.node.extend([new_node])
@ -332,8 +349,7 @@ def reshape_constant_input_folding(g, node):
def concat_constant_folding(g, node):
""" Fold constant and concat nodes to a single constant node.
"""
"""Fold constant and concat nodes to a single constant node."""
node_to_del = []
valid_inputs = True
for input_name in node.input:
@ -342,7 +358,7 @@ def concat_constant_folding(g, node):
if len(input_node_output) > 1:
valid_inputs = False
break
if input_node.op_type != 'Constant':
if input_node.op_type != "Constant":
valid_inputs = False
break
@ -370,7 +386,7 @@ def concat_constant_folding(g, node):
node.output[0],
helper.get_shape(concat_data),
helper.flatten_to_list(concat_data),
data_type=node_data_type
data_type=node_data_type,
)
g.node.extend([new_node])
node_to_del.append(node)
@ -388,8 +404,7 @@ def concat_constant_folding(g, node):
def transpose_constant_folding(g, node):
"""Fold constant and transpose nodes to a single constant node.
"""
"""Fold constant and transpose nodes to a single constant node."""
node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node)
@ -402,7 +417,7 @@ def transpose_constant_folding(g, node):
node.output[0],
new_shape,
new_data.flatten().tolist(),
data_type=pre_node.attribute[0].t.data_type
data_type=pre_node.attribute[0].t.data_type,
)
g.node.extend([new_node])
@ -415,9 +430,7 @@ def transpose_constant_folding(g, node):
g.value_info.remove(next_val_info)
new_val_info = onnx.helper.make_tensor_value_info(
node.output[0],
pre_node.attribute[0].t.data_type,
new_shape
node.output[0], pre_node.attribute[0].t.data_type, new_shape
)
g.value_info.extend([new_val_info])
@ -430,8 +443,7 @@ def transpose_constant_folding(g, node):
def unsqueeze_constant_folding(g, node):
"""Fold constant and unsqueeze nodes to a single constant node.
"""
"""Fold constant and unsqueeze nodes to a single constant node."""
node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node)
@ -449,7 +461,7 @@ def unsqueeze_constant_folding(g, node):
node.output[0],
new_shape,
np_data.flatten().tolist(),
data_type=pre_node.attribute[0].t.data_type
data_type=pre_node.attribute[0].t.data_type,
)
g.node.extend([new_node])
node_to_del.extend([node, pre_node])
@ -464,9 +476,7 @@ def unsqueeze_constant_folding(g, node):
g.value_info.remove(next_val_info)
new_val_info = onnx.helper.make_tensor_value_info(
node.output[0],
pre_node.attribute[0].t.data_type,
new_shape
node.output[0], pre_node.attribute[0].t.data_type, new_shape
)
g.value_info.extend([new_val_info])
@ -478,8 +488,7 @@ def unsqueeze_constant_folding(g, node):
def gather_constant_folding(g, node):
"""Fold constant and gather nodes to a single constant node.
"""
"""Fold constant and gather nodes to a single constant node."""
node_to_del = []
pre_data_node = helper.find_node_by_output_name(g, node.input[0])
@ -502,7 +511,7 @@ def gather_constant_folding(g, node):
node.output[0],
new_shape,
new_data.flatten().tolist(),
data_type=pre_data_node.attribute[0].t.data_type
data_type=pre_data_node.attribute[0].t.data_type,
)
node_to_del.extend([node, pre_data_node, pre_indices_node])
@ -512,9 +521,7 @@ def gather_constant_folding(g, node):
val_info_2 = helper.find_value_by_name(g, node.input[1])
val_info_3 = helper.find_value_by_name(g, node.output[0])
new_val_info = onnx.helper.make_tensor_value_info(
new_node.output[0],
pre_data_node.attribute[0].t.data_type,
new_shape
new_node.output[0], pre_data_node.attribute[0].t.data_type, new_shape
)
if val_info_1 is not None:
@ -533,8 +540,7 @@ def gather_constant_folding(g, node):
def add_constant_folding(g, node):
"""Fold constant and add nodes to a single constant node.
"""
"""Fold constant and add nodes to a single constant node."""
node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
@ -547,14 +553,14 @@ def add_constant_folding(g, node):
np_data2 = np.reshape(data2, shape2)
try:
new_data = np.add(np_data1, np_data2)
except:
raise RuntimeError('can\'t broadcast and add two data sets')
except Exception:
raise RuntimeError("can't broadcast and add two data sets")
new_node = helper.list_to_constant(
node.output[0],
new_data.shape,
new_data.flatten().tolist(),
data_type=pre_node_1.attribute[0].t.data_type
data_type=pre_node_1.attribute[0].t.data_type,
)
g.node.extend([new_node])
@ -571,8 +577,7 @@ def add_constant_folding(g, node):
def sqrt_constant_folding(g, node):
""" Fold constant and sqrt nodes to a single node.
"""
"""Fold constant and sqrt nodes to a single node."""
node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node)
@ -582,17 +587,13 @@ def sqrt_constant_folding(g, node):
data_type = output_val_info.type.tensor_type.elem_type
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
name=node.output[0] + "_data",
data_type=data_type,
dims=shape,
vals=np_data.flatten().tolist()
vals=np_data.flatten().tolist(),
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
)
g.value_info.remove(input_val_info)
@ -607,13 +608,12 @@ def sqrt_constant_folding(g, node):
def reciprocal_constant_folding(g, node):
""" Fold constant and reciprocal nodes to a single constant node.
"""
"""Fold constant and reciprocal nodes to a single constant node."""
node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node)
data = list(map(lambda x: x if abs(x) > 1.e-8 else 1.e-8, data))
data = list(map(lambda x: x if abs(x) > 1.0e-8 else 1.0e-8, data))
np_data = np.reshape(data, shape)
np_data = np.reciprocal(np_data)
@ -622,17 +622,13 @@ def reciprocal_constant_folding(g, node):
data_type = output_val_info.type.tensor_type.elem_type
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
name=node.output[0] + "_data",
data_type=data_type,
dims=shape,
vals=np_data.flatten().tolist()
vals=np_data.flatten().tolist(),
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
)
node_to_del.extend([node, pre_node])
@ -648,8 +644,7 @@ def reciprocal_constant_folding(g, node):
def mul_constant_folding(g, node):
""" Fold constant and mul nodes to a single constant node.
"""
"""Fold constant and mul nodes to a single constant node."""
node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
@ -666,8 +661,8 @@ def mul_constant_folding(g, node):
try:
new_data = np.multiply(np_data1, np_data2)
except:
raise RuntimeError('can not broadcast and multiply two data sets')
except Exception:
raise RuntimeError("can not broadcast and multiply two data sets")
# Special shape for single element.
if shape1 == 1 and shape2 == 1:
@ -676,17 +671,13 @@ def mul_constant_folding(g, node):
new_shape = new_data.shape
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
name=node.output[0] + "_data",
data_type=pre_node_1.attribute[0].t.data_type,
dims=new_shape,
vals=new_data.flatten().tolist()
vals=new_data.flatten().tolist(),
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
)
node_to_del.extend([node, pre_node_1, pre_node_2])
@ -703,8 +694,7 @@ def mul_constant_folding(g, node):
def div_constant_folding(g, node):
""" Fold constant and mul nodes to a single constant node.
"""
"""Fold constant and div nodes to a single constant node."""
node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
@ -721,8 +711,8 @@ def div_constant_folding(g, node):
try:
new_data = np.divide(np_data1, np_data2)
except:
raise RuntimeError('can not broadcast and multiply two data sets')
except Exception:
raise RuntimeError("can not broadcast and multiply two data sets")
# Special shape for single element.
if shape1 == 1 and shape2 == 1:
@ -732,20 +722,16 @@ def div_constant_folding(g, node):
# Check data type if it is int
if pre_node_1.attribute[0].t.data_type == 7:
new_data = new_data.astype('int64')
new_data = new_data.astype("int64")
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
name=node.output[0] + "_data",
data_type=pre_node_1.attribute[0].t.data_type,
dims=new_shape,
vals=new_data.flatten().tolist()
vals=new_data.flatten().tolist(),
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
)
node_to_del.extend([node, pre_node_1, pre_node_2])
@ -762,8 +748,7 @@ def div_constant_folding(g, node):
def sub_constant_folding(g, node):
""" Fold constant and sub nodes to a single node.
"""
"""Fold constant and sub nodes to a single node."""
node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
@ -781,17 +766,13 @@ def sub_constant_folding(g, node):
new_shape = new_data.shape
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
name=node.output[0] + "_data",
data_type=pre_node_1.attribute[0].t.data_type,
dims=new_shape,
vals=helper.flatten_to_list(new_data)
vals=helper.flatten_to_list(new_data),
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
)
g.node.extend([new_node])
@ -815,17 +796,13 @@ def neg_constant_folding(g, node):
new_data_list = [-num for num in data_list]
new_tensor = onnx.helper.make_tensor(
name=pre_node.name+'_neg_tensor',
name=pre_node.name + "_neg_tensor",
data_type=pre_node.attribute[0].t.data_type,
dims=shape,
vals=new_data_list
vals=new_data_list,
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
)
g.node.extend([new_node])
@ -851,17 +828,13 @@ def floor_constant_folding(g, node):
new_shape = shape
new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data',
name=node.output[0] + "_data",
data_type=pre_node.attribute[0].t.data_type,
dims=new_shape,
vals=helper.flatten_to_list(new_data)
vals=helper.flatten_to_list(new_data),
)
new_node = onnx.helper.make_node(
'Constant',
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
)
g.node.extend([new_node])
@ -877,8 +850,7 @@ def floor_constant_folding(g, node):
def bn_constant_folding(g, node):
""" Fold constant and mul nodes to a single constant node.
"""
"""Fold constant and mul nodes to a single constant node."""
# Prepare data
node_to_del = []
input_node = helper.find_node_by_output_name(g, node.input[0])
@ -900,17 +872,22 @@ def bn_constant_folding(g, node):
mean_data = helper.constant_to_numpy(mean_node)
var_data = helper.constant_to_numpy(var_node)
epsilon = helper.get_var_attribute_by_name(node, 'epsilon', 'float')
epsilon = helper.get_var_attribute_by_name(node, "epsilon", "float")
if epsilon is None:
epsilon = 0.00001
# Calculate new node
new_data = scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon) + bias_data
new_data = (
scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon)
+ bias_data
)
new_node = helper.numpy_to_constant(node.output[0], new_data)
# Reconnect the graph
node_to_del.extend([node, input_node, scale_node, bias_node, mean_node, var_node])
node_to_del.extend(
[node, input_node, scale_node, bias_node, mean_node, var_node]
)
g.node.extend([new_node])
for value in input_value_info:
@ -925,8 +902,7 @@ def bn_constant_folding(g, node):
def DequantizeLinear_constant_folding(g, node):
""" Fold constant and mul nodes to a single constant node.
"""
"""Fold constant and mul nodes to a single constant node."""
# Prepare data
node_to_del = []
x_node = helper.find_node_by_output_name(g, node.input[0])
@ -951,7 +927,9 @@ def DequantizeLinear_constant_folding(g, node):
x_zero_point_data = np.array([0.0])
# Calculate new node
new_data = (x_data.astype(np.float32) - x_zero_point_data.astype(np.float32)) * x_scale_data
new_data = (
x_data.astype(np.float32) - x_zero_point_data.astype(np.float32)
) * x_scale_data
new_node = helper.numpy_to_constant(node.output[0], new_data)
@ -974,22 +952,22 @@ def DequantizeLinear_constant_folding(g, node):
# Available constant folding names to function map.
constant_folding_nodes = {
'Add': add_constant_folding,
'BatchNormalization': bn_constant_folding,
'Cast': cast_constant_folding,
'Concat': concat_constant_folding,
'DequantizeLinear': DequantizeLinear_constant_folding,
'Div': div_constant_folding,
'Floor': floor_constant_folding,
'Gather': gather_constant_folding,
'Mul': mul_constant_folding,
'Reciprocal': reciprocal_constant_folding,
'ReduceProd': reduceprod_constant_folding,
'Reshape': reshape_constant_input_folding,
'Slice': slice_constant_folding,
'Sqrt': sqrt_constant_folding,
'Transpose': transpose_constant_folding,
'Unsqueeze': unsqueeze_constant_folding,
'Sub': sub_constant_folding,
'Neg': neg_constant_folding
"Add": add_constant_folding,
"BatchNormalization": bn_constant_folding,
"Cast": cast_constant_folding,
"Concat": concat_constant_folding,
"DequantizeLinear": DequantizeLinear_constant_folding,
"Div": div_constant_folding,
"Floor": floor_constant_folding,
"Gather": gather_constant_folding,
"Mul": mul_constant_folding,
"Reciprocal": reciprocal_constant_folding,
"ReduceProd": reduceprod_constant_folding,
"Reshape": reshape_constant_input_folding,
"Slice": slice_constant_folding,
"Sqrt": sqrt_constant_folding,
"Transpose": transpose_constant_folding,
"Unsqueeze": unsqueeze_constant_folding,
"Sub": sub_constant_folding,
"Neg": neg_constant_folding,
}

View File

@ -7,6 +7,7 @@ from . import helper
from . import modhelper
from .general_graph import Graph
def eliminate_Identify_and_Dropout(g):
"""
Eliminate Identify layers
@ -15,31 +16,46 @@ def eliminate_Identify_and_Dropout(g):
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'Identity' and node.op_type != 'Dropout':
if node.op_type != "Identity" and node.op_type != "Dropout":
continue
# If this node is the last node, leave it to `eliminate_useless_last node`
# If this node is the last, leave it to `eliminate_useless_last node`
if helper.find_output_by_name(g, node.output[0]) is not None:
continue
# Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
# Delete value info
value_between = helper.find_value_by_name(g, node.output[0])
try:
g.value_info.remove(value_between)
except:
except Exception:
print("No value info to delete while eliminating identity layers.")
# Node is waiting for elimination
node_to_remove.append(node)
for node in node_to_remove:
g.node.remove(node)
# Remove last useless nodes
def remove_useless_last_nodes(g):
"""Remove useless nodes from the tail of the graph
"""
USELESS = ["Reshape", "Identity", "Transpose", "Flatten", "Dropout", "Mystery", "Constant", "Squeeze", "Unsqueeze", 'Softmax']
"""Remove useless nodes from the tail of the graph"""
USELESS = [
"Reshape",
"Identity",
"Transpose",
"Flatten",
"Dropout",
"Mystery",
"Constant",
"Squeeze",
"Unsqueeze",
"Softmax",
]
graph = Graph(g)
todo = collections.deque()
for node in graph.output_nodes:
@ -54,19 +70,30 @@ def remove_useless_last_nodes(g):
if cur_node.proto.op_type not in USELESS:
continue
# Find the output
cur_node_output = helper.find_output_by_name(g, cur_node.proto.output[0])
cur_node_output = helper.find_output_by_name(
g, cur_node.proto.output[0]
)
for cur_input in cur_node.parents:
cur_input.children.remove(cur_node)
if len(cur_input.children) == 0:
todo.append(cur_input)
if cur_node_output is not None:
cur_input_output = helper.find_value_by_name(g, cur_input.proto.output[0])
cur_input_output_in_output = helper.find_output_by_name(g, cur_input.proto.output[0])
if cur_input_output is not None and cur_input_output_in_output is None:
cur_input_output = helper.find_value_by_name(
g, cur_input.proto.output[0]
)
cur_input_output_in_output = helper.find_output_by_name(
g, cur_input.proto.output[0]
)
if (
cur_input_output is not None
and cur_input_output_in_output is None
):
g.output.extend([cur_input_output])
node_to_remove.append(cur_node.proto)
try:
g.value_info.remove(helper.find_value_by_name(g, cur_node.proto.output[0]))
g.value_info.remove(
helper.find_value_by_name(g, cur_node.proto.output[0])
)
except ValueError:
pass
if cur_node_output is not None:
@ -76,10 +103,12 @@ def remove_useless_last_nodes(g):
for node in node_to_remove:
g.node.remove(node)
######################################
# TF only optimization passes #
######################################
def eliminate_shape_changing_after_input(g):
"""
Eliminate the Reshape node after input and reshape the input
@ -87,7 +116,14 @@ def eliminate_shape_changing_after_input(g):
:param g: the onnx graph
"""
node_to_remove = []
REMOVE_LIST = ["Reshape", "Transpose", "Flatten", "Dropout", "Squeeze", "Unsqueeze"]
REMOVE_LIST = [
"Reshape",
"Transpose",
"Flatten",
"Dropout",
"Squeeze",
"Unsqueeze",
]
for node in g.node:
# Find an input and the shape node
if node.op_type not in REMOVE_LIST:
@ -105,9 +141,9 @@ def eliminate_shape_changing_after_input(g):
# Remove Weight if any.
output_val_info = helper.find_value_by_name(g, node.output[0])
if node.op_type == 'Reshape':
if node.op_type == "Reshape":
shape_node = helper.find_node_by_output_name(g, node.input[1])
if shape_node.op_type != 'Constant':
if shape_node.op_type != "Constant":
continue
# manuelly set the input shape
@ -117,25 +153,29 @@ def eliminate_shape_changing_after_input(g):
_, new_shape = helper.constant_to_list(shape_node)
for i in range(len(new_shape)):
if new_shape[i] == -1:
dim = int(old_size//np.prod(new_shape)*(-1))
dim = int(old_size // np.prod(new_shape) * (-1))
new_shape[i] = dim
new_input = onnx.helper.make_tensor_value_info(
output_val_info.name,
output_val_info.type.tensor_type.elem_type,
new_shape
new_shape,
)
node_to_remove.append(node)
shape_outputs = helper.find_nodes_by_input_name(g, shape_node.output[0])
shape_outputs = helper.find_nodes_by_input_name(
g, shape_node.output[0]
)
if len(shape_outputs) == 1:
node_to_remove.append(shape_node)
g.value_info.remove(helper.find_value_by_name(g, shape_node.output[0]))
g.value_info.remove(
helper.find_value_by_name(g, shape_node.output[0])
)
g.input.remove(old_input)
g.input.extend([new_input])
g.value_info.remove(output_val_info)
elif node.op_type == 'Transpose':
elif node.op_type == "Transpose":
permutation = list(node.attribute[0].ints)
pre_shape = helper.get_shape_from_value_info(old_input)
new_shape = [pre_shape[i] for i in permutation]
@ -143,7 +183,7 @@ def eliminate_shape_changing_after_input(g):
new_input = onnx.helper.make_tensor_value_info(
output_val_info.name,
output_val_info.type.tensor_type.elem_type,
new_shape
new_shape,
)
node_to_remove.append(node)
@ -151,7 +191,7 @@ def eliminate_shape_changing_after_input(g):
g.input.remove(old_input)
g.input.extend([new_input])
g.value_info.remove(output_val_info)
elif node.op_type == 'Flatten':
elif node.op_type == "Flatten":
axis = node.attribute[0].int
pre_shape = helper.get_shape_from_value_info(old_input)
dim_1, dim_2 = 1, 1
@ -166,7 +206,7 @@ def eliminate_shape_changing_after_input(g):
new_input = onnx.helper.make_tensor_value_info(
output_val_info.name,
output_val_info.type.tensor_type.elem_type,
new_shape
new_shape,
)
node_to_remove.append(node)
@ -174,18 +214,18 @@ def eliminate_shape_changing_after_input(g):
g.input.remove(old_input)
g.input.extend([new_input])
g.value_info.remove(output_val_info)
elif node.op_type == 'Dropout':
elif node.op_type == "Dropout":
g.input.remove(old_input)
g.input.extend([output_val_info])
g.value_info.remove(output_val_info)
node_to_remove.append(node)
elif node.op_type == 'Squeeze':
elif node.op_type == "Squeeze":
axis = list(node.attribute[0].ints)
pre_shape = helper.get_shape_from_value_info(old_input)
for pos in sorted(axis)[::-1]:
if pre_shape[pos] != 1:
raise RuntimeError('invalid axis for squeeze')
raise RuntimeError("invalid axis for squeeze")
else:
pre_shape.pop(pos)
new_shape = pre_shape
@ -193,7 +233,7 @@ def eliminate_shape_changing_after_input(g):
new_input = onnx.helper.make_tensor_value_info(
output_val_info.name,
output_val_info.type.tensor_type.elem_type,
new_shape
new_shape,
)
node_to_remove.append(node)
@ -201,7 +241,7 @@ def eliminate_shape_changing_after_input(g):
g.input.remove(old_input)
g.input.extend([new_input])
g.value_info.remove(output_val_info)
elif node.op_type == 'Unsqueeze':
elif node.op_type == "Unsqueeze":
axis = list(node.attribute[0].ints)
pre_shape = helper.get_shape_from_value_info(old_input)
new_shape = pre_shape
@ -210,7 +250,7 @@ def eliminate_shape_changing_after_input(g):
new_input = onnx.helper.make_tensor_value_info(
output_val_info.name,
output_val_info.type.tensor_type.elem_type,
new_shape
new_shape,
)
node_to_remove.append(node)
@ -222,7 +262,7 @@ def eliminate_shape_changing_after_input(g):
for node in node_to_remove:
g.node.remove(node)
other.topological_sort(g)
@ -231,15 +271,13 @@ def eliminate_Reshape_Cast(g):
:param g: the onnx graph
"""
#Find all reshape layers
node_to_remove = []
# Find all reshape layers
for node in g.node:
if node.op_type != 'Reshape':
if node.op_type != "Reshape":
continue
prev_node = helper.find_node_by_output_name(g, node.input[1])
if prev_node.op_type != 'Cast':
if prev_node.op_type != "Cast":
continue
# Now we find the cast weight pattern. Cast the weight, delete the cast.
reshape_node = node
cast_node = prev_node
weight_node = helper.find_node_by_output_name(g, cast_node.input[0])
@ -248,10 +286,12 @@ def eliminate_Reshape_Cast(g):
weight_node.attribute[0].t.data_type = 7
if weight_node.attribute[0].t.raw_data:
raw_data = weight_node.attribute[0].t.raw_data
int_data = [i[0] for i in struct.iter_unpack('i', raw_data)]
raw_data = struct.pack('q' * len(int_data), *int_data)
elif len(weight_node.attribute[0].t.int64_data) > 0\
or len(weight_node.attribute[0].t.int32_data) > 0:
int_data = [i[0] for i in struct.iter_unpack("i", raw_data)]
raw_data = struct.pack("q" * len(int_data), *int_data)
elif (
len(weight_node.attribute[0].t.int64_data) > 0
or len(weight_node.attribute[0].t.int32_data) > 0
):
# It's already int. Do nothing
pass
else:
@ -264,6 +304,7 @@ def eliminate_Reshape_Cast(g):
g.value_info.remove(origin_weight_out)
g.node.remove(cast_node)
def eliminate_Cast_after_input(g):
"""Eliminate the cast layer right after the input
@ -271,7 +312,7 @@ def eliminate_Cast_after_input(g):
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'Cast':
if node.op_type != "Cast":
continue
old_input = helper.find_input_by_name(g, node.input[0])
if old_input is None:
@ -279,9 +320,7 @@ def eliminate_Cast_after_input(g):
next_val_info = helper.find_value_by_name(g, node.output[0])
shape = helper.get_shape_from_value_info(next_val_info)
new_val_info = onnx.helper.make_tensor_value_info(
next_val_info.name,
node.attribute[0].i,
shape
next_val_info.name, node.attribute[0].i, shape
)
# Delete old value_info
g.input.remove(old_input)
@ -293,6 +332,7 @@ def eliminate_Cast_after_input(g):
for node in node_to_remove:
g.node.remove(node)
def eliminate_consecutive_Cast(g):
"""If two cast is next to each other, remove the first cast
@ -300,10 +340,10 @@ def eliminate_consecutive_Cast(g):
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'Cast':
if node.op_type != "Cast":
continue
first_node = helper.find_node_by_output_name(g, node.input[0])
if first_node is None or first_node.op_type != 'Cast':
if first_node is None or first_node.op_type != "Cast":
continue
# Here we have two consecutive Cast Node
# Reset the input of the later node
@ -315,6 +355,7 @@ def eliminate_consecutive_Cast(g):
for node in node_to_remove:
g.node.remove(node)
def eliminate_Squeeze_before_Reshape(g):
"""If Squeeze and Reshape is next to each other, remove the first node
@ -322,12 +363,12 @@ def eliminate_Squeeze_before_Reshape(g):
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'Reshape':
if node.op_type != "Reshape":
continue
first_node = helper.find_node_by_output_name(g, node.input[0])
if not first_node:
continue
if first_node.op_type != 'Squeeze':
if first_node.op_type != "Squeeze":
continue
# Here we have two consecutive Cast Node
# Reset the input of the later node
@ -339,9 +380,9 @@ def eliminate_Squeeze_before_Reshape(g):
for node in node_to_remove:
g.node.remove(node)
def eliminate_no_children_input(g):
"""Eliminate inputs with no children at all.
"""
"""Eliminate inputs with no children at all."""
# Create a set of input names
input_names = set([i.name for i in g.input])
# If a name is used in any node, remove this name from the set.
@ -353,31 +394,33 @@ def eliminate_no_children_input(g):
info = helper.find_input_by_name(g, i)
g.input.remove(info)
def eliminate_consecutive_reshape(g):
"""Replace consecutive reshape nodes by a single node.
"""
"""Replace consecutive reshape nodes by a single node."""
node_to_del = []
for node in g.node:
if node.op_type != 'Reshape':
if node.op_type != "Reshape":
continue
pre_data_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape_node = helper.find_node_by_output_name(g, node.input[1])
if not pre_data_node or not pre_shape_node:
continue
if pre_shape_node.op_type != 'Constant':
if pre_shape_node.op_type != "Constant":
continue
if pre_data_node.op_type != 'Reshape':
if pre_data_node.op_type != "Reshape":
continue
pre_pre_shape_node = helper.find_node_by_output_name(g, pre_data_node.input[1])
if pre_pre_shape_node.op_type != 'Constant':
pre_pre_shape_node = helper.find_node_by_output_name(
g, pre_data_node.input[1]
)
if pre_pre_shape_node.op_type != "Constant":
continue
new_reshape_node = onnx.helper.make_node(
'Reshape',
"Reshape",
[pre_data_node.input[0], node.input[1]],
[node.output[0]],
name = node.output[0]
name=node.output[0],
)
g.node.extend([new_reshape_node])
@ -394,6 +437,7 @@ def eliminate_consecutive_reshape(g):
node = node_to_del.pop()
g.node.remove(node)
def eliminate_single_input_Concat(g):
"""
Eliminate single input Concat layers
@ -402,12 +446,12 @@ def eliminate_single_input_Concat(g):
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'Concat':
if node.op_type != "Concat":
continue
# If this node has more than 1 input, continue.
if len(node.input) > 1:
continue
# If this node is the output node, set its previous node as output nodes.
# If this node is output node, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0])
the_input_value = helper.find_value_by_name(g, node.input[0])
@ -416,20 +460,25 @@ def eliminate_single_input_Concat(g):
node_to_remove.append(node)
continue
# Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
# Delete value info
value_between = helper.find_value_by_name(g, node.output[0])
try:
g.value_info.remove(value_between)
except:
except Exception:
print("No value info to delete while eliminating identity layers.")
# Node is waiting for elimination
node_to_remove.append(node)
for node in node_to_remove:
g.node.remove(node)
def eliminate_nop_Maxpool_and_AveragePool(g):
"""
Eliminate do nothing MaxPool and AveragePool layers.
@ -439,7 +488,7 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
"""
node_to_remove = []
for node in g.node:
if node.op_type != 'MaxPool' and node.op_type != 'AveragePool':
if node.op_type != "MaxPool" and node.op_type != "AveragePool":
continue
# If this node is actually working, continue.
kernel = helper.get_list_attribute_by_name(node, "kernel_shape", "int")
@ -447,7 +496,7 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
strides = helper.get_list_attribute_by_name(node, "strides", "int")
if kernel != [1, 1] or pads != [0, 0, 0, 0] or strides != [1, 1]:
continue
# If this node is the output node, set its previous node as output nodes.
# If this node is the output, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0])
the_input_value = helper.find_value_by_name(g, node.input[0])
@ -456,14 +505,18 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
node_to_remove.append(node)
continue
# Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
# Delete value info
value_between = helper.find_value_by_name(g, node.output[0])
try:
g.value_info.remove(value_between)
except:
except Exception:
print("No value info to delete while eliminating identity layers.")
# Node is waiting for elimination
node_to_remove.append(node)
@ -474,20 +527,20 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
def eliminate_trivial_maxpool(g):
node_to_del = []
for node in g.node:
if node.op_type != 'MaxPool':
if node.op_type != "MaxPool":
continue
pads = None
strides = None
dilation = None
kernel_shape = None
for att in node.attribute:
if att.name == 'pads':
if att.name == "pads":
pads = list(att.ints)
elif att.name == 'strides':
elif att.name == "strides":
strides = list(att.ints)
elif att.name == 'kernel_shape':
elif att.name == "kernel_shape":
kernel_shape = list(att.ints)
elif att.name == 'dilation':
elif att.name == "dilation":
dilation = list(att.ints)
else:
pass
@ -504,7 +557,7 @@ def eliminate_trivial_maxpool(g):
next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
if next_nodes[0] == None:
if next_nodes[0] is None:
output_value = helper.find_output_by_name(g, node.output[0])
if not output_value:
continue
@ -512,18 +565,21 @@ def eliminate_trivial_maxpool(g):
pre_val_info = helper.find_value_by_name(g, node.input[0])
g.output.extend([pre_val_info])
g.output.remove(output_value)
for next_node in next_nodes:
modhelper.replace_node_input(next_node, node.output[0], node.input[0])
modhelper.replace_node_input(
next_node, node.output[0], node.input[0]
)
next_val_info = helper.find_value_by_name(g, node.output[0])
g.value_info.remove(next_val_info)
while node_to_del:
g.node.remove(node_to_del.pop())
other.topological_sort(g)
def eliminate_empty_value_infos(g):
to_remove = []
for value_info in g.value_info:
@ -532,10 +588,11 @@ def eliminate_empty_value_infos(g):
for value_info in to_remove:
g.value_info.remove(value_info)
def eliminate_nop_pads(g):
node_to_remove = []
for node in g.node:
if node.op_type != 'Pad':
if node.op_type != "Pad":
continue
# Check if the Pad is empty or not
pads_node = helper.find_node_by_output_name(g, node.input[1])
@ -546,11 +603,7 @@ def eliminate_nop_pads(g):
all_zero = False
if not all_zero:
continue
# Check if it has the constant_value_node
constant_value_node = None
if len(node.input) > 2:
constant_value_node = helper.find_node_by_output_name(g, node.input[2])
# If this node is the output node, set its previous node as output nodes.
# If this node is the output, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0])
g.output.remove(todel_output)
@ -559,38 +612,44 @@ def eliminate_nop_pads(g):
if the_input_value is not None:
g.output.extend([the_input_value])
# Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
# Delete value info
value_between = helper.find_value_by_name(g, node.output[0])
try:
g.value_info.remove(value_between)
except:
helper.logger.info("No value info to delete while eliminating identity layers.")
except Exception:
helper.logger.info(
"No value info to delete while eliminating identity layers."
)
# Node is waiting for elimination
node_to_remove.append(node)
for node in node_to_remove:
g.node.remove(node)
def eliminate_trivial_elementwise_calculation(g):
"""Eliminate Add, Sub, Mul, Sub nodes which do nothing.
"""
"""Eliminate Add, Sub, Mul, Sub nodes which do nothing."""
node_to_remove = []
for node in g.node:
weight_node = None
if node.op_type == 'Add' or node.op_type == 'Sub':
if node.op_type == "Add" or node.op_type == "Sub":
# For add and sub, check if the weights are 0s.
weight_node = helper.find_node_by_output_name(g, node.input[1])
if weight_node is None or weight_node.op_type != 'Constant':
if weight_node is None or weight_node.op_type != "Constant":
continue
weight_np = helper.constant_to_numpy(weight_node)
if np.any(weight_np):
continue
elif node.op_type == 'Mul' or node.op_type == 'Div':
elif node.op_type == "Mul" or node.op_type == "Div":
# For Mul and Div, check if the weights are 1s.
weight_node = helper.find_node_by_output_name(g, node.input[1])
if weight_node is None or weight_node.op_type != 'Constant':
if weight_node is None or weight_node.op_type != "Constant":
continue
weight_np = helper.constant_to_numpy(weight_node)
weight_np = weight_np - 1
@ -605,9 +664,13 @@ def eliminate_trivial_elementwise_calculation(g):
if output_value_info is not None:
g.value_info.remove(output_value_info)
# Replace next node input if any.
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
todel_output = helper.find_output_by_name(g, node.output[0])
if todel_output is not None:
g.output.remove(todel_output)
@ -616,38 +679,53 @@ def eliminate_trivial_elementwise_calculation(g):
the_input_value = helper.find_value_by_name(g, node.input[0])
g.output.extend([the_input_value])
# Delete the constant node if it is not used by other nodes
constant_following_nodes = helper.find_following_nodes_by_input_value_name(g, weight_node.output[0])
constant_following_nodes = (
helper.find_following_nodes_by_input_value_name(
g, weight_node.output[0]
)
)
if len(constant_following_nodes) == 1:
node_to_remove.append(weight_node)
output_value_info = helper.find_value_by_name(g, weight_node.output[0])
output_value_info = helper.find_value_by_name(
g, weight_node.output[0]
)
if output_value_info is not None:
g.value_info.remove(output_value_info)
for node in node_to_remove:
g.node.remove(node)
def eliminate_nop_cast(g):
"""Eliminate do nothing Cast nodes.
"""
"""Eliminate do nothing Cast nodes."""
node_to_remove = []
for node in g.node:
if node.op_type != 'Cast':
if node.op_type != "Cast":
continue
# Get input value_info
input_value = helper.find_value_by_name(g, node.input[0])
if input_value is None:
helper.logger.debug(f"Cannot find the input value_info for Cast node {node.name}. Skip elimination check.")
helper.logger.debug(
f"Cannot find the input value_info for Cast node {node.name}. "
"Skip elimination check."
)
continue
# Get output value_info
output_value = helper.find_value_by_name(g, node.output[0])
if output_value is None:
output_value = helper.find_output_by_name(g, node.output[0])
if output_value is None:
helper.logger.debug(f"Cannot find the output value_info for Cast node {node.name}. Skip elimination check.")
helper.logger.debug(
f"Cannot find the output value_info for Cast node {node.name}."
" Skip elimination check."
)
continue
# Compare the type.
if input_value.type.tensor_type.elem_type != output_value.type.tensor_type.elem_type:
if (
input_value.type.tensor_type.elem_type
!= output_value.type.tensor_type.elem_type
):
continue
# If this node is the output node, set its previous node as output nodes.
# If this node is the output, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0])
g.output.remove(todel_output)
@ -656,9 +734,13 @@ def eliminate_nop_cast(g):
if the_input_value is not None:
g.output.extend([the_input_value])
# Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
# Delete value info
value_between = helper.find_value_by_name(g, node.output[0])
if value_between is not None:

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,11 @@
from collections import deque
class Node:
"""A Node which maps a node proto. It has pointers to its parents and
children.
"""
def __init__(self, onnx_node):
"""Initialize a node. This initialization only set up the mapping to
node proto. The pointers should be set up by outside.
@ -17,12 +19,12 @@ class Node:
self.name = onnx_node.name
self.proto = onnx_node
class Graph:
"""A graph which is constructed from the onnx proto.
"""
"""A graph which is constructed from the onnx proto."""
def __init__(self, onnx_graph):
"""Construct the graph from onnx.
"""
"""Construct the graph from onnx."""
self.input_nodes = []
self.output_nodes = []
self.name2node = {}
@ -51,9 +53,9 @@ class Graph:
for value in onnx_graph.value_info:
node = self.output2node[value.name]
node.output_value = value
def get_sorted_node_list(self):
"""Return a node list in topological order.
"""
"""Return a node list in topological order."""
visited = set()
todo = deque()
result = []

View File

@ -6,21 +6,26 @@ import struct
import numpy as np
import logging
__ONNX_VERSION__ = -1
__ONNX_VERSION__ = -1
logger = logging.getLogger("optimizer_scripts")
def setup_current_opset_version(m):
global __ONNX_VERSION__
__ONNX_VERSION__ = m.opset_import[0].version
if __ONNX_VERSION__ not in [11]:
raise RuntimeError('Only support opset 11, but got ' + str(__ONNX_VERSION__))
raise RuntimeError(
"Only support opset 11, but got " + str(__ONNX_VERSION__)
)
def get_current_opset_version():
if __ONNX_VERSION__ == -1:
raise RuntimeError('do setup_current_opset_version first please')
raise RuntimeError("do setup_current_opset_version first please")
return __ONNX_VERSION__
def find_nodes_by_input_name(g, name):
nodes = []
for node in g.node:
@ -28,6 +33,7 @@ def find_nodes_by_input_name(g, name):
nodes.append(node)
return nodes
def find_node_by_output_name(g, name):
"""
Find a node in the graph by its output name
@ -41,6 +47,7 @@ def find_node_by_output_name(g, name):
return i
return None
def find_node_by_node_name(g, name):
"""
Find a node in the graph by its output name
@ -54,6 +61,7 @@ def find_node_by_node_name(g, name):
return i
return None
def find_following_nodes_by_input_value_name(g, name):
""" Find the following nodes of a specific value.
@ -63,6 +71,7 @@ def find_following_nodes_by_input_value_name(g, name):
"""
return find_nodes_by_input_name(g, name)
def find_value_by_name(g, name):
"""
Find a value_info in the graph by name
@ -76,6 +85,7 @@ def find_value_by_name(g, name):
return i
return None
def find_output_by_name(g, name):
"""
Find a value_info in the graph by name
@ -89,6 +99,7 @@ def find_output_by_name(g, name):
return i
return None
def find_input_by_name(g, name):
"""
Find a input in the graph by name
@ -102,6 +113,7 @@ def find_input_by_name(g, name):
return i
return None
def list_to_constant(name, shape, data, data_type=None):
"""Generate a constant node using the given infomation.
@ -119,18 +131,9 @@ def list_to_constant(name, shape, data, data_type=None):
data_type = onnx.helper.TensorProto.INT64
else:
data_type = onnx.helper.TensorProto.FLOAT
tensor = onnx.helper.make_tensor(
name,
data_type,
shape,
data
)
tensor = onnx.helper.make_tensor(name, data_type, shape, data)
new_w_node = onnx.helper.make_node(
"Constant",
[],
[name],
name = name,
value = tensor
"Constant", [], [name], name=name, value=tensor
)
return new_w_node
@ -151,18 +154,9 @@ def scaler_to_constant(name, data, data_type=None):
else:
logger.error("Cannot create scaler constant with a list.")
exit(1)
tensor = onnx.helper.make_tensor(
name,
data_type,
None,
[data]
)
tensor = onnx.helper.make_tensor(name, data_type, None, [data])
new_w_node = onnx.helper.make_node(
"Constant",
[],
[name],
name = name,
value = tensor
"Constant", [], [name], name=name, value=tensor
)
return new_w_node
@ -170,6 +164,7 @@ def scaler_to_constant(name, data, data_type=None):
def numpy_to_constant(name, np_array):
    """Generate a constant node from a numpy array.

    The array is flattened to a plain list and handed, together with its
    shape, to ``list_to_constant``; no explicit data type is passed.

    :param name: the name passed through to ``list_to_constant``.
    :param np_array: the numpy array holding the data.
    :return: whatever ``list_to_constant`` returns.
    """
    return list_to_constant(name, np_array.shape, np_array.flatten().tolist())
def constant_to_list(node):
"""Generate a list from the constant node
@ -184,27 +179,27 @@ def constant_to_list(node):
if len(tensor.int32_data) != 0:
data = list(tensor.int32_data)
else:
data = [i[0] for i in struct.iter_unpack('i', tensor.raw_data)]
data = [i[0] for i in struct.iter_unpack("i", tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.INT64:
if len(tensor.int64_data) != 0:
data = list(tensor.int64_data)
else:
data = [i[0] for i in struct.iter_unpack('q', tensor.raw_data)]
data = [i[0] for i in struct.iter_unpack("q", tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.INT8:
if len(tensor.int32_data) != 0:
data = list(tensor.int32_data)
else:
data = [i[0] for i in struct.iter_unpack('b', tensor.raw_data)]
data = [i[0] for i in struct.iter_unpack("b", tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.FLOAT:
if len(tensor.float_data) != 0:
data = list(tensor.float_data)
else:
data = [i[0] for i in struct.iter_unpack('f', tensor.raw_data)]
data = [i[0] for i in struct.iter_unpack("f", tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.DOUBLE:
if len(tensor.double_data) != 0:
data = list(tensor.double_data)
else:
data = [i[0] for i in struct.iter_unpack('d', tensor.raw_data)]
data = [i[0] for i in struct.iter_unpack("d", tensor.raw_data)]
else:
print("Not supported data type {}".format(tensor.data_type))
raise RuntimeError
@ -214,6 +209,7 @@ def constant_to_list(node):
shape = list(tensor.dims)
return shape, data
def constant_to_numpy(node):
"""Generate a numpy array from the constant node
@ -223,6 +219,7 @@ def constant_to_numpy(node):
shape, data = constant_to_list(node)
return np.array(data).reshape(shape)
def all_constant_input(node):
"""Find the inputs of the given node. If the inputs of this node are all\\
constant nodes, return True. Otherwise, return False.
@ -234,24 +231,26 @@ def all_constant_input(node):
return False
isConstant = True
for parent in node.parents:
if parent.proto is None or parent.proto.op_type != 'Constant':
if parent.proto is None or parent.proto.op_type != "Constant":
isConstant = False
break
return isConstant
def get_padding(size, kernel_size, strides):
    """Calculate the padding array for same padding in the Tensorflow fashion.

    See https://www.tensorflow.org/api_guides/python/nn#Convolution for more.

    :param size: two-element sequence, indexed as ``[0]`` (h) and ``[1]`` (w).
    :param kernel_size: two-element sequence with the kernel extent.
    :param strides: two-element sequence with the stride.
    :return: ``[pad_h // 2, pad_w // 2, pad_h - pad_h // 2,
        pad_w - pad_w // 2]``, i.e. the smaller half first and the
        remainder last for each of the two dimensions.
    """
    if size[0] % strides[0] == 0:
        pad_h = max(kernel_size[0] - strides[0], 0)
    else:
        pad_h = max(kernel_size[0] - (size[0] % strides[0]), 0)
    if size[1] % strides[1] == 0:
        pad_w = max(kernel_size[1] - strides[1], 0)
    else:
        pad_w = max(kernel_size[1] - (size[1] % strides[1]), 0)
    return [pad_h // 2, pad_w // 2, pad_h - pad_h // 2, pad_w - pad_w // 2]
def get_shape_from_value_info(value):
"""Get shape from a value info.
@ -261,12 +260,13 @@ def get_shape_from_value_info(value):
"""
return [d.dim_value for d in value.type.tensor_type.shape.dim]
def find_size_shape_from_value(value):
'''
"""
Find the size of data within the value_info object.
:param value: value_info
:return: int size and list shape of the data in the value_info
'''
"""
if not value:
return None, None
if not value.type.tensor_type.shape.dim:
@ -292,6 +292,7 @@ def get_attribute_by_name(node, attr_name):
return attr
return None
def get_list_attribute_by_name(node, attr_name: str, attr_type: str):
"""Get list attribute with specific name in the given node proto.
@ -317,12 +318,13 @@ def get_list_attribute_by_name(node, attr_name: str, attr_type: str):
print("Warning: undefined type for list attribute extraction")
return None
def get_var_attribute_by_name(node, attr_name: str, attr_type: str):
"""Get variable attribute with specific name in the given node proto.
:param node: the node proto.\\
:param attr_name: str for the name of the target.\\
:param attr_type: str which should be "float", "int", "string" or "tensor".\\
:param node: the node proto.
:param attr_name: str for the name of the target.
:param attr_type: str which should be "float", "int", "string" or "tensor".
:return: if found, return the variable. Else, return None.
"""
attr_proto = get_attribute_by_name(node, attr_name)
@ -333,7 +335,7 @@ def get_var_attribute_by_name(node, attr_name: str, attr_type: str):
elif attr_type == "float":
return attr_proto.f
elif attr_type == "string":
if type(attr_proto.s) == type(b'abc'):
if isinstance(attr_proto.s, bytes):
return attr_proto.s.decode("utf-8")
else:
return attr_proto.s
@ -343,22 +345,25 @@ def get_var_attribute_by_name(node, attr_name: str, attr_type: str):
print("Warning: undefined type for variable attribute extraction")
return None
def flatten_with_depth(data, depth):
    """Flatten nested lists / numpy arrays into ``[item, depth]`` pairs.

    Only objects whose exact type is ``list`` or ``numpy.ndarray`` are
    descended into; anything else (tuples included) is treated as a leaf.

    :param data: the possibly nested data.
    :param depth: the nesting depth of ``data`` itself.
    :return: a list of ``[leaf, leaf_depth]`` pairs in iteration order. When
        ``data`` itself is a leaf the result is ``[[data, 0]]`` regardless
        of ``depth``.
    """
    # Exact-type check on purpose: subclasses are treated as leaves.
    container_types = [type(np.array([1])), type([1])]
    if type(data) not in container_types:
        return [[data, 0]]
    output = []
    for item in data:
        if type(item) not in container_types:
            output.append([item, depth + 1])
        else:
            output += flatten_with_depth(item, depth + 1)
    return output
def flatten_to_list(data):
    """Flatten nested data into a flat list of its leaves.

    :param data: the possibly nested data, as accepted by
        ``flatten_with_depth``.
    :return: the first element of every pair produced by
        ``flatten_with_depth(data, 0)``.
    """
    flatten_depth = flatten_with_depth(data, 0)
    flat_data = [item[0] for item in flatten_depth]
    return flat_data
def get_shape(data):
shape = []
if type(data) not in [type(np.array([1])), type([1])]:
@ -378,7 +383,7 @@ def slice_data(data, starts, ends, axes):
starts_updated = []
ends_updated = []
for i in range(len(starts)):
start_updated = min(starts[i], shape[i]-1) % shape[i]
start_updated = min(starts[i], shape[i] - 1) % shape[i]
starts_updated.append(start_updated)
for j in range(len(starts)):
if ends[j] >= shape[j]:
@ -393,19 +398,21 @@ def slice_data(data, starts, ends, axes):
index_slices.append(list(range(shape[i])))
else:
axe_ind = axes.index(i)
index_slices.append(list(range(starts_updated[axe_ind], ends_updated[axe_ind])))
index_slices.append(
list(range(starts_updated[axe_ind], ends_updated[axe_ind]))
)
indices = [1]
for i in range(len(shape)-1, -1, -1):
step = np.prod(shape[i+1:])
for i in range(len(shape) - 1, -1, -1):
step = np.prod(shape[i + 1:])
temp_pos = indices
new_indices = []
for n in index_slices[i]:
for pos in temp_pos:
new_indices.append(int(n*step+pos))
new_indices.append(int(n * step + pos))
indices = new_indices
sliced_data = [flat_data[k-1] for k in indices]
sliced_data = [flat_data[k - 1] for k in indices]
# reshape to correct shape.
new_shape = []
@ -414,48 +421,51 @@ def slice_data(data, starts, ends, axes):
new_shape.append(shape[i])
else:
axe_ind = axes.index(i)
new_shape.append(ends_updated[axe_ind]-starts_updated[axe_ind])
new_shape.append(ends_updated[axe_ind] - starts_updated[axe_ind])
if any([dim < 1 for dim in new_shape]):
raise RuntimeError('Invalid starts ends.')
raise RuntimeError("Invalid starts ends.")
sliced_data = np.reshape(sliced_data, new_shape)
return sliced_data
def concatenate(data_sets, axis):
    """Concatenate several data sets along one axis.

    :param data_sets: a sequence of nested data sets; their shapes (as
        reported by ``get_shape``) must agree on every axis except ``axis``.
    :param axis: the axis to concatenate along. A negative value counts
        from the last axis.
    :return: the concatenated data reshaped with ``np.reshape``.
    :raises RuntimeError: if the shapes differ outside ``axis``.
    """
    # check shapes
    shapes = []
    shapes_ = []
    for data_set in data_sets:
        shape = get_shape(data_set)
        shapes.append(list(shape))
        shape.pop(axis)
        shapes_.append(shape)
    if not all([s == shapes_[0] for s in shapes_]):
        raise RuntimeError("data sets shapes do not match")

    # The block arithmetic below walks range(axis) and range(axis + 1, ...),
    # which is only meaningful for a non-negative axis.
    if axis < 0:
        axis += len(shapes[0])

    new_dim = sum([s[axis] for s in shapes])
    new_shape = list(shapes[0])
    new_shape[axis] = new_dim

    flat_data_sets = []
    for data_set in data_sets:
        flat_data_sets.append(flatten_to_list(data_set))

    # Number of elements in one slice taken after the concatenation axis.
    sub_block_size = 1
    for i in range(axis + 1, len(shapes[0])):
        sub_block_size *= shapes[0][i]

    # Number of index combinations before the concatenation axis.
    split_num = 1
    for i in range(axis):
        split_num *= shapes[0][i]

    # Interleave: for each leading index take one block from every data set.
    total_flat_data = []
    for i in range(split_num):
        for j in range(len(shapes)):
            block_size = sub_block_size * shapes[j][axis]
            total_flat_data.extend(
                flat_data_sets[j][i * block_size:(i + 1) * block_size]
            )

    new_data = np.reshape(total_flat_data, new_shape)
    return new_data
@ -464,158 +474,169 @@ def concatenate(data_sets, axis):
def broadcast_data_sets(data_set_1, data_set_2):
    """Expand two data sets to a common broadcasted shape.

    :param data_set_1: the first nested data set.
    :param data_set_2: the second nested data set.
    :return: a tuple ``(data_1, data_2)`` of the two expanded data sets.
    :raises RuntimeError: if a pair of dimensions differs and neither is 1.
    """
    shape1 = get_shape(data_set_1)
    shape2 = get_shape(data_set_2)

    # compare shapes and get broadcasted shape
    # list_a / list_b are the very same list objects as shape1 / shape2, so
    # the left-padding below also lengthens the shorter shape in place.
    list_a, list_b = (
        (shape1, shape2) if len(shape1) > len(shape2) else (shape2, shape1)
    )
    while len(list_a) > len(list_b):
        # 0 marks a dimension that the shorter shape does not have.
        list_b.insert(0, 0)
    broadcasted_shape = []
    for i in range(len(list_a)):
        if list_b[i] == 0:
            broadcasted_shape.append(list_a[i])
        elif list_b[i] == 1:
            broadcasted_shape.append(list_a[i])
        elif list_a[i] == 1:
            broadcasted_shape.append(list_b[i])
        elif list_a[i] == list_b[i]:
            broadcasted_shape.append(list_a[i])
        else:
            raise RuntimeError("Can not broadcast two data sets")

    # prepare data for broadcasting.
    # Every 0 entry becomes a size-1 dimension so reshape is valid.
    shape1 = list(map(lambda x: x if x != 0 else 1, shape1))
    shape2 = list(map(lambda x: x if x != 0 else 1, shape2))
    data_1 = np.reshape(data_set_1, shape1)
    data_2 = np.reshape(data_set_2, shape2)

    # Repeat each data set along every axis whose size is still short.
    for i in range(len(shape1)):
        if shape1[i] != broadcasted_shape[i]:
            new_data_total = [
                list(data_1) for _ in range(broadcasted_shape[i])
            ]
            data_1 = concatenate(new_data_total, axis=i)
    for i in range(len(shape2)):
        if shape2[i] != broadcasted_shape[i]:
            new_data_total = [
                list(data_2) for _ in range(broadcasted_shape[i])
            ]
            data_2 = concatenate(new_data_total, axis=i)

    return data_1, data_2
def add(data_set_1, data_set_2):
    """Add two data sets element-wise after broadcasting them.

    :param data_set_1: the first nested data set.
    :param data_set_2: the second nested data set.
    :return: the sums, reshaped with ``np.reshape`` to the shape of the
        first broadcasted data set.
    """
    broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(
        data_set_1, data_set_2
    )

    flat_data_1 = flatten_to_list(broadcasted_data_1)
    flat_data_2 = flatten_to_list(broadcasted_data_2)
    shape = get_shape(broadcasted_data_1)
    res = []
    for i in range(len(flat_data_1)):
        res.append(flat_data_1[i] + flat_data_2[i])

    res = np.reshape(res, shape)
    return res
def reduceprod(data_set, axis, keepdims=1):
    """Multiply the elements of a data set together along the given axes.

    :param data_set: the nested data set.
    :param axis: an iterable of axis indices; each is reduced in turn.
    :param keepdims: if truthy, reduced axes are kept with size 1;
        otherwise they are removed from the result shape.
    :return: the reduced data reshaped with ``np.reshape``.
    """
    flat_data = flatten_to_list(data_set)
    old_shape = get_shape(data_set)

    temp_shape = old_shape
    temp_flat_data = flat_data
    for ax in axis:
        # split_num: index combinations before ax; step: elements after ax.
        split_num = 1
        step = 1
        for i in range(ax):
            split_num *= temp_shape[i]
        for i in range(ax + 1, len(temp_shape)):
            step *= temp_shape[i]

        block_size = len(temp_flat_data) // split_num
        new_flat_data = []
        for j in range(split_num):
            block_data = temp_flat_data[j * block_size:(j + 1) * block_size]
            reduced_block_data = []
            for k in range(step):
                # Elements sharing trailing offset k sit `step` apart.
                val = block_data[k]
                for li in range(1, block_size // step):
                    val *= block_data[k + li * step]
                reduced_block_data.append(val)
            new_flat_data.extend(reduced_block_data)
        temp_flat_data = new_flat_data
        temp_shape[ax] = 1

    new_flat_data = temp_flat_data
    new_shape = temp_shape
    if not keepdims:
        # Pop from the highest axis down so earlier positions stay valid.
        axis = sorted(list(axis))
        for pos in axis[::-1]:
            new_shape.pop(pos)

    return np.reshape(new_flat_data, new_shape)
def transpose(data_set, permutation):
    """Permute the axes of a data set using a series of adjacent-axis swaps.

    :param data_set: the nested data set.
    :param permutation: an iterable containing every axis index of the data
        exactly once.
    :return: the rearranged data reshaped with ``np.reshape``.
    :raises AssertionError: if ``permutation`` is not a permutation of the
        data's axis indices.
    """
    # find series of local swaps
    data_set = list(data_set)
    perm = list(permutation)
    shape = get_shape(data_set)
    flat_data = flatten_to_list(data_set)
    assert set(perm) == set(range(len(shape))), "invalid permutation"

    # Bubble-sort the permutation, recording every adjacent swap performed.
    new_shape = [shape[i] for i in perm]
    swaps = []
    bubbled = True
    while bubbled:
        bubbled = False
        for i in range(len(new_shape) - 1):
            if perm[i] > perm[i + 1]:
                swaps.append([i, i + 1])
                p_1, p_2 = perm[i], perm[i + 1]
                perm[i], perm[i + 1] = p_2, p_1
                bubbled = True

    # apply local swaps
    current_shape = list(shape)
    temp_flat_data = flat_data

    # The recorded swaps sort the permutation; replaying them in reverse
    # order on the data applies the permutation itself.
    for swap in swaps[::-1]:
        ind_1, ind_2 = swap[0], swap[1]
        dim_1 = current_shape[ind_1]
        dim_2 = current_shape[ind_2]
        split_num = 1
        block_size = 1

        for i in range(ind_1):
            split_num *= current_shape[i]
        for i in range(ind_2 + 1, len(current_shape)):
            block_size *= current_shape[i]

        data_blocks = np.reshape(temp_flat_data, [-1, block_size])
        flat_data_1 = []
        for k in range(split_num):
            block = []
            for m in range(dim_2):
                for n in range(dim_1):
                    block_pos = k * dim_1 * dim_2 + n * dim_2 + m
                    block.extend(data_blocks[block_pos])
            flat_data_1.extend(block)

        temp_flat_data = flat_data_1
        current_shape[ind_1] = dim_2
        current_shape[ind_2] = dim_1

    return np.reshape(temp_flat_data, current_shape)
def subtract(data_set_1, data_set_2):
    """Subtract the second data set from the first after broadcasting them.

    :param data_set_1: the nested data set to subtract from.
    :param data_set_2: the nested data set to subtract.
    :return: the differences, reshaped with ``np.reshape`` to the shape of
        the first broadcasted data set.
    """
    broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(
        data_set_1, data_set_2
    )

    shape = get_shape(broadcasted_data_1)
    flat_data_1 = flatten_to_list(broadcasted_data_1)
    flat_data_2 = flatten_to_list(broadcasted_data_2)

    substracted_data = [
        flat_data_1[i] - flat_data_2[i] for i in range(len(flat_data_1))
    ]

    new_data = np.reshape(substracted_data, shape)
    return new_data

View File

@ -1,7 +1,7 @@
"""This module contains helper functions that do graph modifications.
"""
This module contains helper functions that do graph modifications.
"""
import onnx
from . import helper
@ -10,9 +10,10 @@ def replace_node_input(node, old_input, new_input):
if input_name == old_input:
node.input[i] = new_input
def delete_nodes(g, node_list):
node_to_delete = []
#Find target nodes
# Find target nodes
for node in g.node:
if node.name not in node_list:
continue
@ -23,16 +24,28 @@ def delete_nodes(g, node_list):
for node in node_to_delete:
# Check the node whether if it is valid to delete
if len(node.input) == 0:
print("Deleting an Constant node. Please make sure you also delete all its following nodes")
print(
"Deleting an Constant node. "
"Please make sure you also delete all its following nodes"
)
elif len(node.input) > 1:
print("Warning: Node {} has more than one input. This script cannot delete merge nodes.".format(node.name))
print(
f"Warning: Node {node.name} has more than one input. "
"This script cannot delete merge nodes."
)
# Connect the nodes around the target node.
# Set the following node input as the previous node output.
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
if len(node.input) == 0:
for following_node in following_nodes:
following_node.input.remove(node.output[0])
elif len(following_nodes) > 0 and len(node.input) == 1 and helper.find_input_by_name(g, node.input[0]) is not None:
elif (
len(following_nodes) > 0
and len(node.input) == 1
and helper.find_input_by_name(g, node.input[0]) is not None
):
# The node input is an input
new_input = helper.find_value_by_name(g, node.output[0])
g.input.append(new_input)
@ -40,9 +53,11 @@ def delete_nodes(g, node_list):
g.value_info.remove(new_input)
elif len(following_nodes) > 0:
for following_node in following_nodes:
replace_node_input(following_node, node.output[0], node.input[0])
replace_node_input(
following_node, node.output[0], node.input[0]
)
else:
# If the node is the output, replace the output with the previous input.
# If the node is the output, replace it with previous input.
value = helper.find_value_by_name(g, node.input[0])
output_values = []
while len(g.output):
@ -56,6 +71,7 @@ def delete_nodes(g, node_list):
# Remove the node and value info.
g.node.remove(node)
def delete_input(g, target_list):
for name in target_list:
input_value = helper.find_input_by_name(g, name)
@ -64,6 +80,7 @@ def delete_input(g, target_list):
continue
g.input.remove(input_value)
def delete_output(g, target_list):
for name in target_list:
output_value = helper.find_output_by_name(g, name)
@ -72,6 +89,7 @@ def delete_output(g, target_list):
continue
g.output.remove(output_value)
def delete_value_with_name_if_exists(g, name):
value = helper.find_value_by_name(g, name)
if value is not None:

File diff suppressed because it is too large Load Diff

View File

@ -1,317 +1,368 @@
from . import helper
from . import other
from . import modhelper
from . import fusing
import numpy as np
import onnx
import onnx.utils
def eliminate_transposes(m):
    """Repeatedly apply the Transpose rewriting passes to a model.

    Each round first runs ``swap_transpose_with_single_next_node`` until it
    reports no change, then runs the split, annihilate and multi-transpose
    swap passes once each. Rounds continue while any of those three reports
    a change; after such a round the model is passed through
    ``other.polish_model``.

    :param m: the model proto to process.
    :return: the processed model (the object returned by the last
        ``other.polish_model`` call, or ``m`` itself if none was made).
    """
    g = m.graph
    keep_eliminating = True
    while keep_eliminating:
        while swap_transpose_with_single_next_node(g):
            pass
        splitted = split_transpose_for_multiple_next_nodes(g)
        annihilated = annihilate_transposes(g)
        multiple_trans_swapped = swap_multiple_transposes_with_node(g)
        keep_eliminating = splitted or annihilated or multiple_trans_swapped

        if keep_eliminating:
            m = other.polish_model(m)
            g = m.graph

    return m
def swap_transpose_with_single_next_node(g):
    """Move Transpose nodes below their single passable consumer.

    A Transpose is moved when it has exactly one consumer, that consumer's
    op type is in the passable set, and every other input of the consumer
    is produced by a Constant node. Those Constant inputs are rebuilt with
    the inverse permutation applied to their data.

    :param g: the graph proto, modified in place.
    :return: True if at least one Transpose was moved, else False.
    """
    swapped = False
    passable_nodes = set(
        [
            "Relu",
            "Neg",
            "LeakyRelu",
            "Sqrt",
            "Reciprocal",
            "Add",
            "Mul",
            "Tanh",
        ]
    )
    for node in g.node:
        trans_node = node
        # Check for transpose node
        if trans_node.op_type != "Transpose":
            continue
        next_nodes = helper.find_nodes_by_input_name(g, trans_node.output[0])
        if len(next_nodes) != 1:
            continue
        next_node = next_nodes[0]
        # Check if the next node is the type can be swapped
        if next_node.op_type not in passable_nodes:
            continue

        input_nodes = [
            helper.find_node_by_output_name(g, input_name)
            for input_name in next_node.input
        ]

        # Check if the node has nonconstant input
        # other than the Transpose node itself
        nonconstant_input = False
        for input_node in input_nodes:
            if input_node is None:
                nonconstant_input = True
                break
            if input_node.name == trans_node.name:
                continue
            elif input_node.op_type == "Constant":
                continue
            else:
                nonconstant_input = True
                break
        if nonconstant_input:
            continue

        for input_node in input_nodes:
            if input_node.name == trans_node.name:
                # if the input is just the transpose node
                next_value_info = helper.find_value_by_name(
                    g, next_node.output[0]
                )
                mid_value_info = helper.find_value_by_name(
                    g, trans_node.output[0]
                )

                output_nodes = helper.find_nodes_by_input_name(
                    g, next_node.output[0]
                )
                for out_node in output_nodes:
                    modhelper.replace_node_input(
                        out_node, next_node.output[0], trans_node.name
                    )

                # Rewire by value name rather than by position: the
                # Transpose is not necessarily input 0 of the next node
                # (e.g. a constant may come first), and overwriting input 0
                # would then drop the wrong input.
                modhelper.replace_node_input(
                    next_node, trans_node.output[0], trans_node.input[0]
                )
                next_node.output[0] = next_node.name
                trans_node.input[0] = next_node.name
                trans_node.output[0] = trans_node.name

                if next_value_info:
                    next_value_info.name = trans_node.name
                if mid_value_info:
                    g.value_info.remove(mid_value_info)
            else:
                # if the input is a constant node
                old_tensor = input_node.attribute[0].t
                old_shape, data = helper.constant_to_list(input_node)
                # A bare int shape is wrapped so it can be padded below.
                if isinstance(old_shape, int):
                    old_shape = [old_shape]
                permutation = list(trans_node.attribute[0].ints)
                # Left-pad the shape with 1s up to the permutation's rank.
                while len(old_shape) < len(permutation):
                    old_shape.insert(0, 1)
                np_data = np.reshape(data, old_shape)
                # Apply the inverse permutation to the constant's data.
                reverse_perm = []
                for i in range(len(permutation)):
                    reverse_perm.append(permutation.index(i))
                np_data = np.transpose(np_data, reverse_perm)
                new_shape = np_data.shape
                new_tensor = onnx.helper.make_tensor(
                    name=old_tensor.name,
                    data_type=old_tensor.data_type,
                    dims=new_shape,
                    vals=np_data.flatten().tolist(),
                )
                new_node = onnx.helper.make_node(
                    "Constant",
                    [],
                    [input_node.output[0]],
                    name=input_node.name,
                    value=new_tensor,
                )
                g.node.extend([new_node])

                # The lookup may find nothing; only remove a real entry.
                const_value_info = helper.find_value_by_name(
                    g, input_node.output[0]
                )
                if const_value_info is not None:
                    g.value_info.remove(const_value_info)
                g.node.remove(input_node)

        swapped = True

    other.topological_sort(g)
    return swapped
def swap_multiple_transposes_with_node(g):
    """Move identical Transposes on all inputs of Add/Mul below the node.

    A node qualifies when its op type is ``Add`` or ``Mul`` and every input
    is produced by a Transpose node with the same permutation. The node is
    rewired to read the Transposes' inputs directly, and one new Transpose
    with that permutation is inserted in front of each consumer of the node.

    :param g: the graph proto, modified in place.
    :return: True if at least one node was rewritten, else False.
    """
    # here only consider same input transposes
    swapped = False
    passable_nodes = set(["Add", "Mul"])
    node_to_del = []
    for node in g.node:
        if node.op_type not in passable_nodes:
            continue
        input_nodes = [
            helper.find_node_by_output_name(g, input_name)
            for input_name in node.input
        ]
        if any([input_node is None for input_node in input_nodes]):
            continue
        if any(
            [input_node.op_type != "Transpose" for input_node in input_nodes]
        ):
            continue

        permutation = list(input_nodes[0].attribute[0].ints)
        if any(
            [
                list(input_node.attribute[0].ints) != permutation
                for input_node in input_nodes
            ]
        ):
            continue

        # Bypass the Transposes: read their inputs directly.
        for input_name in node.input:
            input_node = helper.find_node_by_output_name(g, input_name)
            modhelper.replace_node_input(node, input_name, input_node.input[0])

        node_to_del.extend(input_nodes)
        for input_node in input_nodes:
            input_val_info = helper.find_value_by_name(g, input_node.output[0])
            if input_val_info is not None:
                g.value_info.remove(input_val_info)
        output_val_info = helper.find_value_by_name(g, node.output[0])
        if output_val_info is not None:
            g.value_info.remove(output_val_info)

        # Re-insert the permutation after the node, once per consumer.
        output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
        for i in range(len(output_nodes)):
            new_trans_node_name = node.name + "_trans_" + str(i)
            new_trans_node = onnx.helper.make_node(
                "Transpose",
                [node.output[0]],
                [new_trans_node_name],
                name=new_trans_node_name,
                perm=permutation,
            )
            modhelper.replace_node_input(
                output_nodes[i], node.output[0], new_trans_node_name
            )

            g.node.extend([new_trans_node])

        swapped = True

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    other.topological_sort(g)
    return swapped
def annihilate_transposes(g):
    """Remove pairs of consecutive Transpose nodes that cancel each other.

    A pair is removed when the first Transpose feeds only the second one
    and applying the two permutations in sequence leaves the axis order
    unchanged. Consumers of the second Transpose are rewired to the input
    of the first.

    :param g: the graph proto, modified in place.
    :return: True if at least one pair was removed, else False.
    """
    node_to_del = []
    annihilated = False
    for node in g.node:
        if node.op_type != "Transpose":
            continue
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        if not pre_node or pre_node.op_type != "Transpose":
            continue
        nodes_from_top_transpose = helper.find_nodes_by_input_name(
            g, pre_node.output[0]
        )
        if len(nodes_from_top_transpose) > 1:
            continue

        perm_1 = list(pre_node.attribute[0].ints)
        perm_2 = list(node.attribute[0].ints)
        # Output axis i of the second Transpose is input axis
        # perm_1[perm_2[i]] of the first. The pair is a no-op only when that
        # composition is the identity; equal permutations satisfy this only
        # if the permutation is its own inverse.
        if len(perm_1) != len(perm_2):
            continue
        if [perm_1[p] for p in perm_2] != list(range(len(perm_1))):
            continue

        out_nodes = helper.find_nodes_by_input_name(g, node.output[0])
        for out_node in out_nodes:
            modhelper.replace_node_input(
                out_node, node.output[0], pre_node.input[0]
            )

        node_to_del.extend([node, pre_node])
        mid_value_info = helper.find_value_by_name(g, pre_node.output[0])
        out_value_info = helper.find_value_by_name(g, node.output[0])
        # Either lookup may find nothing; only remove real entries.
        if mid_value_info is not None:
            g.value_info.remove(mid_value_info)
        if out_value_info is not None:
            g.value_info.remove(out_value_info)

        annihilated = True
    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)
    return annihilated
def split_transpose_for_multiple_next_nodes(g):
    """Give every consumer of a shared Transpose its own private copy.

    For each Transpose whose output is read by two or more nodes, one new
    Transpose per consumer is created (same input, same ``perm``), each
    consumer is rewired to its copy, and the original node is removed.
    The graph is topologically sorted afterwards.

    :param g: the onnx graph, modified in place
    :return: True if at least one Transpose was split, otherwise False
    """
    splitted = False
    node_to_del = []
    new_nodes = []
    for node in g.node:
        if node.op_type != "Transpose":
            continue
        output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
        if len(output_nodes) < 2:
            continue
        perm = list(node.attribute[0].ints)
        for i, output_node in enumerate(output_nodes):
            # The copy's node name doubles as its output tensor name.
            new_trans_node_name = f"{node.name}_{i}"
            new_trans_node = onnx.helper.make_node(
                "Transpose",
                [node.input[0]],
                [new_trans_node_name],
                name=new_trans_node_name,
                perm=perm,
            )
            modhelper.replace_node_input(
                output_node, node.output[0], new_trans_node.output[0]
            )
            new_nodes.append(new_trans_node)
        node_to_del.append(node)
        val_info = helper.find_value_by_name(g, node.output[0])
        # The lookup yields None when no value_info is recorded for the
        # tensor; removing None from the repeated field would raise.
        if val_info is not None:
            g.value_info.remove(val_info)
        splitted = True

    # Graph edits are deferred so g.node is never resized during iteration.
    # Every copy has exactly one consumer, so it would have been skipped by
    # the loop above anyway.
    g.node.extend(new_nodes)
    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)
    other.topological_sort(g)
    return splitted
def remove_trivial_transpose(g):
    """Remove Transpose nodes whose permutation is the identity.

    If the Transpose feeds other nodes, they are rewired to read its input
    directly. If nothing consumes it, it is treated as producing a graph
    output and that output entry is replaced by the value info of the
    Transpose's input. The graph is topologically sorted afterwards.

    :param g: the onnx graph, modified in place
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != "Transpose":
            continue
        # Without a `perm` attribute the operator reverses the axes, which
        # is not a trivial transpose; there is nothing to inspect here.
        if not node.attribute:
            continue
        permutation = list(node.attribute[0].ints)
        if permutation != list(range(len(permutation))):
            continue
        next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
        if not next_nodes:
            # No consumer: the Transpose output is expected to be a graph
            # output. Swap that entry for the Transpose's input value.
            input_val_info = helper.find_value_by_name(g, node.input[0])
            out_val_info = helper.find_output_by_name(g, node.output[0])
            if not input_val_info:
                input_val_info = helper.find_input_by_name(g, node.input[0])
            if not out_val_info or not input_val_info:
                # Either the output is dangling or the input has no value
                # info to promote; leave the node untouched rather than
                # corrupt g.output.
                continue
            g.output.remove(out_val_info)
            g.output.extend([input_val_info])
        else:
            out_val_info = helper.find_value_by_name(g, node.output[0])
            for next_node in next_nodes:
                modhelper.replace_node_input(
                    next_node, node.output[0], node.input[0]
                )
            if out_val_info is not None:
                g.value_info.remove(out_val_info)
        node_to_del.append(node)
    # Deletion is deferred so that g.node is not shrunk while iterating it.
    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)
    other.topological_sort(g)
def fuse_Transpose_into_Gemm_weight(g):
    """Fold a ``Transpose -> Flatten -> Gemm`` pattern into the Gemm weight.

    When a Gemm (with ``transB`` unset or 0) reads a Flatten that reads a
    Transpose with ``perm == [0, 2, 3, 1]``, the Transpose is dropped and
    the rows of the Gemm weight are reordered to compensate. The graph is
    topologically sorted afterwards.

    :param g: the onnx graph, modified in place
    """
    node_to_del = []
    for node in g.node:
        # Check pattern
        if node.op_type != "Gemm":
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None or prev_node.op_type != "Flatten":
            continue
        transpose_node = helper.find_node_by_output_name(g, prev_node.input[0])
        # The Flatten may be fed by something that has no producer node.
        if transpose_node is None or transpose_node.op_type != "Transpose":
            continue
        # Check attribute
        perm = helper.get_list_attribute_by_name(transpose_node, "perm", "int")
        if perm != [0, 2, 3, 1]:
            continue
        transB = helper.get_var_attribute_by_name(node, "transB", "int")
        if transB is not None and transB == 1:
            continue
        # The Transpose is deleted below, so the Flatten must be its only
        # consumer.
        transpose_users = helper.find_nodes_by_input_name(
            g, transpose_node.output[0]
        )
        if len(transpose_users) > 1:
            continue
        # Get the original weight. It must come from a Constant node so it
        # can be read and rewritten.
        origin_weight = helper.find_node_by_output_name(g, node.input[1])
        if origin_weight is None or origin_weight.op_type != "Constant":
            continue
        origin_np = helper.constant_to_numpy(origin_weight)
        # Calculate a new weight.
        # `shape` is the recorded shape of the Flatten input, i.e. of the
        # Transpose output (presumably [N, H, W, C] with N == 1 - TODO
        # confirm). The weight rows are unfolded to that layout, moved to
        # the pre-transpose axis order, then folded back to 2-D.
        transposed_value = helper.find_value_by_name(g, prev_node.input[0])
        shape = helper.get_shape_from_value_info(transposed_value)
        shape.append(-1)
        new_np = np.reshape(origin_np, shape)
        new_np = np.transpose(new_np, [0, 3, 1, 2, 4])
        new_np = np.reshape(new_np, [-1, new_np.shape[-1]])
        new_weight = helper.numpy_to_constant(origin_weight.output[0], new_np)
        # Replace and eliminate
        prev_node.input[0] = transpose_node.input[0]
        node_to_del.append(transpose_node)
        node_to_del.append(origin_weight)
        if transposed_value is not None:
            g.value_info.remove(transposed_value)
        g.node.extend([new_weight])

    # Deletion is deferred so that g.node is not shrunk while iterating it.
    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    other.topological_sort(g)

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,10 @@
"""Special operations on model.
"""
import logging
import onnx.helper
import numpy as np
from . import helper
from . import other
from . import modhelper
def change_first_conv_from_bgr_to_rgb(m):
"""For input channel format BGR model, use this function to change the first
@ -16,12 +15,14 @@ def change_first_conv_from_bgr_to_rgb(m):
# Check for first node.
g = m.graph
input_name = g.input[0].name
first_nodes = helper.find_following_nodes_by_input_value_name(g, input_name)
first_nodes = helper.find_following_nodes_by_input_value_name(
g, input_name
)
if len(first_nodes) > 1:
return False
first_node = first_nodes[0]
# Now we have the first node. Check this first node.
if first_node.op_type != 'Conv':
if first_node.op_type != "Conv":
return False
weight_value = helper.find_value_by_name(g, first_node.input[1])
weight_shape = helper.get_shape_from_value_info(weight_value)
@ -41,10 +42,12 @@ def change_first_conv_from_bgr_to_rgb(m):
other.topological_sort(g)
return True
def change_input_from_bgr_to_rgb(m):
"""For input channel format BGR model, use this function to modify the model
to accepct RGB image.If the first node is a non-group Conv. Modify weight to
adapt the input into RGB. Otherwise create a new node.
"""
For input channel format BGR model, use this function to modify the model
to accept RGB images. If the first node is a non-group Conv, modify its
weight to adapt the input into RGB. Otherwise create a new node.
:param m: the model proto
"""
@ -61,34 +64,33 @@ def change_input_from_bgr_to_rgb(m):
return
# Otherwise, create a special conv node and replace the input
# Construct weight
weight_np = np.zeros((3, 3, 3, 3)).astype('float32')
weight_np = np.zeros((3, 3, 3, 3)).astype("float32")
weight_np[0, 2, 1, 1] = 1.0
weight_np[1, 1, 1, 1] = 1.0
weight_np[2, 0, 1, 1] = 1.0
new_weight = helper.numpy_to_constant("bgr_shuffle_weight", weight_np)
# Construct Conv
new_conv = onnx.helper.make_node(
'Conv',
['rgb_input', "bgr_shuffle_weight"],
"Conv",
["rgb_input", "bgr_shuffle_weight"],
[g.input[0].name],
name='bgr_shuffle',
name="bgr_shuffle",
dilations=[1, 1],
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1]
strides=[1, 1],
)
# Connect the graph
old_input_value = g.input.pop()
new_input_value = onnx.helper.make_tensor_value_info(
'rgb_input',
old_input_value.type.tensor_type.elem_type,
input_shape
"rgb_input", old_input_value.type.tensor_type.elem_type, input_shape
)
g.input.extend([new_input_value])
g.node.extend([new_weight, new_conv])
# topological sort
other.topological_sort(g)
def add_0_5_to_normalized_input(m):
"""For normalized input between -0.5 ~ 0.5, add 0.5 to the input to keep it
between 0 ~ 1.
@ -105,41 +107,37 @@ def add_0_5_to_normalized_input(m):
return
# Construct weight
ch = input_shape[1]
weight_np = np.zeros((ch, ch, 3, 3)).astype('float32')
weight_np = np.zeros((ch, ch, 3, 3)).astype("float32")
for i in range(ch):
weight_np[i, i, 1, 1] = 1.0
new_weight = helper.numpy_to_constant("input_norm_weight", weight_np)
# Construct bias
bias_np = np.array([0.5] * ch).astype('float32')
bias_np = np.array([0.5] * ch).astype("float32")
new_bias = helper.numpy_to_constant("input_norm_bias", bias_np)
# Construct Conv
new_conv = onnx.helper.make_node(
'Conv',
['origin_input', "input_norm_weight", "input_norm_bias"],
"Conv",
["origin_input", "input_norm_weight", "input_norm_bias"],
[g.input[0].name],
name='input_norm',
name="input_norm",
dilations=[1, 1],
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1]
strides=[1, 1],
)
# Construct value_infos
old_input_value = g.input.pop()
weight_value = onnx.helper.make_tensor_value_info(
'input_norm_weight',
"input_norm_weight",
old_input_value.type.tensor_type.elem_type,
[3, 3, 3, 3]
[3, 3, 3, 3],
)
bias_value = onnx.helper.make_tensor_value_info(
'input_norm_bias',
old_input_value.type.tensor_type.elem_type,
[3]
"input_norm_bias", old_input_value.type.tensor_type.elem_type, [3]
)
# Connect the graph
new_input_value = onnx.helper.make_tensor_value_info(
'origin_input',
old_input_value.type.tensor_type.elem_type,
input_shape
"origin_input", old_input_value.type.tensor_type.elem_type, input_shape
)
g.input.extend([new_input_value])
g.node.extend([new_weight, new_bias, new_conv])
@ -147,9 +145,9 @@ def add_0_5_to_normalized_input(m):
# topological sort
other.topological_sort(g)
def add_rgb2yynn_node(m):
"""Add a conv layer which can convert rgb to yynn input.
"""
"""Add a conv layer which can convert rgb to yynn input."""
g = m.graph
if len(g.input) > 1:
print("This model has multiple inputs. Cannot change to rgb input.")
@ -159,37 +157,32 @@ def add_rgb2yynn_node(m):
print("The input shape is not BCHW. Cannot normalize input.")
return
# Construct weight
ch = input_shape[1]
weight_np = np.zeros((3, 3, 4, 4)).astype('float32')
weight_np[1, 1, :3, :2] = np.array([[[[0.299],
[0.587],
[0.114]]]])
weight_np[1, 1, 3, 2:] = 1.
weight_np = np.zeros((3, 3, 4, 4)).astype("float32")
weight_np[1, 1, :3, :2] = np.array([[[[0.299], [0.587], [0.114]]]])
weight_np[1, 1, 3, 2:] = 1.0
weight_np = np.transpose(weight_np, (3, 2, 0, 1))
new_weight = helper.numpy_to_constant("input_rgb2yynn_weight", weight_np)
# Construct conv node
new_conv = onnx.helper.make_node(
'Conv',
['new_input', "input_rgb2yynn_weight"],
"Conv",
["new_input", "input_rgb2yynn_weight"],
[g.input[0].name],
name='input_rgba2yynn',
name="input_rgba2yynn",
dilations=[1, 1],
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1]
strides=[1, 1],
)
# Construct value_infos
old_input_value = g.input.pop()
weight_value = onnx.helper.make_tensor_value_info(
'input_rgb2yynn_weight',
"input_rgb2yynn_weight",
old_input_value.type.tensor_type.elem_type,
[4, 4, 3, 3]
[4, 4, 3, 3],
)
# Connect the graph
new_input_value = onnx.helper.make_tensor_value_info(
'new_input',
old_input_value.type.tensor_type.elem_type,
input_shape
"new_input", old_input_value.type.tensor_type.elem_type, input_shape
)
g.input.extend([new_input_value])
g.node.extend([new_weight, new_conv])
@ -197,6 +190,7 @@ def add_rgb2yynn_node(m):
# topological sort
other.topological_sort(g)
def swap_MatMul_inputs(g, original_matmul_node):
# Create Transpose nodes
input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0])
@ -206,11 +200,12 @@ def swap_MatMul_inputs(g, original_matmul_node):
else:
perm = [0, 2, 1]
new_input_b_node = onnx.helper.make_node(
'Transpose',
inputs = [input_a_value.name],
outputs = [input_a_value.name + '_transposed'],
name = f"{input_a_value.name}_transposed_for_{original_matmul_node.name}",
perm = perm
"Transpose",
inputs=[input_a_value.name],
outputs=[input_a_value.name + "_transposed"],
name=f"{input_a_value.name}_transposed_for_"
f"{original_matmul_node.name}",
perm=perm,
)
input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1])
input_b_shape = helper.get_shape_from_value_info(input_b_value)
@ -219,18 +214,19 @@ def swap_MatMul_inputs(g, original_matmul_node):
else:
perm = [0, 1, 3, 2]
new_input_a_node = onnx.helper.make_node(
'Transpose',
inputs = [input_b_value.name],
outputs = [input_b_value.name + '_transposed'],
name = f'{input_b_value.name}_transposed_for_{original_matmul_node.name}',
perm = perm
"Transpose",
inputs=[input_b_value.name],
outputs=[input_b_value.name + "_transposed"],
name=f"{input_b_value.name}_transposed_for_"
f"{original_matmul_node.name}",
perm=perm,
)
# Create new MatMul node
new_matmul_node = onnx.helper.make_node(
'MatMul',
inputs = [new_input_a_node.output[0], new_input_b_node.output[0]],
outputs = [original_matmul_node.output[0] + '_transposed'],
name = original_matmul_node.name + '_transposed'
"MatMul",
inputs=[new_input_a_node.output[0], new_input_b_node.output[0]],
outputs=[original_matmul_node.output[0] + "_transposed"],
name=original_matmul_node.name + "_transposed",
)
# Create final Transpose node
output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
@ -240,17 +236,25 @@ def swap_MatMul_inputs(g, original_matmul_node):
else:
perm = [0, 1, 3, 2]
new_final_transpose_node = onnx.helper.make_node(
'Transpose',
inputs = [new_matmul_node.output[0]],
outputs = [original_matmul_node.output[0]],
name = original_matmul_node.name + '_final_transpose',
perm = perm
"Transpose",
inputs=[new_matmul_node.output[0]],
outputs=[original_matmul_node.output[0]],
name=original_matmul_node.name + "_final_transpose",
perm=perm,
)
# Add new nodes
g.node.extend([new_input_a_node, new_input_b_node, new_matmul_node, new_final_transpose_node])
g.node.extend(
[
new_input_a_node,
new_input_b_node,
new_matmul_node,
new_final_transpose_node,
]
)
# Delete original nodes
g.node.remove(original_matmul_node)
def split_MatMul_batch_then_concat(g, original_matmul_node):
new_nodes = []
final_concat_inputs = []
@ -265,49 +269,85 @@ def split_MatMul_batch_then_concat(g, original_matmul_node):
batch_count = input_a_shape[1]
for i in range(batch_count):
# Create Split nodes for input A
starts_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_starts", (1, ), [i])
ends_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_ends", (1, ), [i+1])
axes_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_axes", (1, ), [len(input_a_shape) - 3])
starts_node = helper.list_to_constant(
f"{input_a_value.name}_sliced_{i}_starts", (1,), [i]
)
ends_node = helper.list_to_constant(
f"{input_a_value.name}_sliced_{i}_ends", (1,), [i + 1]
)
axes_node = helper.list_to_constant(
f"{input_a_value.name}_sliced_{i}_axes",
(1,),
[len(input_a_shape) - 3],
)
new_sliced_a_node = onnx.helper.make_node(
'Slice',
inputs = [input_a_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]],
outputs = [f"{input_a_value.name}_sliced_{i}"],
name = f"{input_a_value.name}_sliced_{i}_for_{original_matmul_node.name}"
"Slice",
inputs=[
input_a_value.name,
starts_node.output[0],
ends_node.output[0],
axes_node.output[0],
],
outputs=[f"{input_a_value.name}_sliced_{i}"],
name=f"{input_a_value.name}_sliced_{i}_for_"
f"{original_matmul_node.name}",
)
new_nodes.extend(
[starts_node, ends_node, axes_node, new_sliced_a_node]
)
new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_a_node])
# Create Split nodes for input B
starts_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_starts", (1, ), [i])
ends_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_ends", (1, ), [i+1])
axes_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_axes", (1, ), [len(input_b_shape) - 3])
new_sliced_b_node = onnx.helper.make_node(
'Slice',
inputs = [input_b_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]],
outputs = [f"{input_b_value.name}_sliced_{i}"],
name = f"{input_b_value.name}_sliced_{i}_for_{original_matmul_node.name}"
starts_node = helper.list_to_constant(
f"{input_b_value.name}_sliced_{i}_starts", (1,), [i]
)
ends_node = helper.list_to_constant(
f"{input_b_value.name}_sliced_{i}_ends", (1,), [i + 1]
)
axes_node = helper.list_to_constant(
f"{input_b_value.name}_sliced_{i}_axes",
(1,),
[len(input_b_shape) - 3],
)
new_sliced_b_node = onnx.helper.make_node(
"Slice",
inputs=[
input_b_value.name,
starts_node.output[0],
ends_node.output[0],
axes_node.output[0],
],
outputs=[f"{input_b_value.name}_sliced_{i}"],
name=f"{input_b_value.name}_sliced_{i}_for_"
f"{original_matmul_node.name}",
)
new_nodes.extend(
[starts_node, ends_node, axes_node, new_sliced_b_node]
)
new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_b_node])
# Create MatMul nodes
new_matmul_node = onnx.helper.make_node(
'MatMul',
inputs = [new_sliced_a_node.output[0], new_sliced_b_node.output[0]],
outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"],
name = f"{original_matmul_node.name}_sliced_{i}"
"MatMul",
inputs=[new_sliced_a_node.output[0], new_sliced_b_node.output[0]],
outputs=[f"{original_matmul_node.output[0]}_sliced_{i}"],
name=f"{original_matmul_node.name}_sliced_{i}",
)
new_nodes.append(new_matmul_node)
final_concat_inputs.append(new_matmul_node.output[0])
# Create Concat nodes
output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
if output_value is None:
output_value = helper.find_output_by_name(g, original_matmul_node.output[0])
output_value = helper.find_output_by_name(
g, original_matmul_node.output[0]
)
if output_value is None:
helper.logger.error(f"Cannot find value_info for {original_matmul_node.output[0]}")
helper.logger.error(
f"Cannot find value_info for {original_matmul_node.output[0]}"
)
output_shape = helper.get_shape_from_value_info(output_value)
new_concat_node = onnx.helper.make_node(
"Concat",
inputs = final_concat_inputs,
outputs = [original_matmul_node.output[0]],
name = f"{original_matmul_node.name}_final_concat",
axis = len(output_shape) - 3
inputs=final_concat_inputs,
outputs=[original_matmul_node.output[0]],
name=f"{original_matmul_node.name}_final_concat",
axis=len(output_shape) - 3,
)
new_nodes.append(new_concat_node)
# Add new nodes
@ -320,7 +360,9 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
new_nodes = []
final_concat_inputs = []
# Get the batch count
input_b_node = helper.find_node_by_output_name(g, original_matmul_node.input[1])
input_b_node = helper.find_node_by_output_name(
g, original_matmul_node.input[1]
)
input_b_np = helper.constant_to_numpy(input_b_node)
if len(input_b_np.shape) == 3:
batch_count = input_b_np.shape[0]
@ -329,17 +371,19 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
for i in range(batch_count):
# Create new constant node
if len(input_b_np.shape) == 3:
new_np = input_b_np[i:i+1, ...]
new_np = input_b_np[i:i + 1, ...]
else:
new_np = input_b_np[:, i:i+1, ...]
new_weight = helper.numpy_to_constant(f"{input_b_node.name}_sliced_{i}", new_np)
new_np = input_b_np[:, i:i + 1, ...]
new_weight = helper.numpy_to_constant(
f"{input_b_node.name}_sliced_{i}", new_np
)
new_nodes.append(new_weight)
# Create MatMul nodes
new_matmul_node = onnx.helper.make_node(
'MatMul',
inputs = [original_matmul_node.input[0], new_weight.output[0]],
outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"],
name = f"{original_matmul_node.name}_sliced_{i}"
"MatMul",
inputs=[original_matmul_node.input[0], new_weight.output[0]],
outputs=[f"{original_matmul_node.output[0]}_sliced_{i}"],
name=f"{original_matmul_node.name}_sliced_{i}",
)
new_nodes.append(new_matmul_node)
final_concat_inputs.append(new_matmul_node.output[0])
@ -348,10 +392,10 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
output_shape = helper.get_shape_from_value_info(output_value)
new_concat_node = onnx.helper.make_node(
"Concat",
inputs = final_concat_inputs,
outputs = [original_matmul_node.output[0]],
name = f"{original_matmul_node.name}_final_concat",
axis = len(output_shape) - 3
inputs=final_concat_inputs,
outputs=[original_matmul_node.output[0]],
name=f"{original_matmul_node.name}_final_concat",
axis=len(output_shape) - 3,
)
new_nodes.append(new_concat_node)
# Add new nodes
@ -367,7 +411,7 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
def special_MatMul_process(g):
for node in g.node:
if node.op_type != 'MatMul':
if node.op_type != "MatMul":
continue
input_a_name = node.input[0]
input_a_value = helper.find_value_by_name(g, input_a_name)
@ -383,19 +427,30 @@ def special_MatMul_process(g):
continue
# Too many dimensions or too few dimensions. Not supported. Skip
if len(input_a_shape) > 4 or len(input_b_shape) > 4:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have too many dimensions.")
helper.logger.warning(
f"Cannot optimize MatMul {node.name}: "
"inputs have too many dimensions."
)
continue
if len(input_a_shape) < 2 or len(input_b_shape) < 2:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have two few dimensions.")
helper.logger.warning(
f"Cannot optimize MatMul {node.name}: "
"inputs have two few dimensions."
)
continue
# For 4 dimension, check the first dimension (should be 1) and treated as 3 dimensions.
# For 4 dimensions, check the first dimension (should be 1)
# and treat the input as 3-dimensional.
extra_dim = None
if len(input_a_shape) == 4:
extra_dim = input_a_shape[0]
input_a_shape = input_a_shape[1:]
if len(input_b_shape) == 4:
if input_b_shape[0] != extra_dim:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: input dimension batch sizes does not match ({extra_dim} vs {input_b_shape[0]}).")
helper.logger.warning(
f"Cannot optimize MatMul {node.name}: "
"input dimension batch sizes does not match "
f"({extra_dim} vs {input_b_shape[0]})."
)
continue
input_b_shape = input_b_shape[1:]
# Check input B dimension
@ -404,20 +459,31 @@ def special_MatMul_process(g):
continue
# If B is B x W x V, but B is a constant.
input_b_node = helper.find_node_by_output_name(g, input_b_name)
if input_b_node is not None and input_b_node.op_type == 'Constant':
if input_b_node is not None and input_b_node.op_type == "Constant":
# Constant input
helper.logger.debug(f"Optimizing MatMul node {node.name}: split constant input.")
helper.logger.debug(
f"Optimizing MatMul node {node.name}: split constant input."
)
split_MatMul_Constant_input_then_concat(g, node)
# If B is B x W x V and A is 1 x H x W, do the swap.
elif len(input_a_shape) == 2 or (input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1)):
helper.logger.debug(f"Optimizing MatMul node {node.name}: swap input.")
elif len(input_a_shape) == 2 or (
input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1)
):
helper.logger.debug(
f"Optimizing MatMul node {node.name}: swap input."
)
swap_MatMul_inputs(g, node)
# If B is B x W x V and A is B x H x W, do the split.
elif input_b_shape[0] == input_a_shape[0]:
helper.logger.debug(f"Optimizing MatMul node {node.name}: split input batch.")
helper.logger.debug(
f"Optimizing MatMul node {node.name}: split input batch."
)
split_MatMul_batch_then_concat(g, node)
# Other cases are not supported: If B is B x W x V but A is X x H x W.
else:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: unknown reason. Might be shape mismatch.")
helper.logger.warning(
f"Cannot optimize MatMul {node.name}: "
"unknown reason. Might be shape mismatch."
)
continue
other.topological_sort(g)
other.topological_sort(g)