style: fix format so pep8 is satisfied

This commit is contained in:
chingning.chen 2022-04-12 14:26:54 +08:00
parent a783220efa
commit 0136a5b2bd
27 changed files with 3329 additions and 2104 deletions

View File

@ -316,7 +316,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
session_options.register_custom_ops_library(ort_custom_op_path) session_options.register_custom_ops_library(ort_custom_op_path)
providers = ['CPUExecutionProvider'] providers = ['CPUExecutionProvider']
provider_options = [{}] provider_options = [{}]
is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available() is_cuda_available = (
ort.get_device() == 'GPU' and torch.cuda.is_available()
)
if is_cuda_available: if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider') providers.insert(0, 'CUDAExecutionProvider')
device_id = device_id or 0 device_id = device_id or 0
@ -334,7 +336,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
self.output_name_list = [sess_outputs[0].name] self.output_name_list = [sess_outputs[0].name]
self.cfg = cfg # TODO: necessary? self.cfg = cfg # TODO: necessary?
self.test_cfg = cfg.model.test_cfg self.test_cfg = cfg.model.test_cfg
self.test_mode = self.test_cfg.mode # NOTE: should be 'whole' or 'slide' self.test_mode = self.test_cfg.mode # NOTE: either 'whole' or 'slide'
self.is_cuda_available = is_cuda_available self.is_cuda_available = is_cuda_available
self.count_mat = None self.count_mat = None
try: try:

View File

@ -171,7 +171,8 @@ if __name__ == '__main__':
setup( setup(
name='mmsegmentation', name='mmsegmentation',
version=get_version(), version=get_version(),
description='Open MMLab Semantic Segmentation Toolbox and Benchmark (Kneron Edition)', description='Open MMLab Semantic Segmentation Toolbox '
'and Benchmark (Kneron Edition)',
long_description=readme(), long_description=readme(),
long_description_content_type='text/markdown', long_description_content_type='text/markdown',
author='MMSegmentation Contributors and Kneron', author='MMSegmentation Contributors and Kneron',

View File

@ -163,9 +163,9 @@ def main():
efficient_test = eval_kwargs.get('efficient_test', False) efficient_test = eval_kwargs.get('efficient_test', False)
if efficient_test: if efficient_test:
warnings.warn( warnings.warn(
'``efficient_test=True`` does not have effect in tools/test_kneron.py, ' '"efficient_test=True" does not have effect in '
'the evaluation and format results are CPU memory efficient by ' 'tools/test_kneron.py, the evaluation and format '
'default') 'results are CPU memory efficient by default')
eval_on_format_results = ( eval_on_format_results = (
args.eval is not None and 'cityscapes' in args.eval) args.eval is not None and 'cityscapes' in args.eval)

View File

@ -5,55 +5,81 @@ import sys
from tools.other import topological_sort from tools.other import topological_sort
from tools import helper from tools import helper
def fuse_bias_in_consecutive_1x1_conv(g): def fuse_bias_in_consecutive_1x1_conv(g):
for second in g.node: for second in g.node:
# Find two conv # Find two conv
if second.op_type != 'Conv': if second.op_type != "Conv":
continue continue
first = helper.find_node_by_output_name(g, second.input[0]) first = helper.find_node_by_output_name(g, second.input[0])
if first is None or first.op_type != 'Conv': if first is None or first.op_type != "Conv":
continue continue
# Check if the first one has only one folloing node # Check if the first one has only one folloing node
if len(helper.find_following_nodes_by_input_value_name(g, first.output[0])) != 1: if (
len(
helper.find_following_nodes_by_input_value_name(
g, first.output[0]
)
)
!= 1
):
continue continue
# If first node has no bias, continue # If first node has no bias, continue
if len(first.input) == 2: if len(first.input) == 2:
continue continue
# Check their kernel size # Check their kernel size
first_kernel_shape = helper.get_list_attribute_by_name(first, 'kernel_shape', 'int') first_kernel_shape = helper.get_list_attribute_by_name(
second_kernel_shape = helper.get_list_attribute_by_name(second, 'kernel_shape', 'int') first, "kernel_shape", "int"
prod = first_kernel_shape[0] * first_kernel_shape[1] * second_kernel_shape[0] * second_kernel_shape[1] )
second_kernel_shape = helper.get_list_attribute_by_name(
second, "kernel_shape", "int"
)
prod = (
first_kernel_shape[0]
* first_kernel_shape[1]
* second_kernel_shape[0]
* second_kernel_shape[1]
)
if prod != 1: if prod != 1:
continue continue
print('Found: ', first.name, ' ', second.name) print("Found: ", first.name, " ", second.name)
# Get bias of the nodes # Get bias of the nodes
first_bias_node = helper.find_node_by_output_name(g, first.input[2]) first_bias_node = helper.find_node_by_output_name(g, first.input[2])
second_weight_node = helper.find_node_by_output_name(g, second.input[1]) second_weight_node = helper.find_node_by_output_name(
g, second.input[1]
)
second_bias_node = helper.find_node_by_output_name(g, second.input[2]) second_bias_node = helper.find_node_by_output_name(g, second.input[2])
first_bias = helper.constant_to_numpy(first_bias_node) first_bias = helper.constant_to_numpy(first_bias_node)
second_weight = helper.constant_to_numpy(second_weight_node) second_weight = helper.constant_to_numpy(second_weight_node)
second_bias = helper.constant_to_numpy(second_bias_node) second_bias = helper.constant_to_numpy(second_bias_node)
# Calculate the weight for second node # Calculate the weight for second node
first_bias = np.reshape(first_bias, (1, first_bias.size)) first_bias = np.reshape(first_bias, (1, first_bias.size))
second_weight = np.reshape(second_weight, (second_weight.shape[0], second_weight.shape[1])) second_weight = np.reshape(
second_weight, (second_weight.shape[0], second_weight.shape[1])
)
second_weight = np.transpose(second_weight) second_weight = np.transpose(second_weight)
new_second_bias = second_bias + np.matmul(first_bias, second_weight) new_second_bias = second_bias + np.matmul(first_bias, second_weight)
new_second_bias = np.reshape(new_second_bias, (new_second_bias.size,)) new_second_bias = np.reshape(new_second_bias, (new_second_bias.size,))
# Generate new weight # Generate new weight
new_first_bias = np.reshape(first_bias, (first_bias.size, )) new_first_bias = np.reshape(first_bias, (first_bias.size,))
for i in range(new_first_bias.shape[0]): for i in range(new_first_bias.shape[0]):
new_first_bias[i] = 0.0 new_first_bias[i] = 0.0
new_first_bias_node = helper.numpy_to_constant(first_bias_node.output[0], new_first_bias) new_first_bias_node = helper.numpy_to_constant(
new_second_bias_node = helper.numpy_to_constant(second_bias_node.output[0], new_second_bias) first_bias_node.output[0], new_first_bias
)
new_second_bias_node = helper.numpy_to_constant(
second_bias_node.output[0], new_second_bias
)
# Delete old weight and add new weights # Delete old weight and add new weights
g.node.remove(first_bias_node) g.node.remove(first_bias_node)
g.node.remove(second_bias_node) g.node.remove(second_bias_node)
g.node.extend([new_first_bias_node, new_second_bias_node]) g.node.extend([new_first_bias_node, new_second_bias_node])
topological_sort(g) topological_sort(g)
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) != 3: if len(sys.argv) != 3:
exit(1) exit(1)
m = onnx.load(sys.argv[1]) m = onnx.load(sys.argv[1])
fuse_bias_in_consecutive_1x1_conv(m.graph) fuse_bias_in_consecutive_1x1_conv(m.graph)
onnx.save(m, sys.argv[2]) onnx.save(m, sys.argv[2])

View File

@ -1,5 +1,6 @@
import onnx import onnx
import onnx.utils import onnx.utils
try: try:
from onnx import optimizer from onnx import optimizer
except ImportError: except ImportError:
@ -9,23 +10,107 @@ import argparse
import tools.modhelper as helper import tools.modhelper as helper
import tools.other as other import tools.other as other
import tools.replacing as replacing import tools.replacing as replacing
# Main process # Main process
# Argument parser # Argument parser
parser = argparse.ArgumentParser(description="Edit an ONNX model.\nThe processing sequense is 'delete nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting cannot be done with other operations together") parser = argparse.ArgumentParser(
parser.add_argument('in_file', type=str, help='input ONNX FILE') description="Edit an ONNX model.\nThe processing sequense is 'delete "
parser.add_argument('out_file', type=str, help="ouput ONNX FILE") "nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting "
parser.add_argument('-c', '--cut', dest='cut_node', type=str, nargs='+', help="remove nodes from the given nodes(inclusive)") "cannot be done with other operations together"
parser.add_argument('--cut-type', dest='cut_type', type=str, nargs='+', help="remove nodes by type from the given nodes(inclusive)") )
parser.add_argument('-d', '--delete', dest='delete_node', type=str, nargs='+', help="delete nodes by names and only those nodes") parser.add_argument("in_file", type=str, help="input ONNX FILE")
parser.add_argument('--delete-input', dest='delete_input', type=str, nargs='+', help="delete inputs by names") parser.add_argument("out_file", type=str, help="ouput ONNX FILE")
parser.add_argument('--delete-output', dest='delete_output', type=str, nargs='+', help="delete outputs by names") parser.add_argument(
parser.add_argument('-i', '--input', dest='input_change', type=str, nargs='+', help="change input shape (e.g. -i 'input_0 1 3 224 224')") "-c",
parser.add_argument('-o', '--output', dest='output_change', type=str, nargs='+', help="change output shape (e.g. -o 'input_0 1 3 224 224')") "--cut",
parser.add_argument('--add-conv', dest='add_conv', type=str, nargs='+', help='add nop conv using specific input') dest="cut_node",
parser.add_argument('--add-bn', dest='add_bn', type=str, nargs='+', help='add nop bn using specific input') type=str,
parser.add_argument('--rename-output', dest='rename_output', type=str, nargs='+', help='Rename the specific output(e.g. --rename-output old_name new_name)') nargs="+",
parser.add_argument('--pixel-bias-value', dest='pixel_bias_value', type=str, nargs='+', help='(per channel) set pixel value bias bn layer at model front for normalization( e.g. --pixel_bias_value "[104.0, 117.0, 123.0]" )') help="remove nodes from the given nodes(inclusive)",
parser.add_argument('--pixel-scale-value', dest='pixel_scale_value', type=str, nargs='+', help='(per channel) set pixel value scale bn layer at model front for normalization( e.g. --pixel_scale_value "[0.0078125, 0.0078125, 0.0078125]" )') )
parser.add_argument(
"--cut-type",
dest="cut_type",
type=str,
nargs="+",
help="remove nodes by type from the given nodes(inclusive)",
)
parser.add_argument(
"-d",
"--delete",
dest="delete_node",
type=str,
nargs="+",
help="delete nodes by names and only those nodes",
)
parser.add_argument(
"--delete-input",
dest="delete_input",
type=str,
nargs="+",
help="delete inputs by names",
)
parser.add_argument(
"--delete-output",
dest="delete_output",
type=str,
nargs="+",
help="delete outputs by names",
)
parser.add_argument(
"-i",
"--input",
dest="input_change",
type=str,
nargs="+",
help="change input shape (e.g. -i 'input_0 1 3 224 224')",
)
parser.add_argument(
"-o",
"--output",
dest="output_change",
type=str,
nargs="+",
help="change output shape (e.g. -o 'input_0 1 3 224 224')",
)
parser.add_argument(
"--add-conv",
dest="add_conv",
type=str,
nargs="+",
help="add nop conv using specific input",
)
parser.add_argument(
"--add-bn",
dest="add_bn",
type=str,
nargs="+",
help="add nop bn using specific input",
)
parser.add_argument(
"--rename-output",
dest="rename_output",
type=str,
nargs="+",
help="Rename the specific output(e.g. --rename-output old_name new_name)",
)
parser.add_argument(
"--pixel-bias-value",
dest="pixel_bias_value",
type=str,
nargs="+",
help='(per channel) set pixel value bias bn layer at model front for '
'normalization( e.g. --pixel_bias_value "[104.0, 117.0, 123.0]" )',
)
parser.add_argument(
"--pixel-scale-value",
dest="pixel_scale_value",
type=str,
nargs="+",
help='(per channel) set pixel value scale bn layer at model front for '
'normalization( e.g. --pixel_scale_value '
'"[0.0078125, 0.0078125, 0.0078125]" )',
)
args = parser.parse_args() args = parser.parse_args()
@ -60,23 +145,48 @@ if args.add_bn is not None:
if args.pixel_bias_value is not None or args.pixel_scale_value is not None: if args.pixel_bias_value is not None or args.pixel_scale_value is not None:
if len(g.input) > 1: if len(g.input) > 1:
raise ValueError(" '--pixel-bias-value' and '--pixel-scale-value' only support one input node model currently") raise ValueError(
" '--pixel-bias-value' and '--pixel-scale-value' "
"only support one input node model currently"
)
i_n = g.input[0] i_n = g.input[0]
pixel_bias_value = [0] * i_n.type.tensor_type.shape.dim[1].dim_value pixel_bias_value = [0] * i_n.type.tensor_type.shape.dim[1].dim_value
pixel_scale_value = [1] * i_n.type.tensor_type.shape.dim[1].dim_value pixel_scale_value = [1] * i_n.type.tensor_type.shape.dim[1].dim_value
if args.pixel_bias_value is not None and len(args.pixel_bias_value) == 1: if args.pixel_bias_value is not None and len(args.pixel_bias_value) == 1:
pixel_bias_value = [float(n) for n in args.pixel_bias_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')] pixel_bias_value = [
float(n)
for n in args.pixel_bias_value[0]
.replace("[", "")
.replace("]", "")
.split(",")
]
if args.pixel_scale_value is not None and len(args.pixel_scale_value) == 1: if args.pixel_scale_value is not None and len(args.pixel_scale_value) == 1:
pixel_scale_value = [float(n) for n in args.pixel_scale_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')] pixel_scale_value = [
float(n)
for n in args.pixel_scale_value[0]
.replace("[", "")
.replace("]", "")
.split(",")
]
if i_n.type.tensor_type.shape.dim[1].dim_value != len(
if i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_bias_value) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value): pixel_bias_value
raise ValueError("--pixel-bias-value (" + str(pixel_bias_value) + ") and --pixel-scale-value (" + str(pixel_scale_value) + ") should be same as input dimension:" + str(i_n.type.tensor_type.shape.dim[1].dim_value) ) ) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value):
other.add_bias_scale_bn_after(g, i_n.name, pixel_bias_value, pixel_scale_value) raise ValueError(
"--pixel-bias-value ("
+ str(pixel_bias_value)
+ ") and --pixel-scale-value ("
+ str(pixel_scale_value)
+ ") should be same as input dimension:"
+ str(i_n.type.tensor_type.shape.dim[1].dim_value)
)
other.add_bias_scale_bn_after(
g, i_n.name, pixel_bias_value, pixel_scale_value
)
# Change input and output shapes as requested # Change input and output shapes as requested
if args.input_change is not None: if args.input_change is not None:
@ -100,14 +210,21 @@ if args.rename_output:
print("Rename output should be paires of names.") print("Rename output should be paires of names.")
else: else:
for i in range(0, len(args.rename_output), 2): for i in range(0, len(args.rename_output), 2):
other.rename_output_name(g, args.rename_output[i], args.rename_output[i + 1]) other.rename_output_name(
g, args.rename_output[i], args.rename_output[i + 1]
)
# Remove useless nodes # Remove useless nodes
if args.delete_node or args.delete_input or args.input_change or args.output_change: if (
args.delete_node
or args.delete_input
or args.input_change
or args.output_change
):
# If shape changed during the modification, redo shape inference. # If shape changed during the modification, redo shape inference.
while(len(g.value_info) > 0): while len(g.value_info) > 0:
g.value_info.pop() g.value_info.pop()
passes = ['extract_constant_to_initializer'] passes = ["extract_constant_to_initializer"]
m = optimizer.optimize(m, passes) m = optimizer.optimize(m, passes)
g = m.graph g = m.graph
replacing.replace_initializer_with_Constant(g) replacing.replace_initializer_with_Constant(g)
@ -115,4 +232,4 @@ other.topological_sort(g)
# Polish and output # Polish and output
m = other.polish_model(m) m = other.polish_model(m)
other.add_output_to_value_info(m.graph) other.add_output_to_value_info(m.graph)
onnx.save(m, args.out_file) onnx.save(m, args.out_file)

View File

@ -11,42 +11,44 @@ if len(sys.argv) != 3:
# Modify onnx # Modify onnx
m = onnx.load(sys.argv[1]) m = onnx.load(sys.argv[1])
special.add_0_5_to_normalized_input(m) special.add_0_5_to_normalized_input(m)
onnx.save(m, sys.argv[1][:-4] + 'norm.onnx') onnx.save(m, sys.argv[1][:-4] + "norm.onnx")
# Change input node # Change input node
origin_file = open(sys.argv[2], 'r') origin_file = open(sys.argv[2], "r")
origin_json = json.load(origin_file) origin_json = json.load(origin_file)
origin_json["input_node"]["output_datapath_radix"] = [8] origin_json["input_node"]["output_datapath_radix"] = [8]
new_json_str = json.dumps(origin_json) new_json_str = json.dumps(origin_json)
# Modify json # Modify json
file = open(sys.argv[1][:-4] + 'norm.onnx' + '.json', 'w') file = open(sys.argv[1][:-4] + "norm.onnx" + ".json", "w")
s = """{{ s = """{{
\"{0}\" : \"{0}\" :
{{ {{
\"bias_bitwidth\" : 16, \"bias_bitwidth\" : 16,
\"{0}_bias\" : [15], \"{0}_bias\" : [15],
\"{0}_weight\" : [3,3,3], \"{0}_weight\" : [3,3,3],
\"conv_coarse_shift\" : [-4,-4,-4], \"conv_coarse_shift\" : [-4,-4,-4],
\"conv_fine_shift\" : [0,0,0], \"conv_fine_shift\" : [0,0,0],
\"conv_total_shift\" : [-4,-4,-4], \"conv_total_shift\" : [-4,-4,-4],
\"cpu_mode\" : false, \"cpu_mode\" : false,
\"delta_input_bitwidth\" : [0], \"delta_input_bitwidth\" : [0],
\"delta_output_bitwidth\" : 8, \"delta_output_bitwidth\" : 8,
\"flag_radix_bias_eq_output\" : true, \"flag_radix_bias_eq_output\" : true,
\"input_scale\" : [[1.0,1.0,1.0]], \"input_scale\" : [[1.0,1.0,1.0]],
\"output_scale\" : [1.0, 1.0, 1.0], \"output_scale\" : [1.0, 1.0, 1.0],
\"psum_bitwidth\" : 16, \"psum_bitwidth\" : 16,
\"weight_bitwidth\" : 8, \"weight_bitwidth\" : 8,
\"input_datapath_bitwidth\" : [8], \"input_datapath_bitwidth\" : [8],
\"input_datapath_radix\" : [8], \"input_datapath_radix\" : [8],
\"working_input_bitwidth\" : 8, \"working_input_bitwidth\" : 8,
\"working_input_radix\" : [8], \"working_input_radix\" : [8],
\"working_output_bitwidth\" : 16, \"working_output_bitwidth\" : 16,
\"working_output_radix\" : 15, \"working_output_radix\" : 15,
\"output_datapath_bitwidth\" : 8, \"output_datapath_bitwidth\" : 8,
\"output_datapath_radix\" : 7 \"output_datapath_radix\" : 7
}},\n""".format('input_norm') }},\n""".format(
"input_norm"
)
file.write(s + new_json_str[1:]) file.write(s + new_json_str[1:])
file.close() file.close()
origin_file.close() origin_file.close()

View File

@ -2,33 +2,33 @@
import sys import sys
import onnx import onnx
import numpy as np
from onnx import numpy_helper
from tools import other, helper from tools import other, helper
""" """
Change onnx model from version 1.3 to version 1.4. Change onnx model from version 1.3 to version 1.4.
Modify the BN node by removing the spatial attribute - Modify the BN node by removing the spatial attribute
Modify the Upsample node by removing the 'scales' attribute, and adding a constant node instead. - Modify the Upsample node by removing the 'scales' attribute,
Model's ir_version and opset_import are updated. and adding a constant node instead.
- Model's ir_version and opset_import are updated.
""" """
def remove_BN_spatial(g): def remove_BN_spatial(g):
for node in g.node: for node in g.node:
if node.op_type != 'BatchNormalization': if node.op_type != "BatchNormalization":
continue continue
for att in node.attribute: for att in node.attribute:
if att.name == 'spatial': if att.name == "spatial":
node.attribute.remove(att) node.attribute.remove(att)
def upsample_attribute_to_const(g): def upsample_attribute_to_const(g):
for node in g.node: for node in g.node:
if node.op_type != 'Upsample': if node.op_type != "Upsample":
continue continue
scales_exist = False scales_exist = False
for att in node.attribute: for att in node.attribute:
if att.name == 'scales': if att.name == "scales":
scales_exist = True scales_exist = True
break break
if not scales_exist: if not scales_exist:
@ -36,18 +36,23 @@ def upsample_attribute_to_const(g):
shape = [len(att.floats)] shape = [len(att.floats)]
node.attribute.remove(att) node.attribute.remove(att)
new_node = helper.list_to_constant(node.name+'_input', shape, att.floats) new_node = helper.list_to_constant(
node.name + "_input", shape, att.floats
)
g.node.extend([new_node]) g.node.extend([new_node])
value_info = onnx.helper.make_tensor_value_info(node.name+'_input', onnx.TensorProto.FLOAT, shape) value_info = onnx.helper.make_tensor_value_info(
node.input.extend([node.name+'_input']) node.name + "_input", onnx.TensorProto.FLOAT, shape
)
node.input.extend([node.name + "_input"])
g.value_info.extend([value_info]) g.value_info.extend([value_info])
def relu6_to_clip(g): def relu6_to_clip(g):
for node in g.node: for node in g.node:
if node.op_type != 'Relu': if node.op_type != "Relu":
continue continue
max_val = helper.get_var_attribute_by_name(node, 'max', 'float') max_val = helper.get_var_attribute_by_name(node, "max", "float")
if max_val is None: if max_val is None:
continue continue
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
@ -56,11 +61,12 @@ def relu6_to_clip(g):
node.output, node.output,
name=node.name, name=node.name,
max=max_val, max=max_val,
min=0.0 min=0.0,
) )
g.node.remove(node) g.node.remove(node)
g.node.extend([new_node]) g.node.extend([new_node])
def PRelu_weight_reshape(g): def PRelu_weight_reshape(g):
# For PRelu with single dimension weight. Expand it to 1, x, 1, 1 # For PRelu with single dimension weight. Expand it to 1, x, 1, 1
for node in g.node: for node in g.node:
@ -91,16 +97,18 @@ def PRelu_weight_reshape(g):
new_input = onnx.helper.make_tensor_value_info( new_input = onnx.helper.make_tensor_value_info(
node.input[1], node.input[1],
input_value.type.tensor_type.elem_type, input_value.type.tensor_type.elem_type,
(1, slope.dims[1], 1, 1)) (1, slope.dims[1], 1, 1),
)
g.input.remove(input_value) g.input.remove(input_value)
g.input.append(new_input) g.input.append(new_input)
value_info = helper.find_value_by_name(g, node.input[1]) value_info = helper.find_value_by_name(g, node.input[1])
if value_info is not None: if value_info is not None:
g.value_info.remove(value_info) g.value_info.remove(value_info)
def do_convert(m): def do_convert(m):
graph = m.graph graph = m.graph
# Modify the nodes. # Modify the nodes.
remove_BN_spatial(graph) remove_BN_spatial(graph)
upsample_attribute_to_const(graph) upsample_attribute_to_const(graph)
@ -113,6 +121,7 @@ def do_convert(m):
m.opset_import[0].version = 9 m.opset_import[0].version = 9
return m return m
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) != 3: if len(sys.argv) != 3:
print("Usage:{} file_in file_out".format(sys.argv[0])) print("Usage:{} file_in file_out".format(sys.argv[0]))

View File

@ -3,31 +3,38 @@
import sys import sys
import onnx import onnx
import onnx.utils import onnx.utils
import numpy as np
from onnx import numpy_helper
from tools import other, helper, replacing from tools import other, helper, replacing
""" """
Change onnx model from version 1.4 to version 1.6. Change onnx model from version 1.4 to version 1.6.
""" """
def replace_all_attribute_to_const_node_in_pad_node(g): def replace_all_attribute_to_const_node_in_pad_node(g):
node_to_remove = [] node_to_remove = []
node_to_extend = [] node_to_extend = []
for node in g.node: for node in g.node:
if node.op_type != 'Pad': if node.op_type != "Pad":
continue continue
pad_loc_node = None # must have pad_loc_node = None # must have
pad_mode = 'constant' pad_mode = "constant"
pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [0.0]) # need scalar pad_value_node = helper.list_to_constant(
node.name + "_pad_value", [], [0.0]
) # need scalar
for att in node.attribute: for att in node.attribute:
if att.name == 'mode': if att.name == "mode":
pad_mode = helper.get_var_attribute_by_name(node, 'mode', 'string') pad_mode = helper.get_var_attribute_by_name(
if att.name == 'pads': node, "mode", "string"
pad_loc_node = helper.list_to_constant(node.name+'_pad_loc', [len(att.ints)], att.ints) )
if att.name == 'value': if att.name == "pads":
pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [att.f]) pad_loc_node = helper.list_to_constant(
node.name + "_pad_loc", [len(att.ints)], att.ints
)
if att.name == "value":
pad_value_node = helper.list_to_constant(
node.name + "_pad_value", [], [att.f]
)
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
"Pad", "Pad",
@ -40,24 +47,30 @@ def replace_all_attribute_to_const_node_in_pad_node(g):
node_to_extend.append(new_node) node_to_extend.append(new_node)
node_to_extend.append(pad_loc_node) node_to_extend.append(pad_loc_node)
node_to_extend.append(pad_value_node) node_to_extend.append(pad_value_node)
for node in node_to_remove: for node in node_to_remove:
g.node.remove(node) g.node.remove(node)
for node in node_to_extend: for node in node_to_extend:
g.node.extend([node]) g.node.extend([node])
def upsampling_to_resize(g): def upsampling_to_resize(g):
for node in g.node: for node in g.node:
if node.op_type != 'Upsample': if node.op_type != "Upsample":
continue continue
upsampling_mode = helper.get_var_attribute_by_name(node, 'mode', 'string') upsampling_mode = helper.get_var_attribute_by_name(
node, "mode", "string"
)
scale_value_node = helper.find_node_by_output_name(g, node.input[1]) scale_value_node = helper.find_node_by_output_name(g, node.input[1])
if scale_value_node.op_type != "Constant": if scale_value_node.op_type != "Constant":
raise TypeError('seems there is a dynamic "scales" param in Upsampling node: ' + node.name + ' , you might need to do constant folding first') raise TypeError(
'seems there is a dynamic "scales" param in Upsampling node: '
+ node.name
+ " , you might need to do constant folding first"
)
roi_node = helper.list_to_constant(node.name+'_roi_value', [0], []) roi_node = helper.list_to_constant(node.name + "_roi_value", [0], [])
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
"Resize", "Resize",
@ -65,7 +78,7 @@ def upsampling_to_resize(g):
[node.output[0]], [node.output[0]],
name=node.output[0], name=node.output[0],
mode=upsampling_mode, mode=upsampling_mode,
coordinate_transformation_mode = 'asymmetric' coordinate_transformation_mode="asymmetric",
) )
g.node.remove(node) g.node.remove(node)
@ -75,7 +88,7 @@ def upsampling_to_resize(g):
def replace_all_attribute_to_const_node_in_slice_node(g): def replace_all_attribute_to_const_node_in_slice_node(g):
for node in g.node: for node in g.node:
if node.op_type != 'Slice': if node.op_type != "Slice":
continue continue
axes_const_node = None axes_const_node = None
@ -83,62 +96,75 @@ def replace_all_attribute_to_const_node_in_slice_node(g):
starts_const_node = None starts_const_node = None
steps_const_node = None steps_const_node = None
for att in node.attribute: for att in node.attribute:
if att.name == 'axes': if att.name == "axes":
axes_const_node = helper.list_to_constant(node.name+'_axes_value', [len(att.ints)], att.ints) axes_const_node = helper.list_to_constant(
node.name + "_axes_value", [len(att.ints)], att.ints
if att.name == 'ends': )
ends_const_node = helper.list_to_constant(node.name+'_ends_value', [len(att.ints)], att.ints)
if att.name == 'starts': if att.name == "ends":
starts_const_node = helper.list_to_constant(node.name+'_starts_value', [len(att.ints)], att.ints) ends_const_node = helper.list_to_constant(
node.name + "_ends_value", [len(att.ints)], att.ints
)
if att.name == 'steps': if att.name == "starts":
steps_const_node = helper.list_to_constant(node.name+'_steps_value',[ len(att.ints)], att.ints) starts_const_node = helper.list_to_constant(
node.name + "_starts_value", [len(att.ints)], att.ints
)
## pop out from back if att.name == "steps":
steps_const_node = helper.list_to_constant(
node.name + "_steps_value", [len(att.ints)], att.ints
)
# pop out from back
attr_len = len(node.attribute) attr_len = len(node.attribute)
for i in range(attr_len): for i in range(attr_len):
node.attribute.remove(node.attribute[ attr_len -1 - i ]) node.attribute.remove(node.attribute[attr_len - 1 - i])
## according the spec, we need to add node in specific order # according the spec, we need to add node in specific order
if starts_const_node != None: if starts_const_node is not None:
g.node.extend([starts_const_node]) g.node.extend([starts_const_node])
node.input.extend([starts_const_node.name]) node.input.extend([starts_const_node.name])
if ends_const_node != None: if ends_const_node is not None:
g.node.extend([ends_const_node]) g.node.extend([ends_const_node])
node.input.extend([ends_const_node.name]) node.input.extend([ends_const_node.name])
if axes_const_node != None: if axes_const_node is not None:
g.node.extend([axes_const_node]) g.node.extend([axes_const_node])
node.input.extend([axes_const_node.name]) node.input.extend([axes_const_node.name])
if steps_const_node != None: if steps_const_node is not None:
g.node.extend([steps_const_node]) g.node.extend([steps_const_node])
node.input.extend([steps_const_node.name]) node.input.extend([steps_const_node.name])
def replace_min_max_attribute_to_const_node_in_clip_node(g): def replace_min_max_attribute_to_const_node_in_clip_node(g):
for node in g.node: for node in g.node:
if node.op_type != 'Clip': if node.op_type != "Clip":
continue continue
max_const_node = None max_const_node = None
min_const_node = None min_const_node = None
for att in node.attribute: for att in node.attribute:
if att.name == 'max': if att.name == "max":
max_const_node = helper.list_to_constant(node.name+'_max_value', [], [att.f]) max_const_node = helper.list_to_constant(
node.name + "_max_value", [], [att.f]
if att.name == 'min': )
min_const_node = helper.list_to_constant(node.name+'_min_value', [], [att.f])
## pop out from back if att.name == "min":
node.attribute.remove(node.attribute[1]) min_const_node = helper.list_to_constant(
node.attribute.remove(node.attribute[0]) node.name + "_min_value", [], [att.f]
)
## according the spec, we need to add node in specific order
# pop out from back
node.attribute.remove(node.attribute[1])
node.attribute.remove(node.attribute[0])
# according the spec, we need to add node in specific order
g.node.extend([min_const_node]) g.node.extend([min_const_node])
g.node.extend([max_const_node]) g.node.extend([max_const_node])
node.input.extend([min_const_node.name]) node.input.extend([min_const_node.name])
node.input.extend([max_const_node.name]) node.input.extend([max_const_node.name])
def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto: def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
"""Update ir_version from 4 to 6 and update opset from 9 to 11. """Update ir_version from 4 to 6 and update opset from 9 to 11.
@ -173,6 +199,7 @@ def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
model = other.polish_model(model) model = other.polish_model(model)
return model return model
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) != 3: if len(sys.argv) != 3:
print("Usage:{} file_in file_out".format(sys.argv[0])) print("Usage:{} file_in file_out".format(sys.argv[0]))

View File

@ -1,45 +1,51 @@
import onnx import onnx
import onnx.utils import onnx.utils
try:
from onnx import optimizer
except ImportError:
import onnxoptimizer as optimizer
import sys
import argparse import argparse
import logging import logging
from tools import eliminating from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other from tools import other
from tools import special from tools import special
from tools import combo from tools import combo
from tools.helper import logger
# from tools import temp # from tools import temp
def onnx2onnx_flow(m: onnx.ModelProto,
disable_fuse_bn=False, def onnx2onnx_flow(
bn_on_skip=False, m: onnx.ModelProto,
bn_before_add=False, disable_fuse_bn=False,
bgr=False, bn_on_skip=False,
norm=False, bn_before_add=False,
rgba2yynn=False, bgr=False,
eliminate_tail=False, norm=False,
opt_matmul=False, rgba2yynn=False,
duplicate_shared_weights=True) -> onnx.ModelProto: eliminate_tail=False,
opt_matmul=False,
duplicate_shared_weights=True,
) -> onnx.ModelProto:
"""Optimize the onnx. """Optimize the onnx.
Args: Args:
m (ModelProto): the input onnx ModelProto m (ModelProto): the input onnx ModelProto
disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False. disable_fuse_bn (bool, optional): do not fuse BN into Conv.
bn_on_skip (bool, optional): add BN operator on skip branches. Defaults to False. Defaults to False.
bn_before_add (bool, optional): add BN before Add node on every branches. Defaults to False. bn_on_skip (bool, optional): add BN operator on skip branches.
bgr (bool, optional): add an Conv layer to convert rgb input to bgr. Defaults to False. Defaults to False.
norm (bool, optional): add an Conv layer to add 0.5 tp the input. Defaults to False. bn_before_add (bool, optional): add BN before Add node on every branch.
rgba2yynn (bool, optional): add an Conv layer to convert rgb input to yynn . Defaults to False. Defaults to False.
eliminate_tail (bool, optional): remove the trailing NPU unsupported nodes. Defaults to False. bgr (bool, optional): add an Conv layer to convert rgb input to bgr.
opt_matmul(bool, optional): optimize the MatMul layers according to the NPU limit. Defaults to False. Defaults to False.
duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True. norm (bool, optional): add an Conv layer to add 0.5 tp the input.
Defaults to False.
rgba2yynn (bool, optional): add an Conv layer to convert rgb to yynn.
Defaults to False.
eliminate_tail (bool, optional): remove trailing NPU unsupported nodes.
Defaults to False.
opt_matmul(bool, optional): optimize MatMul layers due to NPU limit.
Defaults to False.
duplicate_shared_weights(bool, optional): duplicate shared weights.
Defaults to True.
Returns: Returns:
ModelProto: the optimized onnx model object. ModelProto: the optimized onnx model object.
@ -79,28 +85,83 @@ def onnx2onnx_flow(m: onnx.ModelProto,
return m return m
# Main process # Main process
if __name__ == "__main__": if __name__ == "__main__":
# Argument parser # Argument parser
parser = argparse.ArgumentParser(description="Optimize an ONNX model for Kneron compiler") parser = argparse.ArgumentParser(
parser.add_argument('in_file', help='input ONNX FILE') description="Optimize an ONNX model for Kneron compiler"
parser.add_argument('-o', '--output', dest='out_file', type=str, help="ouput ONNX FILE") )
parser.add_argument('--log', default='i', type=str, help="set log level") parser.add_argument("in_file", help="input ONNX FILE")
parser.add_argument('--bgr', action='store_true', default=False, help="set if the model is trained in BGR mode") parser.add_argument(
parser.add_argument('--norm', action='store_true', default=False, help="set if you have the input -0.5~0.5") "-o", "--output", dest="out_file", type=str, help="ouput ONNX FILE"
parser.add_argument('--rgba2yynn', action='store_true', default=False, help="set if the model has yynn input but you want to take rgba images") )
parser.add_argument('--add-bn-on-skip', dest='bn_on_skip', action='store_true', default=False, parser.add_argument("--log", default="i", type=str, help="set log level")
help="set if you only want to add BN on skip branches") parser.add_argument(
parser.add_argument('--add-bn', dest='bn_before_add', action='store_true', default=False, "--bgr",
help="set if you want to add BN before Add") action="store_true",
parser.add_argument('-t', '--eliminate-tail-unsupported', dest='eliminate_tail', action='store_true', default=False, default=False,
help='whether remove the last unsupported node for hardware') help="set if the model is trained in BGR mode",
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False, )
help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.") parser.add_argument(
parser.add_argument('--opt-matmul', dest='opt_matmul', action='store_true', default=False, "--norm",
help="set if you want to optimize the MatMul operations for the kneron hardware.") action="store_true",
parser.add_argument('--no-duplicate-shared-weights', dest='no_duplicate_shared_weights', action='store_true', default=False, default=False,
help='do not duplicate shared weights. Defaults to False.') help="set if you have the input -0.5~0.5",
)
parser.add_argument(
"--rgba2yynn",
action="store_true",
default=False,
help="set if the model has yynn input but you want "
"to take rgba images",
)
parser.add_argument(
"--add-bn-on-skip",
dest="bn_on_skip",
action="store_true",
default=False,
help="set if you only want to add BN on skip branches",
)
parser.add_argument(
"--add-bn",
dest="bn_before_add",
action="store_true",
default=False,
help="set if you want to add BN before Add",
)
parser.add_argument(
"-t",
"--eliminate-tail-unsupported",
dest="eliminate_tail",
action="store_true",
default=False,
help="whether remove the last unsupported node for hardware",
)
parser.add_argument(
"--no-bn-fusion",
dest="disable_fuse_bn",
action="store_true",
default=False,
help="set if you have met errors which related to inferenced "
"shape mismatch. This option will prevent fusing "
"BatchNormalization into Conv.",
)
parser.add_argument(
"--opt-matmul",
dest="opt_matmul",
action="store_true",
default=False,
help="set if you want to optimize MatMul operations "
"for kneron hardware.",
)
parser.add_argument(
"--no-duplicate-shared-weights",
dest="no_duplicate_shared_weights",
action="store_true",
default=False,
help="do not duplicate shared weights. Defaults to False.",
)
args = parser.parse_args() args = parser.parse_args()
if args.out_file is None: if args.out_file is None:
@ -108,11 +169,11 @@ if __name__ == "__main__":
else: else:
outfile = args.out_file outfile = args.out_file
if args.log == 'w': if args.log == "w":
logging.basicConfig(level=logging.WARN) logging.basicConfig(level=logging.WARN)
elif args.log == 'd': elif args.log == "d":
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
elif args.log == 'e': elif args.log == "e":
logging.basicConfig(level=logging.ERROR) logging.basicConfig(level=logging.ERROR)
else: else:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -131,6 +192,17 @@ if __name__ == "__main__":
# Basic model organize # Basic model organize
m = onnx.load(args.in_file) m = onnx.load(args.in_file)
m = onnx2onnx_flow(m, args.disable_fuse_bn, args.bn_on_skip, args.bn_before_add, args.bgr, args.norm, args.rgba2yynn, args.eliminate_tail, args.opt_matmul, not args.no_duplicate_shared_weights) m = onnx2onnx_flow(
m,
args.disable_fuse_bn,
args.bn_on_skip,
args.bn_before_add,
args.bgr,
args.norm,
args.rgba2yynn,
args.eliminate_tail,
args.opt_matmul,
not args.no_duplicate_shared_weights,
)
onnx.save(m, outfile) onnx.save(m, outfile)

View File

@ -5,12 +5,30 @@ import numpy as np
from tools import helper from tools import helper
onnx2np_dtype = {0: 'float', 1: 'float32', 2: 'uint8', 3: 'int8', 4: 'uint16', 5: 'int16', 6: 'int32', 7: 'int64', 8: 'str', 9: 'bool', 10: 'float16', 11: 'double', 12: 'uint32', 13: 'uint64', 14: 'complex64', 15: 'complex128', 16: 'float'} onnx2np_dtype = {
0: "float",
1: "float32",
2: "uint8",
3: "int8",
4: "uint16",
5: "int16",
6: "int32",
7: "int64",
8: "str",
9: "bool",
10: "float16",
11: "double",
12: "uint32",
13: "uint64",
14: "complex64",
15: "complex128",
16: "float",
}
def onnx_model_results(path_a, path_b, total_times=10): def onnx_model_results(path_a, path_b, total_times=10):
""" using onnxruntime to inference two onnx models' ouputs """using onnxruntime to inference two onnx models' ouputs
:onnx model paths: two model paths :onnx model paths: two model paths
:total_times: inference times, default to be 10 :total_times: inference times, default to be 10
:returns: inference results of two models :returns: inference results of two models
@ -22,13 +40,20 @@ def onnx_model_results(path_a, path_b, total_times=10):
outputs_b = session_b.get_outputs() outputs_b = session_b.get_outputs()
# check outputs # check outputs
assert len(outputs_a) == len(outputs_b), 'Two models have different output numbers.' assert len(outputs_a) == len(
outputs_b
), "Two models have different output numbers."
for i in range(len(outputs_a)): for i in range(len(outputs_a)):
out_shape_a, out_shape_b = outputs_a[i].shape, outputs_b[i].shape out_shape_a, out_shape_b = outputs_a[i].shape, outputs_b[i].shape
out_shape_a = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_a)) out_shape_a = list(
out_shape_b = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_b)) map(lambda x: x if isinstance(x, int) else 1, out_shape_a)
assert out_shape_a == out_shape_b, 'Output {} has unmatched shapes'.format(i) )
out_shape_b = list(
map(lambda x: x if isinstance(x, int) else 1, out_shape_b)
)
assert (
out_shape_a == out_shape_b
), "Output {} has unmatched shapes".format(i)
# load onnx graph_a and graph_b, to find the initializer and inputs # load onnx graph_a and graph_b, to find the initializer and inputs
# then compare to remove the items in the inputs which will be initialized # then compare to remove the items in the inputs which will be initialized
@ -38,9 +63,16 @@ def onnx_model_results(path_a, path_b, total_times=10):
init_a, init_b = graph_a.initializer, graph_b.initializer init_a, init_b = graph_a.initializer, graph_b.initializer
# remove initializer from raw inputs # remove initializer from raw inputs
input_names_a, input_names_b = set([ele.name for ele in inputs_a]), set([ele.name for ele in inputs_b]) input_names_a, input_names_b = set([ele.name for ele in inputs_a]), set(
init_names_a, init_names_b = set([ele.name for ele in init_a]), set([ele.name for ele in init_b]) [ele.name for ele in inputs_b]
real_inputs_names_a, real_inputs_names_b = input_names_a - init_names_a, input_names_b - init_names_b )
init_names_a, init_names_b = set([ele.name for ele in init_a]), set(
[ele.name for ele in init_b]
)
real_inputs_names_a, real_inputs_names_b = (
input_names_a - init_names_a,
input_names_b - init_names_b,
)
# prepare and figure out matching of real inputs a and real inputs b # prepare and figure out matching of real inputs a and real inputs b
# try to keep original orders of each inputs # try to keep original orders of each inputs
@ -61,17 +93,20 @@ def onnx_model_results(path_a, path_b, total_times=10):
for item_a in real_inputs_a: for item_a in real_inputs_a:
size, shape = helper.find_size_shape_from_value(item_a) size, shape = helper.find_size_shape_from_value(item_a)
if size: if size:
assert real_single_input_a is None, 'Multiple inputs of first model, single input expected.' assert (
real_single_input_a is None
), "Multiple inputs of first model, single input expected."
real_single_input_a = item_a real_single_input_a = item_a
size_a, shape_a = size, shape size_a, shape_a = size, shape
for item_b in real_inputs_b: for item_b in real_inputs_b:
size, shape = helper.find_size_shape_from_value(item_b) size, shape = helper.find_size_shape_from_value(item_b)
if size: if size:
assert real_single_input_b is None, 'Multiple inputs of second model, single input expected.' assert (
real_single_input_b is None
), "Multiple inputs of second model, single input expected."
real_single_input_b = item_b real_single_input_b = item_b
size_b, shape_b = size, shape size_b, shape_b = size, shape
assert size_a == size_b, 'Sizes of two models do not match.' assert size_a == size_b, "Sizes of two models do not match."
# construct inputs tensors # construct inputs tensors
input_data_type_a = real_single_input_a.type.tensor_type.elem_type input_data_type_a = real_single_input_a.type.tensor_type.elem_type
@ -84,7 +119,7 @@ def onnx_model_results(path_a, path_b, total_times=10):
results_a = [[] for i in range(len(outputs_a))] results_a = [[] for i in range(len(outputs_a))]
results_b = [[] for i in range(len(outputs_b))] results_b = [[] for i in range(len(outputs_b))]
while times < total_times: while times < total_times:
# initialize inputs by random data, default to be uniform # initialize inputs by random data, default to be uniform
data = np.random.random(size_a) data = np.random.random(size_a)
input_a = np.reshape(data, shape_a).astype(input_data_type_a) input_a = np.reshape(data, shape_a).astype(input_data_type_a)
input_b = np.reshape(data, shape_b).astype(input_data_type_b) input_b = np.reshape(data, shape_b).astype(input_data_type_b)
@ -93,12 +128,18 @@ def onnx_model_results(path_a, path_b, total_times=10):
input_dict_b = {} input_dict_b = {}
for item_a in real_inputs_a: for item_a in real_inputs_a:
item_type_a = onnx2np_dtype[item_a.type.tensor_type.elem_type] item_type_a = onnx2np_dtype[item_a.type.tensor_type.elem_type]
input_dict_a[item_a.name] = np.array([]).astype(item_type_a) \ input_dict_a[item_a.name] = (
if item_a.name != real_single_input_a.name else input_a np.array([]).astype(item_type_a)
if item_a.name != real_single_input_a.name
else input_a
)
for item_b in real_inputs_b: for item_b in real_inputs_b:
item_type_b = onnx2np_dtype[item_b.type.tensor_type.elem_type] item_type_b = onnx2np_dtype[item_b.type.tensor_type.elem_type]
input_dict_b[item_b.name] = np.array([]).astype(item_type_b) \ input_dict_b[item_b.name] = (
if item_b.name != real_single_input_b.name else input_b np.array([]).astype(item_type_b)
if item_b.name != real_single_input_b.name
else input_b
)
ra = session_a.run([], input_dict_a) ra = session_a.run([], input_dict_a)
rb = session_b.run([], input_dict_b) rb = session_b.run([], input_dict_b)
@ -109,26 +150,32 @@ def onnx_model_results(path_a, path_b, total_times=10):
return results_a, results_b return results_a, results_b
if __name__ == '__main__':
if __name__ == "__main__":
# Argument parser. # Argument parser.
parser = argparse.ArgumentParser(description="Compare two ONNX models to check if they have the same output.") parser = argparse.ArgumentParser(
parser.add_argument('in_file_a', help='input ONNX file a') description="Compare two ONNX models to check if "
parser.add_argument('in_file_b', help='input ONNX file b') "they have the same output."
)
parser.add_argument("in_file_a", help="input ONNX file a")
parser.add_argument("in_file_b", help="input ONNX file b")
args = parser.parse_args() args = parser.parse_args()
results_a, results_b = onnx_model_results(args.in_file_a, args.in_file_b, total_times=10) results_a, results_b = onnx_model_results(
args.in_file_a, args.in_file_b, total_times=10
)
ra_flat = helper.flatten_with_depth(results_a, 0) ra_flat = helper.flatten_with_depth(results_a, 0)
rb_flat = helper.flatten_with_depth(results_b, 0) rb_flat = helper.flatten_with_depth(results_b, 0)
shape_a = [item[1] for item in ra_flat] shape_a = [item[1] for item in ra_flat]
shape_b = [item[1] for item in rb_flat] shape_b = [item[1] for item in rb_flat]
assert shape_a == shape_b, 'two results data shape doesn\'t match' assert shape_a == shape_b, "two results data shape doesn't match"
ra_raw = [item[0] for item in ra_flat] ra_raw = [item[0] for item in ra_flat]
rb_raw = [item[0] for item in rb_flat] rb_raw = [item[0] for item in rb_flat]
try: try:
np.testing.assert_almost_equal(ra_raw, rb_raw, 4) np.testing.assert_almost_equal(ra_raw, rb_raw, 4)
print('Two models have the same behaviour.') print("Two models have the same behaviour.")
except Exception as mismatch: except Exception as mismatch:
print(mismatch) print(mismatch)
exit(1) exit(1)

View File

@ -1,4 +1,3 @@
import onnx
import argparse import argparse
import glob import glob
import csv import csv
@ -8,214 +7,242 @@ import matplotlib.pyplot as plt
from tools import helper from tools import helper
import onnx_vs_onnx as onnx_tester import onnx_vs_onnx as onnx_tester
def compare_results(results_a, results_b): def compare_results(results_a, results_b):
""" compare onnx model inference results """compare onnx model inference results
calculate basic statistical values calculate basic statistical values
results: results from inference multiple times results: results from inference multiple times
returns: list of basic statistical values returns: list of basic statistical values
""" """
# input results data can be of nonuniform shape # input results data can be of nonuniform shape
# get flatten data to compare # get flatten data to compare
ra_flat = helper.flatten_with_depth(results_a, 0) ra_flat = helper.flatten_with_depth(results_a, 0)
rb_flat = helper.flatten_with_depth(results_b, 0) rb_flat = helper.flatten_with_depth(results_b, 0)
shape_a = [item[1] for item in ra_flat] shape_a = [item[1] for item in ra_flat]
shape_b = [item[1] for item in rb_flat] shape_b = [item[1] for item in rb_flat]
assert shape_a == shape_b, 'two results data shape doesn\'t match' assert shape_a == shape_b, "two results data shape doesn't match"
ra_raw = [item[0] for item in ra_flat] ra_raw = [item[0] for item in ra_flat]
rb_raw = [item[0] for item in rb_flat] rb_raw = [item[0] for item in rb_flat]
# the statistical values # the statistical values
max_rel_diff = 0 # defined to be max( { abs(diff)/max(abs(ra), abs(rb) ) } ) max_rel_diff = (
max_abs_diff = 0 # defined to be max( { abs(ra-rb) } ) 0 # defined to be max( { abs(diff)/max(abs(ra), abs(rb) ) } )
mean_rel_diff = 0 )
mean_abs_diff = 0 max_abs_diff = 0 # defined to be max( { abs(ra-rb) } )
std_rel_diff = 0 mean_rel_diff = 0
std_abs_diff = 0 mean_abs_diff = 0
acc_with_diff_precision = [] std_rel_diff = 0
rel_diff = [] std_abs_diff = 0
abs_diff_percentiles = [] # rel_diff percentiles acc_with_diff_precision = []
rel_diff_percentiles = [] # abs_diff precentiles rel_diff = []
abs_diff_percentiles = [] # rel_diff percentiles
rel_diff_percentiles = [] # abs_diff precentiles
raw_diff = [ra_raw[i]-rb_raw[i] for i in range(len(ra_raw))] raw_diff = [ra_raw[i] - rb_raw[i] for i in range(len(ra_raw))]
abs_diff = [abs(num) for num in raw_diff] abs_diff = [abs(num) for num in raw_diff]
for i in range(len(ra_raw)):
divider = max([abs(ra_raw[i]), abs(rb_raw[i])])
val = abs_diff[i]/divider if divider != 0 else 0
rel_diff.append(val)
max_rel_diff = max(rel_diff)
max_abs_diff = max(abs_diff)
mean_rel_diff = np.average(rel_diff)
mean_abs_diff = np.average(abs_diff)
std_rel_diff = np.std(rel_diff)
std_abs_diff = np.std(abs_diff)
# calculate accuracy with different precison
for digit in range(8):
correct = 0
for i in range(len(ra_raw)): for i in range(len(ra_raw)):
if format(ra_raw[i], '.'+str(digit)+'f')\ divider = max([abs(ra_raw[i]), abs(rb_raw[i])])
== format(rb_raw[i], '.'+str(digit)+'f'): val = abs_diff[i] / divider if divider != 0 else 0
correct += 1 rel_diff.append(val)
acc_with_diff_precision.append([digit, float(format(correct/len(ra_raw), '.3f'))])
# analyze rel_diff distribution max_rel_diff = max(rel_diff)
rel_diff.sort() max_abs_diff = max(abs_diff)
abs_diff.sort() mean_rel_diff = np.average(rel_diff)
for i in range(20): mean_abs_diff = np.average(abs_diff)
rel_diff_percentiles.append(['{}%'.format(i*5), rel_diff[int((i/20)*len(rel_diff))]]) std_rel_diff = np.std(rel_diff)
abs_diff_percentiles.append(['{}%'.format(i*5), abs_diff[int((i/20)*len(abs_diff))]]) std_abs_diff = np.std(abs_diff)
results = [ # calculate accuracy with different precison
['max_rel_diff', max_rel_diff], for digit in range(8):
['max_abs_diff', max_abs_diff], correct = 0
['mean_rel_diff', mean_rel_diff], for i in range(len(ra_raw)):
['mean_abs_diff', mean_abs_diff], if format(ra_raw[i], "." + str(digit) + "f") == format(
['std_rel_diff', std_rel_diff], rb_raw[i], "." + str(digit) + "f"
['std_abs_diff', std_abs_diff], ):
['acc_with_diff_precision', acc_with_diff_precision], correct += 1
['rel_diff_percentiles', rel_diff_percentiles], acc_with_diff_precision.append(
['abs_diff_percentiles', abs_diff_percentiles] [digit, float(format(correct / len(ra_raw), ".3f"))]
] )
return results
if __name__ == '__main__': # analyze rel_diff distribution
parser = argparse.ArgumentParser(description='test model optimization results') rel_diff.sort()
abs_diff.sort()
parser.add_argument('dir', type=str, help='the directory that stores onnx models') for i in range(20):
parser.add_argument('ending1', type=str, help='model file name ending(eg, .onnx)') rel_diff_percentiles.append(
parser.add_argument('ending2', type=str, help='opt model file name ending(eg. _opt.onnx)') ["{}%".format(i * 5), rel_diff[int((i / 20) * len(rel_diff))]]
parser.add_argument('out_file', type=str, help='output csv file name') )
parser.add_argument('-p', '--plot', default='N', help='get plots (Y/N)') abs_diff_percentiles.append(
parser.add_argument('-i', '--iter_times', default=10, type=int, help='inference times') ["{}%".format(i * 5), abs_diff[int((i / 20) * len(abs_diff))]]
)
args = parser.parse_args() results = [
["max_rel_diff", max_rel_diff],
["max_abs_diff", max_abs_diff],
["mean_rel_diff", mean_rel_diff],
["mean_abs_diff", mean_abs_diff],
["std_rel_diff", std_rel_diff],
["std_abs_diff", std_abs_diff],
["acc_with_diff_precision", acc_with_diff_precision],
["rel_diff_percentiles", rel_diff_percentiles],
["abs_diff_percentiles", abs_diff_percentiles],
]
old_models_paths = glob.glob(args.dir+'*'+args.ending1) return results
new_models_paths = glob.glob(args.dir+'*'+args.ending2)
stats_table = [[
'Model',
'max_rel_diff',
'max_abs_diff',
'mean_rel_diff',
'mean_abs_diff',
'std_rel_diff',
'std_abs_diff',
'acc_with_diff_precision',
'rel_diff_percentiles',
'abs_diff_percentiles'
]]
for new_model_path in new_models_paths:
old_model_path = new_model_path[:-len(args.ending2)] + args.ending1
if old_model_path not in old_models_paths:
continue
# run inference
results_a, results_b = onnx_tester.onnx_model_results(old_model_path, new_model_path, total_times=args.iter_times)
# compare inference results
comparision = compare_results(results_a, results_b)
new_line = [old_model_path.split('/')[-1]]
for item in comparision:
new_line.append(item[1])
stats_table.append(new_line)
# try to read existing file
old_stats_table = []
try:
old_file = open(args.out_file, 'r')
reader = csv.reader(old_file)
old_header = reader.__next__()
for row in reader:
old_stats_table.append(row)
old_file.close()
except:
pass
# compare and merge possible old stat data file with new stat data file
header = stats_table[0]
stats_table = stats_table[1:]
new_model_names = set([item[0] for item in stats_table])
for row in old_stats_table:
if row[0] not in new_model_names:
stats_table.append(row)
stats_table.insert(0, header)
# write a new stat data file, overwrite old file
new_file = open(args.out_file, 'w', newline='')
writer = csv.writer(new_file)
for row in stats_table:
writer.writerow(row)
new_file.close()
# make some plots
if args.plot == 'Y':
if len(stats_table) < 2:
exit(0)
sample_table = stats_table[1:] if len(stats_table) < 6 else stats_table[1:6]
max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]]
plt.hist(max_rel_diffs, bins=15)
plt.title('Max Relavtive Difference Histogram')
plt.xlabel('Max Relative Difference')
plt.ylabel('Counts')
plt.savefig('max_rel_diff_hist.png')
plt.close()
max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]]
plt.hist(max_abs_diffs, bins=15)
plt.title('Max Absolute Difference Histogram')
plt.xlabel('Max Absolute Difference')
plt.ylabel('Counts')
plt.savefig('max_abs_diff_hist.png')
plt.close()
for line in sample_table:
model_name = line[0]
percentiles = line[-2]
x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))]
y = [ele[1] for ele in percentiles]
plt.plot(x, y, label=model_name)
plt.title('Rel_diff Percentiles of Raw and Optimized Models')
plt.xlabel('percentage')
plt.ylabel('relative difference')
plt.legend()
plt.savefig('rel_diff_percentiles.png')
plt.close()
for line in sample_table:
model_name = line[0]
percentiles = line[-1]
x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))]
y = [ele[1] for ele in percentiles]
plt.plot(x, y, label=model_name)
plt.title('Abs_diff Percentiles of Raw and Optimized Models')
plt.xlabel('percentage')
plt.ylabel('absolute difference')
plt.legend()
plt.savefig('abs_diff_percentiles.png')
plt.close()
for line in sample_table:
model_name = line[0]
accuracies = line[-3]
x = [acc[0] for acc in accuracies]
y = [acc[1] for acc in accuracies]
plt.plot(x, y, label=model_name)
plt.title('Accuracies with Different Precisions')
plt.xlabel('Decimals')
plt.ylabel('Precision')
plt.legend()
plt.savefig('precisions.png')
plt.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="test model optimization results"
)
parser.add_argument(
"dir", type=str, help="the directory that stores onnx models"
)
parser.add_argument(
"ending1", type=str, help="model file name ending(eg, .onnx)"
)
parser.add_argument(
"ending2", type=str, help="opt model file name ending(eg. _opt.onnx)"
)
parser.add_argument("out_file", type=str, help="output csv file name")
parser.add_argument("-p", "--plot", default="N", help="get plots (Y/N)")
parser.add_argument(
"-i", "--iter_times", default=10, type=int, help="inference times"
)
args = parser.parse_args()
old_models_paths = glob.glob(args.dir + "*" + args.ending1)
new_models_paths = glob.glob(args.dir + "*" + args.ending2)
stats_table = [
[
"Model",
"max_rel_diff",
"max_abs_diff",
"mean_rel_diff",
"mean_abs_diff",
"std_rel_diff",
"std_abs_diff",
"acc_with_diff_precision",
"rel_diff_percentiles",
"abs_diff_percentiles",
]
]
for new_model_path in new_models_paths:
old_model_path = new_model_path[: -len(args.ending2)] + args.ending1
if old_model_path not in old_models_paths:
continue
# run inference
results_a, results_b = onnx_tester.onnx_model_results(
old_model_path, new_model_path, total_times=args.iter_times
)
# compare inference results
comparision = compare_results(results_a, results_b)
new_line = [old_model_path.split("/")[-1]]
for item in comparision:
new_line.append(item[1])
stats_table.append(new_line)
# try to read existing file
old_stats_table = []
try:
old_file = open(args.out_file, "r")
reader = csv.reader(old_file)
old_header = reader.__next__()
for row in reader:
old_stats_table.append(row)
old_file.close()
except Exception:
pass
# compare and merge possible old stat data file with new stat data file
header = stats_table[0]
stats_table = stats_table[1:]
new_model_names = set([item[0] for item in stats_table])
for row in old_stats_table:
if row[0] not in new_model_names:
stats_table.append(row)
stats_table.insert(0, header)
# write a new stat data file, overwrite old file
new_file = open(args.out_file, "w", newline="")
writer = csv.writer(new_file)
for row in stats_table:
writer.writerow(row)
new_file.close()
# make some plots
if args.plot == "Y":
if len(stats_table) < 2:
exit(0)
sample_table = (
stats_table[1:] if len(stats_table) < 6 else stats_table[1:6]
)
max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]]
plt.hist(max_rel_diffs, bins=15)
plt.title("Max Relavtive Difference Histogram")
plt.xlabel("Max Relative Difference")
plt.ylabel("Counts")
plt.savefig("max_rel_diff_hist.png")
plt.close()
max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]]
plt.hist(max_abs_diffs, bins=15)
plt.title("Max Absolute Difference Histogram")
plt.xlabel("Max Absolute Difference")
plt.ylabel("Counts")
plt.savefig("max_abs_diff_hist.png")
plt.close()
for line in sample_table:
model_name = line[0]
percentiles = line[-2]
x = [
round(i * (1 / len(percentiles)), 2)
for i in range(len(percentiles))
]
y = [ele[1] for ele in percentiles]
plt.plot(x, y, label=model_name)
plt.title("Rel_diff Percentiles of Raw and Optimized Models")
plt.xlabel("percentage")
plt.ylabel("relative difference")
plt.legend()
plt.savefig("rel_diff_percentiles.png")
plt.close()
for line in sample_table:
model_name = line[0]
percentiles = line[-1]
x = [
round(i * (1 / len(percentiles)), 2)
for i in range(len(percentiles))
]
y = [ele[1] for ele in percentiles]
plt.plot(x, y, label=model_name)
plt.title("Abs_diff Percentiles of Raw and Optimized Models")
plt.xlabel("percentage")
plt.ylabel("absolute difference")
plt.legend()
plt.savefig("abs_diff_percentiles.png")
plt.close()
for line in sample_table:
model_name = line[0]
accuracies = line[-3]
x = [acc[0] for acc in accuracies]
y = [acc[1] for acc in accuracies]
plt.plot(x, y, label=model_name)
plt.title("Accuracies with Different Precisions")
plt.xlabel("Decimals")
plt.ylabel("Precision")
plt.legend()
plt.savefig("precisions.png")
plt.close()

View File

@ -1,21 +1,10 @@
import onnx import onnx
import onnx.utils import onnx.utils
try:
from onnx import optimizer
except ImportError:
import onnxoptimizer as optimizer
import sys import sys
import numpy as np
import struct
import logging import logging
import argparse import argparse
from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import combo
from tools import special
from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
# Debug use # Debug use
@ -25,13 +14,28 @@ from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
# Generate a prototype onnx # # Generate a prototype onnx #
###################################### ######################################
parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler") parser = argparse.ArgumentParser(
parser.add_argument('in_file', help='input ONNX or PTH FILE') description="Optimize a Pytorch generated model for Kneron compiler"
parser.add_argument('out_file', help="ouput ONNX FILE") )
parser.add_argument('--input-size', dest='input_size', nargs=3, parser.add_argument("in_file", help="input ONNX or PTH FILE")
help='if you using pth, please use this argument to set up the input size of the model. It should be in \'CH H W\' format, e.g. \'--input-size 3 256 512\'.') parser.add_argument("out_file", help="ouput ONNX FILE")
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False, parser.add_argument(
help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.") "--input-size",
dest="input_size",
nargs=3,
help="if you using pth, please use this argument to set up the input "
"size of the model. It should be in 'CH H W' format, "
"e.g. '--input-size 3 256 512'.",
)
parser.add_argument(
"--no-bn-fusion",
dest="disable_fuse_bn",
action="store_true",
default=False,
help="set if you have met errors which related to inferenced shape "
"mismatch. This option will prevent fusing BatchNormalization "
"into Conv.",
)
args = parser.parse_args() args = parser.parse_args()
@ -39,7 +43,7 @@ if len(args.in_file) <= 4:
# When the filename is too short. # When the filename is too short.
logging.error("Invalid input file: {}".format(args.in_file)) logging.error("Invalid input file: {}".format(args.in_file))
exit(1) exit(1)
elif args.in_file[-4:] == '.pth': elif args.in_file[-4:] == ".pth":
# Pytorch pth case # Pytorch pth case
logging.warning("Converting from pth to onnx is not recommended.") logging.warning("Converting from pth to onnx is not recommended.")
onnx_in = args.out_file onnx_in = args.out_file
@ -47,21 +51,29 @@ elif args.in_file[-4:] == '.pth':
from torch.autograd import Variable from torch.autograd import Variable
import torch import torch
import torch.onnx import torch.onnx
# import torchvision # import torchvision
# Standard ImageNet input - 3 channels, 224x224. # Standard ImageNet input - 3 channels, 224x224.
# Values don't matter as we care about network structure. # Values don't matter as we care about network structure.
# But they can also be real inputs. # But they can also be real inputs.
if args.input_size is None: if args.input_size is None:
logging.error("\'--input-size\' is required for the pth input file.") logging.error("'--input-size' is required for the pth input file.")
exit(1) exit(1)
dummy_input = Variable(torch.randn(1, int(args.input_size[0]), int(args.input_size[1]), int(args.input_size[2]))) dummy_input = Variable(
torch.randn(
1,
int(args.input_size[0]),
int(args.input_size[1]),
int(args.input_size[2]),
)
)
# Obtain your model, it can be also constructed in your script explicitly. # Obtain your model, it can be also constructed in your script explicitly.
model = torch.load(sys.argv[1], map_location='cpu') model = torch.load(sys.argv[1], map_location="cpu")
# model = torchvision.models.resnet34(pretrained=True) # model = torchvision.models.resnet34(pretrained=True)
# Invoke export. # Invoke export.
# torch.save(model, "resnet34.pth") # torch.save(model, "resnet34.pth")
torch.onnx.export(model, dummy_input, args.out_file, opset_version=11) torch.onnx.export(model, dummy_input, args.out_file, opset_version=11)
elif args.in_file[-4:] == 'onnx': elif args.in_file[-4:] == "onnx":
onnx_in = args.in_file onnx_in = args.in_file
else: else:
# When the file is neither an onnx or a pytorch pth. # When the file is neither an onnx or a pytorch pth.

View File

@ -1,29 +1,22 @@
import onnx import onnx
import onnx.utils import onnx.utils
try:
from onnx import optimizer
except ImportError:
import onnxoptimizer as optimizer
import sys
import numpy as np
import struct
import logging import logging
import argparse import argparse
from .tools import eliminating
from .tools import fusing
from .tools import replacing
from .tools import other
from .tools import combo from .tools import combo
from .tools import special
# Define general pytorch exported onnx optimize process # Define general pytorch exported onnx optimize process
def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.ModelProto: def torch_exported_onnx_flow(
m: onnx.ModelProto, disable_fuse_bn=False
) -> onnx.ModelProto:
"""Optimize the Pytorch exported onnx. """Optimize the Pytorch exported onnx.
Args: Args:
m (ModelProto): the input onnx model m (ModelProto): the input onnx model
disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False. disable_fuse_bn (bool, optional): do not fuse BN into Conv.
Defaults to False.
Returns: Returns:
ModelProto: the optimized onnx model ModelProto: the optimized onnx model
@ -38,20 +31,29 @@ def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.
# Main Process # Main Process
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler") parser = argparse.ArgumentParser(
parser.add_argument('in_file', help='input ONNX') description="Optimize a Pytorch generated model for Kneron compiler"
parser.add_argument('out_file', help="ouput ONNX FILE") )
parser.add_argument('--log', default='i', type=str, help="set log level") parser.add_argument("in_file", help="input ONNX")
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False, parser.add_argument("out_file", help="ouput ONNX FILE")
help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.") parser.add_argument("--log", default="i", type=str, help="set log level")
parser.add_argument(
"--no-bn-fusion",
dest="disable_fuse_bn",
action="store_true",
default=False,
help="set if you have met errors which related to inferenced shape "
"mismatch. This option will prevent fusing BatchNormalization "
"into Conv.",
)
args = parser.parse_args() args = parser.parse_args()
if args.log == 'w': if args.log == "w":
logging.basicConfig(level=logging.WARN) logging.basicConfig(level=logging.WARN)
elif args.log == 'd': elif args.log == "d":
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
elif args.log == 'e': elif args.log == "e":
logging.basicConfig(level=logging.ERROR) logging.basicConfig(level=logging.ERROR)
else: else:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -60,7 +62,7 @@ if __name__ == "__main__":
# When the filename is too short. # When the filename is too short.
logging.error("Invalid input file: {}".format(args.in_file)) logging.error("Invalid input file: {}".format(args.in_file))
exit(1) exit(1)
elif args.in_file[-4:] == 'onnx': elif args.in_file[-4:] == "onnx":
onnx_in = args.in_file onnx_in = args.in_file
else: else:
# When the file is not an onnx file. # When the file is not an onnx file.

View File

@ -8,7 +8,8 @@ import onnx.utils
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from tools import combo, eliminating, replacing, other from tools import combo, eliminating, replacing, other
def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
def tf2onnx_flow(pb_path: str, test_mode=False) -> onnx.ModelProto:
"""Convert frozen graph pb file into onnx """Convert frozen graph pb file into onnx
Args: Args:
@ -21,34 +22,45 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
Returns: Returns:
onnx.ModelProto: converted onnx onnx.ModelProto: converted onnx
""" """
TF2ONNX_VERSION = int(tf2onnx.version.version.replace('.', '')) TF2ONNX_VERSION = int(tf2onnx.version.version.replace(".", ""))
if 160 <= TF2ONNX_VERSION: if 160 <= TF2ONNX_VERSION:
from tf2onnx import tf_loader from tf2onnx import tf_loader
else: else:
from tf2onnx import loader as tf_loader from tf2onnx import loader as tf_loader
if pb_path[-3:] == '.pb':
model_name = pb_path.split('/')[-1][:-3]
# always reset tensorflow session at begin if pb_path[-3:] == ".pb":
model_name = pb_path.split("/")[-1][:-3]
# always reset tensorflow session at begin
tf.reset_default_graph() tf.reset_default_graph()
with tf.Session() as sess: with tf.Session() as sess:
with gfile.FastGFile(pb_path, 'rb') as f: with gfile.FastGFile(pb_path, "rb") as f:
graph_def = tf.GraphDef() graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read()) graph_def.ParseFromString(f.read())
sess.graph.as_default() sess.graph.as_default()
tf.import_graph_def(graph_def, name='') tf.import_graph_def(graph_def, name="")
if 160 <= int(tf2onnx.version.version.replace('.', '')): if 160 <= int(tf2onnx.version.version.replace(".", "")):
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions = tf2onnx.tf_utils.tflist_to_onnx( (
sess.graph, onnx_nodes,
{}) op_cnt,
attr_cnt,
output_shapes,
dtypes,
functions,
) = tf2onnx.tf_utils.tflist_to_onnx(sess.graph, {})
else: else:
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tf2onnx.tfonnx.tflist_to_onnx( (
sess.graph.get_operations(), onnx_nodes,
{}) op_cnt,
attr_cnt,
output_shapes,
dtypes,
) = tf2onnx.tfonnx.tflist_to_onnx(
sess.graph.get_operations(), {}
)
for n in onnx_nodes: for n in onnx_nodes:
if len(n.output) == 0: if len(n.output) == 0:
@ -59,12 +71,12 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
nodes_outputs = set() nodes_outputs = set()
for n in onnx_nodes: for n in onnx_nodes:
if n.op_type == 'Placeholder': if n.op_type == "Placeholder":
continue continue
for input in n.input: for input in n.input:
nodes_inputs.add(input) nodes_inputs.add(input)
for output in n.output: for output in n.output:
nodes_outputs.add(output) nodes_outputs.add(output)
graph_input_names = set() graph_input_names = set()
for input_name in nodes_inputs: for input_name in nodes_inputs:
@ -76,35 +88,43 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
if n.input and n.input[0] not in nodes_outputs: if n.input and n.input[0] not in nodes_outputs:
continue continue
if len(n.output) == 0: if len(n.output) == 0:
n.output.append(n.name + ':0') n.output.append(n.name + ":0")
graph_output_names.add(n.output[0]) graph_output_names.add(n.output[0])
else: else:
output_name = n.output[0] output_name = n.output[0]
if (output_name not in nodes_inputs) and (0 < len(n.input)): if (output_name not in nodes_inputs) and (
0 < len(n.input)
):
graph_output_names.add(output_name) graph_output_names.add(output_name)
logging.info('Model Inputs: %s', str(list(graph_input_names))) logging.info("Model Inputs: %s", str(list(graph_input_names)))
logging.info('Model Outputs: %s', str(list(graph_output_names))) logging.info("Model Outputs: %s", str(list(graph_output_names)))
graph_def, inputs, outputs = tf_loader.from_graphdef(model_path=pb_path, graph_def, inputs, outputs = tf_loader.from_graphdef(
input_names=list(graph_input_names), model_path=pb_path,
output_names=list(graph_output_names)) input_names=list(graph_input_names),
output_names=list(graph_output_names),
)
with tf.Graph().as_default() as tf_graph: with tf.Graph().as_default() as tf_graph:
tf.import_graph_def(graph_def, name='') tf.import_graph_def(graph_def, name="")
if 160 <= TF2ONNX_VERSION: if 160 <= TF2ONNX_VERSION:
with tf_loader.tf_session(graph=tf_graph): with tf_loader.tf_session(graph=tf_graph):
onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph, onnx_graph = tf2onnx.tfonnx.process_tf_graph(
input_names=inputs, tf_graph=tf_graph,
output_names=outputs, input_names=inputs,
opset=11) output_names=outputs,
opset=11,
)
else: else:
with tf.Session(graph=tf_graph): with tf.Session(graph=tf_graph):
onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph, onnx_graph = tf2onnx.tfonnx.process_tf_graph(
input_names=inputs, tf_graph=tf_graph,
output_names=outputs, input_names=inputs,
opset=11) output_names=outputs,
opset=11,
)
# Optimize with tf2onnx.optimizer # Optimize with tf2onnx.optimizer
onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph) onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph)
@ -115,7 +135,9 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
model_proto = other.polish_model(model_proto) model_proto = other.polish_model(model_proto)
else: else:
raise Exception('expect .pb file as input, but got "' + str(pb_path) + '"') raise Exception(
'expect .pb file as input, but got "' + str(pb_path) + '"'
)
# rename # rename
m = model_proto m = model_proto
@ -133,15 +155,26 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
return m return m
if __name__ == "__main__":
if __name__ == '__main__': parser = argparse.ArgumentParser(
parser = argparse.ArgumentParser(description='Convert tensorflow pb file to onnx file and optimized onnx file. Or just optimize tensorflow onnx file.') description="Convert tensorflow pb file to onnx file and optimized "
parser.add_argument('in_file', help='input file') "onnx file. Or just optimize tensorflow onnx file."
parser.add_argument('out_file', help='output optimized model file') )
parser.add_argument('-t', '--test_mode', default=False, help='test mode will not eliminate shape changes after input') parser.add_argument("in_file", help="input file")
parser.add_argument("out_file", help="output optimized model file")
parser.add_argument(
"-t",
"--test_mode",
default=False,
help="test mode will not eliminate shape changes after input",
)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO) logging.basicConfig(
stream=sys.stdout,
format="[%(asctime)s] %(levelname)s: %(message)s",
level=logging.INFO,
)
m = tf2onnx_flow(args.in_file, args.test_mode) m = tf2onnx_flow(args.in_file, args.test_mode)
onnx.save(m, args.out_file) onnx.save(m, args.out_file)
logging.info('Save Optimized ONNX: %s', args.out_file) logging.info("Save Optimized ONNX: %s", args.out_file)

View File

@ -6,6 +6,7 @@ import onnxruntime
from tools import helper from tools import helper
def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10): def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
# Setup onnx session and get meta data # Setup onnx session and get meta data
onnx_session = onnxruntime.InferenceSession(onnx_file, None) onnx_session = onnxruntime.InferenceSession(onnx_file, None)
@ -21,21 +22,32 @@ def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
tflite_session.allocate_tensors() tflite_session.allocate_tensors()
tflite_inputs = tflite_session.get_input_details() tflite_inputs = tflite_session.get_input_details()
tflite_outputs = tflite_session.get_output_details() tflite_outputs = tflite_session.get_output_details()
tflite_input_shape = tflite_inputs[0]['shape'] tflite_input_shape = tflite_inputs[0]["shape"]
# Compare input shape # Compare input shape
assert(len(onnx_input_shape) == len(tflite_input_shape)), "TFLite and ONNX shape unmatch." assert len(onnx_input_shape) == len(
assert(onnx_input_shape == [tflite_input_shape[0], tflite_input_shape[3], tflite_input_shape[1], tflite_input_shape[2]]), "TFLite and ONNX shape unmatch." tflite_input_shape
), "TFLite and ONNX shape unmatch."
assert onnx_input_shape == [
tflite_input_shape[0],
tflite_input_shape[3],
tflite_input_shape[1],
tflite_input_shape[2],
], "TFLite and ONNX shape unmatch."
# Generate random number and run # Generate random number and run
tflite_results = [] tflite_results = []
onnx_results = [] onnx_results = []
for _ in range(total_times): for _ in range(total_times):
# Generate input # Generate input
tflite_input_data = np.array(np.random.random_sample(tflite_input_shape), dtype=np.float32) tflite_input_data = np.array(
np.random.random_sample(tflite_input_shape), dtype=np.float32
)
onnx_input_data = np.transpose(tflite_input_data, [0, 3, 1, 2]) onnx_input_data = np.transpose(tflite_input_data, [0, 3, 1, 2])
# Run tflite # Run tflite
tflite_session.set_tensor(tflite_inputs[0]['index'], tflite_input_data) tflite_session.set_tensor(tflite_inputs[0]["index"], tflite_input_data)
tflite_session.invoke() tflite_session.invoke()
tflite_results.append(tflite_session.get_tensor(tflite_outputs[0]['index'])) tflite_results.append(
tflite_session.get_tensor(tflite_outputs[0]["index"])
)
# Run onnx # Run onnx
onnx_input_dict = {onnx_inputs[0].name: onnx_input_data} onnx_input_dict = {onnx_inputs[0].name: onnx_input_data}
onnx_results.append(onnx_session.run([], onnx_input_dict)[0]) onnx_results.append(onnx_session.run([], onnx_input_dict)[0])
@ -43,26 +55,31 @@ def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
return tflite_results, onnx_results return tflite_results, onnx_results
if __name__ == '__main__': if __name__ == "__main__":
# Argument parser. # Argument parser.
parser = argparse.ArgumentParser(description="Compare a TFLite model and an ONNX model to check if they have the same output.") parser = argparse.ArgumentParser(
parser.add_argument('tflite_file', help='input tflite file') description="Compare a TFLite model and an ONNX model to check "
parser.add_argument('onnx_file', help='input ONNX file') "if they have the same output."
)
parser.add_argument("tflite_file", help="input tflite file")
parser.add_argument("onnx_file", help="input ONNX file")
args = parser.parse_args() args = parser.parse_args()
results_a, results_b = compare_tflite_and_onnx(args.tflite_file, args.onnx_file, total_times=10) results_a, results_b = compare_tflite_and_onnx(
args.tflite_file, args.onnx_file, total_times=10
)
ra_flat = helper.flatten_with_depth(results_a, 0) ra_flat = helper.flatten_with_depth(results_a, 0)
rb_flat = helper.flatten_with_depth(results_b, 0) rb_flat = helper.flatten_with_depth(results_b, 0)
shape_a = [item[1] for item in ra_flat] shape_a = [item[1] for item in ra_flat]
shape_b = [item[1] for item in rb_flat] shape_b = [item[1] for item in rb_flat]
assert shape_a == shape_b, 'two results data shape doesn\'t match' assert shape_a == shape_b, "two results data shape doesn't match"
ra_raw = [item[0] for item in ra_flat] ra_raw = [item[0] for item in ra_flat]
rb_raw = [item[0] for item in rb_flat] rb_raw = [item[0] for item in rb_flat]
try: try:
np.testing.assert_almost_equal(ra_raw, rb_raw, 8) np.testing.assert_almost_equal(ra_raw, rb_raw, 8)
print('Two models have the same behaviour.') print("Two models have the same behaviour.")
except Exception as mismatch: except Exception as mismatch:
print(mismatch) print(mismatch)
exit(1) exit(1)

View File

@ -2,7 +2,7 @@
""" """
import logging import logging
import onnx.utils
try: try:
from onnx import optimizer from onnx import optimizer
except ImportError: except ImportError:
@ -15,16 +15,19 @@ from . import eliminating
from . import fusing from . import fusing
from . import constant_folding from . import constant_folding
from . import removing_transpose from . import removing_transpose
from . import modhelper
from .common_pattern import torch_pattern_match, tf_pattern_match from .common_pattern import torch_pattern_match, tf_pattern_match
from .helper import logger from .helper import logger
def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True):
def preprocess(
model_proto, disable_fuse_bn=False, duplicate_shared_weights=True
):
"""The most common used functions before other processing. """The most common used functions before other processing.
Args: Args:
model_proto: the original model input model_proto: the original model input
duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True. duplicate_shared_weights(bool, optional): duplicate shared weights.
Defaults to True.
Return: Return:
the new model after preprocessing the new model after preprocessing
@ -65,22 +68,28 @@ def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True
replacing.replace_initializer_with_Constant(model_proto.graph) replacing.replace_initializer_with_Constant(model_proto.graph)
other.topological_sort(model_proto.graph) other.topological_sort(model_proto.graph)
m = other.polish_model(model_proto) m = other.polish_model(model_proto)
passes = ['extract_constant_to_initializer', passes = [
'eliminate_nop_dropout', "extract_constant_to_initializer",
'eliminate_deadend', "eliminate_nop_dropout",
'fuse_matmul_add_bias_into_gemm', "eliminate_deadend",
'fuse_pad_into_conv'] "fuse_matmul_add_bias_into_gemm",
"fuse_pad_into_conv",
]
if not disable_fuse_bn: if not disable_fuse_bn:
passes.append('fuse_bn_into_conv') passes.append("fuse_bn_into_conv")
m = optimizer.optimize(m, passes) m = optimizer.optimize(m, passes)
g = m.graph g = m.graph
# Add name again since onnx optimizer higher than 1.7 may remove node names. # Add name again since onnx optimizer higher than 1.7 may remove node names
other.add_name_to_node(g) other.add_name_to_node(g)
if duplicate_shared_weights: if duplicate_shared_weights:
replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=True) replacing.replace_initializer_with_Constant(
g, duplicate_shared_weights=True
)
other.duplicate_param_shared_constant(g) other.duplicate_param_shared_constant(g)
else: else:
replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=False) replacing.replace_initializer_with_Constant(
g, duplicate_shared_weights=False
)
other.topological_sort(g) other.topological_sort(g)
m = other.polish_model(m) m = other.polish_model(m)
g = m.graph g = m.graph
@ -161,12 +170,12 @@ def pytorch_constant_folding(m):
other.topological_sort(m.graph) other.topological_sort(m.graph)
while len(m.graph.value_info) != 0: while len(m.graph.value_info) != 0:
m.graph.value_info.pop() m.graph.value_info.pop()
m = other.inference_shapes(m) m = other.inference_shapes(m)
replacing.replace_shape_with_constant(m.graph) replacing.replace_shape_with_constant(m.graph)
other.topological_sort(m.graph) other.topological_sort(m.graph)
m = torch_pattern_match(m) m = torch_pattern_match(m)
m = optimizer.optimize(m, ['eliminate_deadend']) m = optimizer.optimize(m, ["eliminate_deadend"])
return m return m
@ -206,7 +215,7 @@ def tensorflow_optimization(m):
replacing.replace_shape_with_constant(m.graph) replacing.replace_shape_with_constant(m.graph)
other.topological_sort(m.graph) other.topological_sort(m.graph)
m = tf_pattern_match(m) m = tf_pattern_match(m)
m = optimizer.optimize(m, ['eliminate_deadend']) m = optimizer.optimize(m, ["eliminate_deadend"])
eliminating.eliminate_consecutive_reshape(m.graph) eliminating.eliminate_consecutive_reshape(m.graph)
eliminating.eliminate_Squeeze_before_Reshape(m.graph) eliminating.eliminate_Squeeze_before_Reshape(m.graph)
@ -253,6 +262,6 @@ def postprocess(m):
m = other.polish_model(m) m = other.polish_model(m)
other.add_output_to_value_info(m.graph) other.add_output_to_value_info(m.graph)
m = optimizer.optimize(m, ['eliminate_deadend']) m = optimizer.optimize(m, ["eliminate_deadend"])
m.producer_name = 'kneron_formatter' m.producer_name = "kneron_formatter"
return m return m

View File

@ -3,19 +3,20 @@ import numpy as np
import onnx.helper import onnx.helper
import onnx.utils import onnx.utils
from . import modhelper
from . import helper from . import helper
from . import other from . import other
def torch_pattern_match(m): def torch_pattern_match(m):
# Create a map from optype to the nodes. # Create a map from optype to the nodes.
optype2node = defaultdict(list) optype2node = defaultdict(list)
for node in m.graph.node: for node in m.graph.node:
optype2node[node.op_type].append(node) optype2node[node.op_type].append(node)
for matmul_node in optype2node['MatMul']: for matmul_node in optype2node["MatMul"]:
pattern_matmul_mul_add(m.graph, matmul_node) pattern_matmul_mul_add(m.graph, matmul_node)
for resize_node in optype2node['Resize']: for resize_node in optype2node["Resize"]:
# torch nn.UpsamplingBilinear2d will be given us 4 input: "X, roi, scales, sizes" # torch nn.UpsamplingBilinear2d will be given us 4 input:
# "X, roi, scales, sizes"
if len(resize_node.input) != 4: if len(resize_node.input) != 4:
continue continue
make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name) make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name)
@ -24,15 +25,17 @@ def torch_pattern_match(m):
m = other.polish_model(m) m = other.polish_model(m)
return m return m
def tf_pattern_match(m): def tf_pattern_match(m):
# Create a map from optype to the nodes. # Create a map from optype to the nodes.
optype2node = defaultdict(list) optype2node = defaultdict(list)
for node in m.graph.node: for node in m.graph.node:
optype2node[node.op_type].append(node) optype2node[node.op_type].append(node)
for matmul_node in optype2node['MatMul']: for matmul_node in optype2node["MatMul"]:
pattern_matmul_mul_add(m.graph, matmul_node) pattern_matmul_mul_add(m.graph, matmul_node)
for resize_node in optype2node['Resize']: for resize_node in optype2node["Resize"]:
# In tensorflow2onnx, ReizeXXX will be given us 4 input: "X, roi, scales, sizes" # In tensorflow2onnx, ReizeXXX will be given us 4 input:
# "X, roi, scales, sizes"
# and node output name will be given the "node name + :0" # and node output name will be given the "node name + :0"
if len(resize_node.input) != 4: if len(resize_node.input) != 4:
continue continue
@ -42,24 +45,25 @@ def tf_pattern_match(m):
m = other.polish_model(m) m = other.polish_model(m)
return m return m
def pattern_matmul_mul_add(g, matmul_node): def pattern_matmul_mul_add(g, matmul_node):
# Check node match - Mul node # Check node match - Mul node
next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0]) next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
if len(next_nodes) != 1: if len(next_nodes) != 1:
return return
if next_nodes[0].op_type != 'Mul': if next_nodes[0].op_type != "Mul":
return return
mul_node = next_nodes[0] mul_node = next_nodes[0]
# Check node match - Add node # Check node match - Add node
next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0]) next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
if len(next_nodes) != 1: if len(next_nodes) != 1:
return return
if next_nodes[0].op_type != 'Add': if next_nodes[0].op_type != "Add":
return return
add_node = next_nodes[0] add_node = next_nodes[0]
# Check Mul weight # Check Mul weight
mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1]) mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
if mul_weight_node.op_type != 'Constant': if mul_weight_node.op_type != "Constant":
return return
weight_size, mul_weight = helper.constant_to_list(mul_weight_node) weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
for i in mul_weight: for i in mul_weight:
@ -68,15 +72,19 @@ def pattern_matmul_mul_add(g, matmul_node):
channel = weight_size[0] channel = weight_size[0]
# Check Add weight # Check Add weight
add_weight_node = helper.find_node_by_output_name(g, add_node.input[1]) add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
if add_weight_node.op_type != 'Constant': if add_weight_node.op_type != "Constant":
return return
# Check MatMul weight to see if it need weight broadcast # Check MatMul weight to see if it need weight broadcast
matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1]) matmul_weight_node = helper.find_node_by_output_name(
g, matmul_node.input[1]
)
matmul_weight = helper.constant_to_numpy(matmul_weight_node) matmul_weight = helper.constant_to_numpy(matmul_weight_node)
if matmul_weight.shape[1] == 1: if matmul_weight.shape[1] == 1:
# Weight broadcast # Weight broadcast
new_matmul_weight = np.tile(matmul_weight, channel) new_matmul_weight = np.tile(matmul_weight, channel)
new_matmul_weight_node = helper.numpy_to_constant(matmul_weight_node.name, new_matmul_weight) new_matmul_weight_node = helper.numpy_to_constant(
matmul_weight_node.name, new_matmul_weight
)
g.node.remove(matmul_weight_node) g.node.remove(matmul_weight_node)
g.node.extend([new_matmul_weight_node]) g.node.extend([new_matmul_weight_node])
value = helper.find_value_by_name(g, matmul_weight_node.output[0]) value = helper.find_value_by_name(g, matmul_weight_node.output[0])
@ -93,14 +101,14 @@ def pattern_matmul_mul_add(g, matmul_node):
g.value_info.remove(value) g.value_info.remove(value)
# Fuse Matmul and Add # Fuse Matmul and Add
gemm_node = onnx.helper.make_node( gemm_node = onnx.helper.make_node(
'Gemm', "Gemm",
[matmul_node.input[0], matmul_node.input[1], add_node.input[1]], [matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
[add_node.output[0]], [add_node.output[0]],
name = matmul_node.name, name=matmul_node.name,
alpha = 1.0, alpha=1.0,
beta = 1.0, beta=1.0,
transA = 0, transA=0,
transB = 0 transB=0,
) )
g.node.extend([gemm_node]) g.node.extend([gemm_node])
# Clean up # Clean up
@ -111,6 +119,7 @@ def pattern_matmul_mul_add(g, matmul_node):
g.value_info.remove(value) g.value_info.remove(value)
other.topological_sort(g) other.topological_sort(g)
def make_UpsamplingBilinear2d_value_info(g, resize_node_name): def make_UpsamplingBilinear2d_value_info(g, resize_node_name):
resize_node = helper.find_node_by_node_name(g, resize_node_name) resize_node = helper.find_node_by_node_name(g, resize_node_name)
@ -124,34 +133,45 @@ def make_UpsamplingBilinear2d_value_info(g, resize_node_name):
new_output_value_info = onnx.helper.make_tensor_value_info( new_output_value_info = onnx.helper.make_tensor_value_info(
resize_node.output[0], resize_node.output[0],
onnx.helper.TensorProto.FLOAT, onnx.helper.TensorProto.FLOAT,
shape_data.tolist() shape_data.tolist(),
) )
g.value_info.extend([new_output_value_info]) g.value_info.extend([new_output_value_info])
def polish_RESIZE_input_param_node(g, resize_node_name): def polish_RESIZE_input_param_node(g, resize_node_name):
resize_node = helper.find_node_by_node_name(g, resize_node_name) resize_node = helper.find_node_by_node_name(g, resize_node_name)
shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3]) shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
shape_data = helper.constant_to_numpy(shape_data_node).astype(int) shape_data = helper.constant_to_numpy(shape_data_node).astype(int)
# handle 0 batch size which is invalid # handle 0 batch size which is invalid
if shape_data[0] == 0: if shape_data[0] == 0:
shape_data[0] = 1 shape_data[0] = 1
pre_node_output_value_info = helper.find_value_by_name(g, resize_node.input[0]) pre_node_output_value_info = helper.find_value_by_name(
ori_shape = np.array([pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value, g, resize_node.input[0]
pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value, )
pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value, ori_shape = np.array(
pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value]) [
pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value,
resize_node.input.remove(resize_node.input[3]) pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value,
pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value,
pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value,
]
)
resize_scales = np.array(shape_data/ori_shape).astype(float) resize_node.input.remove(resize_node.input[3])
resize_scale_node = helper.list_to_constant('resize_scales_node_' + resize_node.name, resize_scales.shape, resize_scales, data_type=onnx.helper.TensorProto.FLOAT)
resize_scales = np.array(shape_data / ori_shape).astype(float)
resize_scale_node = helper.list_to_constant(
"resize_scales_node_" + resize_node.name,
resize_scales.shape,
resize_scales,
data_type=onnx.helper.TensorProto.FLOAT,
)
resize_node.input[2] = resize_scale_node.name resize_node.input[2] = resize_scale_node.name
g.node.extend([resize_scale_node]) g.node.extend([resize_scale_node])
other.topological_sort(g) other.topological_sort(g)

View File

@ -5,15 +5,14 @@ import logging
import traceback import traceback
from . import helper from . import helper
from .general_graph import Graph, Node
from .other import topological_sort from .other import topological_sort
from .replacing import replace_shape_with_constant
from .helper import logger from .helper import logger
def are_all_inputs_Constant_with_one_child(g, node): def are_all_inputs_Constant_with_one_child(g, node):
for input_name in node.input: for input_name in node.input:
input_node = helper.find_node_by_output_name(g, input_name) input_node = helper.find_node_by_output_name(g, input_name)
if input_node is None or input_node.op_type != 'Constant': if input_node is None or input_node.op_type != "Constant":
return False return False
relative_outputs = helper.find_nodes_by_input_name(g, input_name) relative_outputs = helper.find_nodes_by_input_name(g, input_name)
if len(relative_outputs) > 1: if len(relative_outputs) > 1:
@ -28,7 +27,7 @@ def constant_folding(g):
:return: If any node is folded, return True. Otherwise, return False. :return: If any node is folded, return True. Otherwise, return False.
""" """
keep_folding = True # Keep the while loop keep_folding = True # Keep the while loop
folded = False # Return value folded = False # Return value
try: try:
# Before constant folding, duplicate the constant nodes. # Before constant folding, duplicate the constant nodes.
duplicate_constant_node(g) duplicate_constant_node(g)
@ -38,37 +37,47 @@ def constant_folding(g):
# Check if the node is foldable # Check if the node is foldable
if node.op_type not in constant_folding_nodes.keys(): if node.op_type not in constant_folding_nodes.keys():
continue continue
# Check if the parents of the node are all single follower constant node. # Check if parents of the node are all
# single follower constant node.
if not are_all_inputs_Constant_with_one_child(g, node): if not are_all_inputs_Constant_with_one_child(g, node):
continue continue
# Constant folding for the specific node # Constant folding for the specific node
if constant_folding_nodes[node.op_type](g, node): if constant_folding_nodes[node.op_type](g, node):
logging.debug("Constant nodes and %s %s are folded.", logging.debug(
node.op_type, node.name) "Constant nodes and %s %s are folded.",
node.op_type,
node.name,
)
folded = True folded = True
keep_folding = True keep_folding = True
else: else:
logging.debug( logging.debug(
"Constant nodes and %s %s are skipped.", node.op_type, node.name) "Constant nodes and %s %s are skipped.",
except Exception as e: node.op_type,
node.name,
)
except Exception:
logger.error("An exception is raised while constant folding.") logger.error("An exception is raised while constant folding.")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return folded return folded
def duplicate_constant_node(g): def duplicate_constant_node(g):
""" Duplicate the constant node if its following nodes contain constant folding """
nodes. Create and link the new constant nodes to the constant folding nodes. Duplicate the constant node if its following nodes contain
constant folding nodes. Create and link the new constant nodes
to the constant folding nodes.
""" """
for node in g.node: for node in g.node:
# Find a valid constant node # Find a valid constant node
if node.op_type != 'Constant': if node.op_type != "Constant":
continue continue
output_val_info = helper.find_value_by_name(g, node.output[0]) output_val_info = helper.find_value_by_name(g, node.output[0])
if output_val_info is None: if output_val_info is None:
print("Cannot inference the shape of Const node output: " + print(
node.output[0]) "Cannot inference the shape of Const node output: "
+ node.output[0]
)
exit(1) exit(1)
data_shape = helper.get_shape_from_value_info(output_val_info) data_shape = helper.get_shape_from_value_info(output_val_info)
output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
@ -78,30 +87,37 @@ def duplicate_constant_node(g):
continue continue
# Check if its following nodes are foldable # Check if its following nodes are foldable
foldable_output_nodes = list(filter(lambda n: n.op_type in foldable_output_nodes = list(
constant_folding_nodes.keys(), output_nodes)) filter(
lambda n: n.op_type in constant_folding_nodes.keys(),
output_nodes,
)
)
if not foldable_output_nodes: if not foldable_output_nodes:
continue continue
# Duplicate the node needed by foldable nodes # Duplicate the node needed by foldable nodes
for i in range(len(foldable_output_nodes)): for i in range(len(foldable_output_nodes)):
logging.debug("Found constant %s and %s %s are availble for folding. Duplicate constant.", logging.debug(
node.name, foldable_output_nodes[i].op_type, foldable_output_nodes[i].name) f"Found constant {node.name} and "
output_name = node.output[0] + '_dup_' + str(i) f"{foldable_output_nodes[i].op_type} "
f"{foldable_output_nodes[i].name} are availble for folding. "
"Duplicate constant.",
)
output_name = node.output[0] + "_dup_" + str(i)
new_constant_node = onnx.helper.make_node( new_constant_node = onnx.helper.make_node(
'Constant', "Constant",
[], [],
[output_name], [output_name],
name=output_name, name=output_name,
value=node.attribute[0].t value=node.attribute[0].t,
) )
new_val_info = onnx.helper.make_tensor_value_info( new_val_info = onnx.helper.make_tensor_value_info(
output_name, output_name, node.attribute[0].t.data_type, data_shape
node.attribute[0].t.data_type,
data_shape
) )
input_ind = list(foldable_output_nodes[i].input).index( input_ind = list(foldable_output_nodes[i].input).index(
node.output[0]) node.output[0]
)
foldable_output_nodes[i].input[input_ind] = output_name foldable_output_nodes[i].input[input_ind] = output_name
g.node.extend([new_constant_node]) g.node.extend([new_constant_node])
@ -116,6 +132,7 @@ def duplicate_constant_node(g):
return return
def slice_constant_folding(g, node): def slice_constant_folding(g, node):
op_version = helper.get_current_opset_version() op_version = helper.get_current_opset_version()
# only support opset 9 & 11 # only support opset 9 & 11
@ -124,9 +141,9 @@ def slice_constant_folding(g, node):
elif op_version == 9: elif op_version == 9:
return slice_constant_folding_Opset_9(g, node) return slice_constant_folding_Opset_9(g, node)
def slice_constant_folding_Opset_11(g, node): def slice_constant_folding_Opset_11(g, node):
""" Fold constant and slice nodes to a single constant node. """Fold constant and slice nodes to a single constant node."""
"""
pre_node = helper.find_node_by_output_name(g, node.input[0]) pre_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape, data_list = helper.constant_to_list(pre_node) pre_shape, data_list = helper.constant_to_list(pre_node)
@ -136,20 +153,26 @@ def slice_constant_folding_Opset_11(g, node):
ends_node = helper.find_node_by_output_name(g, node.input[2]) ends_node = helper.find_node_by_output_name(g, node.input[2])
_, ends = helper.constant_to_list(ends_node) _, ends = helper.constant_to_list(ends_node)
axes_node = (
axes_node = None if len(node.input) <= 3 else helper.find_node_by_output_name(g, node.input[3]) None
if len(node.input) <= 3
else helper.find_node_by_output_name(g, node.input[3])
)
if not axes_node: if not axes_node:
axes = list(range(len(helper.get_shape(data_list)))) axes = list(range(len(helper.get_shape(data_list))))
else: else:
_, axes = helper.constant_to_list(axes_node) _, axes = helper.constant_to_list(axes_node)
steps_node = None if len(node.input) <= 4 else helper.find_node_by_output_name(g, node.input[4]) steps_node = (
None
if len(node.input) <= 4
else helper.find_node_by_output_name(g, node.input[4])
)
if not steps_node: if not steps_node:
steps = [1]*len(helper.get_shape(data_list)) steps = [1] * len(helper.get_shape(data_list))
else: else:
_, steps = helper.constant_to_list(steps_node) _, steps = helper.constant_to_list(steps_node)
data_list = list(map(int, data_list)) data_list = list(map(int, data_list))
starts = list(map(int, starts)) starts = list(map(int, starts))
ends = list(map(int, ends)) ends = list(map(int, ends))
@ -160,10 +183,15 @@ def slice_constant_folding_Opset_11(g, node):
new_data = None new_data = None
for idx, _ in enumerate(axes): for idx, _ in enumerate(axes):
new_data = np.apply_along_axis( lambda x: x[starts[idx] : ends[idx] : steps[idx]], idx, data_list ) new_data = np.apply_along_axis(
lambda x: x[starts[idx]:ends[idx]:steps[idx]], idx, data_list
)
new_node = helper.list_to_constant(node.output[0], helper.get_shape( new_node = helper.list_to_constant(
new_data), helper.flatten_to_list(new_data)) node.output[0],
helper.get_shape(new_data),
helper.flatten_to_list(new_data),
)
g.node.extend([new_node]) g.node.extend([new_node])
value_info = helper.find_value_by_name(g, pre_node.output[0]) value_info = helper.find_value_by_name(g, pre_node.output[0])
if value_info is not None: if value_info is not None:
@ -173,16 +201,16 @@ def slice_constant_folding_Opset_11(g, node):
return True return True
def slice_constant_folding_Opset_9(g, node): def slice_constant_folding_Opset_9(g, node):
""" Fold constant and slice nodes to a single constant node. """Fold constant and slice nodes to a single constant node."""
"""
pre_node = helper.find_node_by_output_name(g, node.input[0]) pre_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape, data_list = helper.constant_to_list(pre_node) pre_shape, data_list = helper.constant_to_list(pre_node)
data_list = np.reshape(data_list, pre_shape) data_list = np.reshape(data_list, pre_shape)
axes = helper.get_attribute_by_name(node, 'axes') axes = helper.get_attribute_by_name(node, "axes")
ends = list(helper.get_attribute_by_name(node, 'ends').ints) ends = list(helper.get_attribute_by_name(node, "ends").ints)
starts = list(helper.get_attribute_by_name(node, 'starts').ints) starts = list(helper.get_attribute_by_name(node, "starts").ints)
if not axes: if not axes:
axes = list(range(len(helper.get_shape(data_list)))) axes = list(range(len(helper.get_shape(data_list))))
@ -190,8 +218,11 @@ def slice_constant_folding_Opset_9(g, node):
axes = list(axes.ints) axes = list(axes.ints)
new_data = helper.slice_data(data_list, starts, ends, axes) new_data = helper.slice_data(data_list, starts, ends, axes)
new_node = helper.list_to_constant(node.output[0], helper.get_shape( new_node = helper.list_to_constant(
new_data), helper.flatten_to_list(new_data)) node.output[0],
helper.get_shape(new_data),
helper.flatten_to_list(new_data),
)
g.node.extend([new_node]) g.node.extend([new_node])
value_info = helper.find_value_by_name(g, pre_node.output[0]) value_info = helper.find_value_by_name(g, pre_node.output[0])
if value_info is not None: if value_info is not None:
@ -201,9 +232,9 @@ def slice_constant_folding_Opset_9(g, node):
return True return True
def cast_constant_folding(g, node): def cast_constant_folding(g, node):
""" Fold constant and cast node to a single constant node. """Fold constant and cast node to a single constant node."""
"""
pre_node = helper.find_node_by_output_name(g, node.input[0]) pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node) shape, data = helper.constant_to_list(pre_node)
data_type = node.attribute[0].i data_type = node.attribute[0].i
@ -212,28 +243,24 @@ def cast_constant_folding(g, node):
elif data_type == onnx.helper.TensorProto.FLOAT: elif data_type == onnx.helper.TensorProto.FLOAT:
data = list(map(float, data)) data = list(map(float, data))
else: else:
raise RuntimeError('data type not supported') raise RuntimeError("data type not supported")
if shape == 1: if shape == 1:
tensor = onnx.helper.make_tensor( tensor = onnx.helper.make_tensor(
name=pre_node.attribute[0].name, name=pre_node.attribute[0].name,
data_type=data_type, data_type=data_type,
dims=[], dims=[],
vals=data vals=data,
) )
else: else:
tensor = onnx.helper.make_tensor( tensor = onnx.helper.make_tensor(
name=pre_node.attribute[0].name, name=pre_node.attribute[0].name,
data_type=data_type, data_type=data_type,
dims=shape, dims=shape,
vals=helper.flatten_to_list(data) vals=helper.flatten_to_list(data),
) )
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
'Constant', "Constant", [], [node.output[0]], name=node.output[0], value=tensor
[],
[node.output[0]],
name=node.output[0],
value=tensor
) )
g.node.extend([new_node]) g.node.extend([new_node])
@ -250,15 +277,14 @@ def cast_constant_folding(g, node):
def reduceprod_constant_folding(g, node): def reduceprod_constant_folding(g, node):
""" Fold constant and reduceprod nodes to a single constant node. """Fold constant and reduceprod nodes to a single constant node."""
"""
pre_node = helper.find_node_by_output_name(g, node.input[0]) pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data_set = helper.constant_to_list(pre_node) shape, data_set = helper.constant_to_list(pre_node)
tensor = pre_node.attribute[0].t tensor = pre_node.attribute[0].t
data_set = np.reshape(data_set, shape) data_set = np.reshape(data_set, shape)
for att in node.attribute: for att in node.attribute:
if att.name == 'axes': if att.name == "axes":
axes = list(att.ints) axes = list(att.ints)
else: else:
keepdims = int(att.i) keepdims = int(att.i)
@ -270,14 +296,10 @@ def reduceprod_constant_folding(g, node):
name=node.output[0], name=node.output[0],
data_type=tensor.data_type, data_type=tensor.data_type,
dims=new_shape, dims=new_shape,
vals=new_flat_data vals=new_flat_data,
) )
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
'Constant', "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
) )
g.node.extend([new_node]) g.node.extend([new_node])
@ -294,8 +316,7 @@ def reduceprod_constant_folding(g, node):
def reshape_constant_input_folding(g, node): def reshape_constant_input_folding(g, node):
""" Fold constant and reshape nodes to a single constant node. """Fold constant and reshape nodes to a single constant node."""
"""
pre_data_node = helper.find_node_by_output_name(g, node.input[0]) pre_data_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape_node = helper.find_node_by_output_name(g, node.input[1]) pre_shape_node = helper.find_node_by_output_name(g, node.input[1])
@ -307,14 +328,10 @@ def reshape_constant_input_folding(g, node):
name=node.output[0], name=node.output[0],
data_type=pre_data_node.attribute[0].t.data_type, data_type=pre_data_node.attribute[0].t.data_type,
dims=new_data.shape, dims=new_data.shape,
vals=helper.flatten_to_list(new_data) vals=helper.flatten_to_list(new_data),
) )
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
'Constant', "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
) )
g.node.extend([new_node]) g.node.extend([new_node])
@ -332,8 +349,7 @@ def reshape_constant_input_folding(g, node):
def concat_constant_folding(g, node): def concat_constant_folding(g, node):
""" Fold constant and concat nodes to a single constant node. """Fold constant and concat nodes to a single constant node."""
"""
node_to_del = [] node_to_del = []
valid_inputs = True valid_inputs = True
for input_name in node.input: for input_name in node.input:
@ -342,7 +358,7 @@ def concat_constant_folding(g, node):
if len(input_node_output) > 1: if len(input_node_output) > 1:
valid_inputs = False valid_inputs = False
break break
if input_node.op_type != 'Constant': if input_node.op_type != "Constant":
valid_inputs = False valid_inputs = False
break break
@ -370,7 +386,7 @@ def concat_constant_folding(g, node):
node.output[0], node.output[0],
helper.get_shape(concat_data), helper.get_shape(concat_data),
helper.flatten_to_list(concat_data), helper.flatten_to_list(concat_data),
data_type=node_data_type data_type=node_data_type,
) )
g.node.extend([new_node]) g.node.extend([new_node])
node_to_del.append(node) node_to_del.append(node)
@ -388,8 +404,7 @@ def concat_constant_folding(g, node):
def transpose_constant_folding(g, node): def transpose_constant_folding(g, node):
"""Fold constant and transpose nodes to a single constant node. """Fold constant and transpose nodes to a single constant node."""
"""
node_to_del = [] node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0]) pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node) shape, data = helper.constant_to_list(pre_node)
@ -402,7 +417,7 @@ def transpose_constant_folding(g, node):
node.output[0], node.output[0],
new_shape, new_shape,
new_data.flatten().tolist(), new_data.flatten().tolist(),
data_type=pre_node.attribute[0].t.data_type data_type=pre_node.attribute[0].t.data_type,
) )
g.node.extend([new_node]) g.node.extend([new_node])
@ -415,9 +430,7 @@ def transpose_constant_folding(g, node):
g.value_info.remove(next_val_info) g.value_info.remove(next_val_info)
new_val_info = onnx.helper.make_tensor_value_info( new_val_info = onnx.helper.make_tensor_value_info(
node.output[0], node.output[0], pre_node.attribute[0].t.data_type, new_shape
pre_node.attribute[0].t.data_type,
new_shape
) )
g.value_info.extend([new_val_info]) g.value_info.extend([new_val_info])
@ -430,8 +443,7 @@ def transpose_constant_folding(g, node):
def unsqueeze_constant_folding(g, node): def unsqueeze_constant_folding(g, node):
"""Fold constant and unsqueeze nodes to a single constant node. """Fold constant and unsqueeze nodes to a single constant node."""
"""
node_to_del = [] node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0]) pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node) shape, data = helper.constant_to_list(pre_node)
@ -449,7 +461,7 @@ def unsqueeze_constant_folding(g, node):
node.output[0], node.output[0],
new_shape, new_shape,
np_data.flatten().tolist(), np_data.flatten().tolist(),
data_type=pre_node.attribute[0].t.data_type data_type=pre_node.attribute[0].t.data_type,
) )
g.node.extend([new_node]) g.node.extend([new_node])
node_to_del.extend([node, pre_node]) node_to_del.extend([node, pre_node])
@ -464,9 +476,7 @@ def unsqueeze_constant_folding(g, node):
g.value_info.remove(next_val_info) g.value_info.remove(next_val_info)
new_val_info = onnx.helper.make_tensor_value_info( new_val_info = onnx.helper.make_tensor_value_info(
node.output[0], node.output[0], pre_node.attribute[0].t.data_type, new_shape
pre_node.attribute[0].t.data_type,
new_shape
) )
g.value_info.extend([new_val_info]) g.value_info.extend([new_val_info])
@ -478,8 +488,7 @@ def unsqueeze_constant_folding(g, node):
def gather_constant_folding(g, node): def gather_constant_folding(g, node):
"""Fold constant and gather nodes to a single constant node. """Fold constant and gather nodes to a single constant node."""
"""
node_to_del = [] node_to_del = []
pre_data_node = helper.find_node_by_output_name(g, node.input[0]) pre_data_node = helper.find_node_by_output_name(g, node.input[0])
@ -502,7 +511,7 @@ def gather_constant_folding(g, node):
node.output[0], node.output[0],
new_shape, new_shape,
new_data.flatten().tolist(), new_data.flatten().tolist(),
data_type=pre_data_node.attribute[0].t.data_type data_type=pre_data_node.attribute[0].t.data_type,
) )
node_to_del.extend([node, pre_data_node, pre_indices_node]) node_to_del.extend([node, pre_data_node, pre_indices_node])
@ -512,9 +521,7 @@ def gather_constant_folding(g, node):
val_info_2 = helper.find_value_by_name(g, node.input[1]) val_info_2 = helper.find_value_by_name(g, node.input[1])
val_info_3 = helper.find_value_by_name(g, node.output[0]) val_info_3 = helper.find_value_by_name(g, node.output[0])
new_val_info = onnx.helper.make_tensor_value_info( new_val_info = onnx.helper.make_tensor_value_info(
new_node.output[0], new_node.output[0], pre_data_node.attribute[0].t.data_type, new_shape
pre_data_node.attribute[0].t.data_type,
new_shape
) )
if val_info_1 is not None: if val_info_1 is not None:
@ -533,8 +540,7 @@ def gather_constant_folding(g, node):
def add_constant_folding(g, node): def add_constant_folding(g, node):
"""Fold constant and add nodes to a single constant node. """Fold constant and add nodes to a single constant node."""
"""
node_to_del = [] node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
@ -547,14 +553,14 @@ def add_constant_folding(g, node):
np_data2 = np.reshape(data2, shape2) np_data2 = np.reshape(data2, shape2)
try: try:
new_data = np.add(np_data1, np_data2) new_data = np.add(np_data1, np_data2)
except: except Exception:
raise RuntimeError('can\'t broadcast and add two data sets') raise RuntimeError("can't broadcast and add two data sets")
new_node = helper.list_to_constant( new_node = helper.list_to_constant(
node.output[0], node.output[0],
new_data.shape, new_data.shape,
new_data.flatten().tolist(), new_data.flatten().tolist(),
data_type=pre_node_1.attribute[0].t.data_type data_type=pre_node_1.attribute[0].t.data_type,
) )
g.node.extend([new_node]) g.node.extend([new_node])
@ -571,8 +577,7 @@ def add_constant_folding(g, node):
def sqrt_constant_folding(g, node): def sqrt_constant_folding(g, node):
""" Fold constant and sqrt nodes to a single node. """Fold constant and sqrt nodes to a single node."""
"""
node_to_del = [] node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0]) pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node) shape, data = helper.constant_to_list(pre_node)
@ -582,17 +587,13 @@ def sqrt_constant_folding(g, node):
data_type = output_val_info.type.tensor_type.elem_type data_type = output_val_info.type.tensor_type.elem_type
new_tensor = onnx.helper.make_tensor( new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data', name=node.output[0] + "_data",
data_type=data_type, data_type=data_type,
dims=shape, dims=shape,
vals=np_data.flatten().tolist() vals=np_data.flatten().tolist(),
) )
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
'Constant', "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
) )
g.value_info.remove(input_val_info) g.value_info.remove(input_val_info)
@ -607,13 +608,12 @@ def sqrt_constant_folding(g, node):
def reciprocal_constant_folding(g, node): def reciprocal_constant_folding(g, node):
""" Fold constant and reciprocal nodes to a single constant node. """Fold constant and reciprocal nodes to a single constant node."""
"""
node_to_del = [] node_to_del = []
pre_node = helper.find_node_by_output_name(g, node.input[0]) pre_node = helper.find_node_by_output_name(g, node.input[0])
shape, data = helper.constant_to_list(pre_node) shape, data = helper.constant_to_list(pre_node)
data = list(map(lambda x: x if abs(x) > 1.e-8 else 1.e-8, data)) data = list(map(lambda x: x if abs(x) > 1.0e-8 else 1.0e-8, data))
np_data = np.reshape(data, shape) np_data = np.reshape(data, shape)
np_data = np.reciprocal(np_data) np_data = np.reciprocal(np_data)
@ -622,17 +622,13 @@ def reciprocal_constant_folding(g, node):
data_type = output_val_info.type.tensor_type.elem_type data_type = output_val_info.type.tensor_type.elem_type
new_tensor = onnx.helper.make_tensor( new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data', name=node.output[0] + "_data",
data_type=data_type, data_type=data_type,
dims=shape, dims=shape,
vals=np_data.flatten().tolist() vals=np_data.flatten().tolist(),
) )
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
'Constant', "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
) )
node_to_del.extend([node, pre_node]) node_to_del.extend([node, pre_node])
@ -648,8 +644,7 @@ def reciprocal_constant_folding(g, node):
def mul_constant_folding(g, node): def mul_constant_folding(g, node):
""" Fold constant and mul nodes to a single constant node. """Fold constant and mul nodes to a single constant node."""
"""
node_to_del = [] node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
@ -666,8 +661,8 @@ def mul_constant_folding(g, node):
try: try:
new_data = np.multiply(np_data1, np_data2) new_data = np.multiply(np_data1, np_data2)
except: except Exception:
raise RuntimeError('can not broadcast and multiply two data sets') raise RuntimeError("can not broadcast and multiply two data sets")
# Special shape for single element. # Special shape for single element.
if shape1 == 1 and shape2 == 1: if shape1 == 1 and shape2 == 1:
@ -676,17 +671,13 @@ def mul_constant_folding(g, node):
new_shape = new_data.shape new_shape = new_data.shape
new_tensor = onnx.helper.make_tensor( new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data', name=node.output[0] + "_data",
data_type=pre_node_1.attribute[0].t.data_type, data_type=pre_node_1.attribute[0].t.data_type,
dims=new_shape, dims=new_shape,
vals=new_data.flatten().tolist() vals=new_data.flatten().tolist(),
) )
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
'Constant', "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
) )
node_to_del.extend([node, pre_node_1, pre_node_2]) node_to_del.extend([node, pre_node_1, pre_node_2])
@ -703,8 +694,7 @@ def mul_constant_folding(g, node):
def div_constant_folding(g, node): def div_constant_folding(g, node):
""" Fold constant and mul nodes to a single constant node. """Fold constant and mul nodes to a single constant node."""
"""
node_to_del = [] node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
@ -721,8 +711,8 @@ def div_constant_folding(g, node):
try: try:
new_data = np.divide(np_data1, np_data2) new_data = np.divide(np_data1, np_data2)
except: except Exception:
raise RuntimeError('can not broadcast and multiply two data sets') raise RuntimeError("can not broadcast and multiply two data sets")
# Special shape for single element. # Special shape for single element.
if shape1 == 1 and shape2 == 1: if shape1 == 1 and shape2 == 1:
@ -732,20 +722,16 @@ def div_constant_folding(g, node):
# Check data type if it is int # Check data type if it is int
if pre_node_1.attribute[0].t.data_type == 7: if pre_node_1.attribute[0].t.data_type == 7:
new_data = new_data.astype('int64') new_data = new_data.astype("int64")
new_tensor = onnx.helper.make_tensor( new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data', name=node.output[0] + "_data",
data_type=pre_node_1.attribute[0].t.data_type, data_type=pre_node_1.attribute[0].t.data_type,
dims=new_shape, dims=new_shape,
vals=new_data.flatten().tolist() vals=new_data.flatten().tolist(),
) )
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
'Constant', "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
) )
node_to_del.extend([node, pre_node_1, pre_node_2]) node_to_del.extend([node, pre_node_1, pre_node_2])
@ -762,8 +748,7 @@ def div_constant_folding(g, node):
def sub_constant_folding(g, node): def sub_constant_folding(g, node):
""" Fold constant and sub nodes to a single node. """Fold constant and sub nodes to a single node."""
"""
node_to_del = [] node_to_del = []
pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
@ -781,17 +766,13 @@ def sub_constant_folding(g, node):
new_shape = new_data.shape new_shape = new_data.shape
new_tensor = onnx.helper.make_tensor( new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data', name=node.output[0] + "_data",
data_type=pre_node_1.attribute[0].t.data_type, data_type=pre_node_1.attribute[0].t.data_type,
dims=new_shape, dims=new_shape,
vals=helper.flatten_to_list(new_data) vals=helper.flatten_to_list(new_data),
) )
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
'Constant', "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
) )
g.node.extend([new_node]) g.node.extend([new_node])
@ -815,17 +796,13 @@ def neg_constant_folding(g, node):
new_data_list = [-num for num in data_list] new_data_list = [-num for num in data_list]
new_tensor = onnx.helper.make_tensor( new_tensor = onnx.helper.make_tensor(
name=pre_node.name+'_neg_tensor', name=pre_node.name + "_neg_tensor",
data_type=pre_node.attribute[0].t.data_type, data_type=pre_node.attribute[0].t.data_type,
dims=shape, dims=shape,
vals=new_data_list vals=new_data_list,
) )
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
'Constant', "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
) )
g.node.extend([new_node]) g.node.extend([new_node])
@ -851,17 +828,13 @@ def floor_constant_folding(g, node):
new_shape = shape new_shape = shape
new_tensor = onnx.helper.make_tensor( new_tensor = onnx.helper.make_tensor(
name=node.output[0]+'_data', name=node.output[0] + "_data",
data_type=pre_node.attribute[0].t.data_type, data_type=pre_node.attribute[0].t.data_type,
dims=new_shape, dims=new_shape,
vals=helper.flatten_to_list(new_data) vals=helper.flatten_to_list(new_data),
) )
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
'Constant', "Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
[],
[node.output[0]],
name=node.output[0],
value=new_tensor
) )
g.node.extend([new_node]) g.node.extend([new_node])
@ -877,8 +850,7 @@ def floor_constant_folding(g, node):
def bn_constant_folding(g, node): def bn_constant_folding(g, node):
""" Fold constant and mul nodes to a single constant node. """Fold constant and mul nodes to a single constant node."""
"""
# Prepare data # Prepare data
node_to_del = [] node_to_del = []
input_node = helper.find_node_by_output_name(g, node.input[0]) input_node = helper.find_node_by_output_name(g, node.input[0])
@ -900,17 +872,22 @@ def bn_constant_folding(g, node):
mean_data = helper.constant_to_numpy(mean_node) mean_data = helper.constant_to_numpy(mean_node)
var_data = helper.constant_to_numpy(var_node) var_data = helper.constant_to_numpy(var_node)
epsilon = helper.get_var_attribute_by_name(node, 'epsilon', 'float') epsilon = helper.get_var_attribute_by_name(node, "epsilon", "float")
if epsilon is None: if epsilon is None:
epsilon = 0.00001 epsilon = 0.00001
# Calculate new node # Calculate new node
new_data = scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon) + bias_data new_data = (
scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon)
+ bias_data
)
new_node = helper.numpy_to_constant(node.output[0], new_data) new_node = helper.numpy_to_constant(node.output[0], new_data)
# Reconnect the graph # Reconnect the graph
node_to_del.extend([node, input_node, scale_node, bias_node, mean_node, var_node]) node_to_del.extend(
[node, input_node, scale_node, bias_node, mean_node, var_node]
)
g.node.extend([new_node]) g.node.extend([new_node])
for value in input_value_info: for value in input_value_info:
@ -925,8 +902,7 @@ def bn_constant_folding(g, node):
def DequantizeLinear_constant_folding(g, node): def DequantizeLinear_constant_folding(g, node):
""" Fold constant and mul nodes to a single constant node. """Fold constant and mul nodes to a single constant node."""
"""
# Prepare data # Prepare data
node_to_del = [] node_to_del = []
x_node = helper.find_node_by_output_name(g, node.input[0]) x_node = helper.find_node_by_output_name(g, node.input[0])
@ -951,7 +927,9 @@ def DequantizeLinear_constant_folding(g, node):
x_zero_point_data = np.array([0.0]) x_zero_point_data = np.array([0.0])
# Calculate new node # Calculate new node
new_data = (x_data.astype(np.float32) - x_zero_point_data.astype(np.float32)) * x_scale_data new_data = (
x_data.astype(np.float32) - x_zero_point_data.astype(np.float32)
) * x_scale_data
new_node = helper.numpy_to_constant(node.output[0], new_data) new_node = helper.numpy_to_constant(node.output[0], new_data)
@ -974,22 +952,22 @@ def DequantizeLinear_constant_folding(g, node):
# Available constant folding names to function map. # Available constant folding names to function map.
constant_folding_nodes = { constant_folding_nodes = {
'Add': add_constant_folding, "Add": add_constant_folding,
'BatchNormalization': bn_constant_folding, "BatchNormalization": bn_constant_folding,
'Cast': cast_constant_folding, "Cast": cast_constant_folding,
'Concat': concat_constant_folding, "Concat": concat_constant_folding,
'DequantizeLinear': DequantizeLinear_constant_folding, "DequantizeLinear": DequantizeLinear_constant_folding,
'Div': div_constant_folding, "Div": div_constant_folding,
'Floor': floor_constant_folding, "Floor": floor_constant_folding,
'Gather': gather_constant_folding, "Gather": gather_constant_folding,
'Mul': mul_constant_folding, "Mul": mul_constant_folding,
'Reciprocal': reciprocal_constant_folding, "Reciprocal": reciprocal_constant_folding,
'ReduceProd': reduceprod_constant_folding, "ReduceProd": reduceprod_constant_folding,
'Reshape': reshape_constant_input_folding, "Reshape": reshape_constant_input_folding,
'Slice': slice_constant_folding, "Slice": slice_constant_folding,
'Sqrt': sqrt_constant_folding, "Sqrt": sqrt_constant_folding,
'Transpose': transpose_constant_folding, "Transpose": transpose_constant_folding,
'Unsqueeze': unsqueeze_constant_folding, "Unsqueeze": unsqueeze_constant_folding,
'Sub': sub_constant_folding, "Sub": sub_constant_folding,
'Neg': neg_constant_folding "Neg": neg_constant_folding,
} }

View File

@ -7,6 +7,7 @@ from . import helper
from . import modhelper from . import modhelper
from .general_graph import Graph from .general_graph import Graph
def eliminate_Identify_and_Dropout(g): def eliminate_Identify_and_Dropout(g):
""" """
Eliminate Identify layers Eliminate Identify layers
@ -15,31 +16,46 @@ def eliminate_Identify_and_Dropout(g):
""" """
node_to_remove = [] node_to_remove = []
for node in g.node: for node in g.node:
if node.op_type != 'Identity' and node.op_type != 'Dropout': if node.op_type != "Identity" and node.op_type != "Dropout":
continue continue
# If this node is the last node, leave it to `eliminate_useless_last node` # If this node is the last, leave it to `eliminate_useless_last node`
if helper.find_output_by_name(g, node.output[0]) is not None: if helper.find_output_by_name(g, node.output[0]) is not None:
continue continue
# Replace the parents in all the following nodes # Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes: for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0]) modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
# Delete value info # Delete value info
value_between = helper.find_value_by_name(g, node.output[0]) value_between = helper.find_value_by_name(g, node.output[0])
try: try:
g.value_info.remove(value_between) g.value_info.remove(value_between)
except: except Exception:
print("No value info to delete while eliminating identity layers.") print("No value info to delete while eliminating identity layers.")
# Node is waiting for elimination # Node is waiting for elimination
node_to_remove.append(node) node_to_remove.append(node)
for node in node_to_remove: for node in node_to_remove:
g.node.remove(node) g.node.remove(node)
# Remove last useless nodes # Remove last useless nodes
def remove_useless_last_nodes(g): def remove_useless_last_nodes(g):
"""Remove useless nodes from the tail of the graph """Remove useless nodes from the tail of the graph"""
""" USELESS = [
USELESS = ["Reshape", "Identity", "Transpose", "Flatten", "Dropout", "Mystery", "Constant", "Squeeze", "Unsqueeze", 'Softmax'] "Reshape",
"Identity",
"Transpose",
"Flatten",
"Dropout",
"Mystery",
"Constant",
"Squeeze",
"Unsqueeze",
"Softmax",
]
graph = Graph(g) graph = Graph(g)
todo = collections.deque() todo = collections.deque()
for node in graph.output_nodes: for node in graph.output_nodes:
@ -54,19 +70,30 @@ def remove_useless_last_nodes(g):
if cur_node.proto.op_type not in USELESS: if cur_node.proto.op_type not in USELESS:
continue continue
# Find the output # Find the output
cur_node_output = helper.find_output_by_name(g, cur_node.proto.output[0]) cur_node_output = helper.find_output_by_name(
g, cur_node.proto.output[0]
)
for cur_input in cur_node.parents: for cur_input in cur_node.parents:
cur_input.children.remove(cur_node) cur_input.children.remove(cur_node)
if len(cur_input.children) == 0: if len(cur_input.children) == 0:
todo.append(cur_input) todo.append(cur_input)
if cur_node_output is not None: if cur_node_output is not None:
cur_input_output = helper.find_value_by_name(g, cur_input.proto.output[0]) cur_input_output = helper.find_value_by_name(
cur_input_output_in_output = helper.find_output_by_name(g, cur_input.proto.output[0]) g, cur_input.proto.output[0]
if cur_input_output is not None and cur_input_output_in_output is None: )
cur_input_output_in_output = helper.find_output_by_name(
g, cur_input.proto.output[0]
)
if (
cur_input_output is not None
and cur_input_output_in_output is None
):
g.output.extend([cur_input_output]) g.output.extend([cur_input_output])
node_to_remove.append(cur_node.proto) node_to_remove.append(cur_node.proto)
try: try:
g.value_info.remove(helper.find_value_by_name(g, cur_node.proto.output[0])) g.value_info.remove(
helper.find_value_by_name(g, cur_node.proto.output[0])
)
except ValueError: except ValueError:
pass pass
if cur_node_output is not None: if cur_node_output is not None:
@ -76,10 +103,12 @@ def remove_useless_last_nodes(g):
for node in node_to_remove: for node in node_to_remove:
g.node.remove(node) g.node.remove(node)
###################################### ######################################
# TF only optimization passes # # TF only optimization passes #
###################################### ######################################
def eliminate_shape_changing_after_input(g): def eliminate_shape_changing_after_input(g):
""" """
Eliminate the Reshape node after input and reshape the input Eliminate the Reshape node after input and reshape the input
@ -87,7 +116,14 @@ def eliminate_shape_changing_after_input(g):
:param g: the onnx graph :param g: the onnx graph
""" """
node_to_remove = [] node_to_remove = []
REMOVE_LIST = ["Reshape", "Transpose", "Flatten", "Dropout", "Squeeze", "Unsqueeze"] REMOVE_LIST = [
"Reshape",
"Transpose",
"Flatten",
"Dropout",
"Squeeze",
"Unsqueeze",
]
for node in g.node: for node in g.node:
# Find an input and the shape node # Find an input and the shape node
if node.op_type not in REMOVE_LIST: if node.op_type not in REMOVE_LIST:
@ -105,9 +141,9 @@ def eliminate_shape_changing_after_input(g):
# Remove Weight if any. # Remove Weight if any.
output_val_info = helper.find_value_by_name(g, node.output[0]) output_val_info = helper.find_value_by_name(g, node.output[0])
if node.op_type == 'Reshape': if node.op_type == "Reshape":
shape_node = helper.find_node_by_output_name(g, node.input[1]) shape_node = helper.find_node_by_output_name(g, node.input[1])
if shape_node.op_type != 'Constant': if shape_node.op_type != "Constant":
continue continue
# manuelly set the input shape # manuelly set the input shape
@ -117,25 +153,29 @@ def eliminate_shape_changing_after_input(g):
_, new_shape = helper.constant_to_list(shape_node) _, new_shape = helper.constant_to_list(shape_node)
for i in range(len(new_shape)): for i in range(len(new_shape)):
if new_shape[i] == -1: if new_shape[i] == -1:
dim = int(old_size//np.prod(new_shape)*(-1)) dim = int(old_size // np.prod(new_shape) * (-1))
new_shape[i] = dim new_shape[i] = dim
new_input = onnx.helper.make_tensor_value_info( new_input = onnx.helper.make_tensor_value_info(
output_val_info.name, output_val_info.name,
output_val_info.type.tensor_type.elem_type, output_val_info.type.tensor_type.elem_type,
new_shape new_shape,
) )
node_to_remove.append(node) node_to_remove.append(node)
shape_outputs = helper.find_nodes_by_input_name(g, shape_node.output[0]) shape_outputs = helper.find_nodes_by_input_name(
g, shape_node.output[0]
)
if len(shape_outputs) == 1: if len(shape_outputs) == 1:
node_to_remove.append(shape_node) node_to_remove.append(shape_node)
g.value_info.remove(helper.find_value_by_name(g, shape_node.output[0])) g.value_info.remove(
helper.find_value_by_name(g, shape_node.output[0])
)
g.input.remove(old_input) g.input.remove(old_input)
g.input.extend([new_input]) g.input.extend([new_input])
g.value_info.remove(output_val_info) g.value_info.remove(output_val_info)
elif node.op_type == 'Transpose': elif node.op_type == "Transpose":
permutation = list(node.attribute[0].ints) permutation = list(node.attribute[0].ints)
pre_shape = helper.get_shape_from_value_info(old_input) pre_shape = helper.get_shape_from_value_info(old_input)
new_shape = [pre_shape[i] for i in permutation] new_shape = [pre_shape[i] for i in permutation]
@ -143,7 +183,7 @@ def eliminate_shape_changing_after_input(g):
new_input = onnx.helper.make_tensor_value_info( new_input = onnx.helper.make_tensor_value_info(
output_val_info.name, output_val_info.name,
output_val_info.type.tensor_type.elem_type, output_val_info.type.tensor_type.elem_type,
new_shape new_shape,
) )
node_to_remove.append(node) node_to_remove.append(node)
@ -151,7 +191,7 @@ def eliminate_shape_changing_after_input(g):
g.input.remove(old_input) g.input.remove(old_input)
g.input.extend([new_input]) g.input.extend([new_input])
g.value_info.remove(output_val_info) g.value_info.remove(output_val_info)
elif node.op_type == 'Flatten': elif node.op_type == "Flatten":
axis = node.attribute[0].int axis = node.attribute[0].int
pre_shape = helper.get_shape_from_value_info(old_input) pre_shape = helper.get_shape_from_value_info(old_input)
dim_1, dim_2 = 1, 1 dim_1, dim_2 = 1, 1
@ -166,7 +206,7 @@ def eliminate_shape_changing_after_input(g):
new_input = onnx.helper.make_tensor_value_info( new_input = onnx.helper.make_tensor_value_info(
output_val_info.name, output_val_info.name,
output_val_info.type.tensor_type.elem_type, output_val_info.type.tensor_type.elem_type,
new_shape new_shape,
) )
node_to_remove.append(node) node_to_remove.append(node)
@ -174,18 +214,18 @@ def eliminate_shape_changing_after_input(g):
g.input.remove(old_input) g.input.remove(old_input)
g.input.extend([new_input]) g.input.extend([new_input])
g.value_info.remove(output_val_info) g.value_info.remove(output_val_info)
elif node.op_type == 'Dropout': elif node.op_type == "Dropout":
g.input.remove(old_input) g.input.remove(old_input)
g.input.extend([output_val_info]) g.input.extend([output_val_info])
g.value_info.remove(output_val_info) g.value_info.remove(output_val_info)
node_to_remove.append(node) node_to_remove.append(node)
elif node.op_type == 'Squeeze': elif node.op_type == "Squeeze":
axis = list(node.attribute[0].ints) axis = list(node.attribute[0].ints)
pre_shape = helper.get_shape_from_value_info(old_input) pre_shape = helper.get_shape_from_value_info(old_input)
for pos in sorted(axis)[::-1]: for pos in sorted(axis)[::-1]:
if pre_shape[pos] != 1: if pre_shape[pos] != 1:
raise RuntimeError('invalid axis for squeeze') raise RuntimeError("invalid axis for squeeze")
else: else:
pre_shape.pop(pos) pre_shape.pop(pos)
new_shape = pre_shape new_shape = pre_shape
@ -193,7 +233,7 @@ def eliminate_shape_changing_after_input(g):
new_input = onnx.helper.make_tensor_value_info( new_input = onnx.helper.make_tensor_value_info(
output_val_info.name, output_val_info.name,
output_val_info.type.tensor_type.elem_type, output_val_info.type.tensor_type.elem_type,
new_shape new_shape,
) )
node_to_remove.append(node) node_to_remove.append(node)
@ -201,7 +241,7 @@ def eliminate_shape_changing_after_input(g):
g.input.remove(old_input) g.input.remove(old_input)
g.input.extend([new_input]) g.input.extend([new_input])
g.value_info.remove(output_val_info) g.value_info.remove(output_val_info)
elif node.op_type == 'Unsqueeze': elif node.op_type == "Unsqueeze":
axis = list(node.attribute[0].ints) axis = list(node.attribute[0].ints)
pre_shape = helper.get_shape_from_value_info(old_input) pre_shape = helper.get_shape_from_value_info(old_input)
new_shape = pre_shape new_shape = pre_shape
@ -210,7 +250,7 @@ def eliminate_shape_changing_after_input(g):
new_input = onnx.helper.make_tensor_value_info( new_input = onnx.helper.make_tensor_value_info(
output_val_info.name, output_val_info.name,
output_val_info.type.tensor_type.elem_type, output_val_info.type.tensor_type.elem_type,
new_shape new_shape,
) )
node_to_remove.append(node) node_to_remove.append(node)
@ -222,7 +262,7 @@ def eliminate_shape_changing_after_input(g):
for node in node_to_remove: for node in node_to_remove:
g.node.remove(node) g.node.remove(node)
other.topological_sort(g) other.topological_sort(g)
@ -231,15 +271,13 @@ def eliminate_Reshape_Cast(g):
:param g: the onnx graph :param g: the onnx graph
""" """
#Find all reshape layers # Find all reshape layers
node_to_remove = []
for node in g.node: for node in g.node:
if node.op_type != 'Reshape': if node.op_type != "Reshape":
continue continue
prev_node = helper.find_node_by_output_name(g, node.input[1]) prev_node = helper.find_node_by_output_name(g, node.input[1])
if prev_node.op_type != 'Cast': if prev_node.op_type != "Cast":
continue continue
# Now we find the cast weight pattern. Cast the weight, delete the cast.
reshape_node = node reshape_node = node
cast_node = prev_node cast_node = prev_node
weight_node = helper.find_node_by_output_name(g, cast_node.input[0]) weight_node = helper.find_node_by_output_name(g, cast_node.input[0])
@ -248,10 +286,12 @@ def eliminate_Reshape_Cast(g):
weight_node.attribute[0].t.data_type = 7 weight_node.attribute[0].t.data_type = 7
if weight_node.attribute[0].t.raw_data: if weight_node.attribute[0].t.raw_data:
raw_data = weight_node.attribute[0].t.raw_data raw_data = weight_node.attribute[0].t.raw_data
int_data = [i[0] for i in struct.iter_unpack('i', raw_data)] int_data = [i[0] for i in struct.iter_unpack("i", raw_data)]
raw_data = struct.pack('q' * len(int_data), *int_data) raw_data = struct.pack("q" * len(int_data), *int_data)
elif len(weight_node.attribute[0].t.int64_data) > 0\ elif (
or len(weight_node.attribute[0].t.int32_data) > 0: len(weight_node.attribute[0].t.int64_data) > 0
or len(weight_node.attribute[0].t.int32_data) > 0
):
# It's already int. Do nothing # It's already int. Do nothing
pass pass
else: else:
@ -264,6 +304,7 @@ def eliminate_Reshape_Cast(g):
g.value_info.remove(origin_weight_out) g.value_info.remove(origin_weight_out)
g.node.remove(cast_node) g.node.remove(cast_node)
def eliminate_Cast_after_input(g): def eliminate_Cast_after_input(g):
"""Eliminate the cast layer right after the input """Eliminate the cast layer right after the input
@ -271,7 +312,7 @@ def eliminate_Cast_after_input(g):
""" """
node_to_remove = [] node_to_remove = []
for node in g.node: for node in g.node:
if node.op_type != 'Cast': if node.op_type != "Cast":
continue continue
old_input = helper.find_input_by_name(g, node.input[0]) old_input = helper.find_input_by_name(g, node.input[0])
if old_input is None: if old_input is None:
@ -279,9 +320,7 @@ def eliminate_Cast_after_input(g):
next_val_info = helper.find_value_by_name(g, node.output[0]) next_val_info = helper.find_value_by_name(g, node.output[0])
shape = helper.get_shape_from_value_info(next_val_info) shape = helper.get_shape_from_value_info(next_val_info)
new_val_info = onnx.helper.make_tensor_value_info( new_val_info = onnx.helper.make_tensor_value_info(
next_val_info.name, next_val_info.name, node.attribute[0].i, shape
node.attribute[0].i,
shape
) )
# Delete old value_info # Delete old value_info
g.input.remove(old_input) g.input.remove(old_input)
@ -293,6 +332,7 @@ def eliminate_Cast_after_input(g):
for node in node_to_remove: for node in node_to_remove:
g.node.remove(node) g.node.remove(node)
def eliminate_consecutive_Cast(g): def eliminate_consecutive_Cast(g):
"""If two cast is next to each other, remove the first cast """If two cast is next to each other, remove the first cast
@ -300,10 +340,10 @@ def eliminate_consecutive_Cast(g):
""" """
node_to_remove = [] node_to_remove = []
for node in g.node: for node in g.node:
if node.op_type != 'Cast': if node.op_type != "Cast":
continue continue
first_node = helper.find_node_by_output_name(g, node.input[0]) first_node = helper.find_node_by_output_name(g, node.input[0])
if first_node is None or first_node.op_type != 'Cast': if first_node is None or first_node.op_type != "Cast":
continue continue
# Here we have two consecutive Cast Node # Here we have two consecutive Cast Node
# Reset the input of the later node # Reset the input of the later node
@ -315,6 +355,7 @@ def eliminate_consecutive_Cast(g):
for node in node_to_remove: for node in node_to_remove:
g.node.remove(node) g.node.remove(node)
def eliminate_Squeeze_before_Reshape(g): def eliminate_Squeeze_before_Reshape(g):
"""If Squeeze and Reshape is next to each other, remove the first node """If Squeeze and Reshape is next to each other, remove the first node
@ -322,12 +363,12 @@ def eliminate_Squeeze_before_Reshape(g):
""" """
node_to_remove = [] node_to_remove = []
for node in g.node: for node in g.node:
if node.op_type != 'Reshape': if node.op_type != "Reshape":
continue continue
first_node = helper.find_node_by_output_name(g, node.input[0]) first_node = helper.find_node_by_output_name(g, node.input[0])
if not first_node: if not first_node:
continue continue
if first_node.op_type != 'Squeeze': if first_node.op_type != "Squeeze":
continue continue
# Here we have two consecutive Cast Node # Here we have two consecutive Cast Node
# Reset the input of the later node # Reset the input of the later node
@ -339,9 +380,9 @@ def eliminate_Squeeze_before_Reshape(g):
for node in node_to_remove: for node in node_to_remove:
g.node.remove(node) g.node.remove(node)
def eliminate_no_children_input(g): def eliminate_no_children_input(g):
"""Eliminate inputs with no children at all. """Eliminate inputs with no children at all."""
"""
# Create a set of input names # Create a set of input names
input_names = set([i.name for i in g.input]) input_names = set([i.name for i in g.input])
# If a name is used in any node, remove this name from the set. # If a name is used in any node, remove this name from the set.
@ -353,31 +394,33 @@ def eliminate_no_children_input(g):
info = helper.find_input_by_name(g, i) info = helper.find_input_by_name(g, i)
g.input.remove(info) g.input.remove(info)
def eliminate_consecutive_reshape(g): def eliminate_consecutive_reshape(g):
"""Replace consecutive reshape nodes by a single node. """Replace consecutive reshape nodes by a single node."""
"""
node_to_del = [] node_to_del = []
for node in g.node: for node in g.node:
if node.op_type != 'Reshape': if node.op_type != "Reshape":
continue continue
pre_data_node = helper.find_node_by_output_name(g, node.input[0]) pre_data_node = helper.find_node_by_output_name(g, node.input[0])
pre_shape_node = helper.find_node_by_output_name(g, node.input[1]) pre_shape_node = helper.find_node_by_output_name(g, node.input[1])
if not pre_data_node or not pre_shape_node: if not pre_data_node or not pre_shape_node:
continue continue
if pre_shape_node.op_type != 'Constant': if pre_shape_node.op_type != "Constant":
continue continue
if pre_data_node.op_type != 'Reshape': if pre_data_node.op_type != "Reshape":
continue continue
pre_pre_shape_node = helper.find_node_by_output_name(g, pre_data_node.input[1]) pre_pre_shape_node = helper.find_node_by_output_name(
if pre_pre_shape_node.op_type != 'Constant': g, pre_data_node.input[1]
)
if pre_pre_shape_node.op_type != "Constant":
continue continue
new_reshape_node = onnx.helper.make_node( new_reshape_node = onnx.helper.make_node(
'Reshape', "Reshape",
[pre_data_node.input[0], node.input[1]], [pre_data_node.input[0], node.input[1]],
[node.output[0]], [node.output[0]],
name = node.output[0] name=node.output[0],
) )
g.node.extend([new_reshape_node]) g.node.extend([new_reshape_node])
@ -394,6 +437,7 @@ def eliminate_consecutive_reshape(g):
node = node_to_del.pop() node = node_to_del.pop()
g.node.remove(node) g.node.remove(node)
def eliminate_single_input_Concat(g): def eliminate_single_input_Concat(g):
""" """
Eliminate single input Concat layers Eliminate single input Concat layers
@ -402,12 +446,12 @@ def eliminate_single_input_Concat(g):
""" """
node_to_remove = [] node_to_remove = []
for node in g.node: for node in g.node:
if node.op_type != 'Concat': if node.op_type != "Concat":
continue continue
# If this node has more than 1 input, continue. # If this node has more than 1 input, continue.
if len(node.input) > 1: if len(node.input) > 1:
continue continue
# If this node is the output node, set its previous node as output nodes. # If this node is output node, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None: if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0]) todel_output = helper.find_output_by_name(g, node.output[0])
the_input_value = helper.find_value_by_name(g, node.input[0]) the_input_value = helper.find_value_by_name(g, node.input[0])
@ -416,20 +460,25 @@ def eliminate_single_input_Concat(g):
node_to_remove.append(node) node_to_remove.append(node)
continue continue
# Replace the parents in all the following nodes # Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes: for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0]) modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
# Delete value info # Delete value info
value_between = helper.find_value_by_name(g, node.output[0]) value_between = helper.find_value_by_name(g, node.output[0])
try: try:
g.value_info.remove(value_between) g.value_info.remove(value_between)
except: except Exception:
print("No value info to delete while eliminating identity layers.") print("No value info to delete while eliminating identity layers.")
# Node is waiting for elimination # Node is waiting for elimination
node_to_remove.append(node) node_to_remove.append(node)
for node in node_to_remove: for node in node_to_remove:
g.node.remove(node) g.node.remove(node)
def eliminate_nop_Maxpool_and_AveragePool(g): def eliminate_nop_Maxpool_and_AveragePool(g):
""" """
Eliminate do nothing MaxPool and AveragePool layers. Eliminate do nothing MaxPool and AveragePool layers.
@ -439,7 +488,7 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
""" """
node_to_remove = [] node_to_remove = []
for node in g.node: for node in g.node:
if node.op_type != 'MaxPool' and node.op_type != 'AveragePool': if node.op_type != "MaxPool" and node.op_type != "AveragePool":
continue continue
# If this node is actually working, continue. # If this node is actually working, continue.
kernel = helper.get_list_attribute_by_name(node, "kernel_shape", "int") kernel = helper.get_list_attribute_by_name(node, "kernel_shape", "int")
@ -447,7 +496,7 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
strides = helper.get_list_attribute_by_name(node, "strides", "int") strides = helper.get_list_attribute_by_name(node, "strides", "int")
if kernel != [1, 1] or pads != [0, 0, 0, 0] or strides != [1, 1]: if kernel != [1, 1] or pads != [0, 0, 0, 0] or strides != [1, 1]:
continue continue
# If this node is the output node, set its previous node as output nodes. # If this node is the output, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None: if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0]) todel_output = helper.find_output_by_name(g, node.output[0])
the_input_value = helper.find_value_by_name(g, node.input[0]) the_input_value = helper.find_value_by_name(g, node.input[0])
@ -456,14 +505,18 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
node_to_remove.append(node) node_to_remove.append(node)
continue continue
# Replace the parents in all the following nodes # Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes: for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0]) modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
# Delete value info # Delete value info
value_between = helper.find_value_by_name(g, node.output[0]) value_between = helper.find_value_by_name(g, node.output[0])
try: try:
g.value_info.remove(value_between) g.value_info.remove(value_between)
except: except Exception:
print("No value info to delete while eliminating identity layers.") print("No value info to delete while eliminating identity layers.")
# Node is waiting for elimination # Node is waiting for elimination
node_to_remove.append(node) node_to_remove.append(node)
@ -474,20 +527,20 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
def eliminate_trivial_maxpool(g): def eliminate_trivial_maxpool(g):
node_to_del = [] node_to_del = []
for node in g.node: for node in g.node:
if node.op_type != 'MaxPool': if node.op_type != "MaxPool":
continue continue
pads = None pads = None
strides = None strides = None
dilation = None dilation = None
kernel_shape = None kernel_shape = None
for att in node.attribute: for att in node.attribute:
if att.name == 'pads': if att.name == "pads":
pads = list(att.ints) pads = list(att.ints)
elif att.name == 'strides': elif att.name == "strides":
strides = list(att.ints) strides = list(att.ints)
elif att.name == 'kernel_shape': elif att.name == "kernel_shape":
kernel_shape = list(att.ints) kernel_shape = list(att.ints)
elif att.name == 'dilation': elif att.name == "dilation":
dilation = list(att.ints) dilation = list(att.ints)
else: else:
pass pass
@ -504,7 +557,7 @@ def eliminate_trivial_maxpool(g):
next_nodes = helper.find_nodes_by_input_name(g, node.output[0]) next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
if next_nodes[0] == None: if next_nodes[0] is None:
output_value = helper.find_output_by_name(g, node.output[0]) output_value = helper.find_output_by_name(g, node.output[0])
if not output_value: if not output_value:
continue continue
@ -512,18 +565,21 @@ def eliminate_trivial_maxpool(g):
pre_val_info = helper.find_value_by_name(g, node.input[0]) pre_val_info = helper.find_value_by_name(g, node.input[0])
g.output.extend([pre_val_info]) g.output.extend([pre_val_info])
g.output.remove(output_value) g.output.remove(output_value)
for next_node in next_nodes: for next_node in next_nodes:
modhelper.replace_node_input(next_node, node.output[0], node.input[0]) modhelper.replace_node_input(
next_node, node.output[0], node.input[0]
)
next_val_info = helper.find_value_by_name(g, node.output[0]) next_val_info = helper.find_value_by_name(g, node.output[0])
g.value_info.remove(next_val_info) g.value_info.remove(next_val_info)
while node_to_del: while node_to_del:
g.node.remove(node_to_del.pop()) g.node.remove(node_to_del.pop())
other.topological_sort(g) other.topological_sort(g)
def eliminate_empty_value_infos(g): def eliminate_empty_value_infos(g):
to_remove = [] to_remove = []
for value_info in g.value_info: for value_info in g.value_info:
@ -532,10 +588,11 @@ def eliminate_empty_value_infos(g):
for value_info in to_remove: for value_info in to_remove:
g.value_info.remove(value_info) g.value_info.remove(value_info)
def eliminate_nop_pads(g): def eliminate_nop_pads(g):
node_to_remove = [] node_to_remove = []
for node in g.node: for node in g.node:
if node.op_type != 'Pad': if node.op_type != "Pad":
continue continue
# Check if the Pad is empty or not # Check if the Pad is empty or not
pads_node = helper.find_node_by_output_name(g, node.input[1]) pads_node = helper.find_node_by_output_name(g, node.input[1])
@ -546,11 +603,7 @@ def eliminate_nop_pads(g):
all_zero = False all_zero = False
if not all_zero: if not all_zero:
continue continue
# Check if it has the constant_value_node # If this node is the output, set its previous node as output nodes.
constant_value_node = None
if len(node.input) > 2:
constant_value_node = helper.find_node_by_output_name(g, node.input[2])
# If this node is the output node, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None: if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0]) todel_output = helper.find_output_by_name(g, node.output[0])
g.output.remove(todel_output) g.output.remove(todel_output)
@ -559,38 +612,44 @@ def eliminate_nop_pads(g):
if the_input_value is not None: if the_input_value is not None:
g.output.extend([the_input_value]) g.output.extend([the_input_value])
# Replace the parents in all the following nodes # Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes: for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0]) modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
# Delete value info # Delete value info
value_between = helper.find_value_by_name(g, node.output[0]) value_between = helper.find_value_by_name(g, node.output[0])
try: try:
g.value_info.remove(value_between) g.value_info.remove(value_between)
except: except Exception:
helper.logger.info("No value info to delete while eliminating identity layers.") helper.logger.info(
"No value info to delete while eliminating identity layers."
)
# Node is waiting for elimination # Node is waiting for elimination
node_to_remove.append(node) node_to_remove.append(node)
for node in node_to_remove: for node in node_to_remove:
g.node.remove(node) g.node.remove(node)
def eliminate_trivial_elementwise_calculation(g): def eliminate_trivial_elementwise_calculation(g):
"""Eliminate Add, Sub, Mul, Sub nodes which do nothing. """Eliminate Add, Sub, Mul, Sub nodes which do nothing."""
"""
node_to_remove = [] node_to_remove = []
for node in g.node: for node in g.node:
weight_node = None weight_node = None
if node.op_type == 'Add' or node.op_type == 'Sub': if node.op_type == "Add" or node.op_type == "Sub":
# For add and sub, check if the weights are 0s. # For add and sub, check if the weights are 0s.
weight_node = helper.find_node_by_output_name(g, node.input[1]) weight_node = helper.find_node_by_output_name(g, node.input[1])
if weight_node is None or weight_node.op_type != 'Constant': if weight_node is None or weight_node.op_type != "Constant":
continue continue
weight_np = helper.constant_to_numpy(weight_node) weight_np = helper.constant_to_numpy(weight_node)
if np.any(weight_np): if np.any(weight_np):
continue continue
elif node.op_type == 'Mul' or node.op_type == 'Div': elif node.op_type == "Mul" or node.op_type == "Div":
# For Mul and Div, check if the weights are 1s. # For Mul and Div, check if the weights are 1s.
weight_node = helper.find_node_by_output_name(g, node.input[1]) weight_node = helper.find_node_by_output_name(g, node.input[1])
if weight_node is None or weight_node.op_type != 'Constant': if weight_node is None or weight_node.op_type != "Constant":
continue continue
weight_np = helper.constant_to_numpy(weight_node) weight_np = helper.constant_to_numpy(weight_node)
weight_np = weight_np - 1 weight_np = weight_np - 1
@ -605,9 +664,13 @@ def eliminate_trivial_elementwise_calculation(g):
if output_value_info is not None: if output_value_info is not None:
g.value_info.remove(output_value_info) g.value_info.remove(output_value_info)
# Replace next node input if any. # Replace next node input if any.
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes: for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0]) modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
todel_output = helper.find_output_by_name(g, node.output[0]) todel_output = helper.find_output_by_name(g, node.output[0])
if todel_output is not None: if todel_output is not None:
g.output.remove(todel_output) g.output.remove(todel_output)
@ -616,38 +679,53 @@ def eliminate_trivial_elementwise_calculation(g):
the_input_value = helper.find_value_by_name(g, node.input[0]) the_input_value = helper.find_value_by_name(g, node.input[0])
g.output.extend([the_input_value]) g.output.extend([the_input_value])
# Delete the constant node if it is not used by other nodes # Delete the constant node if it is not used by other nodes
constant_following_nodes = helper.find_following_nodes_by_input_value_name(g, weight_node.output[0]) constant_following_nodes = (
helper.find_following_nodes_by_input_value_name(
g, weight_node.output[0]
)
)
if len(constant_following_nodes) == 1: if len(constant_following_nodes) == 1:
node_to_remove.append(weight_node) node_to_remove.append(weight_node)
output_value_info = helper.find_value_by_name(g, weight_node.output[0]) output_value_info = helper.find_value_by_name(
g, weight_node.output[0]
)
if output_value_info is not None: if output_value_info is not None:
g.value_info.remove(output_value_info) g.value_info.remove(output_value_info)
for node in node_to_remove: for node in node_to_remove:
g.node.remove(node) g.node.remove(node)
def eliminate_nop_cast(g): def eliminate_nop_cast(g):
"""Eliminate do nothing Cast nodes. """Eliminate do nothing Cast nodes."""
"""
node_to_remove = [] node_to_remove = []
for node in g.node: for node in g.node:
if node.op_type != 'Cast': if node.op_type != "Cast":
continue continue
# Get input value_info # Get input value_info
input_value = helper.find_value_by_name(g, node.input[0]) input_value = helper.find_value_by_name(g, node.input[0])
if input_value is None: if input_value is None:
helper.logger.debug(f"Cannot find the input value_info for Cast node {node.name}. Skip elimination check.") helper.logger.debug(
f"Cannot find the input value_info for Cast node {node.name}. "
"Skip elimination check."
)
continue continue
# Get output value_info # Get output value_info
output_value = helper.find_value_by_name(g, node.output[0]) output_value = helper.find_value_by_name(g, node.output[0])
if output_value is None: if output_value is None:
output_value = helper.find_output_by_name(g, node.output[0]) output_value = helper.find_output_by_name(g, node.output[0])
if output_value is None: if output_value is None:
helper.logger.debug(f"Cannot find the output value_info for Cast node {node.name}. Skip elimination check.") helper.logger.debug(
f"Cannot find the output value_info for Cast node {node.name}."
" Skip elimination check."
)
continue continue
# Compare the type. # Compare the type.
if input_value.type.tensor_type.elem_type != output_value.type.tensor_type.elem_type: if (
input_value.type.tensor_type.elem_type
!= output_value.type.tensor_type.elem_type
):
continue continue
# If this node is the output node, set its previous node as output nodes. # If this node is the output, set its previous node as output nodes.
if helper.find_output_by_name(g, node.output[0]) is not None: if helper.find_output_by_name(g, node.output[0]) is not None:
todel_output = helper.find_output_by_name(g, node.output[0]) todel_output = helper.find_output_by_name(g, node.output[0])
g.output.remove(todel_output) g.output.remove(todel_output)
@ -656,9 +734,13 @@ def eliminate_nop_cast(g):
if the_input_value is not None: if the_input_value is not None:
g.output.extend([the_input_value]) g.output.extend([the_input_value])
# Replace the parents in all the following nodes # Replace the parents in all the following nodes
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
for following_node in following_nodes: for following_node in following_nodes:
modhelper.replace_node_input(following_node, node.output[0], node.input[0]) modhelper.replace_node_input(
following_node, node.output[0], node.input[0]
)
# Delete value info # Delete value info
value_between = helper.find_value_by_name(g, node.output[0]) value_between = helper.find_value_by_name(g, node.output[0])
if value_between is not None: if value_between is not None:

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,11 @@
from collections import deque from collections import deque
class Node: class Node:
"""A Node which maps a node proto. It has pointers to its parents and """A Node which maps a node proto. It has pointers to its parents and
children. children.
""" """
def __init__(self, onnx_node): def __init__(self, onnx_node):
"""Initialize a node. This initialization only set up the mapping to """Initialize a node. This initialization only set up the mapping to
node proto. The pointers should be set up by outside. node proto. The pointers should be set up by outside.
@ -17,12 +19,12 @@ class Node:
self.name = onnx_node.name self.name = onnx_node.name
self.proto = onnx_node self.proto = onnx_node
class Graph: class Graph:
"""A graph which is constructed from the onnx proto. """A graph which is constructed from the onnx proto."""
"""
def __init__(self, onnx_graph): def __init__(self, onnx_graph):
"""Construct the graph from onnx. """Construct the graph from onnx."""
"""
self.input_nodes = [] self.input_nodes = []
self.output_nodes = [] self.output_nodes = []
self.name2node = {} self.name2node = {}
@ -51,9 +53,9 @@ class Graph:
for value in onnx_graph.value_info: for value in onnx_graph.value_info:
node = self.output2node[value.name] node = self.output2node[value.name]
node.output_value = value node.output_value = value
def get_sorted_node_list(self): def get_sorted_node_list(self):
"""Return a node list in topological order. """Return a node list in topological order."""
"""
visited = set() visited = set()
todo = deque() todo = deque()
result = [] result = []

View File

@ -6,21 +6,26 @@ import struct
import numpy as np import numpy as np
import logging import logging
__ONNX_VERSION__ = -1 __ONNX_VERSION__ = -1
logger = logging.getLogger("optimizer_scripts") logger = logging.getLogger("optimizer_scripts")
def setup_current_opset_version(m):
    """Record the model's opset version; only opset 11 is accepted.

    :param m: an onnx ModelProto (reads ``m.opset_import[0].version``).
    :raises RuntimeError: if the model's opset is not 11.
    """
    global __ONNX_VERSION__
    version = m.opset_import[0].version
    __ONNX_VERSION__ = version
    if version != 11:
        raise RuntimeError(
            "Only support opset 11, but got " + str(version)
        )
def get_current_opset_version():
    """Return the opset version stored by ``setup_current_opset_version``.

    :raises RuntimeError: if the version has not been set up yet.
    """
    version = __ONNX_VERSION__
    if version == -1:
        raise RuntimeError("do setup_current_opset_version first please")
    return version
def find_nodes_by_input_name(g, name): def find_nodes_by_input_name(g, name):
nodes = [] nodes = []
for node in g.node: for node in g.node:
@ -28,6 +33,7 @@ def find_nodes_by_input_name(g, name):
nodes.append(node) nodes.append(node)
return nodes return nodes
def find_node_by_output_name(g, name): def find_node_by_output_name(g, name):
""" """
Find a node in the graph by its output name Find a node in the graph by its output name
@ -41,6 +47,7 @@ def find_node_by_output_name(g, name):
return i return i
return None return None
def find_node_by_node_name(g, name): def find_node_by_node_name(g, name):
""" """
Find a node in the graph by its output name Find a node in the graph by its output name
@ -54,6 +61,7 @@ def find_node_by_node_name(g, name):
return i return i
return None return None
def find_following_nodes_by_input_value_name(g, name): def find_following_nodes_by_input_value_name(g, name):
""" Find the following nodes of a specific value. """ Find the following nodes of a specific value.
@ -63,6 +71,7 @@ def find_following_nodes_by_input_value_name(g, name):
""" """
return find_nodes_by_input_name(g, name) return find_nodes_by_input_name(g, name)
def find_value_by_name(g, name): def find_value_by_name(g, name):
""" """
Find a value_info in the graph by name Find a value_info in the graph by name
@ -76,6 +85,7 @@ def find_value_by_name(g, name):
return i return i
return None return None
def find_output_by_name(g, name): def find_output_by_name(g, name):
""" """
Find a value_info in the graph by name Find a value_info in the graph by name
@ -89,6 +99,7 @@ def find_output_by_name(g, name):
return i return i
return None return None
def find_input_by_name(g, name): def find_input_by_name(g, name):
""" """
Find a input in the graph by name Find a input in the graph by name
@ -102,6 +113,7 @@ def find_input_by_name(g, name):
return i return i
return None return None
def list_to_constant(name, shape, data, data_type=None): def list_to_constant(name, shape, data, data_type=None):
"""Generate a constant node using the given infomation. """Generate a constant node using the given infomation.
@ -119,18 +131,9 @@ def list_to_constant(name, shape, data, data_type=None):
data_type = onnx.helper.TensorProto.INT64 data_type = onnx.helper.TensorProto.INT64
else: else:
data_type = onnx.helper.TensorProto.FLOAT data_type = onnx.helper.TensorProto.FLOAT
tensor = onnx.helper.make_tensor( tensor = onnx.helper.make_tensor(name, data_type, shape, data)
name,
data_type,
shape,
data
)
new_w_node = onnx.helper.make_node( new_w_node = onnx.helper.make_node(
"Constant", "Constant", [], [name], name=name, value=tensor
[],
[name],
name = name,
value = tensor
) )
return new_w_node return new_w_node
@ -151,18 +154,9 @@ def scaler_to_constant(name, data, data_type=None):
else: else:
logger.error("Cannot create scaler constant with a list.") logger.error("Cannot create scaler constant with a list.")
exit(1) exit(1)
tensor = onnx.helper.make_tensor( tensor = onnx.helper.make_tensor(name, data_type, None, [data])
name,
data_type,
None,
[data]
)
new_w_node = onnx.helper.make_node( new_w_node = onnx.helper.make_node(
"Constant", "Constant", [], [name], name=name, value=tensor
[],
[name],
name = name,
value = tensor
) )
return new_w_node return new_w_node
@ -170,6 +164,7 @@ def scaler_to_constant(name, data, data_type=None):
def numpy_to_constant(name, np_array):
    """Wrap a numpy array into an ONNX Constant node with the given name."""
    flat_data = np_array.flatten().tolist()
    return list_to_constant(name, np_array.shape, flat_data)
def constant_to_list(node): def constant_to_list(node):
"""Generate a list from the constant node """Generate a list from the constant node
@ -184,27 +179,27 @@ def constant_to_list(node):
if len(tensor.int32_data) != 0: if len(tensor.int32_data) != 0:
data = list(tensor.int32_data) data = list(tensor.int32_data)
else: else:
data = [i[0] for i in struct.iter_unpack('i', tensor.raw_data)] data = [i[0] for i in struct.iter_unpack("i", tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.INT64: elif tensor.data_type == onnx.helper.TensorProto.INT64:
if len(tensor.int64_data) != 0: if len(tensor.int64_data) != 0:
data = list(tensor.int64_data) data = list(tensor.int64_data)
else: else:
data = [i[0] for i in struct.iter_unpack('q', tensor.raw_data)] data = [i[0] for i in struct.iter_unpack("q", tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.INT8: elif tensor.data_type == onnx.helper.TensorProto.INT8:
if len(tensor.int32_data) != 0: if len(tensor.int32_data) != 0:
data = list(tensor.int32_data) data = list(tensor.int32_data)
else: else:
data = [i[0] for i in struct.iter_unpack('b', tensor.raw_data)] data = [i[0] for i in struct.iter_unpack("b", tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.FLOAT: elif tensor.data_type == onnx.helper.TensorProto.FLOAT:
if len(tensor.float_data) != 0: if len(tensor.float_data) != 0:
data = list(tensor.float_data) data = list(tensor.float_data)
else: else:
data = [i[0] for i in struct.iter_unpack('f', tensor.raw_data)] data = [i[0] for i in struct.iter_unpack("f", tensor.raw_data)]
elif tensor.data_type == onnx.helper.TensorProto.DOUBLE: elif tensor.data_type == onnx.helper.TensorProto.DOUBLE:
if len(tensor.double_data) != 0: if len(tensor.double_data) != 0:
data = list(tensor.double_data) data = list(tensor.double_data)
else: else:
data = [i[0] for i in struct.iter_unpack('d', tensor.raw_data)] data = [i[0] for i in struct.iter_unpack("d", tensor.raw_data)]
else: else:
print("Not supported data type {}".format(tensor.data_type)) print("Not supported data type {}".format(tensor.data_type))
raise RuntimeError raise RuntimeError
@ -214,6 +209,7 @@ def constant_to_list(node):
shape = list(tensor.dims) shape = list(tensor.dims)
return shape, data return shape, data
def constant_to_numpy(node): def constant_to_numpy(node):
"""Generate a numpy array from the constant node """Generate a numpy array from the constant node
@ -223,6 +219,7 @@ def constant_to_numpy(node):
shape, data = constant_to_list(node) shape, data = constant_to_list(node)
return np.array(data).reshape(shape) return np.array(data).reshape(shape)
def all_constant_input(node): def all_constant_input(node):
"""Find the inputs of the given node. If the inputs of this node are all\\ """Find the inputs of the given node. If the inputs of this node are all\\
constant nodes, return True. Otherwise, return False. constant nodes, return True. Otherwise, return False.
@ -234,24 +231,26 @@ def all_constant_input(node):
return False return False
isConstant = True isConstant = True
for parent in node.parents: for parent in node.parents:
if parent.proto is None or parent.proto.op_type != 'Constant': if parent.proto is None or parent.proto.op_type != "Constant":
isConstant = False isConstant = False
break break
return isConstant return isConstant
def get_padding(size, kernel_size, strides):
    """Calculate the padding array for Tensorflow-style 'same' padding.

    See https://www.tensorflow.org/api_guides/python/nn#Convolution for more.

    :param size: spatial input size ``[h, w]``.
    :param kernel_size: kernel size ``[kh, kw]``.
    :param strides: strides ``[sh, sw]``.
    :return: ``[pad_top, pad_left, pad_bottom, pad_right]``.
    """
    def _axis_pad(dim, kernel, stride):
        # Same-padding rule: pad so every input position is covered.
        remainder = dim % stride
        if remainder == 0:
            return max(kernel - stride, 0)
        return max(kernel - remainder, 0)

    pad_h = _axis_pad(size[0], kernel_size[0], strides[0])
    pad_w = _axis_pad(size[1], kernel_size[1], strides[1])
    # Extra pixel (odd total padding) goes to the bottom/right side.
    return [pad_h // 2, pad_w // 2, pad_h - pad_h // 2, pad_w - pad_w // 2]
def get_shape_from_value_info(value): def get_shape_from_value_info(value):
"""Get shape from a value info. """Get shape from a value info.
@ -261,12 +260,13 @@ def get_shape_from_value_info(value):
""" """
return [d.dim_value for d in value.type.tensor_type.shape.dim] return [d.dim_value for d in value.type.tensor_type.shape.dim]
def find_size_shape_from_value(value): def find_size_shape_from_value(value):
''' """
Find the size of data within the value_info object. Find the size of data within the value_info object.
:param value: value_info :param value: value_info
:return: int size and list shape of the data in the value_info :return: int size and list shape of the data in the value_info
''' """
if not value: if not value:
return None, None return None, None
if not value.type.tensor_type.shape.dim: if not value.type.tensor_type.shape.dim:
@ -292,6 +292,7 @@ def get_attribute_by_name(node, attr_name):
return attr return attr
return None return None
def get_list_attribute_by_name(node, attr_name: str, attr_type: str): def get_list_attribute_by_name(node, attr_name: str, attr_type: str):
"""Get list attribute with specific name in the given node proto. """Get list attribute with specific name in the given node proto.
@ -317,12 +318,13 @@ def get_list_attribute_by_name(node, attr_name: str, attr_type: str):
print("Warning: undefined type for list attribute extraction") print("Warning: undefined type for list attribute extraction")
return None return None
def get_var_attribute_by_name(node, attr_name: str, attr_type: str): def get_var_attribute_by_name(node, attr_name: str, attr_type: str):
"""Get variable attribute with specific name in the given node proto. """Get variable attribute with specific name in the given node proto.
:param node: the node proto.\\ :param node: the node proto.
:param attr_name: str for the name of the target.\\ :param attr_name: str for the name of the target.
:param attr_type: str which should be "float", "int", "string" or "tensor".\\ :param attr_type: str which should be "float", "int", "string" or "tensor".
:return: if found, return the variable. Else, return None. :return: if found, return the variable. Else, return None.
""" """
attr_proto = get_attribute_by_name(node, attr_name) attr_proto = get_attribute_by_name(node, attr_name)
@ -333,7 +335,7 @@ def get_var_attribute_by_name(node, attr_name: str, attr_type: str):
elif attr_type == "float": elif attr_type == "float":
return attr_proto.f return attr_proto.f
elif attr_type == "string": elif attr_type == "string":
if type(attr_proto.s) == type(b'abc'): if isinstance(attr_proto.s, bytes):
return attr_proto.s.decode("utf-8") return attr_proto.s.decode("utf-8")
else: else:
return attr_proto.s return attr_proto.s
@ -343,22 +345,25 @@ def get_var_attribute_by_name(node, attr_name: str, attr_type: str):
print("Warning: undefined type for variable attribute extraction") print("Warning: undefined type for variable attribute extraction")
return None return None
def flatten_with_depth(data, depth):
    """Flatten nested lists/arrays into ``[value, depth]`` pairs.

    :param data: a scalar, or an arbitrarily nested list / numpy array.
    :param depth: the nesting depth attributed to ``data`` itself.
    :return: list of ``[leaf_value, leaf_depth]`` pairs in traversal order.
        A bare scalar is returned as ``[[data, 0]]`` regardless of ``depth``
        (preserving the original behavior).
    """
    # isinstance replaces the original exact type() membership check; it is
    # the idiomatic form and also accepts list/ndarray subclasses.
    if not isinstance(data, (list, np.ndarray)):
        return [[data, 0]]
    output = []
    for item in data:
        if isinstance(item, (list, np.ndarray)):
            output += flatten_with_depth(item, depth + 1)
        else:
            output.append([item, depth + 1])
    return output
def flatten_to_list(data):
    """Flatten arbitrarily nested data into a flat list of leaf values."""
    return [pair[0] for pair in flatten_with_depth(data, 0)]
def get_shape(data): def get_shape(data):
shape = [] shape = []
if type(data) not in [type(np.array([1])), type([1])]: if type(data) not in [type(np.array([1])), type([1])]:
@ -378,7 +383,7 @@ def slice_data(data, starts, ends, axes):
starts_updated = [] starts_updated = []
ends_updated = [] ends_updated = []
for i in range(len(starts)): for i in range(len(starts)):
start_updated = min(starts[i], shape[i]-1) % shape[i] start_updated = min(starts[i], shape[i] - 1) % shape[i]
starts_updated.append(start_updated) starts_updated.append(start_updated)
for j in range(len(starts)): for j in range(len(starts)):
if ends[j] >= shape[j]: if ends[j] >= shape[j]:
@ -393,19 +398,21 @@ def slice_data(data, starts, ends, axes):
index_slices.append(list(range(shape[i]))) index_slices.append(list(range(shape[i])))
else: else:
axe_ind = axes.index(i) axe_ind = axes.index(i)
index_slices.append(list(range(starts_updated[axe_ind], ends_updated[axe_ind]))) index_slices.append(
list(range(starts_updated[axe_ind], ends_updated[axe_ind]))
)
indices = [1] indices = [1]
for i in range(len(shape)-1, -1, -1): for i in range(len(shape) - 1, -1, -1):
step = np.prod(shape[i+1:]) step = np.prod(shape[i + 1:])
temp_pos = indices temp_pos = indices
new_indices = [] new_indices = []
for n in index_slices[i]: for n in index_slices[i]:
for pos in temp_pos: for pos in temp_pos:
new_indices.append(int(n*step+pos)) new_indices.append(int(n * step + pos))
indices = new_indices indices = new_indices
sliced_data = [flat_data[k-1] for k in indices] sliced_data = [flat_data[k - 1] for k in indices]
# reshape to correct shape. # reshape to correct shape.
new_shape = [] new_shape = []
@ -414,48 +421,51 @@ def slice_data(data, starts, ends, axes):
new_shape.append(shape[i]) new_shape.append(shape[i])
else: else:
axe_ind = axes.index(i) axe_ind = axes.index(i)
new_shape.append(ends_updated[axe_ind]-starts_updated[axe_ind]) new_shape.append(ends_updated[axe_ind] - starts_updated[axe_ind])
if any([dim < 1 for dim in new_shape]): if any([dim < 1 for dim in new_shape]):
raise RuntimeError('Invalid starts ends.') raise RuntimeError("Invalid starts ends.")
sliced_data = np.reshape(sliced_data, new_shape) sliced_data = np.reshape(sliced_data, new_shape)
return sliced_data return sliced_data
def concatenate(data_sets, axis):
    """Concatenate nested-list data sets along ``axis``.

    Pure-Python analogue of ``np.concatenate`` that works on nested lists.

    :param data_sets: list of nested lists/arrays whose shapes match on
        every dimension except ``axis``.
    :param axis: the axis to join along.
    :return: numpy array holding the concatenated data.
    :raises RuntimeError: if the non-axis dimensions do not match.
    """
    # check shapes: all inputs must agree on every dim except `axis`.
    shapes = []
    shapes_ = []  # shapes with the concat axis removed, for comparison
    for data_set in data_sets:
        shape = get_shape(data_set)
        shapes.append(list(shape))
        shape.pop(axis)
        shapes_.append(shape)
    if not all([s == shapes_[0] for s in shapes_]):
        raise RuntimeError("data sets shapes do not match")

    # Output shape equals the common shape, with `axis` sized as the sum.
    new_dim = sum([s[axis] for s in shapes])
    new_shape = list(shapes[0])
    new_shape[axis] = new_dim

    flat_data_sets = []
    for data_set in data_sets:
        flat_data_sets.append(flatten_to_list(data_set))

    # sub_block_size: element count of one slice below `axis`.
    sub_block_size = 1
    for i in range(axis + 1, len(shapes[0])):
        sub_block_size *= shapes[0][i]
    # split_num: number of independent outer blocks above `axis`.
    split_num = 1
    for i in range(axis):
        split_num *= shapes[0][i]

    # Interleave: for each outer block, append each input's contribution
    # along the concat axis, in input order.
    total_flat_data = []
    for i in range(split_num):
        for j in range(len(shapes)):
            block_size = sub_block_size * shapes[j][axis]
            total_flat_data.extend(
                flat_data_sets[j][i * block_size:(i + 1) * block_size]
            )
    new_data = np.reshape(total_flat_data, new_shape)
    return new_data
@ -464,158 +474,169 @@ def concatenate(data_sets, axis):
def broadcast_data_sets(data_set_1, data_set_2):
    """Broadcast two nested data sets to a common shape (numpy-like rules).

    :param data_set_1: nested list/array.
    :param data_set_2: nested list/array.
    :return: tuple ``(data_1, data_2)`` both expanded to the broadcast shape.
    :raises RuntimeError: if the shapes cannot be broadcast together.
    """
    shape1 = get_shape(data_set_1)
    shape2 = get_shape(data_set_2)
    # compare shapes and get broadcasted shape
    list_a, list_b = (
        (shape1, shape2) if len(shape1) > len(shape2) else (shape2, shape1)
    )
    # NOTE: list_b aliases the shorter of shape1/shape2, so this left-pads
    # that shape in place with 0 (a "missing dim" marker).
    while len(list_a) > len(list_b):
        list_b.insert(0, 0)
    broadcasted_shape = []
    for i in range(len(list_a)):
        if list_b[i] == 0:
            # Missing dimension: take the other side's size.
            broadcasted_shape.append(list_a[i])
        elif list_b[i] == 1:
            broadcasted_shape.append(list_a[i])
        elif list_a[i] == 1:
            broadcasted_shape.append(list_b[i])
        elif list_a[i] == list_b[i]:
            broadcasted_shape.append(list_a[i])
        else:
            raise RuntimeError("Can not broadcast two data sets")
    # prepare data for broadcasting: turn 0-markers into size-1 dims so
    # np.reshape accepts them.
    shape1 = list(map(lambda x: x if x != 0 else 1, shape1))
    shape2 = list(map(lambda x: x if x != 0 else 1, shape2))
    data_1 = np.reshape(data_set_1, shape1)
    data_2 = np.reshape(data_set_2, shape2)
    # Expand each size-1 axis by stacking copies along it.
    for i in range(len(shape1)):
        if shape1[i] != broadcasted_shape[i]:
            new_data_total = [
                list(data_1) for _ in range(broadcasted_shape[i])
            ]
            data_1 = concatenate(new_data_total, axis=i)
    for i in range(len(shape2)):
        if shape2[i] != broadcasted_shape[i]:
            new_data_total = [
                list(data_2) for _ in range(broadcasted_shape[i])
            ]
            data_2 = concatenate(new_data_total, axis=i)
    return data_1, data_2
def add(data_set_1, data_set_2):
    """Element-wise addition of two broadcastable nested data sets.

    :return: numpy array of the broadcast shape holding the sums.
    """
    lhs, rhs = broadcast_data_sets(data_set_1, data_set_2)
    out_shape = get_shape(lhs)
    flat_lhs = flatten_to_list(lhs)
    flat_rhs = flatten_to_list(rhs)
    summed = [a + b for a, b in zip(flat_lhs, flat_rhs)]
    return np.reshape(summed, out_shape)
def reduceprod(data_set, axis, keepdims=1):
    """Reduce ``data_set`` by multiplication along each axis in ``axis``.

    Pure-Python analogue of ONNX ReduceProd on nested lists.

    :param data_set: nested list/array.
    :param axis: iterable of axis indices to reduce over.
    :param keepdims: if truthy, reduced axes stay with size 1; otherwise
        they are removed from the result shape.
    :return: numpy array with the reduced data.
    """
    flat_data = flatten_to_list(data_set)
    old_shape = get_shape(data_set)
    temp_shape = old_shape
    temp_flat_data = flat_data
    for ax in axis:
        # split_num: outer blocks before `ax`; step: stride below `ax`.
        split_num = 1
        step = 1
        for i in range(ax):
            split_num *= temp_shape[i]
        for i in range(ax + 1, len(temp_shape)):
            step *= temp_shape[i]
        block_size = len(temp_flat_data) // split_num
        new_flat_data = []
        for j in range(split_num):
            block_data = temp_flat_data[j * block_size:(j + 1) * block_size]
            reduced_block_data = []
            for k in range(step):
                # Multiply every element along axis `ax` (stride `step`).
                val = block_data[k]
                for li in range(1, block_size // step):
                    val *= block_data[k + li * step]
                reduced_block_data.append(val)
            new_flat_data.extend(reduced_block_data)
        temp_flat_data = new_flat_data
        # The reduced axis collapses to size 1 for subsequent passes.
        temp_shape[ax] = 1
    new_flat_data = temp_flat_data
    new_shape = temp_shape
    if not keepdims:
        # Drop reduced axes from highest index down so positions stay valid.
        axis = sorted(list(axis))
        for pos in axis[::-1]:
            new_shape.pop(pos)
    return np.reshape(new_flat_data, new_shape)
def transpose(data_set, permutation):
    """Transpose a nested data set according to ``permutation``.

    Decomposes the permutation into a series of adjacent-axis swaps
    (bubble sort) and applies each swap on the flattened data.

    :param data_set: nested list/array.
    :param permutation: target axis order, a permutation of
        ``range(rank)``.
    :return: numpy array with the transposed data.
    """
    # find series of local swaps
    data_set = list(data_set)
    perm = list(permutation)
    shape = get_shape(data_set)
    flat_data = flatten_to_list(data_set)
    assert set(perm) == set(range(len(shape))), "invalid permutation"
    new_shape = [shape[i] for i in perm]
    swaps = []
    bubbled = True
    while bubbled:
        # Bubble-sort `perm`; each recorded swap is an adjacent transpose.
        bubbled = False
        for i in range(len(new_shape) - 1):
            if perm[i] > perm[i + 1]:
                swaps.append([i, i + 1])
                p_1, p_2 = perm[i], perm[i + 1]
                perm[i], perm[i + 1] = p_2, p_1
                bubbled = True

    # apply local swaps (in reverse recording order)
    current_shape = list(shape)
    temp_flat_data = flat_data

    for swap in swaps[::-1]:
        ind_1, ind_2 = swap[0], swap[1]
        dim_1 = current_shape[ind_1]
        dim_2 = current_shape[ind_2]
        split_num = 1
        block_size = 1

        # split_num: outer blocks before ind_1; block_size: elements in
        # one slice below ind_2.
        for i in range(ind_1):
            split_num *= current_shape[i]
        for i in range(ind_2 + 1, len(current_shape)):
            block_size *= current_shape[i]

        data_blocks = np.reshape(temp_flat_data, [-1, block_size])
        flat_data_1 = []
        for k in range(split_num):
            # Rebuild each outer block with the two adjacent axes swapped.
            block = []
            for m in range(dim_2):
                for n in range(dim_1):
                    block_pos = k * dim_1 * dim_2 + n * dim_2 + m
                    block.extend(data_blocks[block_pos])
            flat_data_1.extend(block)
        temp_flat_data = flat_data_1
        current_shape[ind_1] = dim_2
        current_shape[ind_2] = dim_1

    return np.reshape(temp_flat_data, current_shape)
def subtract(data_set_1, data_set_2):
    """Element-wise subtraction (``data_set_1 - data_set_2``) with
    broadcasting.

    :return: numpy array of the broadcast shape holding the differences.
    """
    lhs, rhs = broadcast_data_sets(data_set_1, data_set_2)
    out_shape = get_shape(lhs)
    flat_lhs = flatten_to_list(lhs)
    flat_rhs = flatten_to_list(rhs)
    diffs = [a - b for a, b in zip(flat_lhs, flat_rhs)]
    return np.reshape(diffs, out_shape)

View File

@ -1,7 +1,7 @@
"""This module contains helper functions that do graph modifications. """
This module contains helper functions that do graph modifications.
""" """
import onnx
from . import helper from . import helper
@ -10,9 +10,10 @@ def replace_node_input(node, old_input, new_input):
if input_name == old_input: if input_name == old_input:
node.input[i] = new_input node.input[i] = new_input
def delete_nodes(g, node_list): def delete_nodes(g, node_list):
node_to_delete = [] node_to_delete = []
#Find target nodes # Find target nodes
for node in g.node: for node in g.node:
if node.name not in node_list: if node.name not in node_list:
continue continue
@ -23,16 +24,28 @@ def delete_nodes(g, node_list):
for node in node_to_delete: for node in node_to_delete:
# Check the node whether if it is valid to delete # Check the node whether if it is valid to delete
if len(node.input) == 0: if len(node.input) == 0:
print("Deleting an Constant node. Please make sure you also delete all its following nodes") print(
"Deleting an Constant node. "
"Please make sure you also delete all its following nodes"
)
elif len(node.input) > 1: elif len(node.input) > 1:
print("Warning: Node {} has more than one input. This script cannot delete merge nodes.".format(node.name)) print(
f"Warning: Node {node.name} has more than one input. "
"This script cannot delete merge nodes."
)
# Connect the nodes around the target node. # Connect the nodes around the target node.
# Set the following node input as the previous node output. # Set the following node input as the previous node output.
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) following_nodes = helper.find_following_nodes_by_input_value_name(
g, node.output[0]
)
if len(node.input) == 0: if len(node.input) == 0:
for following_node in following_nodes: for following_node in following_nodes:
following_node.input.remove(node.output[0]) following_node.input.remove(node.output[0])
elif len(following_nodes) > 0 and len(node.input) == 1 and helper.find_input_by_name(g, node.input[0]) is not None: elif (
len(following_nodes) > 0
and len(node.input) == 1
and helper.find_input_by_name(g, node.input[0]) is not None
):
# The node input is an input # The node input is an input
new_input = helper.find_value_by_name(g, node.output[0]) new_input = helper.find_value_by_name(g, node.output[0])
g.input.append(new_input) g.input.append(new_input)
@ -40,9 +53,11 @@ def delete_nodes(g, node_list):
g.value_info.remove(new_input) g.value_info.remove(new_input)
elif len(following_nodes) > 0: elif len(following_nodes) > 0:
for following_node in following_nodes: for following_node in following_nodes:
replace_node_input(following_node, node.output[0], node.input[0]) replace_node_input(
following_node, node.output[0], node.input[0]
)
else: else:
# If the node is the output, replace the output with the previous input. # If the node is the output, replace it with previous input.
value = helper.find_value_by_name(g, node.input[0]) value = helper.find_value_by_name(g, node.input[0])
output_values = [] output_values = []
while len(g.output): while len(g.output):
@ -56,6 +71,7 @@ def delete_nodes(g, node_list):
# Remove the node and value info. # Remove the node and value info.
g.node.remove(node) g.node.remove(node)
def delete_input(g, target_list): def delete_input(g, target_list):
for name in target_list: for name in target_list:
input_value = helper.find_input_by_name(g, name) input_value = helper.find_input_by_name(g, name)
@ -64,6 +80,7 @@ def delete_input(g, target_list):
continue continue
g.input.remove(input_value) g.input.remove(input_value)
def delete_output(g, target_list): def delete_output(g, target_list):
for name in target_list: for name in target_list:
output_value = helper.find_output_by_name(g, name) output_value = helper.find_output_by_name(g, name)
@ -72,6 +89,7 @@ def delete_output(g, target_list):
continue continue
g.output.remove(output_value) g.output.remove(output_value)
def delete_value_with_name_if_exists(g, name): def delete_value_with_name_if_exists(g, name):
value = helper.find_value_by_name(g, name) value = helper.find_value_by_name(g, name)
if value is not None: if value is not None:

File diff suppressed because it is too large Load Diff

View File

@ -1,317 +1,368 @@
from . import helper from . import helper
from . import other from . import other
from . import modhelper from . import modhelper
from . import fusing
import numpy as np import numpy as np
import onnx import onnx
import onnx.utils import onnx.utils
def eliminate_transposes(m):
g = m.graph
keep_eliminating = True
while keep_eliminating:
while swap_transpose_with_single_next_node(g):
pass
splitted = split_transpose_for_multiple_next_nodes(g)
annihilated = annihilate_transposes(g)
multiple_trans_swapped = swap_multiple_transposes_with_node(g)
keep_eliminating = splitted or annihilated or multiple_trans_swapped
if keep_eliminating: def eliminate_transposes(m):
m = other.polish_model(m) g = m.graph
g = m.graph keep_eliminating = True
while keep_eliminating:
return m while swap_transpose_with_single_next_node(g):
pass
splitted = split_transpose_for_multiple_next_nodes(g)
annihilated = annihilate_transposes(g)
multiple_trans_swapped = swap_multiple_transposes_with_node(g)
keep_eliminating = splitted or annihilated or multiple_trans_swapped
if keep_eliminating:
m = other.polish_model(m)
g = m.graph
return m
def swap_transpose_with_single_next_node(g): def swap_transpose_with_single_next_node(g):
swapped = False swapped = False
passable_nodes = set(['Relu', 'Neg', 'LeakyRelu', 'Sqrt', 'Reciprocal', 'Add', 'Mul', 'Tanh']) passable_nodes = set(
for node in g.node: [
trans_node = node "Relu",
# Check for transpose node "Neg",
if trans_node.op_type != 'Transpose': "LeakyRelu",
continue "Sqrt",
next_nodes = helper.find_nodes_by_input_name(g, trans_node.output[0]) "Reciprocal",
if len(next_nodes) != 1: "Add",
continue "Mul",
next_node = next_nodes[0] "Tanh",
# Check if the next node is the type can be swapped ]
if next_node.op_type not in passable_nodes: )
continue for node in g.node:
trans_node = node
# Check for transpose node
if trans_node.op_type != "Transpose":
continue
next_nodes = helper.find_nodes_by_input_name(g, trans_node.output[0])
if len(next_nodes) != 1:
continue
next_node = next_nodes[0]
# Check if the next node is the type can be swapped
if next_node.op_type not in passable_nodes:
continue
input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in next_node.input] input_nodes = [
helper.find_node_by_output_name(g, input_name)
for input_name in next_node.input
]
# Check if the node has nonconstant input other than the Transpose node itself # Check if the node has nonconstant input
nonconstant_input = False # other than the Transpose node itself
for input_node in input_nodes: nonconstant_input = False
if input_node == None: for input_node in input_nodes:
nonconstant_input = True if input_node is None:
break nonconstant_input = True
if input_node.name == trans_node.name: break
continue if input_node.name == trans_node.name:
elif input_node.op_type == 'Constant': continue
continue elif input_node.op_type == "Constant":
else: continue
nonconstant_input = True else:
break nonconstant_input = True
if nonconstant_input: break
continue if nonconstant_input:
continue
for input_node in input_nodes: for input_node in input_nodes:
if input_node.name == trans_node.name: if input_node.name == trans_node.name:
# if the input is just the transpose node # if the input is just the transpose node
next_value_info = helper.find_value_by_name(g, next_node.output[0]) next_value_info = helper.find_value_by_name(
mid_value_info = helper.find_value_by_name(g, trans_node.output[0]) g, next_node.output[0]
)
mid_value_info = helper.find_value_by_name(
g, trans_node.output[0]
)
output_nodes = helper.find_nodes_by_input_name(g, next_node.output[0]) output_nodes = helper.find_nodes_by_input_name(
for out_node in output_nodes: g, next_node.output[0]
modhelper.replace_node_input(out_node, next_node.output[0], trans_node.name) )
for out_node in output_nodes:
modhelper.replace_node_input(
out_node, next_node.output[0], trans_node.name
)
next_node.input[0] = trans_node.input[0] next_node.input[0] = trans_node.input[0]
next_node.output[0] = next_node.name next_node.output[0] = next_node.name
trans_node.input[0] = next_node.name trans_node.input[0] = next_node.name
trans_node.output[0] = trans_node.name trans_node.output[0] = trans_node.name
if next_value_info: if next_value_info:
next_value_info.name = trans_node.name next_value_info.name = trans_node.name
if mid_value_info: if mid_value_info:
g.value_info.remove(mid_value_info) g.value_info.remove(mid_value_info)
else: else:
# if the input is a constant node # if the input is a constant node
old_tensor = input_node.attribute[0].t old_tensor = input_node.attribute[0].t
old_shape, data = helper.constant_to_list(input_node) old_shape, data = helper.constant_to_list(input_node)
# If the constant node is a scaler, no action is needed # If the constant node is a scaler, no action is needed
if type(old_shape) == int: if type(old_shape) == int:
old_shape = [old_shape] old_shape = [old_shape]
permutation = list(trans_node.attribute[0].ints) permutation = list(trans_node.attribute[0].ints)
while len(old_shape) < len(permutation): while len(old_shape) < len(permutation):
old_shape.insert(0, 1) old_shape.insert(0, 1)
np_data = np.reshape(data, old_shape) np_data = np.reshape(data, old_shape)
reverse_perm = [] reverse_perm = []
for i in range(len(permutation)): for i in range(len(permutation)):
reverse_perm.append(permutation.index(i)) reverse_perm.append(permutation.index(i))
np_data = np.transpose(np_data, reverse_perm) np_data = np.transpose(np_data, reverse_perm)
new_shape = np_data.shape new_shape = np_data.shape
new_tensor = onnx.helper.make_tensor( new_tensor = onnx.helper.make_tensor(
name=old_tensor.name, name=old_tensor.name,
data_type=old_tensor.data_type, data_type=old_tensor.data_type,
dims=new_shape, dims=new_shape,
vals=np_data.flatten().tolist() vals=np_data.flatten().tolist(),
) )
new_node = onnx.helper.make_node( new_node = onnx.helper.make_node(
'Constant', "Constant",
[], [],
[input_node.output[0]], [input_node.output[0]],
name=input_node.name, name=input_node.name,
value=new_tensor value=new_tensor,
) )
g.node.extend([new_node]) g.node.extend([new_node])
g.value_info.remove(helper.find_value_by_name(g, input_node.output[0])) g.value_info.remove(
g.node.remove(input_node) helper.find_value_by_name(g, input_node.output[0])
)
g.node.remove(input_node)
swapped = True swapped = True
other.topological_sort(g) other.topological_sort(g)
return swapped return swapped
def swap_multiple_transposes_with_node(g): def swap_multiple_transposes_with_node(g):
# here only consider same input transposes # here only consider same input transposes
swapped = False swapped = False
passable_nodes = set(['Add', 'Mul']) passable_nodes = set(["Add", "Mul"])
node_to_del = [] node_to_del = []
for node in g.node: for node in g.node:
if node.op_type not in passable_nodes: if node.op_type not in passable_nodes:
continue continue
input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in node.input] input_nodes = [
if any([input_node == None for input_node in input_nodes]): helper.find_node_by_output_name(g, input_name)
continue for input_name in node.input
if any([input_node.op_type != 'Transpose' for input_node in input_nodes]): ]
continue if any([input_node is None for input_node in input_nodes]):
continue
if any(
[input_node.op_type != "Transpose" for input_node in input_nodes]
):
continue
permutation = list(input_nodes[0].attribute[0].ints) permutation = list(input_nodes[0].attribute[0].ints)
if any([list(input_node.attribute[0].ints) != permutation for input_node in input_nodes]): if any(
continue [
list(input_node.attribute[0].ints) != permutation
for input_name in node.input: for input_node in input_nodes
input_node = helper.find_node_by_output_name(g, input_name) ]
modhelper.replace_node_input(node, input_name, input_node.input[0]) ):
continue
node_to_del.extend(input_nodes) for input_name in node.input:
for input_node in input_nodes: input_node = helper.find_node_by_output_name(g, input_name)
input_val_info = helper.find_value_by_name(g, input_node.output[0]) modhelper.replace_node_input(node, input_name, input_node.input[0])
if input_val_info is not None:
g.value_info.remove(input_val_info)
output_val_info = helper.find_value_by_name(g, node.output[0])
if output_val_info is not None:
g.value_info.remove(output_val_info)
output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) node_to_del.extend(input_nodes)
for i in range(len(output_nodes)): for input_node in input_nodes:
new_trans_node_name = node.name+'_trans_'+str(i) input_val_info = helper.find_value_by_name(g, input_node.output[0])
new_trans_node = onnx.helper.make_node( if input_val_info is not None:
'Transpose', g.value_info.remove(input_val_info)
[node.output[0]], output_val_info = helper.find_value_by_name(g, node.output[0])
[new_trans_node_name], if output_val_info is not None:
name=new_trans_node_name, g.value_info.remove(output_val_info)
perm=permutation
) output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
modhelper.replace_node_input(output_nodes[i], node.output[0], new_trans_node_name) for i in range(len(output_nodes)):
new_trans_node_name = node.name + "_trans_" + str(i)
g.node.extend([new_trans_node]) new_trans_node = onnx.helper.make_node(
"Transpose",
swapped = True [node.output[0]],
[new_trans_node_name],
while node_to_del: name=new_trans_node_name,
node = node_to_del.pop() perm=permutation,
g.node.remove(node) )
modhelper.replace_node_input(
other.topological_sort(g) output_nodes[i], node.output[0], new_trans_node_name
return swapped )
g.node.extend([new_trans_node])
swapped = True
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
other.topological_sort(g)
return swapped
def annihilate_transposes(g): def annihilate_transposes(g):
node_to_del = [] node_to_del = []
annihilated = False annihilated = False
for node in g.node: for node in g.node:
if node.op_type != 'Transpose': if node.op_type != "Transpose":
continue continue
pre_node = helper.find_node_by_output_name(g, node.input[0]) pre_node = helper.find_node_by_output_name(g, node.input[0])
if not pre_node or pre_node.op_type != 'Transpose': if not pre_node or pre_node.op_type != "Transpose":
continue continue
nodes_from_top_transpose = helper.find_nodes_by_input_name(g, pre_node.output[0]) nodes_from_top_transpose = helper.find_nodes_by_input_name(
if len(nodes_from_top_transpose) > 1: g, pre_node.output[0]
continue )
if len(nodes_from_top_transpose) > 1:
perm_1 = list(pre_node.attribute[0].ints) continue
perm_2 = list(node.attribute[0].ints)
if perm_1 != perm_2:
continue
out_nodes = helper.find_nodes_by_input_name(g, node.output[0]) perm_1 = list(pre_node.attribute[0].ints)
for out_node in out_nodes: perm_2 = list(node.attribute[0].ints)
modhelper.replace_node_input(out_node, node.output[0], pre_node.input[0]) if perm_1 != perm_2:
continue
node_to_del.extend([node, pre_node])
mid_value_info = helper.find_value_by_name(g, pre_node.output[0])
out_value_info = helper.find_value_by_name(g, node.output[0])
g.value_info.remove(mid_value_info)
g.value_info.remove(out_value_info)
annihilated = True out_nodes = helper.find_nodes_by_input_name(g, node.output[0])
while node_to_del: for out_node in out_nodes:
node = node_to_del.pop() modhelper.replace_node_input(
g.node.remove(node) out_node, node.output[0], pre_node.input[0]
)
return annihilated
node_to_del.extend([node, pre_node])
mid_value_info = helper.find_value_by_name(g, pre_node.output[0])
out_value_info = helper.find_value_by_name(g, node.output[0])
g.value_info.remove(mid_value_info)
g.value_info.remove(out_value_info)
annihilated = True
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
return annihilated
def split_transpose_for_multiple_next_nodes(g): def split_transpose_for_multiple_next_nodes(g):
splitted = False splitted = False
node_to_del = [] node_to_del = []
for node in g.node: for node in g.node:
if node.op_type != 'Transpose': if node.op_type != "Transpose":
continue continue
output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
if len(output_nodes) < 2: if len(output_nodes) < 2:
continue continue
for i in range(len(output_nodes)): for i in range(len(output_nodes)):
output_node = output_nodes[i] output_node = output_nodes[i]
new_trans_node_name = node.name + '_' + str(i) new_trans_node_name = node.name + "_" + str(i)
new_trans_node = onnx.helper.make_node( new_trans_node = onnx.helper.make_node(
'Transpose', "Transpose",
[node.input[0]], [node.input[0]],
[new_trans_node_name], [new_trans_node_name],
name=new_trans_node_name, name=new_trans_node_name,
perm=list(node.attribute[0].ints) perm=list(node.attribute[0].ints),
) )
modhelper.replace_node_input(output_node, node.output[0], new_trans_node.output[0]) modhelper.replace_node_input(
g.node.extend([new_trans_node]) output_node, node.output[0], new_trans_node.output[0]
)
node_to_del.append(node) g.node.extend([new_trans_node])
val_info = helper.find_value_by_name(g, node.output[0])
g.value_info.remove(val_info) node_to_del.append(node)
val_info = helper.find_value_by_name(g, node.output[0])
g.value_info.remove(val_info)
splitted = True
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
other.topological_sort(g)
return splitted
splitted = True
while node_to_del:
node = node_to_del.pop()
g.node.remove(node)
other.topological_sort(g)
return splitted
def remove_trivial_transpose(g): def remove_trivial_transpose(g):
node_to_del = [] node_to_del = []
for node in g.node: for node in g.node:
if node.op_type != 'Transpose': if node.op_type != "Transpose":
continue continue
permutation = list(node.attribute[0].ints) permutation = list(node.attribute[0].ints)
if permutation != list(range(len(permutation))): if permutation != list(range(len(permutation))):
continue continue
next_nodes = helper.find_nodes_by_input_name(g, node.output[0]) next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
if not next_nodes: if not next_nodes:
input_val_info = helper.find_value_by_name(g, node.input[0]) input_val_info = helper.find_value_by_name(g, node.input[0])
out_val_info = helper.find_output_by_name(g, node.output[0]) out_val_info = helper.find_output_by_name(g, node.output[0])
if not input_val_info: if not input_val_info:
input_val_info = helper.find_input_by_name(g, node.input[0]) input_val_info = helper.find_input_by_name(g, node.input[0])
g.output.remove(out_val_info) g.output.remove(out_val_info)
g.output.extend([input_val_info]) g.output.extend([input_val_info])
else: else:
out_val_info = helper.find_value_by_name(g, node.output[0]) out_val_info = helper.find_value_by_name(g, node.output[0])
for next_node in next_nodes: for next_node in next_nodes:
modhelper.replace_node_input(next_node, node.output[0], node.input[0]) modhelper.replace_node_input(
g.value_info.remove(out_val_info) next_node, node.output[0], node.input[0]
)
node_to_del.append(node) g.value_info.remove(out_val_info)
while node_to_del: node_to_del.append(node)
node = node_to_del.pop()
g.node.remove(node) while node_to_del:
node = node_to_del.pop()
other.topological_sort(g) g.node.remove(node)
other.topological_sort(g)
def fuse_Transpose_into_Gemm_weight(g): def fuse_Transpose_into_Gemm_weight(g):
node_to_del = [] node_to_del = []
for node in g.node: for node in g.node:
# Check pattern # Check pattern
if node.op_type != 'Gemm': if node.op_type != "Gemm":
continue continue
prev_node = helper.find_node_by_output_name(g, node.input[0]) prev_node = helper.find_node_by_output_name(g, node.input[0])
if prev_node is None or prev_node.op_type != 'Flatten': if prev_node is None or prev_node.op_type != "Flatten":
continue continue
transpose_node = helper.find_node_by_output_name(g, prev_node.input[0]) transpose_node = helper.find_node_by_output_name(g, prev_node.input[0])
if transpose_node.op_type != 'Transpose': if transpose_node.op_type != "Transpose":
continue continue
# Check attribute # Check attribute
perm = helper.get_list_attribute_by_name(transpose_node, 'perm', 'int') perm = helper.get_list_attribute_by_name(transpose_node, "perm", "int")
if perm != [0, 2, 3, 1]: if perm != [0, 2, 3, 1]:
continue continue
transB = helper.get_var_attribute_by_name(node, 'transB', 'int') transB = helper.get_var_attribute_by_name(node, "transB", "int")
if transB is not None and transB == 1: if transB is not None and transB == 1:
continue continue
# Get the original weight # Get the original weight
origin_weight = helper.find_node_by_output_name(g, node.input[1]) origin_weight = helper.find_node_by_output_name(g, node.input[1])
origin_np = helper.constant_to_numpy(origin_weight) origin_np = helper.constant_to_numpy(origin_weight)
# Calculate a new weight # Calculate a new weight
shape = helper.get_shape_from_value_info(helper.find_value_by_name(g, prev_node.input[0])) shape = helper.get_shape_from_value_info(
shape.append(-1) helper.find_value_by_name(g, prev_node.input[0])
new_np = np.reshape(origin_np, shape) )
new_np = np.transpose(new_np, [0, 3, 1, 2, 4]) shape.append(-1)
new_np = np.reshape(new_np, [-1, new_np.shape[-1]]) new_np = np.reshape(origin_np, shape)
new_weight = helper.numpy_to_constant(origin_weight.output[0], new_np) new_np = np.transpose(new_np, [0, 3, 1, 2, 4])
# Replace and eliminate new_np = np.reshape(new_np, [-1, new_np.shape[-1]])
prev_node.input[0] = transpose_node.input[0] new_weight = helper.numpy_to_constant(origin_weight.output[0], new_np)
node_to_del.append(transpose_node) # Replace and eliminate
node_to_del.append(origin_weight) prev_node.input[0] = transpose_node.input[0]
g.value_info.remove(helper.find_value_by_name(g, transpose_node.output[0])) node_to_del.append(transpose_node)
g.node.extend([new_weight]) node_to_del.append(origin_weight)
g.value_info.remove(
helper.find_value_by_name(g, transpose_node.output[0])
)
g.node.extend([new_weight])
while node_to_del: while node_to_del:
node = node_to_del.pop() node = node_to_del.pop()
g.node.remove(node) g.node.remove(node)
other.topological_sort(g) other.topological_sort(g)

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,10 @@
"""Special operations on model. """Special operations on model.
""" """
import logging
import onnx.helper import onnx.helper
import numpy as np import numpy as np
from . import helper from . import helper
from . import other from . import other
from . import modhelper
def change_first_conv_from_bgr_to_rgb(m): def change_first_conv_from_bgr_to_rgb(m):
"""For input channel format BGR model, use this function to change the first """For input channel format BGR model, use this function to change the first
@ -16,12 +15,14 @@ def change_first_conv_from_bgr_to_rgb(m):
# Check for first node. # Check for first node.
g = m.graph g = m.graph
input_name = g.input[0].name input_name = g.input[0].name
first_nodes = helper.find_following_nodes_by_input_value_name(g, input_name) first_nodes = helper.find_following_nodes_by_input_value_name(
g, input_name
)
if len(first_nodes) > 1: if len(first_nodes) > 1:
return False return False
first_node = first_nodes[0] first_node = first_nodes[0]
# Now we have the first node. Check this first node. # Now we have the first node. Check this first node.
if first_node.op_type != 'Conv': if first_node.op_type != "Conv":
return False return False
weight_value = helper.find_value_by_name(g, first_node.input[1]) weight_value = helper.find_value_by_name(g, first_node.input[1])
weight_shape = helper.get_shape_from_value_info(weight_value) weight_shape = helper.get_shape_from_value_info(weight_value)
@ -41,10 +42,12 @@ def change_first_conv_from_bgr_to_rgb(m):
other.topological_sort(g) other.topological_sort(g)
return True return True
def change_input_from_bgr_to_rgb(m): def change_input_from_bgr_to_rgb(m):
"""For input channel format BGR model, use this function to modify the model """
to accepct RGB image.If the first node is a non-group Conv. Modify weight to For input channel format BGR model, use this function to modify the model
adapt the input into RGB. Otherwise create a new node. to accepct RGB image.If the first node is a non-group Conv.
Modify weight to adapt the input into RGB. Otherwise create a new node.
:param m: the model proto :param m: the model proto
""" """
@ -61,34 +64,33 @@ def change_input_from_bgr_to_rgb(m):
return return
# Otherwise, create a special conv node and replace the input # Otherwise, create a special conv node and replace the input
# Construct weight # Construct weight
weight_np = np.zeros((3, 3, 3, 3)).astype('float32') weight_np = np.zeros((3, 3, 3, 3)).astype("float32")
weight_np[0, 2, 1, 1] = 1.0 weight_np[0, 2, 1, 1] = 1.0
weight_np[1, 1, 1, 1] = 1.0 weight_np[1, 1, 1, 1] = 1.0
weight_np[2, 0, 1, 1] = 1.0 weight_np[2, 0, 1, 1] = 1.0
new_weight = helper.numpy_to_constant("bgr_shuffle_weight", weight_np) new_weight = helper.numpy_to_constant("bgr_shuffle_weight", weight_np)
# Construct Conv # Construct Conv
new_conv = onnx.helper.make_node( new_conv = onnx.helper.make_node(
'Conv', "Conv",
['rgb_input', "bgr_shuffle_weight"], ["rgb_input", "bgr_shuffle_weight"],
[g.input[0].name], [g.input[0].name],
name='bgr_shuffle', name="bgr_shuffle",
dilations=[1, 1], dilations=[1, 1],
kernel_shape=[3, 3], kernel_shape=[3, 3],
pads=[1, 1, 1, 1], pads=[1, 1, 1, 1],
strides=[1, 1] strides=[1, 1],
) )
# Connect the graph # Connect the graph
old_input_value = g.input.pop() old_input_value = g.input.pop()
new_input_value = onnx.helper.make_tensor_value_info( new_input_value = onnx.helper.make_tensor_value_info(
'rgb_input', "rgb_input", old_input_value.type.tensor_type.elem_type, input_shape
old_input_value.type.tensor_type.elem_type,
input_shape
) )
g.input.extend([new_input_value]) g.input.extend([new_input_value])
g.node.extend([new_weight, new_conv]) g.node.extend([new_weight, new_conv])
# topological sort # topological sort
other.topological_sort(g) other.topological_sort(g)
def add_0_5_to_normalized_input(m): def add_0_5_to_normalized_input(m):
"""For normalized input between -0.5 ~ 0.5, add 0.5 to the input to keep it """For normalized input between -0.5 ~ 0.5, add 0.5 to the input to keep it
between 0 ~ 1. between 0 ~ 1.
@ -105,41 +107,37 @@ def add_0_5_to_normalized_input(m):
return return
# Construct weight # Construct weight
ch = input_shape[1] ch = input_shape[1]
weight_np = np.zeros((ch, ch, 3, 3)).astype('float32') weight_np = np.zeros((ch, ch, 3, 3)).astype("float32")
for i in range(ch): for i in range(ch):
weight_np[i, i, 1, 1] = 1.0 weight_np[i, i, 1, 1] = 1.0
new_weight = helper.numpy_to_constant("input_norm_weight", weight_np) new_weight = helper.numpy_to_constant("input_norm_weight", weight_np)
# Construct bias # Construct bias
bias_np = np.array([0.5] * ch).astype('float32') bias_np = np.array([0.5] * ch).astype("float32")
new_bias = helper.numpy_to_constant("input_norm_bias", bias_np) new_bias = helper.numpy_to_constant("input_norm_bias", bias_np)
# Construct Conv # Construct Conv
new_conv = onnx.helper.make_node( new_conv = onnx.helper.make_node(
'Conv', "Conv",
['origin_input', "input_norm_weight", "input_norm_bias"], ["origin_input", "input_norm_weight", "input_norm_bias"],
[g.input[0].name], [g.input[0].name],
name='input_norm', name="input_norm",
dilations=[1, 1], dilations=[1, 1],
kernel_shape=[3, 3], kernel_shape=[3, 3],
pads=[1, 1, 1, 1], pads=[1, 1, 1, 1],
strides=[1, 1] strides=[1, 1],
) )
# Construct value_infos # Construct value_infos
old_input_value = g.input.pop() old_input_value = g.input.pop()
weight_value = onnx.helper.make_tensor_value_info( weight_value = onnx.helper.make_tensor_value_info(
'input_norm_weight', "input_norm_weight",
old_input_value.type.tensor_type.elem_type, old_input_value.type.tensor_type.elem_type,
[3, 3, 3, 3] [3, 3, 3, 3],
) )
bias_value = onnx.helper.make_tensor_value_info( bias_value = onnx.helper.make_tensor_value_info(
'input_norm_bias', "input_norm_bias", old_input_value.type.tensor_type.elem_type, [3]
old_input_value.type.tensor_type.elem_type,
[3]
) )
# Connect the graph # Connect the graph
new_input_value = onnx.helper.make_tensor_value_info( new_input_value = onnx.helper.make_tensor_value_info(
'origin_input', "origin_input", old_input_value.type.tensor_type.elem_type, input_shape
old_input_value.type.tensor_type.elem_type,
input_shape
) )
g.input.extend([new_input_value]) g.input.extend([new_input_value])
g.node.extend([new_weight, new_bias, new_conv]) g.node.extend([new_weight, new_bias, new_conv])
@ -147,9 +145,9 @@ def add_0_5_to_normalized_input(m):
# topological sort # topological sort
other.topological_sort(g) other.topological_sort(g)
def add_rgb2yynn_node(m): def add_rgb2yynn_node(m):
"""Add a conv layer which can convert rgb to yynn input. """Add a conv layer which can convert rgb to yynn input."""
"""
g = m.graph g = m.graph
if len(g.input) > 1: if len(g.input) > 1:
print("This model has multiple inputs. Cannot change to rgb input.") print("This model has multiple inputs. Cannot change to rgb input.")
@ -159,37 +157,32 @@ def add_rgb2yynn_node(m):
print("The input shape is not BCHW. Cannot normalize input.") print("The input shape is not BCHW. Cannot normalize input.")
return return
# Construct weight # Construct weight
ch = input_shape[1] weight_np = np.zeros((3, 3, 4, 4)).astype("float32")
weight_np = np.zeros((3, 3, 4, 4)).astype('float32') weight_np[1, 1, :3, :2] = np.array([[[[0.299], [0.587], [0.114]]]])
weight_np[1, 1, :3, :2] = np.array([[[[0.299], weight_np[1, 1, 3, 2:] = 1.0
[0.587],
[0.114]]]])
weight_np[1, 1, 3, 2:] = 1.
weight_np = np.transpose(weight_np, (3, 2, 0, 1)) weight_np = np.transpose(weight_np, (3, 2, 0, 1))
new_weight = helper.numpy_to_constant("input_rgb2yynn_weight", weight_np) new_weight = helper.numpy_to_constant("input_rgb2yynn_weight", weight_np)
# Construct conv node # Construct conv node
new_conv = onnx.helper.make_node( new_conv = onnx.helper.make_node(
'Conv', "Conv",
['new_input', "input_rgb2yynn_weight"], ["new_input", "input_rgb2yynn_weight"],
[g.input[0].name], [g.input[0].name],
name='input_rgba2yynn', name="input_rgba2yynn",
dilations=[1, 1], dilations=[1, 1],
kernel_shape=[3, 3], kernel_shape=[3, 3],
pads=[1, 1, 1, 1], pads=[1, 1, 1, 1],
strides=[1, 1] strides=[1, 1],
) )
# Construct value_infos # Construct value_infos
old_input_value = g.input.pop() old_input_value = g.input.pop()
weight_value = onnx.helper.make_tensor_value_info( weight_value = onnx.helper.make_tensor_value_info(
'input_rgb2yynn_weight', "input_rgb2yynn_weight",
old_input_value.type.tensor_type.elem_type, old_input_value.type.tensor_type.elem_type,
[4, 4, 3, 3] [4, 4, 3, 3],
) )
# Connect the graph # Connect the graph
new_input_value = onnx.helper.make_tensor_value_info( new_input_value = onnx.helper.make_tensor_value_info(
'new_input', "new_input", old_input_value.type.tensor_type.elem_type, input_shape
old_input_value.type.tensor_type.elem_type,
input_shape
) )
g.input.extend([new_input_value]) g.input.extend([new_input_value])
g.node.extend([new_weight, new_conv]) g.node.extend([new_weight, new_conv])
@ -197,6 +190,7 @@ def add_rgb2yynn_node(m):
# topological sort # topological sort
other.topological_sort(g) other.topological_sort(g)
def swap_MatMul_inputs(g, original_matmul_node): def swap_MatMul_inputs(g, original_matmul_node):
# Create Transpose nodes # Create Transpose nodes
input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0]) input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0])
@ -206,11 +200,12 @@ def swap_MatMul_inputs(g, original_matmul_node):
else: else:
perm = [0, 2, 1] perm = [0, 2, 1]
new_input_b_node = onnx.helper.make_node( new_input_b_node = onnx.helper.make_node(
'Transpose', "Transpose",
inputs = [input_a_value.name], inputs=[input_a_value.name],
outputs = [input_a_value.name + '_transposed'], outputs=[input_a_value.name + "_transposed"],
name = f"{input_a_value.name}_transposed_for_{original_matmul_node.name}", name=f"{input_a_value.name}_transposed_for_"
perm = perm f"{original_matmul_node.name}",
perm=perm,
) )
input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1]) input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1])
input_b_shape = helper.get_shape_from_value_info(input_b_value) input_b_shape = helper.get_shape_from_value_info(input_b_value)
@ -219,18 +214,19 @@ def swap_MatMul_inputs(g, original_matmul_node):
else: else:
perm = [0, 1, 3, 2] perm = [0, 1, 3, 2]
new_input_a_node = onnx.helper.make_node( new_input_a_node = onnx.helper.make_node(
'Transpose', "Transpose",
inputs = [input_b_value.name], inputs=[input_b_value.name],
outputs = [input_b_value.name + '_transposed'], outputs=[input_b_value.name + "_transposed"],
name = f'{input_b_value.name}_transposed_for_{original_matmul_node.name}', name=f"{input_b_value.name}_transposed_for_"
perm = perm f"{original_matmul_node.name}",
perm=perm,
) )
# Create new MatMul node # Create new MatMul node
new_matmul_node = onnx.helper.make_node( new_matmul_node = onnx.helper.make_node(
'MatMul', "MatMul",
inputs = [new_input_a_node.output[0], new_input_b_node.output[0]], inputs=[new_input_a_node.output[0], new_input_b_node.output[0]],
outputs = [original_matmul_node.output[0] + '_transposed'], outputs=[original_matmul_node.output[0] + "_transposed"],
name = original_matmul_node.name + '_transposed' name=original_matmul_node.name + "_transposed",
) )
# Create final Transpose node # Create final Transpose node
output_value = helper.find_value_by_name(g, original_matmul_node.output[0]) output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
@ -240,17 +236,25 @@ def swap_MatMul_inputs(g, original_matmul_node):
else: else:
perm = [0, 1, 3, 2] perm = [0, 1, 3, 2]
new_final_transpose_node = onnx.helper.make_node( new_final_transpose_node = onnx.helper.make_node(
'Transpose', "Transpose",
inputs = [new_matmul_node.output[0]], inputs=[new_matmul_node.output[0]],
outputs = [original_matmul_node.output[0]], outputs=[original_matmul_node.output[0]],
name = original_matmul_node.name + '_final_transpose', name=original_matmul_node.name + "_final_transpose",
perm = perm perm=perm,
) )
# Add new nodes # Add new nodes
g.node.extend([new_input_a_node, new_input_b_node, new_matmul_node, new_final_transpose_node]) g.node.extend(
[
new_input_a_node,
new_input_b_node,
new_matmul_node,
new_final_transpose_node,
]
)
# Delete original nodes # Delete original nodes
g.node.remove(original_matmul_node) g.node.remove(original_matmul_node)
def split_MatMul_batch_then_concat(g, original_matmul_node): def split_MatMul_batch_then_concat(g, original_matmul_node):
new_nodes = [] new_nodes = []
final_concat_inputs = [] final_concat_inputs = []
@ -265,49 +269,85 @@ def split_MatMul_batch_then_concat(g, original_matmul_node):
batch_count = input_a_shape[1] batch_count = input_a_shape[1]
for i in range(batch_count): for i in range(batch_count):
# Create Split nodes for input A # Create Split nodes for input A
starts_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_starts", (1, ), [i]) starts_node = helper.list_to_constant(
ends_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_ends", (1, ), [i+1]) f"{input_a_value.name}_sliced_{i}_starts", (1,), [i]
axes_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_axes", (1, ), [len(input_a_shape) - 3]) )
ends_node = helper.list_to_constant(
f"{input_a_value.name}_sliced_{i}_ends", (1,), [i + 1]
)
axes_node = helper.list_to_constant(
f"{input_a_value.name}_sliced_{i}_axes",
(1,),
[len(input_a_shape) - 3],
)
new_sliced_a_node = onnx.helper.make_node( new_sliced_a_node = onnx.helper.make_node(
'Slice', "Slice",
inputs = [input_a_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]], inputs=[
outputs = [f"{input_a_value.name}_sliced_{i}"], input_a_value.name,
name = f"{input_a_value.name}_sliced_{i}_for_{original_matmul_node.name}" starts_node.output[0],
ends_node.output[0],
axes_node.output[0],
],
outputs=[f"{input_a_value.name}_sliced_{i}"],
name=f"{input_a_value.name}_sliced_{i}_for_"
f"{original_matmul_node.name}",
)
new_nodes.extend(
[starts_node, ends_node, axes_node, new_sliced_a_node]
) )
new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_a_node])
# Create Split nodes for input B # Create Split nodes for input B
starts_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_starts", (1, ), [i]) starts_node = helper.list_to_constant(
ends_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_ends", (1, ), [i+1]) f"{input_b_value.name}_sliced_{i}_starts", (1,), [i]
axes_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_axes", (1, ), [len(input_b_shape) - 3]) )
new_sliced_b_node = onnx.helper.make_node( ends_node = helper.list_to_constant(
'Slice', f"{input_b_value.name}_sliced_{i}_ends", (1,), [i + 1]
inputs = [input_b_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]], )
outputs = [f"{input_b_value.name}_sliced_{i}"], axes_node = helper.list_to_constant(
name = f"{input_b_value.name}_sliced_{i}_for_{original_matmul_node.name}" f"{input_b_value.name}_sliced_{i}_axes",
(1,),
[len(input_b_shape) - 3],
)
new_sliced_b_node = onnx.helper.make_node(
"Slice",
inputs=[
input_b_value.name,
starts_node.output[0],
ends_node.output[0],
axes_node.output[0],
],
outputs=[f"{input_b_value.name}_sliced_{i}"],
name=f"{input_b_value.name}_sliced_{i}_for_"
f"{original_matmul_node.name}",
)
new_nodes.extend(
[starts_node, ends_node, axes_node, new_sliced_b_node]
) )
new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_b_node])
# Create MatMul nodes # Create MatMul nodes
new_matmul_node = onnx.helper.make_node( new_matmul_node = onnx.helper.make_node(
'MatMul', "MatMul",
inputs = [new_sliced_a_node.output[0], new_sliced_b_node.output[0]], inputs=[new_sliced_a_node.output[0], new_sliced_b_node.output[0]],
outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"], outputs=[f"{original_matmul_node.output[0]}_sliced_{i}"],
name = f"{original_matmul_node.name}_sliced_{i}" name=f"{original_matmul_node.name}_sliced_{i}",
) )
new_nodes.append(new_matmul_node) new_nodes.append(new_matmul_node)
final_concat_inputs.append(new_matmul_node.output[0]) final_concat_inputs.append(new_matmul_node.output[0])
# Create Concat nodes # Create Concat nodes
output_value = helper.find_value_by_name(g, original_matmul_node.output[0]) output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
if output_value is None: if output_value is None:
output_value = helper.find_output_by_name(g, original_matmul_node.output[0]) output_value = helper.find_output_by_name(
g, original_matmul_node.output[0]
)
if output_value is None: if output_value is None:
helper.logger.error(f"Cannot find value_info for {original_matmul_node.output[0]}") helper.logger.error(
f"Cannot find value_info for {original_matmul_node.output[0]}"
)
output_shape = helper.get_shape_from_value_info(output_value) output_shape = helper.get_shape_from_value_info(output_value)
new_concat_node = onnx.helper.make_node( new_concat_node = onnx.helper.make_node(
"Concat", "Concat",
inputs = final_concat_inputs, inputs=final_concat_inputs,
outputs = [original_matmul_node.output[0]], outputs=[original_matmul_node.output[0]],
name = f"{original_matmul_node.name}_final_concat", name=f"{original_matmul_node.name}_final_concat",
axis = len(output_shape) - 3 axis=len(output_shape) - 3,
) )
new_nodes.append(new_concat_node) new_nodes.append(new_concat_node)
# Add new nodes # Add new nodes
@ -320,7 +360,9 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
new_nodes = [] new_nodes = []
final_concat_inputs = [] final_concat_inputs = []
# Get the batch count # Get the batch count
input_b_node = helper.find_node_by_output_name(g, original_matmul_node.input[1]) input_b_node = helper.find_node_by_output_name(
g, original_matmul_node.input[1]
)
input_b_np = helper.constant_to_numpy(input_b_node) input_b_np = helper.constant_to_numpy(input_b_node)
if len(input_b_np.shape) == 3: if len(input_b_np.shape) == 3:
batch_count = input_b_np.shape[0] batch_count = input_b_np.shape[0]
@ -329,17 +371,19 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
for i in range(batch_count): for i in range(batch_count):
# Create new constant node # Create new constant node
if len(input_b_np.shape) == 3: if len(input_b_np.shape) == 3:
new_np = input_b_np[i:i+1, ...] new_np = input_b_np[i:i + 1, ...]
else: else:
new_np = input_b_np[:, i:i+1, ...] new_np = input_b_np[:, i:i + 1, ...]
new_weight = helper.numpy_to_constant(f"{input_b_node.name}_sliced_{i}", new_np) new_weight = helper.numpy_to_constant(
f"{input_b_node.name}_sliced_{i}", new_np
)
new_nodes.append(new_weight) new_nodes.append(new_weight)
# Create MatMul nodes # Create MatMul nodes
new_matmul_node = onnx.helper.make_node( new_matmul_node = onnx.helper.make_node(
'MatMul', "MatMul",
inputs = [original_matmul_node.input[0], new_weight.output[0]], inputs=[original_matmul_node.input[0], new_weight.output[0]],
outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"], outputs=[f"{original_matmul_node.output[0]}_sliced_{i}"],
name = f"{original_matmul_node.name}_sliced_{i}" name=f"{original_matmul_node.name}_sliced_{i}",
) )
new_nodes.append(new_matmul_node) new_nodes.append(new_matmul_node)
final_concat_inputs.append(new_matmul_node.output[0]) final_concat_inputs.append(new_matmul_node.output[0])
@ -348,10 +392,10 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
output_shape = helper.get_shape_from_value_info(output_value) output_shape = helper.get_shape_from_value_info(output_value)
new_concat_node = onnx.helper.make_node( new_concat_node = onnx.helper.make_node(
"Concat", "Concat",
inputs = final_concat_inputs, inputs=final_concat_inputs,
outputs = [original_matmul_node.output[0]], outputs=[original_matmul_node.output[0]],
name = f"{original_matmul_node.name}_final_concat", name=f"{original_matmul_node.name}_final_concat",
axis = len(output_shape) - 3 axis=len(output_shape) - 3,
) )
new_nodes.append(new_concat_node) new_nodes.append(new_concat_node)
# Add new nodes # Add new nodes
@ -367,7 +411,7 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
def special_MatMul_process(g): def special_MatMul_process(g):
for node in g.node: for node in g.node:
if node.op_type != 'MatMul': if node.op_type != "MatMul":
continue continue
input_a_name = node.input[0] input_a_name = node.input[0]
input_a_value = helper.find_value_by_name(g, input_a_name) input_a_value = helper.find_value_by_name(g, input_a_name)
@ -383,19 +427,30 @@ def special_MatMul_process(g):
continue continue
# Too many dimensions or too few dimensions. Not supported. Skip # Too many dimensions or too few dimensions. Not supported. Skip
if len(input_a_shape) > 4 or len(input_b_shape) > 4: if len(input_a_shape) > 4 or len(input_b_shape) > 4:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have too many dimensions.") helper.logger.warning(
f"Cannot optimize MatMul {node.name}: "
"inputs have too many dimensions."
)
continue continue
if len(input_a_shape) < 2 or len(input_b_shape) < 2: if len(input_a_shape) < 2 or len(input_b_shape) < 2:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have two few dimensions.") helper.logger.warning(
f"Cannot optimize MatMul {node.name}: "
"inputs have two few dimensions."
)
continue continue
# For 4 dimension, check the first dimension (should be 1) and treated as 3 dimensions. # For 4 dimension, check the first dimension (should be 1)
# and treated as 3 dimensions.
extra_dim = None extra_dim = None
if len(input_a_shape) == 4: if len(input_a_shape) == 4:
extra_dim = input_a_shape[0] extra_dim = input_a_shape[0]
input_a_shape = input_a_shape[1:] input_a_shape = input_a_shape[1:]
if len(input_b_shape) == 4: if len(input_b_shape) == 4:
if input_b_shape[0] != extra_dim: if input_b_shape[0] != extra_dim:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: input dimension batch sizes does not match ({extra_dim} vs {input_b_shape[0]}).") helper.logger.warning(
f"Cannot optimize MatMul {node.name}: "
"input dimension batch sizes does not match "
f"({extra_dim} vs {input_b_shape[0]})."
)
continue continue
input_b_shape = input_b_shape[1:] input_b_shape = input_b_shape[1:]
# Check input B dimension # Check input B dimension
@ -404,20 +459,31 @@ def special_MatMul_process(g):
continue continue
# If B is B x W x V, but B is a constant. # If B is B x W x V, but B is a constant.
input_b_node = helper.find_node_by_output_name(g, input_b_name) input_b_node = helper.find_node_by_output_name(g, input_b_name)
if input_b_node is not None and input_b_node.op_type == 'Constant': if input_b_node is not None and input_b_node.op_type == "Constant":
# Constant input # Constant input
helper.logger.debug(f"Optimizing MatMul node {node.name}: split constant input.") helper.logger.debug(
f"Optimizing MatMul node {node.name}: split constant input."
)
split_MatMul_Constant_input_then_concat(g, node) split_MatMul_Constant_input_then_concat(g, node)
# If B is B x W x V and A is 1 x H x W, do the swap. # If B is B x W x V and A is 1 x H x W, do the swap.
elif len(input_a_shape) == 2 or (input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1)): elif len(input_a_shape) == 2 or (
helper.logger.debug(f"Optimizing MatMul node {node.name}: swap input.") input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1)
):
helper.logger.debug(
f"Optimizing MatMul node {node.name}: swap input."
)
swap_MatMul_inputs(g, node) swap_MatMul_inputs(g, node)
# If B is B x W x V and A is B x H x W, do the split. # If B is B x W x V and A is B x H x W, do the split.
elif input_b_shape[0] == input_a_shape[0]: elif input_b_shape[0] == input_a_shape[0]:
helper.logger.debug(f"Optimizing MatMul node {node.name}: split input batch.") helper.logger.debug(
f"Optimizing MatMul node {node.name}: split input batch."
)
split_MatMul_batch_then_concat(g, node) split_MatMul_batch_then_concat(g, node)
# Other cases are not supported: If B is B x W x V but A is X x H x W. # Other cases are not supported: If B is B x W x V but A is X x H x W.
else: else:
helper.logger.warning(f"Cannot optimize MatMul {node.name}: unknown reason. Might be shape mismatch.") helper.logger.warning(
f"Cannot optimize MatMul {node.name}: "
"unknown reason. Might be shape mismatch."
)
continue continue
other.topological_sort(g) other.topological_sort(g)