style: fix format so pep8 is satisfied
This commit is contained in:
parent
a783220efa
commit
0136a5b2bd
@ -316,7 +316,9 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
session_options.register_custom_ops_library(ort_custom_op_path)
|
||||
providers = ['CPUExecutionProvider']
|
||||
provider_options = [{}]
|
||||
is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available()
|
||||
is_cuda_available = (
|
||||
ort.get_device() == 'GPU' and torch.cuda.is_available()
|
||||
)
|
||||
if is_cuda_available:
|
||||
providers.insert(0, 'CUDAExecutionProvider')
|
||||
device_id = device_id or 0
|
||||
@ -334,7 +336,7 @@ class ONNXRuntimeSegmentorKN(BaseSegmentor):
|
||||
self.output_name_list = [sess_outputs[0].name]
|
||||
self.cfg = cfg # TODO: necessary?
|
||||
self.test_cfg = cfg.model.test_cfg
|
||||
self.test_mode = self.test_cfg.mode # NOTE: should be 'whole' or 'slide'
|
||||
self.test_mode = self.test_cfg.mode # NOTE: either 'whole' or 'slide'
|
||||
self.is_cuda_available = is_cuda_available
|
||||
self.count_mat = None
|
||||
try:
|
||||
|
||||
3
setup.py
3
setup.py
@ -171,7 +171,8 @@ if __name__ == '__main__':
|
||||
setup(
|
||||
name='mmsegmentation',
|
||||
version=get_version(),
|
||||
description='Open MMLab Semantic Segmentation Toolbox and Benchmark (Kneron Edition)',
|
||||
description='Open MMLab Semantic Segmentation Toolbox '
|
||||
'and Benchmark (Kneron Edition)',
|
||||
long_description=readme(),
|
||||
long_description_content_type='text/markdown',
|
||||
author='MMSegmentation Contributors and Kneron',
|
||||
|
||||
@ -163,9 +163,9 @@ def main():
|
||||
efficient_test = eval_kwargs.get('efficient_test', False)
|
||||
if efficient_test:
|
||||
warnings.warn(
|
||||
'``efficient_test=True`` does not have effect in tools/test_kneron.py, '
|
||||
'the evaluation and format results are CPU memory efficient by '
|
||||
'default')
|
||||
'"efficient_test=True" does not have effect in '
|
||||
'tools/test_kneron.py, the evaluation and format '
|
||||
'results are CPU memory efficient by default')
|
||||
|
||||
eval_on_format_results = (
|
||||
args.eval is not None and 'cityscapes' in args.eval)
|
||||
|
||||
@ -5,55 +5,81 @@ import sys
|
||||
from tools.other import topological_sort
|
||||
from tools import helper
|
||||
|
||||
|
||||
def fuse_bias_in_consecutive_1x1_conv(g):
    """Fold the bias of a 1x1 Conv into the bias of the 1x1 Conv after it.

    For every pair ``first -> second`` of Conv nodes where ``second`` is the
    only consumer of ``first`` and both kernels are 1x1, the bias of
    ``first`` is pushed through the weight of ``second``::

        new_second_bias = second_bias + first_bias @ second_weight.T

    and the bias of ``first`` is replaced by zeros of the same size. The
    bias Constant nodes are replaced in ``g`` and the graph is
    topologically sorted at the end.

    A pair is skipped (left untouched) when:

    * either Conv has no bias input,
    * the bias/weight producers cannot be found as nodes of ``g``,
    * the bias length of ``first`` does not match the second weight's
      input-channel dimension.

    NOTE(review): the fold is only exact when ``second`` sees the output of
    ``first`` everywhere, i.e. presumably when ``second`` has no padding;
    the pads attribute is not checked here -- confirm against callers.

    Args:
        g: the ONNX graph to modify in place.
    """
    # Iterate over a snapshot: bias nodes are removed from and appended to
    # g.node inside the loop.
    for second in list(g.node):
        # Find two consecutive Conv nodes.
        if second.op_type != "Conv":
            continue
        first = helper.find_node_by_output_name(g, second.input[0])
        if first is None or first.op_type != "Conv":
            continue
        # The first Conv must have exactly one following node.
        followers = helper.find_following_nodes_by_input_value_name(
            g, first.output[0]
        )
        if len(followers) != 1:
            continue
        # Both Conv nodes need a bias input (the third input); otherwise
        # there is nothing to move or nowhere to move it to.
        if len(first.input) < 3 or len(second.input) < 3:
            continue
        # Both kernels must be 1x1 (2D kernels are assumed here).
        first_kernel_shape = helper.get_list_attribute_by_name(
            first, "kernel_shape", "int"
        )
        second_kernel_shape = helper.get_list_attribute_by_name(
            second, "kernel_shape", "int"
        )
        prod = (
            first_kernel_shape[0]
            * first_kernel_shape[1]
            * second_kernel_shape[0]
            * second_kernel_shape[1]
        )
        if prod != 1:
            continue
        # Locate the nodes producing the biases and the second weight.
        first_bias_node = helper.find_node_by_output_name(g, first.input[2])
        second_weight_node = helper.find_node_by_output_name(
            g, second.input[1]
        )
        second_bias_node = helper.find_node_by_output_name(g, second.input[2])
        if (
            first_bias_node is None
            or second_weight_node is None
            or second_bias_node is None
        ):
            # The value is not produced by a node of this graph.
            continue
        first_bias = helper.constant_to_numpy(first_bias_node)
        second_weight = helper.constant_to_numpy(second_weight_node)
        second_bias = helper.constant_to_numpy(second_bias_node)
        if first_bias.size != second_weight.shape[1]:
            # Channel mismatch: the matrix product below is undefined.
            continue
        print("Found: ", first.name, " ", second.name)
        # Push the first bias through the second weight.
        first_bias = np.reshape(first_bias, (1, first_bias.size))
        second_weight = np.reshape(
            second_weight, (second_weight.shape[0], second_weight.shape[1])
        )
        second_weight = np.transpose(second_weight)
        new_second_bias = second_bias + np.matmul(first_bias, second_weight)
        new_second_bias = np.reshape(new_second_bias, (new_second_bias.size,))
        # The first Conv keeps a bias input, now all zeros.
        new_first_bias = np.zeros_like(
            np.reshape(first_bias, (first_bias.size,))
        )
        new_first_bias_node = helper.numpy_to_constant(
            first_bias_node.output[0], new_first_bias
        )
        new_second_bias_node = helper.numpy_to_constant(
            second_bias_node.output[0], new_second_bias
        )
        # Delete the old bias nodes and add the new ones.
        g.node.remove(first_bias_node)
        g.node.remove(second_bias_node)
        g.node.extend([new_first_bias_node, new_second_bias_node])
    topological_sort(g)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Expect exactly two arguments: the input and the output model path.
    if len(sys.argv) != 3:
        print("Usage:{} file_in file_out".format(sys.argv[0]))
        # sys.exit instead of the site-provided exit() helper, which is
        # not guaranteed to exist in every interpreter configuration.
        sys.exit(1)
    m = onnx.load(sys.argv[1])
    fuse_bias_in_consecutive_1x1_conv(m.graph)
    onnx.save(m, sys.argv[2])
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import onnx
|
||||
import onnx.utils
|
||||
|
||||
try:
|
||||
from onnx import optimizer
|
||||
except ImportError:
|
||||
@ -9,23 +10,107 @@ import argparse
|
||||
import tools.modhelper as helper
|
||||
import tools.other as other
|
||||
import tools.replacing as replacing
|
||||
|
||||
# Main process
|
||||
# Argument parser
|
||||
parser = argparse.ArgumentParser(description="Edit an ONNX model.\nThe processing sequense is 'delete nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting cannot be done with other operations together")
|
||||
parser.add_argument('in_file', type=str, help='input ONNX FILE')
|
||||
parser.add_argument('out_file', type=str, help="ouput ONNX FILE")
|
||||
parser.add_argument('-c', '--cut', dest='cut_node', type=str, nargs='+', help="remove nodes from the given nodes(inclusive)")
|
||||
parser.add_argument('--cut-type', dest='cut_type', type=str, nargs='+', help="remove nodes by type from the given nodes(inclusive)")
|
||||
parser.add_argument('-d', '--delete', dest='delete_node', type=str, nargs='+', help="delete nodes by names and only those nodes")
|
||||
parser.add_argument('--delete-input', dest='delete_input', type=str, nargs='+', help="delete inputs by names")
|
||||
parser.add_argument('--delete-output', dest='delete_output', type=str, nargs='+', help="delete outputs by names")
|
||||
parser.add_argument('-i', '--input', dest='input_change', type=str, nargs='+', help="change input shape (e.g. -i 'input_0 1 3 224 224')")
|
||||
parser.add_argument('-o', '--output', dest='output_change', type=str, nargs='+', help="change output shape (e.g. -o 'input_0 1 3 224 224')")
|
||||
parser.add_argument('--add-conv', dest='add_conv', type=str, nargs='+', help='add nop conv using specific input')
|
||||
parser.add_argument('--add-bn', dest='add_bn', type=str, nargs='+', help='add nop bn using specific input')
|
||||
parser.add_argument('--rename-output', dest='rename_output', type=str, nargs='+', help='Rename the specific output(e.g. --rename-output old_name new_name)')
|
||||
parser.add_argument('--pixel-bias-value', dest='pixel_bias_value', type=str, nargs='+', help='(per channel) set pixel value bias bn layer at model front for normalization( e.g. --pixel_bias_value "[104.0, 117.0, 123.0]" )')
|
||||
parser.add_argument('--pixel-scale-value', dest='pixel_scale_value', type=str, nargs='+', help='(per channel) set pixel value scale bn layer at model front for normalization( e.g. --pixel_scale_value "[0.0078125, 0.0078125, 0.0078125]" )')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Edit an ONNX model.\nThe processing sequense is 'delete "
|
||||
"nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting "
|
||||
"cannot be done with other operations together"
|
||||
)
|
||||
parser.add_argument("in_file", type=str, help="input ONNX FILE")
|
||||
parser.add_argument("out_file", type=str, help="ouput ONNX FILE")
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--cut",
|
||||
dest="cut_node",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="remove nodes from the given nodes(inclusive)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cut-type",
|
||||
dest="cut_type",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="remove nodes by type from the given nodes(inclusive)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--delete",
|
||||
dest="delete_node",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="delete nodes by names and only those nodes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delete-input",
|
||||
dest="delete_input",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="delete inputs by names",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delete-output",
|
||||
dest="delete_output",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="delete outputs by names",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input",
|
||||
dest="input_change",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="change input shape (e.g. -i 'input_0 1 3 224 224')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
dest="output_change",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="change output shape (e.g. -o 'input_0 1 3 224 224')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add-conv",
|
||||
dest="add_conv",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="add nop conv using specific input",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add-bn",
|
||||
dest="add_bn",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="add nop bn using specific input",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rename-output",
|
||||
dest="rename_output",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Rename the specific output(e.g. --rename-output old_name new_name)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pixel-bias-value",
|
||||
dest="pixel_bias_value",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help='(per channel) set pixel value bias bn layer at model front for '
|
||||
'normalization( e.g. --pixel_bias_value "[104.0, 117.0, 123.0]" )',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pixel-scale-value",
|
||||
dest="pixel_scale_value",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help='(per channel) set pixel value scale bn layer at model front for '
|
||||
'normalization( e.g. --pixel_scale_value '
|
||||
'"[0.0078125, 0.0078125, 0.0078125]" )',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -60,23 +145,48 @@ if args.add_bn is not None:
|
||||
if args.pixel_bias_value is not None or args.pixel_scale_value is not None:
|
||||
|
||||
if len(g.input) > 1:
|
||||
raise ValueError(" '--pixel-bias-value' and '--pixel-scale-value' only support one input node model currently")
|
||||
|
||||
raise ValueError(
|
||||
" '--pixel-bias-value' and '--pixel-scale-value' "
|
||||
"only support one input node model currently"
|
||||
)
|
||||
|
||||
i_n = g.input[0]
|
||||
|
||||
pixel_bias_value = [0] * i_n.type.tensor_type.shape.dim[1].dim_value
|
||||
pixel_scale_value = [1] * i_n.type.tensor_type.shape.dim[1].dim_value
|
||||
|
||||
if args.pixel_bias_value is not None and len(args.pixel_bias_value) == 1:
|
||||
pixel_bias_value = [float(n) for n in args.pixel_bias_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')]
|
||||
pixel_bias_value = [
|
||||
float(n)
|
||||
for n in args.pixel_bias_value[0]
|
||||
.replace("[", "")
|
||||
.replace("]", "")
|
||||
.split(",")
|
||||
]
|
||||
|
||||
if args.pixel_scale_value is not None and len(args.pixel_scale_value) == 1:
|
||||
pixel_scale_value = [float(n) for n in args.pixel_scale_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')]
|
||||
pixel_scale_value = [
|
||||
float(n)
|
||||
for n in args.pixel_scale_value[0]
|
||||
.replace("[", "")
|
||||
.replace("]", "")
|
||||
.split(",")
|
||||
]
|
||||
|
||||
|
||||
if i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_bias_value) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value):
|
||||
raise ValueError("--pixel-bias-value (" + str(pixel_bias_value) + ") and --pixel-scale-value (" + str(pixel_scale_value) + ") should be same as input dimension:" + str(i_n.type.tensor_type.shape.dim[1].dim_value) )
|
||||
other.add_bias_scale_bn_after(g, i_n.name, pixel_bias_value, pixel_scale_value)
|
||||
if i_n.type.tensor_type.shape.dim[1].dim_value != len(
|
||||
pixel_bias_value
|
||||
) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value):
|
||||
raise ValueError(
|
||||
"--pixel-bias-value ("
|
||||
+ str(pixel_bias_value)
|
||||
+ ") and --pixel-scale-value ("
|
||||
+ str(pixel_scale_value)
|
||||
+ ") should be same as input dimension:"
|
||||
+ str(i_n.type.tensor_type.shape.dim[1].dim_value)
|
||||
)
|
||||
other.add_bias_scale_bn_after(
|
||||
g, i_n.name, pixel_bias_value, pixel_scale_value
|
||||
)
|
||||
|
||||
# Change input and output shapes as requested
|
||||
if args.input_change is not None:
|
||||
@ -100,14 +210,21 @@ if args.rename_output:
|
||||
print("Rename output should be paires of names.")
|
||||
else:
|
||||
for i in range(0, len(args.rename_output), 2):
|
||||
other.rename_output_name(g, args.rename_output[i], args.rename_output[i + 1])
|
||||
other.rename_output_name(
|
||||
g, args.rename_output[i], args.rename_output[i + 1]
|
||||
)
|
||||
|
||||
# Remove useless nodes
|
||||
if args.delete_node or args.delete_input or args.input_change or args.output_change:
|
||||
if (
|
||||
args.delete_node
|
||||
or args.delete_input
|
||||
or args.input_change
|
||||
or args.output_change
|
||||
):
|
||||
# If shape changed during the modification, redo shape inference.
|
||||
while(len(g.value_info) > 0):
|
||||
while len(g.value_info) > 0:
|
||||
g.value_info.pop()
|
||||
passes = ['extract_constant_to_initializer']
|
||||
passes = ["extract_constant_to_initializer"]
|
||||
m = optimizer.optimize(m, passes)
|
||||
g = m.graph
|
||||
replacing.replace_initializer_with_Constant(g)
|
||||
@ -115,4 +232,4 @@ other.topological_sort(g)
|
||||
# Polish and output
|
||||
m = other.polish_model(m)
|
||||
other.add_output_to_value_info(m.graph)
|
||||
onnx.save(m, args.out_file)
|
||||
onnx.save(m, args.out_file)
|
||||
|
||||
@ -11,42 +11,44 @@ if len(sys.argv) != 3:
|
||||
# Modify onnx
|
||||
m = onnx.load(sys.argv[1])
|
||||
special.add_0_5_to_normalized_input(m)
|
||||
onnx.save(m, sys.argv[1][:-4] + 'norm.onnx')
|
||||
onnx.save(m, sys.argv[1][:-4] + "norm.onnx")
|
||||
|
||||
# Change input node
|
||||
origin_file = open(sys.argv[2], 'r')
|
||||
origin_file = open(sys.argv[2], "r")
|
||||
origin_json = json.load(origin_file)
|
||||
origin_json["input_node"]["output_datapath_radix"] = [8]
|
||||
new_json_str = json.dumps(origin_json)
|
||||
|
||||
# Modify json
|
||||
file = open(sys.argv[1][:-4] + 'norm.onnx' + '.json', 'w')
|
||||
file = open(sys.argv[1][:-4] + "norm.onnx" + ".json", "w")
|
||||
s = """{{
|
||||
\"{0}\" :
|
||||
{{
|
||||
\"bias_bitwidth\" : 16,
|
||||
\"{0}_bias\" : [15],
|
||||
\"{0}_weight\" : [3,3,3],
|
||||
\"conv_coarse_shift\" : [-4,-4,-4],
|
||||
\"conv_fine_shift\" : [0,0,0],
|
||||
\"conv_total_shift\" : [-4,-4,-4],
|
||||
\"cpu_mode\" : false,
|
||||
\"delta_input_bitwidth\" : [0],
|
||||
\"delta_output_bitwidth\" : 8,
|
||||
\"flag_radix_bias_eq_output\" : true,
|
||||
\"input_scale\" : [[1.0,1.0,1.0]],
|
||||
\"output_scale\" : [1.0, 1.0, 1.0],
|
||||
\"psum_bitwidth\" : 16,
|
||||
\"weight_bitwidth\" : 8,
|
||||
\"input_datapath_bitwidth\" : [8],
|
||||
\"input_datapath_radix\" : [8],
|
||||
\"working_input_bitwidth\" : 8,
|
||||
\"working_input_radix\" : [8],
|
||||
\"working_output_bitwidth\" : 16,
|
||||
\"working_output_radix\" : 15,
|
||||
\"output_datapath_bitwidth\" : 8,
|
||||
\"output_datapath_radix\" : 7
|
||||
}},\n""".format('input_norm')
|
||||
\"{0}\" :
|
||||
{{
|
||||
\"bias_bitwidth\" : 16,
|
||||
\"{0}_bias\" : [15],
|
||||
\"{0}_weight\" : [3,3,3],
|
||||
\"conv_coarse_shift\" : [-4,-4,-4],
|
||||
\"conv_fine_shift\" : [0,0,0],
|
||||
\"conv_total_shift\" : [-4,-4,-4],
|
||||
\"cpu_mode\" : false,
|
||||
\"delta_input_bitwidth\" : [0],
|
||||
\"delta_output_bitwidth\" : 8,
|
||||
\"flag_radix_bias_eq_output\" : true,
|
||||
\"input_scale\" : [[1.0,1.0,1.0]],
|
||||
\"output_scale\" : [1.0, 1.0, 1.0],
|
||||
\"psum_bitwidth\" : 16,
|
||||
\"weight_bitwidth\" : 8,
|
||||
\"input_datapath_bitwidth\" : [8],
|
||||
\"input_datapath_radix\" : [8],
|
||||
\"working_input_bitwidth\" : 8,
|
||||
\"working_input_radix\" : [8],
|
||||
\"working_output_bitwidth\" : 16,
|
||||
\"working_output_radix\" : 15,
|
||||
\"output_datapath_bitwidth\" : 8,
|
||||
\"output_datapath_radix\" : 7
|
||||
}},\n""".format(
|
||||
"input_norm"
|
||||
)
|
||||
file.write(s + new_json_str[1:])
|
||||
file.close()
|
||||
origin_file.close()
|
||||
|
||||
@ -2,33 +2,33 @@
|
||||
|
||||
import sys
|
||||
import onnx
|
||||
import numpy as np
|
||||
from onnx import numpy_helper
|
||||
from tools import other, helper
|
||||
|
||||
"""
|
||||
Change onnx model from version 1.3 to version 1.4.
|
||||
Modify the BN node by removing the spatial attribute
|
||||
Modify the Upsample node by removing the 'scales' attribute, and adding a constant node instead.
|
||||
Model's ir_version and opset_import are updated.
|
||||
Change onnx model from version 1.3 to version 1.4.
|
||||
- Modify the BN node by removing the spatial attribute
|
||||
- Modify the Upsample node by removing the 'scales' attribute,
|
||||
and adding a constant node instead.
|
||||
- Model's ir_version and opset_import are updated.
|
||||
"""
|
||||
|
||||
|
||||
def remove_BN_spatial(g):
    """Remove the ``spatial`` attribute from every BatchNormalization node.

    Nodes of any other op type are left untouched. The graph is modified
    in place.

    Args:
        g: graph object whose ``node`` sequence is scanned.
    """
    for node in g.node:
        if node.op_type != "BatchNormalization":
            continue
        # Walk a snapshot: removing from the sequence that is being
        # iterated would skip the element following each removal.
        for att in list(node.attribute):
            if att.name == "spatial":
                node.attribute.remove(att)
|
||||
|
||||
|
||||
def upsample_attribute_to_const(g):
|
||||
for node in g.node:
|
||||
if node.op_type != 'Upsample':
|
||||
if node.op_type != "Upsample":
|
||||
continue
|
||||
scales_exist = False
|
||||
for att in node.attribute:
|
||||
if att.name == 'scales':
|
||||
if att.name == "scales":
|
||||
scales_exist = True
|
||||
break
|
||||
if not scales_exist:
|
||||
@ -36,18 +36,23 @@ def upsample_attribute_to_const(g):
|
||||
|
||||
shape = [len(att.floats)]
|
||||
node.attribute.remove(att)
|
||||
new_node = helper.list_to_constant(node.name+'_input', shape, att.floats)
|
||||
new_node = helper.list_to_constant(
|
||||
node.name + "_input", shape, att.floats
|
||||
)
|
||||
|
||||
g.node.extend([new_node])
|
||||
value_info = onnx.helper.make_tensor_value_info(node.name+'_input', onnx.TensorProto.FLOAT, shape)
|
||||
node.input.extend([node.name+'_input'])
|
||||
value_info = onnx.helper.make_tensor_value_info(
|
||||
node.name + "_input", onnx.TensorProto.FLOAT, shape
|
||||
)
|
||||
node.input.extend([node.name + "_input"])
|
||||
g.value_info.extend([value_info])
|
||||
|
||||
|
||||
def relu6_to_clip(g):
|
||||
for node in g.node:
|
||||
if node.op_type != 'Relu':
|
||||
if node.op_type != "Relu":
|
||||
continue
|
||||
max_val = helper.get_var_attribute_by_name(node, 'max', 'float')
|
||||
max_val = helper.get_var_attribute_by_name(node, "max", "float")
|
||||
if max_val is None:
|
||||
continue
|
||||
new_node = onnx.helper.make_node(
|
||||
@ -56,11 +61,12 @@ def relu6_to_clip(g):
|
||||
node.output,
|
||||
name=node.name,
|
||||
max=max_val,
|
||||
min=0.0
|
||||
min=0.0,
|
||||
)
|
||||
g.node.remove(node)
|
||||
g.node.extend([new_node])
|
||||
|
||||
|
||||
def PRelu_weight_reshape(g):
|
||||
# For PRelu with single dimension weight. Expand it to 1, x, 1, 1
|
||||
for node in g.node:
|
||||
@ -91,16 +97,18 @@ def PRelu_weight_reshape(g):
|
||||
new_input = onnx.helper.make_tensor_value_info(
|
||||
node.input[1],
|
||||
input_value.type.tensor_type.elem_type,
|
||||
(1, slope.dims[1], 1, 1))
|
||||
(1, slope.dims[1], 1, 1),
|
||||
)
|
||||
g.input.remove(input_value)
|
||||
g.input.append(new_input)
|
||||
value_info = helper.find_value_by_name(g, node.input[1])
|
||||
if value_info is not None:
|
||||
g.value_info.remove(value_info)
|
||||
|
||||
|
||||
def do_convert(m):
|
||||
graph = m.graph
|
||||
|
||||
|
||||
# Modify the nodes.
|
||||
remove_BN_spatial(graph)
|
||||
upsample_attribute_to_const(graph)
|
||||
@ -113,6 +121,7 @@ def do_convert(m):
|
||||
m.opset_import[0].version = 9
|
||||
return m
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print("Usage:{} file_in file_out".format(sys.argv[0]))
|
||||
|
||||
@ -3,31 +3,38 @@
|
||||
import sys
|
||||
import onnx
|
||||
import onnx.utils
|
||||
import numpy as np
|
||||
from onnx import numpy_helper
|
||||
from tools import other, helper, replacing
|
||||
|
||||
"""
|
||||
Change onnx model from version 1.4 to version 1.6.
|
||||
"""
|
||||
|
||||
|
||||
def replace_all_attribute_to_const_node_in_pad_node(g):
|
||||
node_to_remove = []
|
||||
node_to_extend = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Pad':
|
||||
if node.op_type != "Pad":
|
||||
continue
|
||||
|
||||
pad_loc_node = None # must have
|
||||
pad_mode = 'constant'
|
||||
pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [0.0]) # need scalar
|
||||
pad_loc_node = None # must have
|
||||
pad_mode = "constant"
|
||||
pad_value_node = helper.list_to_constant(
|
||||
node.name + "_pad_value", [], [0.0]
|
||||
) # need scalar
|
||||
for att in node.attribute:
|
||||
if att.name == 'mode':
|
||||
pad_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
|
||||
if att.name == 'pads':
|
||||
pad_loc_node = helper.list_to_constant(node.name+'_pad_loc', [len(att.ints)], att.ints)
|
||||
if att.name == 'value':
|
||||
pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [att.f])
|
||||
if att.name == "mode":
|
||||
pad_mode = helper.get_var_attribute_by_name(
|
||||
node, "mode", "string"
|
||||
)
|
||||
if att.name == "pads":
|
||||
pad_loc_node = helper.list_to_constant(
|
||||
node.name + "_pad_loc", [len(att.ints)], att.ints
|
||||
)
|
||||
if att.name == "value":
|
||||
pad_value_node = helper.list_to_constant(
|
||||
node.name + "_pad_value", [], [att.f]
|
||||
)
|
||||
|
||||
new_node = onnx.helper.make_node(
|
||||
"Pad",
|
||||
@ -40,24 +47,30 @@ def replace_all_attribute_to_const_node_in_pad_node(g):
|
||||
node_to_extend.append(new_node)
|
||||
node_to_extend.append(pad_loc_node)
|
||||
node_to_extend.append(pad_value_node)
|
||||
|
||||
for node in node_to_remove:
|
||||
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
for node in node_to_extend:
|
||||
for node in node_to_extend:
|
||||
g.node.extend([node])
|
||||
|
||||
|
||||
def upsampling_to_resize(g):
|
||||
for node in g.node:
|
||||
if node.op_type != 'Upsample':
|
||||
if node.op_type != "Upsample":
|
||||
continue
|
||||
upsampling_mode = helper.get_var_attribute_by_name(node, 'mode', 'string')
|
||||
upsampling_mode = helper.get_var_attribute_by_name(
|
||||
node, "mode", "string"
|
||||
)
|
||||
|
||||
scale_value_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
if scale_value_node.op_type != "Constant":
|
||||
raise TypeError('seems there is a dynamic "scales" param in Upsampling node: ' + node.name + ' , you might need to do constant folding first')
|
||||
raise TypeError(
|
||||
'seems there is a dynamic "scales" param in Upsampling node: '
|
||||
+ node.name
|
||||
+ " , you might need to do constant folding first"
|
||||
)
|
||||
|
||||
roi_node = helper.list_to_constant(node.name+'_roi_value', [0], [])
|
||||
roi_node = helper.list_to_constant(node.name + "_roi_value", [0], [])
|
||||
|
||||
new_node = onnx.helper.make_node(
|
||||
"Resize",
|
||||
@ -65,7 +78,7 @@ def upsampling_to_resize(g):
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
mode=upsampling_mode,
|
||||
coordinate_transformation_mode = 'asymmetric'
|
||||
coordinate_transformation_mode="asymmetric",
|
||||
)
|
||||
|
||||
g.node.remove(node)
|
||||
@ -75,7 +88,7 @@ def upsampling_to_resize(g):
|
||||
|
||||
def replace_all_attribute_to_const_node_in_slice_node(g):
|
||||
for node in g.node:
|
||||
if node.op_type != 'Slice':
|
||||
if node.op_type != "Slice":
|
||||
continue
|
||||
|
||||
axes_const_node = None
|
||||
@ -83,62 +96,75 @@ def replace_all_attribute_to_const_node_in_slice_node(g):
|
||||
starts_const_node = None
|
||||
steps_const_node = None
|
||||
for att in node.attribute:
|
||||
if att.name == 'axes':
|
||||
axes_const_node = helper.list_to_constant(node.name+'_axes_value', [len(att.ints)], att.ints)
|
||||
|
||||
if att.name == 'ends':
|
||||
ends_const_node = helper.list_to_constant(node.name+'_ends_value', [len(att.ints)], att.ints)
|
||||
if att.name == "axes":
|
||||
axes_const_node = helper.list_to_constant(
|
||||
node.name + "_axes_value", [len(att.ints)], att.ints
|
||||
)
|
||||
|
||||
if att.name == 'starts':
|
||||
starts_const_node = helper.list_to_constant(node.name+'_starts_value', [len(att.ints)], att.ints)
|
||||
if att.name == "ends":
|
||||
ends_const_node = helper.list_to_constant(
|
||||
node.name + "_ends_value", [len(att.ints)], att.ints
|
||||
)
|
||||
|
||||
if att.name == 'steps':
|
||||
steps_const_node = helper.list_to_constant(node.name+'_steps_value',[ len(att.ints)], att.ints)
|
||||
if att.name == "starts":
|
||||
starts_const_node = helper.list_to_constant(
|
||||
node.name + "_starts_value", [len(att.ints)], att.ints
|
||||
)
|
||||
|
||||
## pop out from back
|
||||
if att.name == "steps":
|
||||
steps_const_node = helper.list_to_constant(
|
||||
node.name + "_steps_value", [len(att.ints)], att.ints
|
||||
)
|
||||
|
||||
# pop out from back
|
||||
attr_len = len(node.attribute)
|
||||
for i in range(attr_len):
|
||||
node.attribute.remove(node.attribute[ attr_len -1 - i ])
|
||||
node.attribute.remove(node.attribute[attr_len - 1 - i])
|
||||
|
||||
## according the spec, we need to add node in specific order
|
||||
if starts_const_node != None:
|
||||
# according the spec, we need to add node in specific order
|
||||
if starts_const_node is not None:
|
||||
g.node.extend([starts_const_node])
|
||||
node.input.extend([starts_const_node.name])
|
||||
if ends_const_node != None:
|
||||
if ends_const_node is not None:
|
||||
g.node.extend([ends_const_node])
|
||||
node.input.extend([ends_const_node.name])
|
||||
if axes_const_node != None:
|
||||
node.input.extend([ends_const_node.name])
|
||||
if axes_const_node is not None:
|
||||
g.node.extend([axes_const_node])
|
||||
node.input.extend([axes_const_node.name])
|
||||
if steps_const_node != None:
|
||||
if steps_const_node is not None:
|
||||
g.node.extend([steps_const_node])
|
||||
node.input.extend([steps_const_node.name])
|
||||
|
||||
|
||||
|
||||
def replace_min_max_attribute_to_const_node_in_clip_node(g):
    """Turn the ``min``/``max`` attributes of Clip nodes into Constant inputs.

    For every Clip node carrying a ``min`` and/or ``max`` attribute, the
    attributes are removed and replaced by scalar Constant nodes named
    ``<node.name>_min_value`` / ``<node.name>_max_value`` that are appended
    to the node inputs in the order ``min``, ``max``.

    Clip nodes with neither attribute are left untouched. When only ``max``
    is present, a ``min`` constant holding the lowest float32 value is
    still emitted so that ``max`` lands in the correct input position.
    When only ``min`` is present, no ``max`` input is added.

    Args:
        g: the ONNX graph to modify in place.
    """
    # Lowest finite float32; the documented default of the attribute-based
    # Clip operator when ``min`` is not given.
    default_min = -3.4028234663852886e38

    # Iterate over a snapshot: Constant nodes are appended inside the loop.
    for node in list(g.node):
        if node.op_type != "Clip":
            continue

        min_val = None
        max_val = None
        for att in node.attribute:
            if att.name == "max":
                max_val = att.f
            elif att.name == "min":
                min_val = att.f

        if min_val is None and max_val is None:
            # Nothing to convert on this node.
            continue

        # Drop the converted attributes by name rather than by position,
        # so a node with a single attribute does not raise IndexError.
        for att in list(node.attribute):
            if att.name in ("min", "max"):
                node.attribute.remove(att)

        # According to the spec the inputs must be added in the order
        # min, max; min is therefore always materialized.
        min_const_node = helper.list_to_constant(
            node.name + "_min_value",
            [],
            [default_min if min_val is None else min_val],
        )
        g.node.extend([min_const_node])
        max_const_node = None
        if max_val is not None:
            max_const_node = helper.list_to_constant(
                node.name + "_max_value", [], [max_val]
            )
            g.node.extend([max_const_node])

        node.input.extend([min_const_node.name])
        if max_const_node is not None:
            node.input.extend([max_const_node.name])
|
||||
|
||||
|
||||
def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
|
||||
"""Update ir_version from 4 to 6 and update opset from 9 to 11.
|
||||
|
||||
@ -173,6 +199,7 @@ def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto:
|
||||
model = other.polish_model(model)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print("Usage:{} file_in file_out".format(sys.argv[0]))
|
||||
|
||||
@ -1,45 +1,51 @@
|
||||
import onnx
|
||||
import onnx.utils
|
||||
try:
|
||||
from onnx import optimizer
|
||||
except ImportError:
|
||||
import onnxoptimizer as optimizer
|
||||
import sys
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from tools import eliminating
|
||||
from tools import fusing
|
||||
from tools import replacing
|
||||
from tools import other
|
||||
from tools import special
|
||||
from tools import combo
|
||||
from tools.helper import logger
|
||||
|
||||
# from tools import temp
|
||||
|
||||
def onnx2onnx_flow(m: onnx.ModelProto,
|
||||
disable_fuse_bn=False,
|
||||
bn_on_skip=False,
|
||||
bn_before_add=False,
|
||||
bgr=False,
|
||||
norm=False,
|
||||
rgba2yynn=False,
|
||||
eliminate_tail=False,
|
||||
opt_matmul=False,
|
||||
duplicate_shared_weights=True) -> onnx.ModelProto:
|
||||
|
||||
def onnx2onnx_flow(
|
||||
m: onnx.ModelProto,
|
||||
disable_fuse_bn=False,
|
||||
bn_on_skip=False,
|
||||
bn_before_add=False,
|
||||
bgr=False,
|
||||
norm=False,
|
||||
rgba2yynn=False,
|
||||
eliminate_tail=False,
|
||||
opt_matmul=False,
|
||||
duplicate_shared_weights=True,
|
||||
) -> onnx.ModelProto:
|
||||
"""Optimize the onnx.
|
||||
|
||||
Args:
|
||||
m (ModelProto): the input onnx ModelProto
|
||||
disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.
|
||||
bn_on_skip (bool, optional): add BN operator on skip branches. Defaults to False.
|
||||
bn_before_add (bool, optional): add BN before Add node on every branches. Defaults to False.
|
||||
bgr (bool, optional): add an Conv layer to convert rgb input to bgr. Defaults to False.
|
||||
norm (bool, optional): add an Conv layer to add 0.5 tp the input. Defaults to False.
|
||||
rgba2yynn (bool, optional): add an Conv layer to convert rgb input to yynn . Defaults to False.
|
||||
eliminate_tail (bool, optional): remove the trailing NPU unsupported nodes. Defaults to False.
|
||||
opt_matmul(bool, optional): optimize the MatMul layers according to the NPU limit. Defaults to False.
|
||||
duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True.
|
||||
disable_fuse_bn (bool, optional): do not fuse BN into Conv.
|
||||
Defaults to False.
|
||||
bn_on_skip (bool, optional): add BN operator on skip branches.
|
||||
Defaults to False.
|
||||
bn_before_add (bool, optional): add BN before Add node on every branch.
|
||||
Defaults to False.
|
||||
bgr (bool, optional): add an Conv layer to convert rgb input to bgr.
|
||||
Defaults to False.
|
||||
norm (bool, optional): add an Conv layer to add 0.5 tp the input.
|
||||
Defaults to False.
|
||||
rgba2yynn (bool, optional): add an Conv layer to convert rgb to yynn.
|
||||
Defaults to False.
|
||||
eliminate_tail (bool, optional): remove trailing NPU unsupported nodes.
|
||||
Defaults to False.
|
||||
opt_matmul(bool, optional): optimize MatMul layers due to NPU limit.
|
||||
Defaults to False.
|
||||
duplicate_shared_weights(bool, optional): duplicate shared weights.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
ModelProto: the optimized onnx model object.
|
||||
@ -79,28 +85,83 @@ def onnx2onnx_flow(m: onnx.ModelProto,
|
||||
|
||||
return m
|
||||
|
||||
|
||||
# Main process
|
||||
if __name__ == "__main__":
|
||||
# Argument parser
|
||||
parser = argparse.ArgumentParser(description="Optimize an ONNX model for Kneron compiler")
|
||||
parser.add_argument('in_file', help='input ONNX FILE')
|
||||
parser.add_argument('-o', '--output', dest='out_file', type=str, help="ouput ONNX FILE")
|
||||
parser.add_argument('--log', default='i', type=str, help="set log level")
|
||||
parser.add_argument('--bgr', action='store_true', default=False, help="set if the model is trained in BGR mode")
|
||||
parser.add_argument('--norm', action='store_true', default=False, help="set if you have the input -0.5~0.5")
|
||||
parser.add_argument('--rgba2yynn', action='store_true', default=False, help="set if the model has yynn input but you want to take rgba images")
|
||||
parser.add_argument('--add-bn-on-skip', dest='bn_on_skip', action='store_true', default=False,
|
||||
help="set if you only want to add BN on skip branches")
|
||||
parser.add_argument('--add-bn', dest='bn_before_add', action='store_true', default=False,
|
||||
help="set if you want to add BN before Add")
|
||||
parser.add_argument('-t', '--eliminate-tail-unsupported', dest='eliminate_tail', action='store_true', default=False,
|
||||
help='whether remove the last unsupported node for hardware')
|
||||
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
|
||||
help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.")
|
||||
parser.add_argument('--opt-matmul', dest='opt_matmul', action='store_true', default=False,
|
||||
help="set if you want to optimize the MatMul operations for the kneron hardware.")
|
||||
parser.add_argument('--no-duplicate-shared-weights', dest='no_duplicate_shared_weights', action='store_true', default=False,
|
||||
help='do not duplicate shared weights. Defaults to False.')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Optimize an ONNX model for Kneron compiler"
|
||||
)
|
||||
parser.add_argument("in_file", help="input ONNX FILE")
|
||||
parser.add_argument(
|
||||
"-o", "--output", dest="out_file", type=str, help="ouput ONNX FILE"
|
||||
)
|
||||
parser.add_argument("--log", default="i", type=str, help="set log level")
|
||||
parser.add_argument(
|
||||
"--bgr",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="set if the model is trained in BGR mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--norm",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="set if you have the input -0.5~0.5",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rgba2yynn",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="set if the model has yynn input but you want "
|
||||
"to take rgba images",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add-bn-on-skip",
|
||||
dest="bn_on_skip",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="set if you only want to add BN on skip branches",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add-bn",
|
||||
dest="bn_before_add",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="set if you want to add BN before Add",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--eliminate-tail-unsupported",
|
||||
dest="eliminate_tail",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="whether remove the last unsupported node for hardware",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-bn-fusion",
|
||||
dest="disable_fuse_bn",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="set if you have met errors which related to inferenced "
|
||||
"shape mismatch. This option will prevent fusing "
|
||||
"BatchNormalization into Conv.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--opt-matmul",
|
||||
dest="opt_matmul",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="set if you want to optimize MatMul operations "
|
||||
"for kneron hardware.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-duplicate-shared-weights",
|
||||
dest="no_duplicate_shared_weights",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="do not duplicate shared weights. Defaults to False.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.out_file is None:
|
||||
@ -108,11 +169,11 @@ if __name__ == "__main__":
|
||||
else:
|
||||
outfile = args.out_file
|
||||
|
||||
if args.log == 'w':
|
||||
if args.log == "w":
|
||||
logging.basicConfig(level=logging.WARN)
|
||||
elif args.log == 'd':
|
||||
elif args.log == "d":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
elif args.log == 'e':
|
||||
elif args.log == "e":
|
||||
logging.basicConfig(level=logging.ERROR)
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@ -131,6 +192,17 @@ if __name__ == "__main__":
|
||||
# Basic model organize
|
||||
m = onnx.load(args.in_file)
|
||||
|
||||
m = onnx2onnx_flow(m, args.disable_fuse_bn, args.bn_on_skip, args.bn_before_add, args.bgr, args.norm, args.rgba2yynn, args.eliminate_tail, args.opt_matmul, not args.no_duplicate_shared_weights)
|
||||
m = onnx2onnx_flow(
|
||||
m,
|
||||
args.disable_fuse_bn,
|
||||
args.bn_on_skip,
|
||||
args.bn_before_add,
|
||||
args.bgr,
|
||||
args.norm,
|
||||
args.rgba2yynn,
|
||||
args.eliminate_tail,
|
||||
args.opt_matmul,
|
||||
not args.no_duplicate_shared_weights,
|
||||
)
|
||||
|
||||
onnx.save(m, outfile)
|
||||
|
||||
@ -5,12 +5,30 @@ import numpy as np
|
||||
from tools import helper
|
||||
|
||||
|
||||
onnx2np_dtype = {0: 'float', 1: 'float32', 2: 'uint8', 3: 'int8', 4: 'uint16', 5: 'int16', 6: 'int32', 7: 'int64', 8: 'str', 9: 'bool', 10: 'float16', 11: 'double', 12: 'uint32', 13: 'uint64', 14: 'complex64', 15: 'complex128', 16: 'float'}
|
||||
onnx2np_dtype = {
|
||||
0: "float",
|
||||
1: "float32",
|
||||
2: "uint8",
|
||||
3: "int8",
|
||||
4: "uint16",
|
||||
5: "int16",
|
||||
6: "int32",
|
||||
7: "int64",
|
||||
8: "str",
|
||||
9: "bool",
|
||||
10: "float16",
|
||||
11: "double",
|
||||
12: "uint32",
|
||||
13: "uint64",
|
||||
14: "complex64",
|
||||
15: "complex128",
|
||||
16: "float",
|
||||
}
|
||||
|
||||
|
||||
def onnx_model_results(path_a, path_b, total_times=10):
|
||||
""" using onnxruntime to inference two onnx models' outputs
|
||||
|
||||
"""using onnxruntime to inference two onnx models' outputs
|
||||
|
||||
:onnx model paths: two model paths
|
||||
:total_times: inference times, default to be 10
|
||||
:returns: inference results of two models
|
||||
@ -22,13 +40,20 @@ def onnx_model_results(path_a, path_b, total_times=10):
|
||||
outputs_b = session_b.get_outputs()
|
||||
|
||||
# check outputs
|
||||
assert len(outputs_a) == len(outputs_b), 'Two models have different output numbers.'
|
||||
assert len(outputs_a) == len(
|
||||
outputs_b
|
||||
), "Two models have different output numbers."
|
||||
for i in range(len(outputs_a)):
|
||||
out_shape_a, out_shape_b = outputs_a[i].shape, outputs_b[i].shape
|
||||
out_shape_a = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_a))
|
||||
out_shape_b = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_b))
|
||||
assert out_shape_a == out_shape_b, 'Output {} has unmatched shapes'.format(i)
|
||||
|
||||
out_shape_a = list(
|
||||
map(lambda x: x if isinstance(x, int) else 1, out_shape_a)
|
||||
)
|
||||
out_shape_b = list(
|
||||
map(lambda x: x if isinstance(x, int) else 1, out_shape_b)
|
||||
)
|
||||
assert (
|
||||
out_shape_a == out_shape_b
|
||||
), "Output {} has unmatched shapes".format(i)
|
||||
|
||||
# load onnx graph_a and graph_b, to find the initializer and inputs
|
||||
# then compare to remove the items in the inputs which will be initialized
|
||||
@ -38,9 +63,16 @@ def onnx_model_results(path_a, path_b, total_times=10):
|
||||
init_a, init_b = graph_a.initializer, graph_b.initializer
|
||||
|
||||
# remove initializer from raw inputs
|
||||
input_names_a, input_names_b = set([ele.name for ele in inputs_a]), set([ele.name for ele in inputs_b])
|
||||
init_names_a, init_names_b = set([ele.name for ele in init_a]), set([ele.name for ele in init_b])
|
||||
real_inputs_names_a, real_inputs_names_b = input_names_a - init_names_a, input_names_b - init_names_b
|
||||
input_names_a, input_names_b = set([ele.name for ele in inputs_a]), set(
|
||||
[ele.name for ele in inputs_b]
|
||||
)
|
||||
init_names_a, init_names_b = set([ele.name for ele in init_a]), set(
|
||||
[ele.name for ele in init_b]
|
||||
)
|
||||
real_inputs_names_a, real_inputs_names_b = (
|
||||
input_names_a - init_names_a,
|
||||
input_names_b - init_names_b,
|
||||
)
|
||||
|
||||
# prepare and figure out matching of real inputs a and real inputs b
|
||||
# try to keep original orders of each inputs
|
||||
@ -61,17 +93,20 @@ def onnx_model_results(path_a, path_b, total_times=10):
|
||||
for item_a in real_inputs_a:
|
||||
size, shape = helper.find_size_shape_from_value(item_a)
|
||||
if size:
|
||||
assert real_single_input_a is None, 'Multiple inputs of first model, single input expected.'
|
||||
assert (
|
||||
real_single_input_a is None
|
||||
), "Multiple inputs of first model, single input expected."
|
||||
real_single_input_a = item_a
|
||||
size_a, shape_a = size, shape
|
||||
for item_b in real_inputs_b:
|
||||
size, shape = helper.find_size_shape_from_value(item_b)
|
||||
if size:
|
||||
assert real_single_input_b is None, 'Multiple inputs of second model, single input expected.'
|
||||
assert (
|
||||
real_single_input_b is None
|
||||
), "Multiple inputs of second model, single input expected."
|
||||
real_single_input_b = item_b
|
||||
size_b, shape_b = size, shape
|
||||
assert size_a == size_b, 'Sizes of two models do not match.'
|
||||
|
||||
assert size_a == size_b, "Sizes of two models do not match."
|
||||
|
||||
# construct inputs tensors
|
||||
input_data_type_a = real_single_input_a.type.tensor_type.elem_type
|
||||
@ -84,7 +119,7 @@ def onnx_model_results(path_a, path_b, total_times=10):
|
||||
results_a = [[] for i in range(len(outputs_a))]
|
||||
results_b = [[] for i in range(len(outputs_b))]
|
||||
while times < total_times:
|
||||
# initialize inputs by random data, default to be uniform
|
||||
# initialize inputs by random data, default to be uniform
|
||||
data = np.random.random(size_a)
|
||||
input_a = np.reshape(data, shape_a).astype(input_data_type_a)
|
||||
input_b = np.reshape(data, shape_b).astype(input_data_type_b)
|
||||
@ -93,12 +128,18 @@ def onnx_model_results(path_a, path_b, total_times=10):
|
||||
input_dict_b = {}
|
||||
for item_a in real_inputs_a:
|
||||
item_type_a = onnx2np_dtype[item_a.type.tensor_type.elem_type]
|
||||
input_dict_a[item_a.name] = np.array([]).astype(item_type_a) \
|
||||
if item_a.name != real_single_input_a.name else input_a
|
||||
input_dict_a[item_a.name] = (
|
||||
np.array([]).astype(item_type_a)
|
||||
if item_a.name != real_single_input_a.name
|
||||
else input_a
|
||||
)
|
||||
for item_b in real_inputs_b:
|
||||
item_type_b = onnx2np_dtype[item_b.type.tensor_type.elem_type]
|
||||
input_dict_b[item_b.name] = np.array([]).astype(item_type_b) \
|
||||
if item_b.name != real_single_input_b.name else input_b
|
||||
input_dict_b[item_b.name] = (
|
||||
np.array([]).astype(item_type_b)
|
||||
if item_b.name != real_single_input_b.name
|
||||
else input_b
|
||||
)
|
||||
|
||||
ra = session_a.run([], input_dict_a)
|
||||
rb = session_b.run([], input_dict_b)
|
||||
@ -109,26 +150,32 @@ def onnx_model_results(path_a, path_b, total_times=10):
|
||||
|
||||
return results_a, results_b
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Argument parser.
|
||||
parser = argparse.ArgumentParser(description="Compare two ONNX models to check if they have the same output.")
|
||||
parser.add_argument('in_file_a', help='input ONNX file a')
|
||||
parser.add_argument('in_file_b', help='input ONNX file b')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compare two ONNX models to check if "
|
||||
"they have the same output."
|
||||
)
|
||||
parser.add_argument("in_file_a", help="input ONNX file a")
|
||||
parser.add_argument("in_file_b", help="input ONNX file b")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
results_a, results_b = onnx_model_results(args.in_file_a, args.in_file_b, total_times=10)
|
||||
results_a, results_b = onnx_model_results(
|
||||
args.in_file_a, args.in_file_b, total_times=10
|
||||
)
|
||||
ra_flat = helper.flatten_with_depth(results_a, 0)
|
||||
rb_flat = helper.flatten_with_depth(results_b, 0)
|
||||
shape_a = [item[1] for item in ra_flat]
|
||||
shape_b = [item[1] for item in rb_flat]
|
||||
assert shape_a == shape_b, 'two results data shape doesn\'t match'
|
||||
assert shape_a == shape_b, "two results data shape doesn't match"
|
||||
ra_raw = [item[0] for item in ra_flat]
|
||||
rb_raw = [item[0] for item in rb_flat]
|
||||
|
||||
try:
|
||||
np.testing.assert_almost_equal(ra_raw, rb_raw, 4)
|
||||
print('Two models have the same behaviour.')
|
||||
print("Two models have the same behaviour.")
|
||||
except Exception as mismatch:
|
||||
print(mismatch)
|
||||
exit(1)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import onnx
|
||||
import argparse
|
||||
import glob
|
||||
import csv
|
||||
@ -8,214 +7,242 @@ import matplotlib.pyplot as plt
|
||||
from tools import helper
|
||||
import onnx_vs_onnx as onnx_tester
|
||||
|
||||
|
||||
def compare_results(results_a, results_b):
|
||||
""" compare onnx model inference results
|
||||
calculate basic statistical values
|
||||
results: results from inference multiple times
|
||||
returns: list of basic statistical values
|
||||
"""
|
||||
# input results data can be of nonuniform shape
|
||||
# get flatten data to compare
|
||||
ra_flat = helper.flatten_with_depth(results_a, 0)
|
||||
rb_flat = helper.flatten_with_depth(results_b, 0)
|
||||
shape_a = [item[1] for item in ra_flat]
|
||||
shape_b = [item[1] for item in rb_flat]
|
||||
assert shape_a == shape_b, 'two results data shape doesn\'t match'
|
||||
ra_raw = [item[0] for item in ra_flat]
|
||||
rb_raw = [item[0] for item in rb_flat]
|
||||
"""compare onnx model inference results
|
||||
calculate basic statistical values
|
||||
results: results from inference multiple times
|
||||
returns: list of basic statistical values
|
||||
"""
|
||||
# input results data can be of nonuniform shape
|
||||
# get flatten data to compare
|
||||
ra_flat = helper.flatten_with_depth(results_a, 0)
|
||||
rb_flat = helper.flatten_with_depth(results_b, 0)
|
||||
shape_a = [item[1] for item in ra_flat]
|
||||
shape_b = [item[1] for item in rb_flat]
|
||||
assert shape_a == shape_b, "two results data shape doesn't match"
|
||||
ra_raw = [item[0] for item in ra_flat]
|
||||
rb_raw = [item[0] for item in rb_flat]
|
||||
|
||||
# the statistical values
|
||||
max_rel_diff = 0 # defined to be max( { abs(diff)/max(abs(ra), abs(rb) ) } )
|
||||
max_abs_diff = 0 # defined to be max( { abs(ra-rb) } )
|
||||
mean_rel_diff = 0
|
||||
mean_abs_diff = 0
|
||||
std_rel_diff = 0
|
||||
std_abs_diff = 0
|
||||
acc_with_diff_precision = []
|
||||
rel_diff = []
|
||||
abs_diff_percentiles = [] # abs_diff percentiles
|
||||
rel_diff_percentiles = [] # rel_diff percentiles
|
||||
# the statistical values
|
||||
max_rel_diff = (
|
||||
0 # defined to be max( { abs(diff)/max(abs(ra), abs(rb) ) } )
|
||||
)
|
||||
max_abs_diff = 0 # defined to be max( { abs(ra-rb) } )
|
||||
mean_rel_diff = 0
|
||||
mean_abs_diff = 0
|
||||
std_rel_diff = 0
|
||||
std_abs_diff = 0
|
||||
acc_with_diff_precision = []
|
||||
rel_diff = []
|
||||
abs_diff_percentiles = [] # abs_diff percentiles
|
||||
rel_diff_percentiles = [] # rel_diff percentiles
|
||||
|
||||
raw_diff = [ra_raw[i]-rb_raw[i] for i in range(len(ra_raw))]
|
||||
abs_diff = [abs(num) for num in raw_diff]
|
||||
for i in range(len(ra_raw)):
|
||||
divider = max([abs(ra_raw[i]), abs(rb_raw[i])])
|
||||
val = abs_diff[i]/divider if divider != 0 else 0
|
||||
rel_diff.append(val)
|
||||
|
||||
max_rel_diff = max(rel_diff)
|
||||
max_abs_diff = max(abs_diff)
|
||||
mean_rel_diff = np.average(rel_diff)
|
||||
mean_abs_diff = np.average(abs_diff)
|
||||
std_rel_diff = np.std(rel_diff)
|
||||
std_abs_diff = np.std(abs_diff)
|
||||
|
||||
# calculate accuracy with different precision
|
||||
for digit in range(8):
|
||||
correct = 0
|
||||
raw_diff = [ra_raw[i] - rb_raw[i] for i in range(len(ra_raw))]
|
||||
abs_diff = [abs(num) for num in raw_diff]
|
||||
for i in range(len(ra_raw)):
|
||||
if format(ra_raw[i], '.'+str(digit)+'f')\
|
||||
== format(rb_raw[i], '.'+str(digit)+'f'):
|
||||
correct += 1
|
||||
acc_with_diff_precision.append([digit, float(format(correct/len(ra_raw), '.3f'))])
|
||||
divider = max([abs(ra_raw[i]), abs(rb_raw[i])])
|
||||
val = abs_diff[i] / divider if divider != 0 else 0
|
||||
rel_diff.append(val)
|
||||
|
||||
# analyze rel_diff distribution
|
||||
rel_diff.sort()
|
||||
abs_diff.sort()
|
||||
for i in range(20):
|
||||
rel_diff_percentiles.append(['{}%'.format(i*5), rel_diff[int((i/20)*len(rel_diff))]])
|
||||
abs_diff_percentiles.append(['{}%'.format(i*5), abs_diff[int((i/20)*len(abs_diff))]])
|
||||
max_rel_diff = max(rel_diff)
|
||||
max_abs_diff = max(abs_diff)
|
||||
mean_rel_diff = np.average(rel_diff)
|
||||
mean_abs_diff = np.average(abs_diff)
|
||||
std_rel_diff = np.std(rel_diff)
|
||||
std_abs_diff = np.std(abs_diff)
|
||||
|
||||
results = [
|
||||
['max_rel_diff', max_rel_diff],
|
||||
['max_abs_diff', max_abs_diff],
|
||||
['mean_rel_diff', mean_rel_diff],
|
||||
['mean_abs_diff', mean_abs_diff],
|
||||
['std_rel_diff', std_rel_diff],
|
||||
['std_abs_diff', std_abs_diff],
|
||||
['acc_with_diff_precision', acc_with_diff_precision],
|
||||
['rel_diff_percentiles', rel_diff_percentiles],
|
||||
['abs_diff_percentiles', abs_diff_percentiles]
|
||||
]
|
||||
|
||||
return results
|
||||
# calculate accuracy with different precision
|
||||
for digit in range(8):
|
||||
correct = 0
|
||||
for i in range(len(ra_raw)):
|
||||
if format(ra_raw[i], "." + str(digit) + "f") == format(
|
||||
rb_raw[i], "." + str(digit) + "f"
|
||||
):
|
||||
correct += 1
|
||||
acc_with_diff_precision.append(
|
||||
[digit, float(format(correct / len(ra_raw), ".3f"))]
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='test model optimization results')
|
||||
|
||||
parser.add_argument('dir', type=str, help='the directory that stores onnx models')
|
||||
parser.add_argument('ending1', type=str, help='model file name ending(eg, .onnx)')
|
||||
parser.add_argument('ending2', type=str, help='opt model file name ending(eg. _opt.onnx)')
|
||||
parser.add_argument('out_file', type=str, help='output csv file name')
|
||||
parser.add_argument('-p', '--plot', default='N', help='get plots (Y/N)')
|
||||
parser.add_argument('-i', '--iter_times', default=10, type=int, help='inference times')
|
||||
# analyze rel_diff distribution
|
||||
rel_diff.sort()
|
||||
abs_diff.sort()
|
||||
for i in range(20):
|
||||
rel_diff_percentiles.append(
|
||||
["{}%".format(i * 5), rel_diff[int((i / 20) * len(rel_diff))]]
|
||||
)
|
||||
abs_diff_percentiles.append(
|
||||
["{}%".format(i * 5), abs_diff[int((i / 20) * len(abs_diff))]]
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
results = [
|
||||
["max_rel_diff", max_rel_diff],
|
||||
["max_abs_diff", max_abs_diff],
|
||||
["mean_rel_diff", mean_rel_diff],
|
||||
["mean_abs_diff", mean_abs_diff],
|
||||
["std_rel_diff", std_rel_diff],
|
||||
["std_abs_diff", std_abs_diff],
|
||||
["acc_with_diff_precision", acc_with_diff_precision],
|
||||
["rel_diff_percentiles", rel_diff_percentiles],
|
||||
["abs_diff_percentiles", abs_diff_percentiles],
|
||||
]
|
||||
|
||||
old_models_paths = glob.glob(args.dir+'*'+args.ending1)
|
||||
new_models_paths = glob.glob(args.dir+'*'+args.ending2)
|
||||
|
||||
stats_table = [[
|
||||
'Model',
|
||||
'max_rel_diff',
|
||||
'max_abs_diff',
|
||||
'mean_rel_diff',
|
||||
'mean_abs_diff',
|
||||
'std_rel_diff',
|
||||
'std_abs_diff',
|
||||
'acc_with_diff_precision',
|
||||
'rel_diff_percentiles',
|
||||
'abs_diff_percentiles'
|
||||
]]
|
||||
|
||||
for new_model_path in new_models_paths:
|
||||
old_model_path = new_model_path[:-len(args.ending2)] + args.ending1
|
||||
if old_model_path not in old_models_paths:
|
||||
continue
|
||||
|
||||
# run inference
|
||||
results_a, results_b = onnx_tester.onnx_model_results(old_model_path, new_model_path, total_times=args.iter_times)
|
||||
|
||||
# compare inference results
|
||||
comparision = compare_results(results_a, results_b)
|
||||
|
||||
new_line = [old_model_path.split('/')[-1]]
|
||||
for item in comparision:
|
||||
new_line.append(item[1])
|
||||
|
||||
stats_table.append(new_line)
|
||||
|
||||
# try to read existing file
|
||||
old_stats_table = []
|
||||
try:
|
||||
old_file = open(args.out_file, 'r')
|
||||
reader = csv.reader(old_file)
|
||||
old_header = reader.__next__()
|
||||
for row in reader:
|
||||
old_stats_table.append(row)
|
||||
old_file.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
# compare and merge possible old stat data file with new stat data file
|
||||
header = stats_table[0]
|
||||
stats_table = stats_table[1:]
|
||||
new_model_names = set([item[0] for item in stats_table])
|
||||
for row in old_stats_table:
|
||||
if row[0] not in new_model_names:
|
||||
stats_table.append(row)
|
||||
stats_table.insert(0, header)
|
||||
|
||||
# write a new stat data file, overwrite old file
|
||||
new_file = open(args.out_file, 'w', newline='')
|
||||
writer = csv.writer(new_file)
|
||||
for row in stats_table:
|
||||
writer.writerow(row)
|
||||
new_file.close()
|
||||
|
||||
# make some plots
|
||||
if args.plot == 'Y':
|
||||
if len(stats_table) < 2:
|
||||
exit(0)
|
||||
|
||||
sample_table = stats_table[1:] if len(stats_table) < 6 else stats_table[1:6]
|
||||
|
||||
max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]]
|
||||
plt.hist(max_rel_diffs, bins=15)
|
||||
plt.title('Max Relavtive Difference Histogram')
|
||||
plt.xlabel('Max Relative Difference')
|
||||
plt.ylabel('Counts')
|
||||
plt.savefig('max_rel_diff_hist.png')
|
||||
plt.close()
|
||||
|
||||
max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]]
|
||||
plt.hist(max_abs_diffs, bins=15)
|
||||
plt.title('Max Absolute Difference Histogram')
|
||||
plt.xlabel('Max Absolute Difference')
|
||||
plt.ylabel('Counts')
|
||||
plt.savefig('max_abs_diff_hist.png')
|
||||
plt.close()
|
||||
|
||||
for line in sample_table:
|
||||
model_name = line[0]
|
||||
percentiles = line[-2]
|
||||
x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))]
|
||||
y = [ele[1] for ele in percentiles]
|
||||
plt.plot(x, y, label=model_name)
|
||||
plt.title('Rel_diff Percentiles of Raw and Optimized Models')
|
||||
plt.xlabel('percentage')
|
||||
plt.ylabel('relative difference')
|
||||
plt.legend()
|
||||
plt.savefig('rel_diff_percentiles.png')
|
||||
plt.close()
|
||||
|
||||
for line in sample_table:
|
||||
model_name = line[0]
|
||||
percentiles = line[-1]
|
||||
x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))]
|
||||
y = [ele[1] for ele in percentiles]
|
||||
plt.plot(x, y, label=model_name)
|
||||
plt.title('Abs_diff Percentiles of Raw and Optimized Models')
|
||||
plt.xlabel('percentage')
|
||||
plt.ylabel('absolute difference')
|
||||
plt.legend()
|
||||
plt.savefig('abs_diff_percentiles.png')
|
||||
plt.close()
|
||||
|
||||
for line in sample_table:
|
||||
model_name = line[0]
|
||||
accuracies = line[-3]
|
||||
x = [acc[0] for acc in accuracies]
|
||||
y = [acc[1] for acc in accuracies]
|
||||
plt.plot(x, y, label=model_name)
|
||||
plt.title('Accuracies with Different Precisions')
|
||||
plt.xlabel('Decimals')
|
||||
plt.ylabel('Precision')
|
||||
plt.legend()
|
||||
plt.savefig('precisions.png')
|
||||
plt.close()
|
||||
return results
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="test model optimization results"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"dir", type=str, help="the directory that stores onnx models"
|
||||
)
|
||||
parser.add_argument(
|
||||
"ending1", type=str, help="model file name ending(eg, .onnx)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"ending2", type=str, help="opt model file name ending(eg. _opt.onnx)"
|
||||
)
|
||||
parser.add_argument("out_file", type=str, help="output csv file name")
|
||||
parser.add_argument("-p", "--plot", default="N", help="get plots (Y/N)")
|
||||
parser.add_argument(
|
||||
"-i", "--iter_times", default=10, type=int, help="inference times"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
old_models_paths = glob.glob(args.dir + "*" + args.ending1)
|
||||
new_models_paths = glob.glob(args.dir + "*" + args.ending2)
|
||||
|
||||
stats_table = [
|
||||
[
|
||||
"Model",
|
||||
"max_rel_diff",
|
||||
"max_abs_diff",
|
||||
"mean_rel_diff",
|
||||
"mean_abs_diff",
|
||||
"std_rel_diff",
|
||||
"std_abs_diff",
|
||||
"acc_with_diff_precision",
|
||||
"rel_diff_percentiles",
|
||||
"abs_diff_percentiles",
|
||||
]
|
||||
]
|
||||
|
||||
for new_model_path in new_models_paths:
|
||||
old_model_path = new_model_path[: -len(args.ending2)] + args.ending1
|
||||
if old_model_path not in old_models_paths:
|
||||
continue
|
||||
|
||||
# run inference
|
||||
results_a, results_b = onnx_tester.onnx_model_results(
|
||||
old_model_path, new_model_path, total_times=args.iter_times
|
||||
)
|
||||
|
||||
# compare inference results
|
||||
comparision = compare_results(results_a, results_b)
|
||||
|
||||
new_line = [old_model_path.split("/")[-1]]
|
||||
for item in comparision:
|
||||
new_line.append(item[1])
|
||||
|
||||
stats_table.append(new_line)
|
||||
|
||||
# try to read existing file
|
||||
old_stats_table = []
|
||||
try:
|
||||
old_file = open(args.out_file, "r")
|
||||
reader = csv.reader(old_file)
|
||||
old_header = reader.__next__()
|
||||
for row in reader:
|
||||
old_stats_table.append(row)
|
||||
old_file.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# compare and merge possible old stat data file with new stat data file
|
||||
header = stats_table[0]
|
||||
stats_table = stats_table[1:]
|
||||
new_model_names = set([item[0] for item in stats_table])
|
||||
for row in old_stats_table:
|
||||
if row[0] not in new_model_names:
|
||||
stats_table.append(row)
|
||||
stats_table.insert(0, header)
|
||||
|
||||
# write a new stat data file, overwrite old file
|
||||
new_file = open(args.out_file, "w", newline="")
|
||||
writer = csv.writer(new_file)
|
||||
for row in stats_table:
|
||||
writer.writerow(row)
|
||||
new_file.close()
|
||||
|
||||
# make some plots
|
||||
if args.plot == "Y":
|
||||
if len(stats_table) < 2:
|
||||
exit(0)
|
||||
|
||||
sample_table = (
|
||||
stats_table[1:] if len(stats_table) < 6 else stats_table[1:6]
|
||||
)
|
||||
|
||||
max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]]
|
||||
plt.hist(max_rel_diffs, bins=15)
|
||||
plt.title("Max Relavtive Difference Histogram")
|
||||
plt.xlabel("Max Relative Difference")
|
||||
plt.ylabel("Counts")
|
||||
plt.savefig("max_rel_diff_hist.png")
|
||||
plt.close()
|
||||
|
||||
max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]]
|
||||
plt.hist(max_abs_diffs, bins=15)
|
||||
plt.title("Max Absolute Difference Histogram")
|
||||
plt.xlabel("Max Absolute Difference")
|
||||
plt.ylabel("Counts")
|
||||
plt.savefig("max_abs_diff_hist.png")
|
||||
plt.close()
|
||||
|
||||
for line in sample_table:
|
||||
model_name = line[0]
|
||||
percentiles = line[-2]
|
||||
x = [
|
||||
round(i * (1 / len(percentiles)), 2)
|
||||
for i in range(len(percentiles))
|
||||
]
|
||||
y = [ele[1] for ele in percentiles]
|
||||
plt.plot(x, y, label=model_name)
|
||||
plt.title("Rel_diff Percentiles of Raw and Optimized Models")
|
||||
plt.xlabel("percentage")
|
||||
plt.ylabel("relative difference")
|
||||
plt.legend()
|
||||
plt.savefig("rel_diff_percentiles.png")
|
||||
plt.close()
|
||||
|
||||
for line in sample_table:
|
||||
model_name = line[0]
|
||||
percentiles = line[-1]
|
||||
x = [
|
||||
round(i * (1 / len(percentiles)), 2)
|
||||
for i in range(len(percentiles))
|
||||
]
|
||||
y = [ele[1] for ele in percentiles]
|
||||
plt.plot(x, y, label=model_name)
|
||||
plt.title("Abs_diff Percentiles of Raw and Optimized Models")
|
||||
plt.xlabel("percentage")
|
||||
plt.ylabel("absolute difference")
|
||||
plt.legend()
|
||||
plt.savefig("abs_diff_percentiles.png")
|
||||
plt.close()
|
||||
|
||||
for line in sample_table:
|
||||
model_name = line[0]
|
||||
accuracies = line[-3]
|
||||
x = [acc[0] for acc in accuracies]
|
||||
y = [acc[1] for acc in accuracies]
|
||||
plt.plot(x, y, label=model_name)
|
||||
plt.title("Accuracies with Different Precisions")
|
||||
plt.xlabel("Decimals")
|
||||
plt.ylabel("Precision")
|
||||
plt.legend()
|
||||
plt.savefig("precisions.png")
|
||||
plt.close()
|
||||
|
||||
@ -1,21 +1,10 @@
|
||||
import onnx
|
||||
import onnx.utils
|
||||
try:
|
||||
from onnx import optimizer
|
||||
except ImportError:
|
||||
import onnxoptimizer as optimizer
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import struct
|
||||
import logging
|
||||
import argparse
|
||||
|
||||
from tools import eliminating
|
||||
from tools import fusing
|
||||
from tools import replacing
|
||||
from tools import other
|
||||
from tools import combo
|
||||
from tools import special
|
||||
from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
|
||||
|
||||
# Debug use
|
||||
@ -25,13 +14,28 @@ from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
|
||||
# Generate a prototype onnx #
|
||||
######################################
|
||||
|
||||
parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
|
||||
parser.add_argument('in_file', help='input ONNX or PTH FILE')
|
||||
parser.add_argument('out_file', help="ouput ONNX FILE")
|
||||
parser.add_argument('--input-size', dest='input_size', nargs=3,
|
||||
help='if you using pth, please use this argument to set up the input size of the model. It should be in \'CH H W\' format, e.g. \'--input-size 3 256 512\'.')
|
||||
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
|
||||
help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Optimize a Pytorch generated model for Kneron compiler"
|
||||
)
|
||||
parser.add_argument("in_file", help="input ONNX or PTH FILE")
|
||||
parser.add_argument("out_file", help="ouput ONNX FILE")
|
||||
parser.add_argument(
|
||||
"--input-size",
|
||||
dest="input_size",
|
||||
nargs=3,
|
||||
help="if you using pth, please use this argument to set up the input "
|
||||
"size of the model. It should be in 'CH H W' format, "
|
||||
"e.g. '--input-size 3 256 512'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-bn-fusion",
|
||||
dest="disable_fuse_bn",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="set if you have met errors which related to inferenced shape "
|
||||
"mismatch. This option will prevent fusing BatchNormalization "
|
||||
"into Conv.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -39,7 +43,7 @@ if len(args.in_file) <= 4:
|
||||
# When the filename is too short.
|
||||
logging.error("Invalid input file: {}".format(args.in_file))
|
||||
exit(1)
|
||||
elif args.in_file[-4:] == '.pth':
|
||||
elif args.in_file[-4:] == ".pth":
|
||||
# Pytorch pth case
|
||||
logging.warning("Converting from pth to onnx is not recommended.")
|
||||
onnx_in = args.out_file
|
||||
@ -47,21 +51,29 @@ elif args.in_file[-4:] == '.pth':
|
||||
from torch.autograd import Variable
|
||||
import torch
|
||||
import torch.onnx
|
||||
|
||||
# import torchvision
|
||||
# Standard ImageNet input - 3 channels, 224x224.
|
||||
# Values don't matter as we care about network structure.
|
||||
# But they can also be real inputs.
|
||||
if args.input_size is None:
|
||||
logging.error("\'--input-size\' is required for the pth input file.")
|
||||
logging.error("'--input-size' is required for the pth input file.")
|
||||
exit(1)
|
||||
dummy_input = Variable(torch.randn(1, int(args.input_size[0]), int(args.input_size[1]), int(args.input_size[2])))
|
||||
dummy_input = Variable(
|
||||
torch.randn(
|
||||
1,
|
||||
int(args.input_size[0]),
|
||||
int(args.input_size[1]),
|
||||
int(args.input_size[2]),
|
||||
)
|
||||
)
|
||||
# Obtain your model, it can be also constructed in your script explicitly.
|
||||
model = torch.load(sys.argv[1], map_location='cpu')
|
||||
model = torch.load(sys.argv[1], map_location="cpu")
|
||||
# model = torchvision.models.resnet34(pretrained=True)
|
||||
# Invoke export.
|
||||
# torch.save(model, "resnet34.pth")
|
||||
torch.onnx.export(model, dummy_input, args.out_file, opset_version=11)
|
||||
elif args.in_file[-4:] == 'onnx':
|
||||
elif args.in_file[-4:] == "onnx":
|
||||
onnx_in = args.in_file
|
||||
else:
|
||||
# When the file is neither an onnx nor a pytorch pth.
|
||||
|
||||
@ -1,29 +1,22 @@
|
||||
import onnx
|
||||
import onnx.utils
|
||||
try:
|
||||
from onnx import optimizer
|
||||
except ImportError:
|
||||
import onnxoptimizer as optimizer
|
||||
import sys
|
||||
import numpy as np
|
||||
import struct
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
|
||||
from .tools import eliminating
|
||||
from .tools import fusing
|
||||
from .tools import replacing
|
||||
from .tools import other
|
||||
from .tools import combo
|
||||
from .tools import special
|
||||
|
||||
|
||||
# Define general pytorch exported onnx optimize process
|
||||
def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.ModelProto:
|
||||
def torch_exported_onnx_flow(
|
||||
m: onnx.ModelProto, disable_fuse_bn=False
|
||||
) -> onnx.ModelProto:
|
||||
"""Optimize the Pytorch exported onnx.
|
||||
|
||||
Args:
|
||||
m (ModelProto): the input onnx model
|
||||
disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.
|
||||
disable_fuse_bn (bool, optional): do not fuse BN into Conv.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
ModelProto: the optimized onnx model
|
||||
@ -38,20 +31,29 @@ def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.
|
||||
|
||||
# Main Process
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
|
||||
parser.add_argument('in_file', help='input ONNX')
|
||||
parser.add_argument('out_file', help="ouput ONNX FILE")
|
||||
parser.add_argument('--log', default='i', type=str, help="set log level")
|
||||
parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
|
||||
help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Optimize a Pytorch generated model for Kneron compiler"
|
||||
)
|
||||
parser.add_argument("in_file", help="input ONNX")
|
||||
parser.add_argument("out_file", help="ouput ONNX FILE")
|
||||
parser.add_argument("--log", default="i", type=str, help="set log level")
|
||||
parser.add_argument(
|
||||
"--no-bn-fusion",
|
||||
dest="disable_fuse_bn",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="set if you have met errors which related to inferenced shape "
|
||||
"mismatch. This option will prevent fusing BatchNormalization "
|
||||
"into Conv.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.log == 'w':
|
||||
if args.log == "w":
|
||||
logging.basicConfig(level=logging.WARN)
|
||||
elif args.log == 'd':
|
||||
elif args.log == "d":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
elif args.log == 'e':
|
||||
elif args.log == "e":
|
||||
logging.basicConfig(level=logging.ERROR)
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@ -60,7 +62,7 @@ if __name__ == "__main__":
|
||||
# When the filename is too short.
|
||||
logging.error("Invalid input file: {}".format(args.in_file))
|
||||
exit(1)
|
||||
elif args.in_file[-4:] == 'onnx':
|
||||
elif args.in_file[-4:] == "onnx":
|
||||
onnx_in = args.in_file
|
||||
else:
|
||||
# When the file is not an onnx file.
|
||||
|
||||
@ -8,7 +8,8 @@ import onnx.utils
|
||||
from tensorflow.python.platform import gfile
|
||||
from tools import combo, eliminating, replacing, other
|
||||
|
||||
def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
|
||||
|
||||
def tf2onnx_flow(pb_path: str, test_mode=False) -> onnx.ModelProto:
|
||||
"""Convert frozen graph pb file into onnx
|
||||
|
||||
Args:
|
||||
@ -21,34 +22,45 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
|
||||
Returns:
|
||||
onnx.ModelProto: converted onnx
|
||||
"""
|
||||
TF2ONNX_VERSION = int(tf2onnx.version.version.replace('.', ''))
|
||||
TF2ONNX_VERSION = int(tf2onnx.version.version.replace(".", ""))
|
||||
|
||||
if 160 <= TF2ONNX_VERSION:
|
||||
from tf2onnx import tf_loader
|
||||
else:
|
||||
from tf2onnx import loader as tf_loader
|
||||
|
||||
if pb_path[-3:] == '.pb':
|
||||
model_name = pb_path.split('/')[-1][:-3]
|
||||
|
||||
# always reset tensorflow session at begin
|
||||
if pb_path[-3:] == ".pb":
|
||||
model_name = pb_path.split("/")[-1][:-3]
|
||||
|
||||
# always reset tensorflow session at the beginning
|
||||
tf.reset_default_graph()
|
||||
|
||||
|
||||
with tf.Session() as sess:
|
||||
with gfile.FastGFile(pb_path, 'rb') as f:
|
||||
with gfile.FastGFile(pb_path, "rb") as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
sess.graph.as_default()
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
tf.import_graph_def(graph_def, name="")
|
||||
|
||||
if 160 <= int(tf2onnx.version.version.replace('.', '')):
|
||||
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions = tf2onnx.tf_utils.tflist_to_onnx(
|
||||
sess.graph,
|
||||
{})
|
||||
if 160 <= int(tf2onnx.version.version.replace(".", "")):
|
||||
(
|
||||
onnx_nodes,
|
||||
op_cnt,
|
||||
attr_cnt,
|
||||
output_shapes,
|
||||
dtypes,
|
||||
functions,
|
||||
) = tf2onnx.tf_utils.tflist_to_onnx(sess.graph, {})
|
||||
else:
|
||||
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tf2onnx.tfonnx.tflist_to_onnx(
|
||||
sess.graph.get_operations(),
|
||||
{})
|
||||
(
|
||||
onnx_nodes,
|
||||
op_cnt,
|
||||
attr_cnt,
|
||||
output_shapes,
|
||||
dtypes,
|
||||
) = tf2onnx.tfonnx.tflist_to_onnx(
|
||||
sess.graph.get_operations(), {}
|
||||
)
|
||||
|
||||
for n in onnx_nodes:
|
||||
if len(n.output) == 0:
|
||||
@ -59,12 +71,12 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
|
||||
nodes_outputs = set()
|
||||
|
||||
for n in onnx_nodes:
|
||||
if n.op_type == 'Placeholder':
|
||||
if n.op_type == "Placeholder":
|
||||
continue
|
||||
for input in n.input:
|
||||
nodes_inputs.add(input)
|
||||
for output in n.output:
|
||||
nodes_outputs.add(output)
|
||||
nodes_outputs.add(output)
|
||||
|
||||
graph_input_names = set()
|
||||
for input_name in nodes_inputs:
|
||||
@ -76,35 +88,43 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
|
||||
if n.input and n.input[0] not in nodes_outputs:
|
||||
continue
|
||||
if len(n.output) == 0:
|
||||
n.output.append(n.name + ':0')
|
||||
n.output.append(n.name + ":0")
|
||||
graph_output_names.add(n.output[0])
|
||||
else:
|
||||
output_name = n.output[0]
|
||||
if (output_name not in nodes_inputs) and (0 < len(n.input)):
|
||||
if (output_name not in nodes_inputs) and (
|
||||
0 < len(n.input)
|
||||
):
|
||||
graph_output_names.add(output_name)
|
||||
|
||||
logging.info('Model Inputs: %s', str(list(graph_input_names)))
|
||||
logging.info('Model Outputs: %s', str(list(graph_output_names)))
|
||||
logging.info("Model Inputs: %s", str(list(graph_input_names)))
|
||||
logging.info("Model Outputs: %s", str(list(graph_output_names)))
|
||||
|
||||
graph_def, inputs, outputs = tf_loader.from_graphdef(model_path=pb_path,
|
||||
input_names=list(graph_input_names),
|
||||
output_names=list(graph_output_names))
|
||||
graph_def, inputs, outputs = tf_loader.from_graphdef(
|
||||
model_path=pb_path,
|
||||
input_names=list(graph_input_names),
|
||||
output_names=list(graph_output_names),
|
||||
)
|
||||
|
||||
with tf.Graph().as_default() as tf_graph:
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
tf.import_graph_def(graph_def, name="")
|
||||
|
||||
if 160 <= TF2ONNX_VERSION:
|
||||
with tf_loader.tf_session(graph=tf_graph):
|
||||
onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
|
||||
input_names=inputs,
|
||||
output_names=outputs,
|
||||
opset=11)
|
||||
onnx_graph = tf2onnx.tfonnx.process_tf_graph(
|
||||
tf_graph=tf_graph,
|
||||
input_names=inputs,
|
||||
output_names=outputs,
|
||||
opset=11,
|
||||
)
|
||||
else:
|
||||
with tf.Session(graph=tf_graph):
|
||||
onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
|
||||
input_names=inputs,
|
||||
output_names=outputs,
|
||||
opset=11)
|
||||
onnx_graph = tf2onnx.tfonnx.process_tf_graph(
|
||||
tf_graph=tf_graph,
|
||||
input_names=inputs,
|
||||
output_names=outputs,
|
||||
opset=11,
|
||||
)
|
||||
|
||||
# Optimize with tf2onnx.optimizer
|
||||
onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph)
|
||||
@ -115,7 +135,9 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
|
||||
model_proto = other.polish_model(model_proto)
|
||||
|
||||
else:
|
||||
raise Exception('expect .pb file as input, but got "' + str(pb_path) + '"')
|
||||
raise Exception(
|
||||
'expect .pb file as input, but got "' + str(pb_path) + '"'
|
||||
)
|
||||
|
||||
# rename
|
||||
m = model_proto
|
||||
@ -133,15 +155,26 @@ def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto:
|
||||
return m
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Convert tensorflow pb file to onnx file and optimized onnx file. Or just optimize tensorflow onnx file.')
|
||||
parser.add_argument('in_file', help='input file')
|
||||
parser.add_argument('out_file', help='output optimized model file')
|
||||
parser.add_argument('-t', '--test_mode', default=False, help='test mode will not eliminate shape changes after input')
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert tensorflow pb file to onnx file and optimized "
|
||||
"onnx file. Or just optimize tensorflow onnx file."
|
||||
)
|
||||
parser.add_argument("in_file", help="input file")
|
||||
parser.add_argument("out_file", help="output optimized model file")
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--test_mode",
|
||||
default=False,
|
||||
help="test mode will not eliminate shape changes after input",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO)
|
||||
logging.basicConfig(
|
||||
stream=sys.stdout,
|
||||
format="[%(asctime)s] %(levelname)s: %(message)s",
|
||||
level=logging.INFO,
|
||||
)
|
||||
m = tf2onnx_flow(args.in_file, args.test_mode)
|
||||
onnx.save(m, args.out_file)
|
||||
logging.info('Save Optimized ONNX: %s', args.out_file)
|
||||
logging.info("Save Optimized ONNX: %s", args.out_file)
|
||||
|
||||
@ -6,6 +6,7 @@ import onnxruntime
|
||||
|
||||
from tools import helper
|
||||
|
||||
|
||||
def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
|
||||
# Setup onnx session and get meta data
|
||||
onnx_session = onnxruntime.InferenceSession(onnx_file, None)
|
||||
@ -21,21 +22,32 @@ def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
|
||||
tflite_session.allocate_tensors()
|
||||
tflite_inputs = tflite_session.get_input_details()
|
||||
tflite_outputs = tflite_session.get_output_details()
|
||||
tflite_input_shape = tflite_inputs[0]['shape']
|
||||
tflite_input_shape = tflite_inputs[0]["shape"]
|
||||
# Compare input shape
|
||||
assert(len(onnx_input_shape) == len(tflite_input_shape)), "TFLite and ONNX shape unmatch."
|
||||
assert(onnx_input_shape == [tflite_input_shape[0], tflite_input_shape[3], tflite_input_shape[1], tflite_input_shape[2]]), "TFLite and ONNX shape unmatch."
|
||||
assert len(onnx_input_shape) == len(
|
||||
tflite_input_shape
|
||||
), "TFLite and ONNX shape unmatch."
|
||||
assert onnx_input_shape == [
|
||||
tflite_input_shape[0],
|
||||
tflite_input_shape[3],
|
||||
tflite_input_shape[1],
|
||||
tflite_input_shape[2],
|
||||
], "TFLite and ONNX shape unmatch."
|
||||
# Generate random number and run
|
||||
tflite_results = []
|
||||
onnx_results = []
|
||||
for _ in range(total_times):
|
||||
# Generate input
|
||||
tflite_input_data = np.array(np.random.random_sample(tflite_input_shape), dtype=np.float32)
|
||||
tflite_input_data = np.array(
|
||||
np.random.random_sample(tflite_input_shape), dtype=np.float32
|
||||
)
|
||||
onnx_input_data = np.transpose(tflite_input_data, [0, 3, 1, 2])
|
||||
# Run tflite
|
||||
tflite_session.set_tensor(tflite_inputs[0]['index'], tflite_input_data)
|
||||
tflite_session.set_tensor(tflite_inputs[0]["index"], tflite_input_data)
|
||||
tflite_session.invoke()
|
||||
tflite_results.append(tflite_session.get_tensor(tflite_outputs[0]['index']))
|
||||
tflite_results.append(
|
||||
tflite_session.get_tensor(tflite_outputs[0]["index"])
|
||||
)
|
||||
# Run onnx
|
||||
onnx_input_dict = {onnx_inputs[0].name: onnx_input_data}
|
||||
onnx_results.append(onnx_session.run([], onnx_input_dict)[0])
|
||||
@ -43,26 +55,31 @@ def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
|
||||
return tflite_results, onnx_results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# Argument parser.
|
||||
parser = argparse.ArgumentParser(description="Compare a TFLite model and an ONNX model to check if they have the same output.")
|
||||
parser.add_argument('tflite_file', help='input tflite file')
|
||||
parser.add_argument('onnx_file', help='input ONNX file')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compare a TFLite model and an ONNX model to check "
|
||||
"if they have the same output."
|
||||
)
|
||||
parser.add_argument("tflite_file", help="input tflite file")
|
||||
parser.add_argument("onnx_file", help="input ONNX file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
results_a, results_b = compare_tflite_and_onnx(args.tflite_file, args.onnx_file, total_times=10)
|
||||
results_a, results_b = compare_tflite_and_onnx(
|
||||
args.tflite_file, args.onnx_file, total_times=10
|
||||
)
|
||||
ra_flat = helper.flatten_with_depth(results_a, 0)
|
||||
rb_flat = helper.flatten_with_depth(results_b, 0)
|
||||
shape_a = [item[1] for item in ra_flat]
|
||||
shape_b = [item[1] for item in rb_flat]
|
||||
assert shape_a == shape_b, 'two results data shape doesn\'t match'
|
||||
assert shape_a == shape_b, "two results data shape doesn't match"
|
||||
ra_raw = [item[0] for item in ra_flat]
|
||||
rb_raw = [item[0] for item in rb_flat]
|
||||
|
||||
try:
|
||||
np.testing.assert_almost_equal(ra_raw, rb_raw, 8)
|
||||
print('Two models have the same behaviour.')
|
||||
print("Two models have the same behaviour.")
|
||||
except Exception as mismatch:
|
||||
print(mismatch)
|
||||
exit(1)
|
||||
exit(1)
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
import onnx.utils
|
||||
|
||||
try:
|
||||
from onnx import optimizer
|
||||
except ImportError:
|
||||
@ -15,16 +15,19 @@ from . import eliminating
|
||||
from . import fusing
|
||||
from . import constant_folding
|
||||
from . import removing_transpose
|
||||
from . import modhelper
|
||||
from .common_pattern import torch_pattern_match, tf_pattern_match
|
||||
from .helper import logger
|
||||
|
||||
def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True):
|
||||
|
||||
def preprocess(
|
||||
model_proto, disable_fuse_bn=False, duplicate_shared_weights=True
|
||||
):
|
||||
"""The most common used functions before other processing.
|
||||
|
||||
Args:
|
||||
model_proto: the original model input
|
||||
duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True.
|
||||
duplicate_shared_weights(bool, optional): duplicate shared weights.
|
||||
Defaults to True.
|
||||
|
||||
Return:
|
||||
the new model after preprocessing
|
||||
@ -65,22 +68,28 @@ def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True
|
||||
replacing.replace_initializer_with_Constant(model_proto.graph)
|
||||
other.topological_sort(model_proto.graph)
|
||||
m = other.polish_model(model_proto)
|
||||
passes = ['extract_constant_to_initializer',
|
||||
'eliminate_nop_dropout',
|
||||
'eliminate_deadend',
|
||||
'fuse_matmul_add_bias_into_gemm',
|
||||
'fuse_pad_into_conv']
|
||||
passes = [
|
||||
"extract_constant_to_initializer",
|
||||
"eliminate_nop_dropout",
|
||||
"eliminate_deadend",
|
||||
"fuse_matmul_add_bias_into_gemm",
|
||||
"fuse_pad_into_conv",
|
||||
]
|
||||
if not disable_fuse_bn:
|
||||
passes.append('fuse_bn_into_conv')
|
||||
passes.append("fuse_bn_into_conv")
|
||||
m = optimizer.optimize(m, passes)
|
||||
g = m.graph
|
||||
# Add name again since onnx optimizer higher than 1.7 may remove node names.
|
||||
# Add name again since onnx optimizer higher than 1.7 may remove node names
|
||||
other.add_name_to_node(g)
|
||||
if duplicate_shared_weights:
|
||||
replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=True)
|
||||
replacing.replace_initializer_with_Constant(
|
||||
g, duplicate_shared_weights=True
|
||||
)
|
||||
other.duplicate_param_shared_constant(g)
|
||||
else:
|
||||
replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=False)
|
||||
replacing.replace_initializer_with_Constant(
|
||||
g, duplicate_shared_weights=False
|
||||
)
|
||||
other.topological_sort(g)
|
||||
m = other.polish_model(m)
|
||||
g = m.graph
|
||||
@ -161,12 +170,12 @@ def pytorch_constant_folding(m):
|
||||
other.topological_sort(m.graph)
|
||||
while len(m.graph.value_info) != 0:
|
||||
m.graph.value_info.pop()
|
||||
|
||||
|
||||
m = other.inference_shapes(m)
|
||||
replacing.replace_shape_with_constant(m.graph)
|
||||
other.topological_sort(m.graph)
|
||||
m = torch_pattern_match(m)
|
||||
m = optimizer.optimize(m, ['eliminate_deadend'])
|
||||
m = optimizer.optimize(m, ["eliminate_deadend"])
|
||||
return m
|
||||
|
||||
|
||||
@ -206,7 +215,7 @@ def tensorflow_optimization(m):
|
||||
replacing.replace_shape_with_constant(m.graph)
|
||||
other.topological_sort(m.graph)
|
||||
m = tf_pattern_match(m)
|
||||
m = optimizer.optimize(m, ['eliminate_deadend'])
|
||||
m = optimizer.optimize(m, ["eliminate_deadend"])
|
||||
|
||||
eliminating.eliminate_consecutive_reshape(m.graph)
|
||||
eliminating.eliminate_Squeeze_before_Reshape(m.graph)
|
||||
@ -253,6 +262,6 @@ def postprocess(m):
|
||||
m = other.polish_model(m)
|
||||
|
||||
other.add_output_to_value_info(m.graph)
|
||||
m = optimizer.optimize(m, ['eliminate_deadend'])
|
||||
m.producer_name = 'kneron_formatter'
|
||||
m = optimizer.optimize(m, ["eliminate_deadend"])
|
||||
m.producer_name = "kneron_formatter"
|
||||
return m
|
||||
|
||||
@ -3,19 +3,20 @@ import numpy as np
|
||||
import onnx.helper
|
||||
import onnx.utils
|
||||
|
||||
from . import modhelper
|
||||
from . import helper
|
||||
from . import other
|
||||
|
||||
|
||||
def torch_pattern_match(m):
|
||||
# Create a map from optype to the nodes.
|
||||
optype2node = defaultdict(list)
|
||||
for node in m.graph.node:
|
||||
optype2node[node.op_type].append(node)
|
||||
for matmul_node in optype2node['MatMul']:
|
||||
for matmul_node in optype2node["MatMul"]:
|
||||
pattern_matmul_mul_add(m.graph, matmul_node)
|
||||
for resize_node in optype2node['Resize']:
|
||||
# torch nn.UpsamplingBilinear2d will be given us 4 input: "X, roi, scales, sizes"
|
||||
for resize_node in optype2node["Resize"]:
|
||||
# torch nn.UpsamplingBilinear2d will give us 4 inputs:
|
||||
# "X, roi, scales, sizes"
|
||||
if len(resize_node.input) != 4:
|
||||
continue
|
||||
make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name)
|
||||
@ -24,15 +25,17 @@ def torch_pattern_match(m):
|
||||
m = other.polish_model(m)
|
||||
return m
|
||||
|
||||
|
||||
def tf_pattern_match(m):
|
||||
# Create a map from optype to the nodes.
|
||||
optype2node = defaultdict(list)
|
||||
for node in m.graph.node:
|
||||
optype2node[node.op_type].append(node)
|
||||
for matmul_node in optype2node['MatMul']:
|
||||
for matmul_node in optype2node["MatMul"]:
|
||||
pattern_matmul_mul_add(m.graph, matmul_node)
|
||||
for resize_node in optype2node['Resize']:
|
||||
# In tensorflow2onnx, ReizeXXX will be given us 4 input: "X, roi, scales, sizes"
|
||||
for resize_node in optype2node["Resize"]:
|
||||
# In tensorflow2onnx, ResizeXXX will give us 4 inputs:
|
||||
# "X, roi, scales, sizes"
|
||||
# and node output name will be given the "node name + :0"
|
||||
if len(resize_node.input) != 4:
|
||||
continue
|
||||
@ -42,24 +45,25 @@ def tf_pattern_match(m):
|
||||
m = other.polish_model(m)
|
||||
return m
|
||||
|
||||
|
||||
def pattern_matmul_mul_add(g, matmul_node):
|
||||
# Check node match - Mul node
|
||||
next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
|
||||
if len(next_nodes) != 1:
|
||||
return
|
||||
if next_nodes[0].op_type != 'Mul':
|
||||
if next_nodes[0].op_type != "Mul":
|
||||
return
|
||||
mul_node = next_nodes[0]
|
||||
# Check node match - Add node
|
||||
next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
|
||||
if len(next_nodes) != 1:
|
||||
return
|
||||
if next_nodes[0].op_type != 'Add':
|
||||
if next_nodes[0].op_type != "Add":
|
||||
return
|
||||
add_node = next_nodes[0]
|
||||
# Check Mul weight
|
||||
mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
|
||||
if mul_weight_node.op_type != 'Constant':
|
||||
if mul_weight_node.op_type != "Constant":
|
||||
return
|
||||
weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
|
||||
for i in mul_weight:
|
||||
@ -68,15 +72,19 @@ def pattern_matmul_mul_add(g, matmul_node):
|
||||
channel = weight_size[0]
|
||||
# Check Add weight
|
||||
add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
|
||||
if add_weight_node.op_type != 'Constant':
|
||||
if add_weight_node.op_type != "Constant":
|
||||
return
|
||||
# Check MatMul weight to see if it needs weight broadcast
|
||||
matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1])
|
||||
matmul_weight_node = helper.find_node_by_output_name(
|
||||
g, matmul_node.input[1]
|
||||
)
|
||||
matmul_weight = helper.constant_to_numpy(matmul_weight_node)
|
||||
if matmul_weight.shape[1] == 1:
|
||||
# Weight broadcast
|
||||
new_matmul_weight = np.tile(matmul_weight, channel)
|
||||
new_matmul_weight_node = helper.numpy_to_constant(matmul_weight_node.name, new_matmul_weight)
|
||||
new_matmul_weight_node = helper.numpy_to_constant(
|
||||
matmul_weight_node.name, new_matmul_weight
|
||||
)
|
||||
g.node.remove(matmul_weight_node)
|
||||
g.node.extend([new_matmul_weight_node])
|
||||
value = helper.find_value_by_name(g, matmul_weight_node.output[0])
|
||||
@ -93,14 +101,14 @@ def pattern_matmul_mul_add(g, matmul_node):
|
||||
g.value_info.remove(value)
|
||||
# Fuse Matmul and Add
|
||||
gemm_node = onnx.helper.make_node(
|
||||
'Gemm',
|
||||
"Gemm",
|
||||
[matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
|
||||
[add_node.output[0]],
|
||||
name = matmul_node.name,
|
||||
alpha = 1.0,
|
||||
beta = 1.0,
|
||||
transA = 0,
|
||||
transB = 0
|
||||
name=matmul_node.name,
|
||||
alpha=1.0,
|
||||
beta=1.0,
|
||||
transA=0,
|
||||
transB=0,
|
||||
)
|
||||
g.node.extend([gemm_node])
|
||||
# Clean up
|
||||
@ -111,6 +119,7 @@ def pattern_matmul_mul_add(g, matmul_node):
|
||||
g.value_info.remove(value)
|
||||
other.topological_sort(g)
|
||||
|
||||
|
||||
def make_UpsamplingBilinear2d_value_info(g, resize_node_name):
|
||||
resize_node = helper.find_node_by_node_name(g, resize_node_name)
|
||||
|
||||
@ -124,34 +133,45 @@ def make_UpsamplingBilinear2d_value_info(g, resize_node_name):
|
||||
new_output_value_info = onnx.helper.make_tensor_value_info(
|
||||
resize_node.output[0],
|
||||
onnx.helper.TensorProto.FLOAT,
|
||||
shape_data.tolist()
|
||||
shape_data.tolist(),
|
||||
)
|
||||
|
||||
g.value_info.extend([new_output_value_info])
|
||||
|
||||
|
||||
def polish_RESIZE_input_param_node(g, resize_node_name):
|
||||
resize_node = helper.find_node_by_node_name(g, resize_node_name)
|
||||
|
||||
shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
|
||||
shape_data = helper.constant_to_numpy(shape_data_node).astype(int)
|
||||
|
||||
# handle 0 batch size which is invalid
|
||||
|
||||
# handle 0 batch size which is invalid
|
||||
if shape_data[0] == 0:
|
||||
shape_data[0] = 1
|
||||
|
||||
pre_node_output_value_info = helper.find_value_by_name(g, resize_node.input[0])
|
||||
ori_shape = np.array([pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value,
|
||||
pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value,
|
||||
pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value,
|
||||
pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value])
|
||||
|
||||
resize_node.input.remove(resize_node.input[3])
|
||||
|
||||
pre_node_output_value_info = helper.find_value_by_name(
|
||||
g, resize_node.input[0]
|
||||
)
|
||||
ori_shape = np.array(
|
||||
[
|
||||
pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value,
|
||||
pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value,
|
||||
pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value,
|
||||
pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value,
|
||||
]
|
||||
)
|
||||
|
||||
resize_scales = np.array(shape_data/ori_shape).astype(float)
|
||||
resize_scale_node = helper.list_to_constant('resize_scales_node_' + resize_node.name, resize_scales.shape, resize_scales, data_type=onnx.helper.TensorProto.FLOAT)
|
||||
resize_node.input.remove(resize_node.input[3])
|
||||
|
||||
resize_scales = np.array(shape_data / ori_shape).astype(float)
|
||||
resize_scale_node = helper.list_to_constant(
|
||||
"resize_scales_node_" + resize_node.name,
|
||||
resize_scales.shape,
|
||||
resize_scales,
|
||||
data_type=onnx.helper.TensorProto.FLOAT,
|
||||
)
|
||||
|
||||
resize_node.input[2] = resize_scale_node.name
|
||||
g.node.extend([resize_scale_node])
|
||||
|
||||
|
||||
other.topological_sort(g)
|
||||
|
||||
@ -5,15 +5,14 @@ import logging
|
||||
import traceback
|
||||
|
||||
from . import helper
|
||||
from .general_graph import Graph, Node
|
||||
from .other import topological_sort
|
||||
from .replacing import replace_shape_with_constant
|
||||
from .helper import logger
|
||||
|
||||
|
||||
def are_all_inputs_Constant_with_one_child(g, node):
|
||||
for input_name in node.input:
|
||||
input_node = helper.find_node_by_output_name(g, input_name)
|
||||
if input_node is None or input_node.op_type != 'Constant':
|
||||
if input_node is None or input_node.op_type != "Constant":
|
||||
return False
|
||||
relative_outputs = helper.find_nodes_by_input_name(g, input_name)
|
||||
if len(relative_outputs) > 1:
|
||||
@ -28,7 +27,7 @@ def constant_folding(g):
|
||||
:return: If any node is folded, return True. Otherwise, return False.
|
||||
"""
|
||||
keep_folding = True # Keep the while loop
|
||||
folded = False # Return value
|
||||
folded = False # Return value
|
||||
try:
|
||||
# Before constant folding, duplicate the constant nodes.
|
||||
duplicate_constant_node(g)
|
||||
@ -38,37 +37,47 @@ def constant_folding(g):
|
||||
# Check if the node is foldable
|
||||
if node.op_type not in constant_folding_nodes.keys():
|
||||
continue
|
||||
# Check if the parents of the node are all single follower constant node.
|
||||
# Check if parents of the node are all
|
||||
# single-follower constant nodes.
|
||||
if not are_all_inputs_Constant_with_one_child(g, node):
|
||||
continue
|
||||
# Constant folding for the specific node
|
||||
if constant_folding_nodes[node.op_type](g, node):
|
||||
logging.debug("Constant nodes and %s %s are folded.",
|
||||
node.op_type, node.name)
|
||||
logging.debug(
|
||||
"Constant nodes and %s %s are folded.",
|
||||
node.op_type,
|
||||
node.name,
|
||||
)
|
||||
folded = True
|
||||
keep_folding = True
|
||||
else:
|
||||
logging.debug(
|
||||
"Constant nodes and %s %s are skipped.", node.op_type, node.name)
|
||||
except Exception as e:
|
||||
"Constant nodes and %s %s are skipped.",
|
||||
node.op_type,
|
||||
node.name,
|
||||
)
|
||||
except Exception:
|
||||
logger.error("An exception is raised while constant folding.")
|
||||
logger.error(traceback.format_exc())
|
||||
return folded
|
||||
|
||||
|
||||
|
||||
def duplicate_constant_node(g):
|
||||
""" Duplicate the constant node if its following nodes contain constant folding
|
||||
nodes. Create and link the new constant nodes to the constant folding nodes.
|
||||
"""
|
||||
Duplicate the constant node if its following nodes contain
|
||||
constant folding nodes. Create and link the new constant nodes
|
||||
to the constant folding nodes.
|
||||
"""
|
||||
for node in g.node:
|
||||
# Find a valid constant node
|
||||
if node.op_type != 'Constant':
|
||||
if node.op_type != "Constant":
|
||||
continue
|
||||
output_val_info = helper.find_value_by_name(g, node.output[0])
|
||||
if output_val_info is None:
|
||||
print("Cannot inference the shape of Const node output: " +
|
||||
node.output[0])
|
||||
print(
|
||||
"Cannot inference the shape of Const node output: "
|
||||
+ node.output[0]
|
||||
)
|
||||
exit(1)
|
||||
data_shape = helper.get_shape_from_value_info(output_val_info)
|
||||
output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
|
||||
@ -78,30 +87,37 @@ def duplicate_constant_node(g):
|
||||
continue
|
||||
|
||||
# Check if its following nodes are foldable
|
||||
foldable_output_nodes = list(filter(lambda n: n.op_type in
|
||||
constant_folding_nodes.keys(), output_nodes))
|
||||
foldable_output_nodes = list(
|
||||
filter(
|
||||
lambda n: n.op_type in constant_folding_nodes.keys(),
|
||||
output_nodes,
|
||||
)
|
||||
)
|
||||
if not foldable_output_nodes:
|
||||
continue
|
||||
|
||||
# Duplicate the node needed by foldable nodes
|
||||
for i in range(len(foldable_output_nodes)):
|
||||
logging.debug("Found constant %s and %s %s are availble for folding. Duplicate constant.",
|
||||
node.name, foldable_output_nodes[i].op_type, foldable_output_nodes[i].name)
|
||||
output_name = node.output[0] + '_dup_' + str(i)
|
||||
logging.debug(
|
||||
f"Found constant {node.name} and "
|
||||
f"{foldable_output_nodes[i].op_type} "
|
||||
f"{foldable_output_nodes[i].name} are availble for folding. "
|
||||
"Duplicate constant.",
|
||||
)
|
||||
output_name = node.output[0] + "_dup_" + str(i)
|
||||
new_constant_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
"Constant",
|
||||
[],
|
||||
[output_name],
|
||||
name=output_name,
|
||||
value=node.attribute[0].t
|
||||
value=node.attribute[0].t,
|
||||
)
|
||||
new_val_info = onnx.helper.make_tensor_value_info(
|
||||
output_name,
|
||||
node.attribute[0].t.data_type,
|
||||
data_shape
|
||||
output_name, node.attribute[0].t.data_type, data_shape
|
||||
)
|
||||
input_ind = list(foldable_output_nodes[i].input).index(
|
||||
node.output[0])
|
||||
node.output[0]
|
||||
)
|
||||
foldable_output_nodes[i].input[input_ind] = output_name
|
||||
|
||||
g.node.extend([new_constant_node])
|
||||
@ -116,6 +132,7 @@ def duplicate_constant_node(g):
|
||||
|
||||
return
|
||||
|
||||
|
||||
def slice_constant_folding(g, node):
|
||||
op_version = helper.get_current_opset_version()
|
||||
# only support opset 9 & 11
|
||||
@ -124,9 +141,9 @@ def slice_constant_folding(g, node):
|
||||
elif op_version == 9:
|
||||
return slice_constant_folding_Opset_9(g, node)
|
||||
|
||||
|
||||
def slice_constant_folding_Opset_11(g, node):
|
||||
""" Fold constant and slice nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and slice nodes to a single constant node."""
|
||||
pre_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
pre_shape, data_list = helper.constant_to_list(pre_node)
|
||||
|
||||
@ -136,20 +153,26 @@ def slice_constant_folding_Opset_11(g, node):
|
||||
ends_node = helper.find_node_by_output_name(g, node.input[2])
|
||||
_, ends = helper.constant_to_list(ends_node)
|
||||
|
||||
|
||||
axes_node = None if len(node.input) <= 3 else helper.find_node_by_output_name(g, node.input[3])
|
||||
axes_node = (
|
||||
None
|
||||
if len(node.input) <= 3
|
||||
else helper.find_node_by_output_name(g, node.input[3])
|
||||
)
|
||||
if not axes_node:
|
||||
axes = list(range(len(helper.get_shape(data_list))))
|
||||
else:
|
||||
_, axes = helper.constant_to_list(axes_node)
|
||||
|
||||
steps_node = None if len(node.input) <= 4 else helper.find_node_by_output_name(g, node.input[4])
|
||||
steps_node = (
|
||||
None
|
||||
if len(node.input) <= 4
|
||||
else helper.find_node_by_output_name(g, node.input[4])
|
||||
)
|
||||
if not steps_node:
|
||||
steps = [1]*len(helper.get_shape(data_list))
|
||||
steps = [1] * len(helper.get_shape(data_list))
|
||||
else:
|
||||
_, steps = helper.constant_to_list(steps_node)
|
||||
|
||||
|
||||
data_list = list(map(int, data_list))
|
||||
starts = list(map(int, starts))
|
||||
ends = list(map(int, ends))
|
||||
@ -160,10 +183,15 @@ def slice_constant_folding_Opset_11(g, node):
|
||||
|
||||
new_data = None
|
||||
for idx, _ in enumerate(axes):
|
||||
new_data = np.apply_along_axis( lambda x: x[starts[idx] : ends[idx] : steps[idx]], idx, data_list )
|
||||
new_data = np.apply_along_axis(
|
||||
lambda x: x[starts[idx]:ends[idx]:steps[idx]], idx, data_list
|
||||
)
|
||||
|
||||
new_node = helper.list_to_constant(node.output[0], helper.get_shape(
|
||||
new_data), helper.flatten_to_list(new_data))
|
||||
new_node = helper.list_to_constant(
|
||||
node.output[0],
|
||||
helper.get_shape(new_data),
|
||||
helper.flatten_to_list(new_data),
|
||||
)
|
||||
g.node.extend([new_node])
|
||||
value_info = helper.find_value_by_name(g, pre_node.output[0])
|
||||
if value_info is not None:
|
||||
@ -173,16 +201,16 @@ def slice_constant_folding_Opset_11(g, node):
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def slice_constant_folding_Opset_9(g, node):
|
||||
""" Fold constant and slice nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and slice nodes to a single constant node."""
|
||||
pre_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
pre_shape, data_list = helper.constant_to_list(pre_node)
|
||||
|
||||
data_list = np.reshape(data_list, pre_shape)
|
||||
axes = helper.get_attribute_by_name(node, 'axes')
|
||||
ends = list(helper.get_attribute_by_name(node, 'ends').ints)
|
||||
starts = list(helper.get_attribute_by_name(node, 'starts').ints)
|
||||
axes = helper.get_attribute_by_name(node, "axes")
|
||||
ends = list(helper.get_attribute_by_name(node, "ends").ints)
|
||||
starts = list(helper.get_attribute_by_name(node, "starts").ints)
|
||||
|
||||
if not axes:
|
||||
axes = list(range(len(helper.get_shape(data_list))))
|
||||
@ -190,8 +218,11 @@ def slice_constant_folding_Opset_9(g, node):
|
||||
axes = list(axes.ints)
|
||||
|
||||
new_data = helper.slice_data(data_list, starts, ends, axes)
|
||||
new_node = helper.list_to_constant(node.output[0], helper.get_shape(
|
||||
new_data), helper.flatten_to_list(new_data))
|
||||
new_node = helper.list_to_constant(
|
||||
node.output[0],
|
||||
helper.get_shape(new_data),
|
||||
helper.flatten_to_list(new_data),
|
||||
)
|
||||
g.node.extend([new_node])
|
||||
value_info = helper.find_value_by_name(g, pre_node.output[0])
|
||||
if value_info is not None:
|
||||
@ -201,9 +232,9 @@ def slice_constant_folding_Opset_9(g, node):
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def cast_constant_folding(g, node):
|
||||
""" Fold constant and cast node to a single constant node.
|
||||
"""
|
||||
"""Fold constant and cast node to a single constant node."""
|
||||
pre_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
shape, data = helper.constant_to_list(pre_node)
|
||||
data_type = node.attribute[0].i
|
||||
@ -212,28 +243,24 @@ def cast_constant_folding(g, node):
|
||||
elif data_type == onnx.helper.TensorProto.FLOAT:
|
||||
data = list(map(float, data))
|
||||
else:
|
||||
raise RuntimeError('data type not supported')
|
||||
raise RuntimeError("data type not supported")
|
||||
|
||||
if shape == 1:
|
||||
tensor = onnx.helper.make_tensor(
|
||||
name=pre_node.attribute[0].name,
|
||||
data_type=data_type,
|
||||
dims=[],
|
||||
vals=data
|
||||
vals=data,
|
||||
)
|
||||
else:
|
||||
tensor = onnx.helper.make_tensor(
|
||||
name=pre_node.attribute[0].name,
|
||||
data_type=data_type,
|
||||
dims=shape,
|
||||
vals=helper.flatten_to_list(data)
|
||||
vals=helper.flatten_to_list(data),
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
[],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
value=tensor
|
||||
"Constant", [], [node.output[0]], name=node.output[0], value=tensor
|
||||
)
|
||||
g.node.extend([new_node])
|
||||
|
||||
@ -250,15 +277,14 @@ def cast_constant_folding(g, node):
|
||||
|
||||
|
||||
def reduceprod_constant_folding(g, node):
|
||||
""" Fold constant and reduceprod nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and reduceprod nodes to a single constant node."""
|
||||
pre_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
shape, data_set = helper.constant_to_list(pre_node)
|
||||
tensor = pre_node.attribute[0].t
|
||||
|
||||
data_set = np.reshape(data_set, shape)
|
||||
for att in node.attribute:
|
||||
if att.name == 'axes':
|
||||
if att.name == "axes":
|
||||
axes = list(att.ints)
|
||||
else:
|
||||
keepdims = int(att.i)
|
||||
@ -270,14 +296,10 @@ def reduceprod_constant_folding(g, node):
|
||||
name=node.output[0],
|
||||
data_type=tensor.data_type,
|
||||
dims=new_shape,
|
||||
vals=new_flat_data
|
||||
vals=new_flat_data,
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
[],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
value=new_tensor
|
||||
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
|
||||
)
|
||||
|
||||
g.node.extend([new_node])
|
||||
@ -294,8 +316,7 @@ def reduceprod_constant_folding(g, node):
|
||||
|
||||
|
||||
def reshape_constant_input_folding(g, node):
|
||||
""" Fold constant and reshape nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and reshape nodes to a single constant node."""
|
||||
pre_data_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
pre_shape_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
|
||||
@ -307,14 +328,10 @@ def reshape_constant_input_folding(g, node):
|
||||
name=node.output[0],
|
||||
data_type=pre_data_node.attribute[0].t.data_type,
|
||||
dims=new_data.shape,
|
||||
vals=helper.flatten_to_list(new_data)
|
||||
vals=helper.flatten_to_list(new_data),
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
[],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
value=new_tensor
|
||||
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
|
||||
)
|
||||
g.node.extend([new_node])
|
||||
|
||||
@ -332,8 +349,7 @@ def reshape_constant_input_folding(g, node):
|
||||
|
||||
|
||||
def concat_constant_folding(g, node):
|
||||
""" Fold constant and concat nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and concat nodes to a single constant node."""
|
||||
node_to_del = []
|
||||
valid_inputs = True
|
||||
for input_name in node.input:
|
||||
@ -342,7 +358,7 @@ def concat_constant_folding(g, node):
|
||||
if len(input_node_output) > 1:
|
||||
valid_inputs = False
|
||||
break
|
||||
if input_node.op_type != 'Constant':
|
||||
if input_node.op_type != "Constant":
|
||||
valid_inputs = False
|
||||
break
|
||||
|
||||
@ -370,7 +386,7 @@ def concat_constant_folding(g, node):
|
||||
node.output[0],
|
||||
helper.get_shape(concat_data),
|
||||
helper.flatten_to_list(concat_data),
|
||||
data_type=node_data_type
|
||||
data_type=node_data_type,
|
||||
)
|
||||
g.node.extend([new_node])
|
||||
node_to_del.append(node)
|
||||
@ -388,8 +404,7 @@ def concat_constant_folding(g, node):
|
||||
|
||||
|
||||
def transpose_constant_folding(g, node):
|
||||
"""Fold constant and transpose nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and transpose nodes to a single constant node."""
|
||||
node_to_del = []
|
||||
pre_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
shape, data = helper.constant_to_list(pre_node)
|
||||
@ -402,7 +417,7 @@ def transpose_constant_folding(g, node):
|
||||
node.output[0],
|
||||
new_shape,
|
||||
new_data.flatten().tolist(),
|
||||
data_type=pre_node.attribute[0].t.data_type
|
||||
data_type=pre_node.attribute[0].t.data_type,
|
||||
)
|
||||
|
||||
g.node.extend([new_node])
|
||||
@ -415,9 +430,7 @@ def transpose_constant_folding(g, node):
|
||||
g.value_info.remove(next_val_info)
|
||||
|
||||
new_val_info = onnx.helper.make_tensor_value_info(
|
||||
node.output[0],
|
||||
pre_node.attribute[0].t.data_type,
|
||||
new_shape
|
||||
node.output[0], pre_node.attribute[0].t.data_type, new_shape
|
||||
)
|
||||
g.value_info.extend([new_val_info])
|
||||
|
||||
@ -430,8 +443,7 @@ def transpose_constant_folding(g, node):
|
||||
|
||||
|
||||
def unsqueeze_constant_folding(g, node):
|
||||
"""Fold constant and unsqueeze nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and unsqueeze nodes to a single constant node."""
|
||||
node_to_del = []
|
||||
pre_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
shape, data = helper.constant_to_list(pre_node)
|
||||
@ -449,7 +461,7 @@ def unsqueeze_constant_folding(g, node):
|
||||
node.output[0],
|
||||
new_shape,
|
||||
np_data.flatten().tolist(),
|
||||
data_type=pre_node.attribute[0].t.data_type
|
||||
data_type=pre_node.attribute[0].t.data_type,
|
||||
)
|
||||
g.node.extend([new_node])
|
||||
node_to_del.extend([node, pre_node])
|
||||
@ -464,9 +476,7 @@ def unsqueeze_constant_folding(g, node):
|
||||
g.value_info.remove(next_val_info)
|
||||
|
||||
new_val_info = onnx.helper.make_tensor_value_info(
|
||||
node.output[0],
|
||||
pre_node.attribute[0].t.data_type,
|
||||
new_shape
|
||||
node.output[0], pre_node.attribute[0].t.data_type, new_shape
|
||||
)
|
||||
g.value_info.extend([new_val_info])
|
||||
|
||||
@ -478,8 +488,7 @@ def unsqueeze_constant_folding(g, node):
|
||||
|
||||
|
||||
def gather_constant_folding(g, node):
|
||||
"""Fold constant and gather nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and gather nodes to a single constant node."""
|
||||
node_to_del = []
|
||||
|
||||
pre_data_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
@ -502,7 +511,7 @@ def gather_constant_folding(g, node):
|
||||
node.output[0],
|
||||
new_shape,
|
||||
new_data.flatten().tolist(),
|
||||
data_type=pre_data_node.attribute[0].t.data_type
|
||||
data_type=pre_data_node.attribute[0].t.data_type,
|
||||
)
|
||||
|
||||
node_to_del.extend([node, pre_data_node, pre_indices_node])
|
||||
@ -512,9 +521,7 @@ def gather_constant_folding(g, node):
|
||||
val_info_2 = helper.find_value_by_name(g, node.input[1])
|
||||
val_info_3 = helper.find_value_by_name(g, node.output[0])
|
||||
new_val_info = onnx.helper.make_tensor_value_info(
|
||||
new_node.output[0],
|
||||
pre_data_node.attribute[0].t.data_type,
|
||||
new_shape
|
||||
new_node.output[0], pre_data_node.attribute[0].t.data_type, new_shape
|
||||
)
|
||||
|
||||
if val_info_1 is not None:
|
||||
@ -533,8 +540,7 @@ def gather_constant_folding(g, node):
|
||||
|
||||
|
||||
def add_constant_folding(g, node):
|
||||
"""Fold constant and add nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and add nodes to a single constant node."""
|
||||
node_to_del = []
|
||||
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
|
||||
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
|
||||
@ -547,14 +553,14 @@ def add_constant_folding(g, node):
|
||||
np_data2 = np.reshape(data2, shape2)
|
||||
try:
|
||||
new_data = np.add(np_data1, np_data2)
|
||||
except:
|
||||
raise RuntimeError('can\'t broadcast and add two data sets')
|
||||
except Exception:
|
||||
raise RuntimeError("can't broadcast and add two data sets")
|
||||
|
||||
new_node = helper.list_to_constant(
|
||||
node.output[0],
|
||||
new_data.shape,
|
||||
new_data.flatten().tolist(),
|
||||
data_type=pre_node_1.attribute[0].t.data_type
|
||||
data_type=pre_node_1.attribute[0].t.data_type,
|
||||
)
|
||||
|
||||
g.node.extend([new_node])
|
||||
@ -571,8 +577,7 @@ def add_constant_folding(g, node):
|
||||
|
||||
|
||||
def sqrt_constant_folding(g, node):
|
||||
""" Fold constant and sqrt nodes to a single node.
|
||||
"""
|
||||
"""Fold constant and sqrt nodes to a single node."""
|
||||
node_to_del = []
|
||||
pre_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
shape, data = helper.constant_to_list(pre_node)
|
||||
@ -582,17 +587,13 @@ def sqrt_constant_folding(g, node):
|
||||
data_type = output_val_info.type.tensor_type.elem_type
|
||||
|
||||
new_tensor = onnx.helper.make_tensor(
|
||||
name=node.output[0]+'_data',
|
||||
name=node.output[0] + "_data",
|
||||
data_type=data_type,
|
||||
dims=shape,
|
||||
vals=np_data.flatten().tolist()
|
||||
vals=np_data.flatten().tolist(),
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
[],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
value=new_tensor
|
||||
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
|
||||
)
|
||||
|
||||
g.value_info.remove(input_val_info)
|
||||
@ -607,13 +608,12 @@ def sqrt_constant_folding(g, node):
|
||||
|
||||
|
||||
def reciprocal_constant_folding(g, node):
|
||||
""" Fold constant and reciprocal nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and reciprocal nodes to a single constant node."""
|
||||
node_to_del = []
|
||||
|
||||
pre_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
shape, data = helper.constant_to_list(pre_node)
|
||||
data = list(map(lambda x: x if abs(x) > 1.e-8 else 1.e-8, data))
|
||||
data = list(map(lambda x: x if abs(x) > 1.0e-8 else 1.0e-8, data))
|
||||
np_data = np.reshape(data, shape)
|
||||
np_data = np.reciprocal(np_data)
|
||||
|
||||
@ -622,17 +622,13 @@ def reciprocal_constant_folding(g, node):
|
||||
data_type = output_val_info.type.tensor_type.elem_type
|
||||
|
||||
new_tensor = onnx.helper.make_tensor(
|
||||
name=node.output[0]+'_data',
|
||||
name=node.output[0] + "_data",
|
||||
data_type=data_type,
|
||||
dims=shape,
|
||||
vals=np_data.flatten().tolist()
|
||||
vals=np_data.flatten().tolist(),
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
[],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
value=new_tensor
|
||||
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
|
||||
)
|
||||
|
||||
node_to_del.extend([node, pre_node])
|
||||
@ -648,8 +644,7 @@ def reciprocal_constant_folding(g, node):
|
||||
|
||||
|
||||
def mul_constant_folding(g, node):
|
||||
""" Fold constant and mul nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and mul nodes to a single constant node."""
|
||||
node_to_del = []
|
||||
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
|
||||
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
|
||||
@ -666,8 +661,8 @@ def mul_constant_folding(g, node):
|
||||
|
||||
try:
|
||||
new_data = np.multiply(np_data1, np_data2)
|
||||
except:
|
||||
raise RuntimeError('can not broadcast and multiply two data sets')
|
||||
except Exception:
|
||||
raise RuntimeError("can not broadcast and multiply two data sets")
|
||||
|
||||
# Special shape for single element.
|
||||
if shape1 == 1 and shape2 == 1:
|
||||
@ -676,17 +671,13 @@ def mul_constant_folding(g, node):
|
||||
new_shape = new_data.shape
|
||||
|
||||
new_tensor = onnx.helper.make_tensor(
|
||||
name=node.output[0]+'_data',
|
||||
name=node.output[0] + "_data",
|
||||
data_type=pre_node_1.attribute[0].t.data_type,
|
||||
dims=new_shape,
|
||||
vals=new_data.flatten().tolist()
|
||||
vals=new_data.flatten().tolist(),
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
[],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
value=new_tensor
|
||||
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
|
||||
)
|
||||
|
||||
node_to_del.extend([node, pre_node_1, pre_node_2])
|
||||
@ -703,8 +694,7 @@ def mul_constant_folding(g, node):
|
||||
|
||||
|
||||
def div_constant_folding(g, node):
|
||||
""" Fold constant and mul nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and mul nodes to a single constant node."""
|
||||
node_to_del = []
|
||||
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
|
||||
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
|
||||
@ -721,8 +711,8 @@ def div_constant_folding(g, node):
|
||||
|
||||
try:
|
||||
new_data = np.divide(np_data1, np_data2)
|
||||
except:
|
||||
raise RuntimeError('can not broadcast and multiply two data sets')
|
||||
except Exception:
|
||||
raise RuntimeError("can not broadcast and multiply two data sets")
|
||||
|
||||
# Special shape for single element.
|
||||
if shape1 == 1 and shape2 == 1:
|
||||
@ -732,20 +722,16 @@ def div_constant_folding(g, node):
|
||||
|
||||
# Check data type if it is int
|
||||
if pre_node_1.attribute[0].t.data_type == 7:
|
||||
new_data = new_data.astype('int64')
|
||||
new_data = new_data.astype("int64")
|
||||
|
||||
new_tensor = onnx.helper.make_tensor(
|
||||
name=node.output[0]+'_data',
|
||||
name=node.output[0] + "_data",
|
||||
data_type=pre_node_1.attribute[0].t.data_type,
|
||||
dims=new_shape,
|
||||
vals=new_data.flatten().tolist()
|
||||
vals=new_data.flatten().tolist(),
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
[],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
value=new_tensor
|
||||
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
|
||||
)
|
||||
|
||||
node_to_del.extend([node, pre_node_1, pre_node_2])
|
||||
@ -762,8 +748,7 @@ def div_constant_folding(g, node):
|
||||
|
||||
|
||||
def sub_constant_folding(g, node):
|
||||
""" Fold constant and sub nodes to a single node.
|
||||
"""
|
||||
"""Fold constant and sub nodes to a single node."""
|
||||
node_to_del = []
|
||||
pre_node_1 = helper.find_node_by_output_name(g, node.input[0])
|
||||
pre_node_2 = helper.find_node_by_output_name(g, node.input[1])
|
||||
@ -781,17 +766,13 @@ def sub_constant_folding(g, node):
|
||||
new_shape = new_data.shape
|
||||
|
||||
new_tensor = onnx.helper.make_tensor(
|
||||
name=node.output[0]+'_data',
|
||||
name=node.output[0] + "_data",
|
||||
data_type=pre_node_1.attribute[0].t.data_type,
|
||||
dims=new_shape,
|
||||
vals=helper.flatten_to_list(new_data)
|
||||
vals=helper.flatten_to_list(new_data),
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
[],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
value=new_tensor
|
||||
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
|
||||
)
|
||||
|
||||
g.node.extend([new_node])
|
||||
@ -815,17 +796,13 @@ def neg_constant_folding(g, node):
|
||||
new_data_list = [-num for num in data_list]
|
||||
|
||||
new_tensor = onnx.helper.make_tensor(
|
||||
name=pre_node.name+'_neg_tensor',
|
||||
name=pre_node.name + "_neg_tensor",
|
||||
data_type=pre_node.attribute[0].t.data_type,
|
||||
dims=shape,
|
||||
vals=new_data_list
|
||||
vals=new_data_list,
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
[],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
value=new_tensor
|
||||
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
|
||||
)
|
||||
|
||||
g.node.extend([new_node])
|
||||
@ -851,17 +828,13 @@ def floor_constant_folding(g, node):
|
||||
new_shape = shape
|
||||
|
||||
new_tensor = onnx.helper.make_tensor(
|
||||
name=node.output[0]+'_data',
|
||||
name=node.output[0] + "_data",
|
||||
data_type=pre_node.attribute[0].t.data_type,
|
||||
dims=new_shape,
|
||||
vals=helper.flatten_to_list(new_data)
|
||||
vals=helper.flatten_to_list(new_data),
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
[],
|
||||
[node.output[0]],
|
||||
name=node.output[0],
|
||||
value=new_tensor
|
||||
"Constant", [], [node.output[0]], name=node.output[0], value=new_tensor
|
||||
)
|
||||
|
||||
g.node.extend([new_node])
|
||||
@ -877,8 +850,7 @@ def floor_constant_folding(g, node):
|
||||
|
||||
|
||||
def bn_constant_folding(g, node):
|
||||
""" Fold constant and mul nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and mul nodes to a single constant node."""
|
||||
# Prepare data
|
||||
node_to_del = []
|
||||
input_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
@ -900,17 +872,22 @@ def bn_constant_folding(g, node):
|
||||
mean_data = helper.constant_to_numpy(mean_node)
|
||||
var_data = helper.constant_to_numpy(var_node)
|
||||
|
||||
epsilon = helper.get_var_attribute_by_name(node, 'epsilon', 'float')
|
||||
epsilon = helper.get_var_attribute_by_name(node, "epsilon", "float")
|
||||
if epsilon is None:
|
||||
epsilon = 0.00001
|
||||
|
||||
# Calculate new node
|
||||
new_data = scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon) + bias_data
|
||||
new_data = (
|
||||
scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon)
|
||||
+ bias_data
|
||||
)
|
||||
|
||||
new_node = helper.numpy_to_constant(node.output[0], new_data)
|
||||
|
||||
# Reconnect the graph
|
||||
node_to_del.extend([node, input_node, scale_node, bias_node, mean_node, var_node])
|
||||
node_to_del.extend(
|
||||
[node, input_node, scale_node, bias_node, mean_node, var_node]
|
||||
)
|
||||
g.node.extend([new_node])
|
||||
|
||||
for value in input_value_info:
|
||||
@ -925,8 +902,7 @@ def bn_constant_folding(g, node):
|
||||
|
||||
|
||||
def DequantizeLinear_constant_folding(g, node):
|
||||
""" Fold constant and mul nodes to a single constant node.
|
||||
"""
|
||||
"""Fold constant and mul nodes to a single constant node."""
|
||||
# Prepare data
|
||||
node_to_del = []
|
||||
x_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
@ -951,7 +927,9 @@ def DequantizeLinear_constant_folding(g, node):
|
||||
x_zero_point_data = np.array([0.0])
|
||||
|
||||
# Calculate new node
|
||||
new_data = (x_data.astype(np.float32) - x_zero_point_data.astype(np.float32)) * x_scale_data
|
||||
new_data = (
|
||||
x_data.astype(np.float32) - x_zero_point_data.astype(np.float32)
|
||||
) * x_scale_data
|
||||
|
||||
new_node = helper.numpy_to_constant(node.output[0], new_data)
|
||||
|
||||
@ -974,22 +952,22 @@ def DequantizeLinear_constant_folding(g, node):
|
||||
|
||||
# Available constant folding names to function map.
|
||||
constant_folding_nodes = {
|
||||
'Add': add_constant_folding,
|
||||
'BatchNormalization': bn_constant_folding,
|
||||
'Cast': cast_constant_folding,
|
||||
'Concat': concat_constant_folding,
|
||||
'DequantizeLinear': DequantizeLinear_constant_folding,
|
||||
'Div': div_constant_folding,
|
||||
'Floor': floor_constant_folding,
|
||||
'Gather': gather_constant_folding,
|
||||
'Mul': mul_constant_folding,
|
||||
'Reciprocal': reciprocal_constant_folding,
|
||||
'ReduceProd': reduceprod_constant_folding,
|
||||
'Reshape': reshape_constant_input_folding,
|
||||
'Slice': slice_constant_folding,
|
||||
'Sqrt': sqrt_constant_folding,
|
||||
'Transpose': transpose_constant_folding,
|
||||
'Unsqueeze': unsqueeze_constant_folding,
|
||||
'Sub': sub_constant_folding,
|
||||
'Neg': neg_constant_folding
|
||||
"Add": add_constant_folding,
|
||||
"BatchNormalization": bn_constant_folding,
|
||||
"Cast": cast_constant_folding,
|
||||
"Concat": concat_constant_folding,
|
||||
"DequantizeLinear": DequantizeLinear_constant_folding,
|
||||
"Div": div_constant_folding,
|
||||
"Floor": floor_constant_folding,
|
||||
"Gather": gather_constant_folding,
|
||||
"Mul": mul_constant_folding,
|
||||
"Reciprocal": reciprocal_constant_folding,
|
||||
"ReduceProd": reduceprod_constant_folding,
|
||||
"Reshape": reshape_constant_input_folding,
|
||||
"Slice": slice_constant_folding,
|
||||
"Sqrt": sqrt_constant_folding,
|
||||
"Transpose": transpose_constant_folding,
|
||||
"Unsqueeze": unsqueeze_constant_folding,
|
||||
"Sub": sub_constant_folding,
|
||||
"Neg": neg_constant_folding,
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ from . import helper
|
||||
from . import modhelper
|
||||
from .general_graph import Graph
|
||||
|
||||
|
||||
def eliminate_Identify_and_Dropout(g):
|
||||
"""
|
||||
Eliminate Identify layers
|
||||
@ -15,31 +16,46 @@ def eliminate_Identify_and_Dropout(g):
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Identity' and node.op_type != 'Dropout':
|
||||
if node.op_type != "Identity" and node.op_type != "Dropout":
|
||||
continue
|
||||
# If this node is the last node, leave it to `eliminate_useless_last node`
|
||||
# If this node is the last, leave it to `eliminate_useless_last node`
|
||||
if helper.find_output_by_name(g, node.output[0]) is not None:
|
||||
continue
|
||||
# Replace the parents in all the following nodes
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(
|
||||
g, node.output[0]
|
||||
)
|
||||
for following_node in following_nodes:
|
||||
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
|
||||
modhelper.replace_node_input(
|
||||
following_node, node.output[0], node.input[0]
|
||||
)
|
||||
# Delete value info
|
||||
value_between = helper.find_value_by_name(g, node.output[0])
|
||||
try:
|
||||
g.value_info.remove(value_between)
|
||||
except:
|
||||
except Exception:
|
||||
print("No value info to delete while eliminating identity layers.")
|
||||
# Node is waiting for elimination
|
||||
node_to_remove.append(node)
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
# Remove last useless nodes
|
||||
def remove_useless_last_nodes(g):
|
||||
"""Remove useless nodes from the tail of the graph
|
||||
"""
|
||||
USELESS = ["Reshape", "Identity", "Transpose", "Flatten", "Dropout", "Mystery", "Constant", "Squeeze", "Unsqueeze", 'Softmax']
|
||||
"""Remove useless nodes from the tail of the graph"""
|
||||
USELESS = [
|
||||
"Reshape",
|
||||
"Identity",
|
||||
"Transpose",
|
||||
"Flatten",
|
||||
"Dropout",
|
||||
"Mystery",
|
||||
"Constant",
|
||||
"Squeeze",
|
||||
"Unsqueeze",
|
||||
"Softmax",
|
||||
]
|
||||
graph = Graph(g)
|
||||
todo = collections.deque()
|
||||
for node in graph.output_nodes:
|
||||
@ -54,19 +70,30 @@ def remove_useless_last_nodes(g):
|
||||
if cur_node.proto.op_type not in USELESS:
|
||||
continue
|
||||
# Find the output
|
||||
cur_node_output = helper.find_output_by_name(g, cur_node.proto.output[0])
|
||||
cur_node_output = helper.find_output_by_name(
|
||||
g, cur_node.proto.output[0]
|
||||
)
|
||||
for cur_input in cur_node.parents:
|
||||
cur_input.children.remove(cur_node)
|
||||
if len(cur_input.children) == 0:
|
||||
todo.append(cur_input)
|
||||
if cur_node_output is not None:
|
||||
cur_input_output = helper.find_value_by_name(g, cur_input.proto.output[0])
|
||||
cur_input_output_in_output = helper.find_output_by_name(g, cur_input.proto.output[0])
|
||||
if cur_input_output is not None and cur_input_output_in_output is None:
|
||||
cur_input_output = helper.find_value_by_name(
|
||||
g, cur_input.proto.output[0]
|
||||
)
|
||||
cur_input_output_in_output = helper.find_output_by_name(
|
||||
g, cur_input.proto.output[0]
|
||||
)
|
||||
if (
|
||||
cur_input_output is not None
|
||||
and cur_input_output_in_output is None
|
||||
):
|
||||
g.output.extend([cur_input_output])
|
||||
node_to_remove.append(cur_node.proto)
|
||||
try:
|
||||
g.value_info.remove(helper.find_value_by_name(g, cur_node.proto.output[0]))
|
||||
g.value_info.remove(
|
||||
helper.find_value_by_name(g, cur_node.proto.output[0])
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
if cur_node_output is not None:
|
||||
@ -76,10 +103,12 @@ def remove_useless_last_nodes(g):
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
######################################
|
||||
# TF only optimization passes #
|
||||
######################################
|
||||
|
||||
|
||||
def eliminate_shape_changing_after_input(g):
|
||||
"""
|
||||
Eliminate the Reshape node after input and reshape the input
|
||||
@ -87,7 +116,14 @@ def eliminate_shape_changing_after_input(g):
|
||||
:param g: the onnx graph
|
||||
"""
|
||||
node_to_remove = []
|
||||
REMOVE_LIST = ["Reshape", "Transpose", "Flatten", "Dropout", "Squeeze", "Unsqueeze"]
|
||||
REMOVE_LIST = [
|
||||
"Reshape",
|
||||
"Transpose",
|
||||
"Flatten",
|
||||
"Dropout",
|
||||
"Squeeze",
|
||||
"Unsqueeze",
|
||||
]
|
||||
for node in g.node:
|
||||
# Find an input and the shape node
|
||||
if node.op_type not in REMOVE_LIST:
|
||||
@ -105,9 +141,9 @@ def eliminate_shape_changing_after_input(g):
|
||||
# Remove Weight if any.
|
||||
output_val_info = helper.find_value_by_name(g, node.output[0])
|
||||
|
||||
if node.op_type == 'Reshape':
|
||||
if node.op_type == "Reshape":
|
||||
shape_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
if shape_node.op_type != 'Constant':
|
||||
if shape_node.op_type != "Constant":
|
||||
continue
|
||||
|
||||
# manuelly set the input shape
|
||||
@ -117,25 +153,29 @@ def eliminate_shape_changing_after_input(g):
|
||||
_, new_shape = helper.constant_to_list(shape_node)
|
||||
for i in range(len(new_shape)):
|
||||
if new_shape[i] == -1:
|
||||
dim = int(old_size//np.prod(new_shape)*(-1))
|
||||
dim = int(old_size // np.prod(new_shape) * (-1))
|
||||
new_shape[i] = dim
|
||||
new_input = onnx.helper.make_tensor_value_info(
|
||||
output_val_info.name,
|
||||
output_val_info.type.tensor_type.elem_type,
|
||||
new_shape
|
||||
new_shape,
|
||||
)
|
||||
|
||||
node_to_remove.append(node)
|
||||
|
||||
shape_outputs = helper.find_nodes_by_input_name(g, shape_node.output[0])
|
||||
shape_outputs = helper.find_nodes_by_input_name(
|
||||
g, shape_node.output[0]
|
||||
)
|
||||
if len(shape_outputs) == 1:
|
||||
node_to_remove.append(shape_node)
|
||||
g.value_info.remove(helper.find_value_by_name(g, shape_node.output[0]))
|
||||
|
||||
g.value_info.remove(
|
||||
helper.find_value_by_name(g, shape_node.output[0])
|
||||
)
|
||||
|
||||
g.input.remove(old_input)
|
||||
g.input.extend([new_input])
|
||||
g.value_info.remove(output_val_info)
|
||||
elif node.op_type == 'Transpose':
|
||||
elif node.op_type == "Transpose":
|
||||
permutation = list(node.attribute[0].ints)
|
||||
pre_shape = helper.get_shape_from_value_info(old_input)
|
||||
new_shape = [pre_shape[i] for i in permutation]
|
||||
@ -143,7 +183,7 @@ def eliminate_shape_changing_after_input(g):
|
||||
new_input = onnx.helper.make_tensor_value_info(
|
||||
output_val_info.name,
|
||||
output_val_info.type.tensor_type.elem_type,
|
||||
new_shape
|
||||
new_shape,
|
||||
)
|
||||
|
||||
node_to_remove.append(node)
|
||||
@ -151,7 +191,7 @@ def eliminate_shape_changing_after_input(g):
|
||||
g.input.remove(old_input)
|
||||
g.input.extend([new_input])
|
||||
g.value_info.remove(output_val_info)
|
||||
elif node.op_type == 'Flatten':
|
||||
elif node.op_type == "Flatten":
|
||||
axis = node.attribute[0].int
|
||||
pre_shape = helper.get_shape_from_value_info(old_input)
|
||||
dim_1, dim_2 = 1, 1
|
||||
@ -166,7 +206,7 @@ def eliminate_shape_changing_after_input(g):
|
||||
new_input = onnx.helper.make_tensor_value_info(
|
||||
output_val_info.name,
|
||||
output_val_info.type.tensor_type.elem_type,
|
||||
new_shape
|
||||
new_shape,
|
||||
)
|
||||
|
||||
node_to_remove.append(node)
|
||||
@ -174,18 +214,18 @@ def eliminate_shape_changing_after_input(g):
|
||||
g.input.remove(old_input)
|
||||
g.input.extend([new_input])
|
||||
g.value_info.remove(output_val_info)
|
||||
elif node.op_type == 'Dropout':
|
||||
elif node.op_type == "Dropout":
|
||||
g.input.remove(old_input)
|
||||
g.input.extend([output_val_info])
|
||||
g.value_info.remove(output_val_info)
|
||||
|
||||
|
||||
node_to_remove.append(node)
|
||||
elif node.op_type == 'Squeeze':
|
||||
elif node.op_type == "Squeeze":
|
||||
axis = list(node.attribute[0].ints)
|
||||
pre_shape = helper.get_shape_from_value_info(old_input)
|
||||
for pos in sorted(axis)[::-1]:
|
||||
if pre_shape[pos] != 1:
|
||||
raise RuntimeError('invalid axis for squeeze')
|
||||
raise RuntimeError("invalid axis for squeeze")
|
||||
else:
|
||||
pre_shape.pop(pos)
|
||||
new_shape = pre_shape
|
||||
@ -193,7 +233,7 @@ def eliminate_shape_changing_after_input(g):
|
||||
new_input = onnx.helper.make_tensor_value_info(
|
||||
output_val_info.name,
|
||||
output_val_info.type.tensor_type.elem_type,
|
||||
new_shape
|
||||
new_shape,
|
||||
)
|
||||
|
||||
node_to_remove.append(node)
|
||||
@ -201,7 +241,7 @@ def eliminate_shape_changing_after_input(g):
|
||||
g.input.remove(old_input)
|
||||
g.input.extend([new_input])
|
||||
g.value_info.remove(output_val_info)
|
||||
elif node.op_type == 'Unsqueeze':
|
||||
elif node.op_type == "Unsqueeze":
|
||||
axis = list(node.attribute[0].ints)
|
||||
pre_shape = helper.get_shape_from_value_info(old_input)
|
||||
new_shape = pre_shape
|
||||
@ -210,7 +250,7 @@ def eliminate_shape_changing_after_input(g):
|
||||
new_input = onnx.helper.make_tensor_value_info(
|
||||
output_val_info.name,
|
||||
output_val_info.type.tensor_type.elem_type,
|
||||
new_shape
|
||||
new_shape,
|
||||
)
|
||||
node_to_remove.append(node)
|
||||
|
||||
@ -222,7 +262,7 @@ def eliminate_shape_changing_after_input(g):
|
||||
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
other.topological_sort(g)
|
||||
|
||||
|
||||
@ -231,15 +271,13 @@ def eliminate_Reshape_Cast(g):
|
||||
|
||||
:param g: the onnx graph
|
||||
"""
|
||||
#Find all reshape layers
|
||||
node_to_remove = []
|
||||
# Find all reshape layers
|
||||
for node in g.node:
|
||||
if node.op_type != 'Reshape':
|
||||
if node.op_type != "Reshape":
|
||||
continue
|
||||
prev_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
if prev_node.op_type != 'Cast':
|
||||
if prev_node.op_type != "Cast":
|
||||
continue
|
||||
# Now we find the cast weight pattern. Cast the weight, delete the cast.
|
||||
reshape_node = node
|
||||
cast_node = prev_node
|
||||
weight_node = helper.find_node_by_output_name(g, cast_node.input[0])
|
||||
@ -248,10 +286,12 @@ def eliminate_Reshape_Cast(g):
|
||||
weight_node.attribute[0].t.data_type = 7
|
||||
if weight_node.attribute[0].t.raw_data:
|
||||
raw_data = weight_node.attribute[0].t.raw_data
|
||||
int_data = [i[0] for i in struct.iter_unpack('i', raw_data)]
|
||||
raw_data = struct.pack('q' * len(int_data), *int_data)
|
||||
elif len(weight_node.attribute[0].t.int64_data) > 0\
|
||||
or len(weight_node.attribute[0].t.int32_data) > 0:
|
||||
int_data = [i[0] for i in struct.iter_unpack("i", raw_data)]
|
||||
raw_data = struct.pack("q" * len(int_data), *int_data)
|
||||
elif (
|
||||
len(weight_node.attribute[0].t.int64_data) > 0
|
||||
or len(weight_node.attribute[0].t.int32_data) > 0
|
||||
):
|
||||
# It's already int. Do nothing
|
||||
pass
|
||||
else:
|
||||
@ -264,6 +304,7 @@ def eliminate_Reshape_Cast(g):
|
||||
g.value_info.remove(origin_weight_out)
|
||||
g.node.remove(cast_node)
|
||||
|
||||
|
||||
def eliminate_Cast_after_input(g):
|
||||
"""Eliminate the cast layer right after the input
|
||||
|
||||
@ -271,7 +312,7 @@ def eliminate_Cast_after_input(g):
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Cast':
|
||||
if node.op_type != "Cast":
|
||||
continue
|
||||
old_input = helper.find_input_by_name(g, node.input[0])
|
||||
if old_input is None:
|
||||
@ -279,9 +320,7 @@ def eliminate_Cast_after_input(g):
|
||||
next_val_info = helper.find_value_by_name(g, node.output[0])
|
||||
shape = helper.get_shape_from_value_info(next_val_info)
|
||||
new_val_info = onnx.helper.make_tensor_value_info(
|
||||
next_val_info.name,
|
||||
node.attribute[0].i,
|
||||
shape
|
||||
next_val_info.name, node.attribute[0].i, shape
|
||||
)
|
||||
# Delete old value_info
|
||||
g.input.remove(old_input)
|
||||
@ -293,6 +332,7 @@ def eliminate_Cast_after_input(g):
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
def eliminate_consecutive_Cast(g):
|
||||
"""If two cast is next to each other, remove the first cast
|
||||
|
||||
@ -300,10 +340,10 @@ def eliminate_consecutive_Cast(g):
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Cast':
|
||||
if node.op_type != "Cast":
|
||||
continue
|
||||
first_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
if first_node is None or first_node.op_type != 'Cast':
|
||||
if first_node is None or first_node.op_type != "Cast":
|
||||
continue
|
||||
# Here we have two consecutive Cast Node
|
||||
# Reset the input of the later node
|
||||
@ -315,6 +355,7 @@ def eliminate_consecutive_Cast(g):
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
def eliminate_Squeeze_before_Reshape(g):
|
||||
"""If Squeeze and Reshape is next to each other, remove the first node
|
||||
|
||||
@ -322,12 +363,12 @@ def eliminate_Squeeze_before_Reshape(g):
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Reshape':
|
||||
if node.op_type != "Reshape":
|
||||
continue
|
||||
first_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
if not first_node:
|
||||
continue
|
||||
if first_node.op_type != 'Squeeze':
|
||||
if first_node.op_type != "Squeeze":
|
||||
continue
|
||||
# Here we have two consecutive Cast Node
|
||||
# Reset the input of the later node
|
||||
@ -339,9 +380,9 @@ def eliminate_Squeeze_before_Reshape(g):
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
def eliminate_no_children_input(g):
|
||||
"""Eliminate inputs with no children at all.
|
||||
"""
|
||||
"""Eliminate inputs with no children at all."""
|
||||
# Create a set of input names
|
||||
input_names = set([i.name for i in g.input])
|
||||
# If a name is used in any node, remove this name from the set.
|
||||
@ -353,31 +394,33 @@ def eliminate_no_children_input(g):
|
||||
info = helper.find_input_by_name(g, i)
|
||||
g.input.remove(info)
|
||||
|
||||
|
||||
def eliminate_consecutive_reshape(g):
|
||||
"""Replace consecutive reshape nodes by a single node.
|
||||
"""
|
||||
"""Replace consecutive reshape nodes by a single node."""
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Reshape':
|
||||
if node.op_type != "Reshape":
|
||||
continue
|
||||
pre_data_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
pre_shape_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
if not pre_data_node or not pre_shape_node:
|
||||
continue
|
||||
if pre_shape_node.op_type != 'Constant':
|
||||
if pre_shape_node.op_type != "Constant":
|
||||
continue
|
||||
if pre_data_node.op_type != 'Reshape':
|
||||
if pre_data_node.op_type != "Reshape":
|
||||
continue
|
||||
|
||||
pre_pre_shape_node = helper.find_node_by_output_name(g, pre_data_node.input[1])
|
||||
if pre_pre_shape_node.op_type != 'Constant':
|
||||
|
||||
pre_pre_shape_node = helper.find_node_by_output_name(
|
||||
g, pre_data_node.input[1]
|
||||
)
|
||||
if pre_pre_shape_node.op_type != "Constant":
|
||||
continue
|
||||
|
||||
new_reshape_node = onnx.helper.make_node(
|
||||
'Reshape',
|
||||
"Reshape",
|
||||
[pre_data_node.input[0], node.input[1]],
|
||||
[node.output[0]],
|
||||
name = node.output[0]
|
||||
name=node.output[0],
|
||||
)
|
||||
|
||||
g.node.extend([new_reshape_node])
|
||||
@ -394,6 +437,7 @@ def eliminate_consecutive_reshape(g):
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
def eliminate_single_input_Concat(g):
|
||||
"""
|
||||
Eliminate single input Concat layers
|
||||
@ -402,12 +446,12 @@ def eliminate_single_input_Concat(g):
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Concat':
|
||||
if node.op_type != "Concat":
|
||||
continue
|
||||
# If this node has more than 1 input, continue.
|
||||
if len(node.input) > 1:
|
||||
continue
|
||||
# If this node is the output node, set its previous node as output nodes.
|
||||
# If this node is output node, set its previous node as output nodes.
|
||||
if helper.find_output_by_name(g, node.output[0]) is not None:
|
||||
todel_output = helper.find_output_by_name(g, node.output[0])
|
||||
the_input_value = helper.find_value_by_name(g, node.input[0])
|
||||
@ -416,20 +460,25 @@ def eliminate_single_input_Concat(g):
|
||||
node_to_remove.append(node)
|
||||
continue
|
||||
# Replace the parents in all the following nodes
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(
|
||||
g, node.output[0]
|
||||
)
|
||||
for following_node in following_nodes:
|
||||
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
|
||||
modhelper.replace_node_input(
|
||||
following_node, node.output[0], node.input[0]
|
||||
)
|
||||
# Delete value info
|
||||
value_between = helper.find_value_by_name(g, node.output[0])
|
||||
try:
|
||||
g.value_info.remove(value_between)
|
||||
except:
|
||||
except Exception:
|
||||
print("No value info to delete while eliminating identity layers.")
|
||||
# Node is waiting for elimination
|
||||
node_to_remove.append(node)
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
def eliminate_nop_Maxpool_and_AveragePool(g):
|
||||
"""
|
||||
Eliminate do nothing MaxPool and AveragePool layers.
|
||||
@ -439,7 +488,7 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
|
||||
"""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'MaxPool' and node.op_type != 'AveragePool':
|
||||
if node.op_type != "MaxPool" and node.op_type != "AveragePool":
|
||||
continue
|
||||
# If this node is actually working, continue.
|
||||
kernel = helper.get_list_attribute_by_name(node, "kernel_shape", "int")
|
||||
@ -447,7 +496,7 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
|
||||
strides = helper.get_list_attribute_by_name(node, "strides", "int")
|
||||
if kernel != [1, 1] or pads != [0, 0, 0, 0] or strides != [1, 1]:
|
||||
continue
|
||||
# If this node is the output node, set its previous node as output nodes.
|
||||
# If this node is the output, set its previous node as output nodes.
|
||||
if helper.find_output_by_name(g, node.output[0]) is not None:
|
||||
todel_output = helper.find_output_by_name(g, node.output[0])
|
||||
the_input_value = helper.find_value_by_name(g, node.input[0])
|
||||
@ -456,14 +505,18 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
|
||||
node_to_remove.append(node)
|
||||
continue
|
||||
# Replace the parents in all the following nodes
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(
|
||||
g, node.output[0]
|
||||
)
|
||||
for following_node in following_nodes:
|
||||
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
|
||||
modhelper.replace_node_input(
|
||||
following_node, node.output[0], node.input[0]
|
||||
)
|
||||
# Delete value info
|
||||
value_between = helper.find_value_by_name(g, node.output[0])
|
||||
try:
|
||||
g.value_info.remove(value_between)
|
||||
except:
|
||||
except Exception:
|
||||
print("No value info to delete while eliminating identity layers.")
|
||||
# Node is waiting for elimination
|
||||
node_to_remove.append(node)
|
||||
@ -474,20 +527,20 @@ def eliminate_nop_Maxpool_and_AveragePool(g):
|
||||
def eliminate_trivial_maxpool(g):
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'MaxPool':
|
||||
if node.op_type != "MaxPool":
|
||||
continue
|
||||
pads = None
|
||||
strides = None
|
||||
dilation = None
|
||||
kernel_shape = None
|
||||
for att in node.attribute:
|
||||
if att.name == 'pads':
|
||||
if att.name == "pads":
|
||||
pads = list(att.ints)
|
||||
elif att.name == 'strides':
|
||||
elif att.name == "strides":
|
||||
strides = list(att.ints)
|
||||
elif att.name == 'kernel_shape':
|
||||
elif att.name == "kernel_shape":
|
||||
kernel_shape = list(att.ints)
|
||||
elif att.name == 'dilation':
|
||||
elif att.name == "dilation":
|
||||
dilation = list(att.ints)
|
||||
else:
|
||||
pass
|
||||
@ -504,7 +557,7 @@ def eliminate_trivial_maxpool(g):
|
||||
|
||||
next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
|
||||
|
||||
if next_nodes[0] == None:
|
||||
if next_nodes[0] is None:
|
||||
output_value = helper.find_output_by_name(g, node.output[0])
|
||||
if not output_value:
|
||||
continue
|
||||
@ -512,18 +565,21 @@ def eliminate_trivial_maxpool(g):
|
||||
pre_val_info = helper.find_value_by_name(g, node.input[0])
|
||||
g.output.extend([pre_val_info])
|
||||
g.output.remove(output_value)
|
||||
|
||||
|
||||
for next_node in next_nodes:
|
||||
modhelper.replace_node_input(next_node, node.output[0], node.input[0])
|
||||
|
||||
modhelper.replace_node_input(
|
||||
next_node, node.output[0], node.input[0]
|
||||
)
|
||||
|
||||
next_val_info = helper.find_value_by_name(g, node.output[0])
|
||||
g.value_info.remove(next_val_info)
|
||||
|
||||
while node_to_del:
|
||||
g.node.remove(node_to_del.pop())
|
||||
|
||||
|
||||
other.topological_sort(g)
|
||||
|
||||
|
||||
def eliminate_empty_value_infos(g):
|
||||
to_remove = []
|
||||
for value_info in g.value_info:
|
||||
@ -532,10 +588,11 @@ def eliminate_empty_value_infos(g):
|
||||
for value_info in to_remove:
|
||||
g.value_info.remove(value_info)
|
||||
|
||||
|
||||
def eliminate_nop_pads(g):
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Pad':
|
||||
if node.op_type != "Pad":
|
||||
continue
|
||||
# Check if the Pad is empty or not
|
||||
pads_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
@ -546,11 +603,7 @@ def eliminate_nop_pads(g):
|
||||
all_zero = False
|
||||
if not all_zero:
|
||||
continue
|
||||
# Check if it has the constant_value_node
|
||||
constant_value_node = None
|
||||
if len(node.input) > 2:
|
||||
constant_value_node = helper.find_node_by_output_name(g, node.input[2])
|
||||
# If this node is the output node, set its previous node as output nodes.
|
||||
# If this node is the output, set its previous node as output nodes.
|
||||
if helper.find_output_by_name(g, node.output[0]) is not None:
|
||||
todel_output = helper.find_output_by_name(g, node.output[0])
|
||||
g.output.remove(todel_output)
|
||||
@ -559,38 +612,44 @@ def eliminate_nop_pads(g):
|
||||
if the_input_value is not None:
|
||||
g.output.extend([the_input_value])
|
||||
# Replace the parents in all the following nodes
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(
|
||||
g, node.output[0]
|
||||
)
|
||||
for following_node in following_nodes:
|
||||
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
|
||||
modhelper.replace_node_input(
|
||||
following_node, node.output[0], node.input[0]
|
||||
)
|
||||
# Delete value info
|
||||
value_between = helper.find_value_by_name(g, node.output[0])
|
||||
try:
|
||||
g.value_info.remove(value_between)
|
||||
except:
|
||||
helper.logger.info("No value info to delete while eliminating identity layers.")
|
||||
except Exception:
|
||||
helper.logger.info(
|
||||
"No value info to delete while eliminating identity layers."
|
||||
)
|
||||
# Node is waiting for elimination
|
||||
node_to_remove.append(node)
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
def eliminate_trivial_elementwise_calculation(g):
|
||||
"""Eliminate Add, Sub, Mul, Sub nodes which do nothing.
|
||||
"""
|
||||
"""Eliminate Add, Sub, Mul, Sub nodes which do nothing."""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
weight_node = None
|
||||
if node.op_type == 'Add' or node.op_type == 'Sub':
|
||||
if node.op_type == "Add" or node.op_type == "Sub":
|
||||
# For add and sub, check if the weights are 0s.
|
||||
weight_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
if weight_node is None or weight_node.op_type != 'Constant':
|
||||
if weight_node is None or weight_node.op_type != "Constant":
|
||||
continue
|
||||
weight_np = helper.constant_to_numpy(weight_node)
|
||||
if np.any(weight_np):
|
||||
continue
|
||||
elif node.op_type == 'Mul' or node.op_type == 'Div':
|
||||
elif node.op_type == "Mul" or node.op_type == "Div":
|
||||
# For Mul and Div, check if the weights are 1s.
|
||||
weight_node = helper.find_node_by_output_name(g, node.input[1])
|
||||
if weight_node is None or weight_node.op_type != 'Constant':
|
||||
if weight_node is None or weight_node.op_type != "Constant":
|
||||
continue
|
||||
weight_np = helper.constant_to_numpy(weight_node)
|
||||
weight_np = weight_np - 1
|
||||
@ -605,9 +664,13 @@ def eliminate_trivial_elementwise_calculation(g):
|
||||
if output_value_info is not None:
|
||||
g.value_info.remove(output_value_info)
|
||||
# Replace next node input if any.
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(
|
||||
g, node.output[0]
|
||||
)
|
||||
for following_node in following_nodes:
|
||||
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
|
||||
modhelper.replace_node_input(
|
||||
following_node, node.output[0], node.input[0]
|
||||
)
|
||||
todel_output = helper.find_output_by_name(g, node.output[0])
|
||||
if todel_output is not None:
|
||||
g.output.remove(todel_output)
|
||||
@ -616,38 +679,53 @@ def eliminate_trivial_elementwise_calculation(g):
|
||||
the_input_value = helper.find_value_by_name(g, node.input[0])
|
||||
g.output.extend([the_input_value])
|
||||
# Delete the constant node if it is not used by other nodes
|
||||
constant_following_nodes = helper.find_following_nodes_by_input_value_name(g, weight_node.output[0])
|
||||
constant_following_nodes = (
|
||||
helper.find_following_nodes_by_input_value_name(
|
||||
g, weight_node.output[0]
|
||||
)
|
||||
)
|
||||
if len(constant_following_nodes) == 1:
|
||||
node_to_remove.append(weight_node)
|
||||
output_value_info = helper.find_value_by_name(g, weight_node.output[0])
|
||||
output_value_info = helper.find_value_by_name(
|
||||
g, weight_node.output[0]
|
||||
)
|
||||
if output_value_info is not None:
|
||||
g.value_info.remove(output_value_info)
|
||||
for node in node_to_remove:
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
def eliminate_nop_cast(g):
|
||||
"""Eliminate do nothing Cast nodes.
|
||||
"""
|
||||
"""Eliminate do nothing Cast nodes."""
|
||||
node_to_remove = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Cast':
|
||||
if node.op_type != "Cast":
|
||||
continue
|
||||
# Get input value_info
|
||||
input_value = helper.find_value_by_name(g, node.input[0])
|
||||
if input_value is None:
|
||||
helper.logger.debug(f"Cannot find the input value_info for Cast node {node.name}. Skip elimination check.")
|
||||
helper.logger.debug(
|
||||
f"Cannot find the input value_info for Cast node {node.name}. "
|
||||
"Skip elimination check."
|
||||
)
|
||||
continue
|
||||
# Get output value_info
|
||||
output_value = helper.find_value_by_name(g, node.output[0])
|
||||
if output_value is None:
|
||||
output_value = helper.find_output_by_name(g, node.output[0])
|
||||
if output_value is None:
|
||||
helper.logger.debug(f"Cannot find the output value_info for Cast node {node.name}. Skip elimination check.")
|
||||
helper.logger.debug(
|
||||
f"Cannot find the output value_info for Cast node {node.name}."
|
||||
" Skip elimination check."
|
||||
)
|
||||
continue
|
||||
# Compare the type.
|
||||
if input_value.type.tensor_type.elem_type != output_value.type.tensor_type.elem_type:
|
||||
if (
|
||||
input_value.type.tensor_type.elem_type
|
||||
!= output_value.type.tensor_type.elem_type
|
||||
):
|
||||
continue
|
||||
# If this node is the output node, set its previous node as output nodes.
|
||||
# If this node is the output, set its previous node as output nodes.
|
||||
if helper.find_output_by_name(g, node.output[0]) is not None:
|
||||
todel_output = helper.find_output_by_name(g, node.output[0])
|
||||
g.output.remove(todel_output)
|
||||
@ -656,9 +734,13 @@ def eliminate_nop_cast(g):
|
||||
if the_input_value is not None:
|
||||
g.output.extend([the_input_value])
|
||||
# Replace the parents in all the following nodes
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(
|
||||
g, node.output[0]
|
||||
)
|
||||
for following_node in following_nodes:
|
||||
modhelper.replace_node_input(following_node, node.output[0], node.input[0])
|
||||
modhelper.replace_node_input(
|
||||
following_node, node.output[0], node.input[0]
|
||||
)
|
||||
# Delete value info
|
||||
value_between = helper.find_value_by_name(g, node.output[0])
|
||||
if value_between is not None:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,9 +1,11 @@
|
||||
from collections import deque
|
||||
|
||||
|
||||
class Node:
|
||||
"""A Node which maps a node proto. It has pointers to its parents and
|
||||
children.
|
||||
"""
|
||||
|
||||
def __init__(self, onnx_node):
|
||||
"""Initialize a node. This initialization only set up the mapping to
|
||||
node proto. The pointers should be set up by outside.
|
||||
@ -17,12 +19,12 @@ class Node:
|
||||
self.name = onnx_node.name
|
||||
self.proto = onnx_node
|
||||
|
||||
|
||||
class Graph:
|
||||
"""A graph which is constructed from the onnx proto.
|
||||
"""
|
||||
"""A graph which is constructed from the onnx proto."""
|
||||
|
||||
def __init__(self, onnx_graph):
|
||||
"""Construct the graph from onnx.
|
||||
"""
|
||||
"""Construct the graph from onnx."""
|
||||
self.input_nodes = []
|
||||
self.output_nodes = []
|
||||
self.name2node = {}
|
||||
@ -51,9 +53,9 @@ class Graph:
|
||||
for value in onnx_graph.value_info:
|
||||
node = self.output2node[value.name]
|
||||
node.output_value = value
|
||||
|
||||
def get_sorted_node_list(self):
|
||||
"""Return a node list in topological order.
|
||||
"""
|
||||
"""Return a node list in topological order."""
|
||||
visited = set()
|
||||
todo = deque()
|
||||
result = []
|
||||
|
||||
@ -6,21 +6,26 @@ import struct
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
__ONNX_VERSION__ = -1
|
||||
__ONNX_VERSION__ = -1
|
||||
|
||||
logger = logging.getLogger("optimizer_scripts")
|
||||
|
||||
|
||||
def setup_current_opset_version(m):
|
||||
global __ONNX_VERSION__
|
||||
__ONNX_VERSION__ = m.opset_import[0].version
|
||||
if __ONNX_VERSION__ not in [11]:
|
||||
raise RuntimeError('Only support opset 11, but got ' + str(__ONNX_VERSION__))
|
||||
raise RuntimeError(
|
||||
"Only support opset 11, but got " + str(__ONNX_VERSION__)
|
||||
)
|
||||
|
||||
|
||||
def get_current_opset_version():
|
||||
if __ONNX_VERSION__ == -1:
|
||||
raise RuntimeError('do setup_current_opset_version first please')
|
||||
raise RuntimeError("do setup_current_opset_version first please")
|
||||
return __ONNX_VERSION__
|
||||
|
||||
|
||||
def find_nodes_by_input_name(g, name):
|
||||
nodes = []
|
||||
for node in g.node:
|
||||
@ -28,6 +33,7 @@ def find_nodes_by_input_name(g, name):
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
|
||||
def find_node_by_output_name(g, name):
|
||||
"""
|
||||
Find a node in the graph by its output name
|
||||
@ -41,6 +47,7 @@ def find_node_by_output_name(g, name):
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def find_node_by_node_name(g, name):
|
||||
"""
|
||||
Find a node in the graph by its output name
|
||||
@ -54,6 +61,7 @@ def find_node_by_node_name(g, name):
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def find_following_nodes_by_input_value_name(g, name):
|
||||
""" Find the following nodes of a specific value.
|
||||
|
||||
@ -63,6 +71,7 @@ def find_following_nodes_by_input_value_name(g, name):
|
||||
"""
|
||||
return find_nodes_by_input_name(g, name)
|
||||
|
||||
|
||||
def find_value_by_name(g, name):
|
||||
"""
|
||||
Find a value_info in the graph by name
|
||||
@ -76,6 +85,7 @@ def find_value_by_name(g, name):
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def find_output_by_name(g, name):
|
||||
"""
|
||||
Find a value_info in the graph by name
|
||||
@ -89,6 +99,7 @@ def find_output_by_name(g, name):
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def find_input_by_name(g, name):
|
||||
"""
|
||||
Find a input in the graph by name
|
||||
@ -102,6 +113,7 @@ def find_input_by_name(g, name):
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def list_to_constant(name, shape, data, data_type=None):
|
||||
"""Generate a constant node using the given infomation.
|
||||
|
||||
@ -119,18 +131,9 @@ def list_to_constant(name, shape, data, data_type=None):
|
||||
data_type = onnx.helper.TensorProto.INT64
|
||||
else:
|
||||
data_type = onnx.helper.TensorProto.FLOAT
|
||||
tensor = onnx.helper.make_tensor(
|
||||
name,
|
||||
data_type,
|
||||
shape,
|
||||
data
|
||||
)
|
||||
tensor = onnx.helper.make_tensor(name, data_type, shape, data)
|
||||
new_w_node = onnx.helper.make_node(
|
||||
"Constant",
|
||||
[],
|
||||
[name],
|
||||
name = name,
|
||||
value = tensor
|
||||
"Constant", [], [name], name=name, value=tensor
|
||||
)
|
||||
return new_w_node
|
||||
|
||||
@ -151,18 +154,9 @@ def scaler_to_constant(name, data, data_type=None):
|
||||
else:
|
||||
logger.error("Cannot create scaler constant with a list.")
|
||||
exit(1)
|
||||
tensor = onnx.helper.make_tensor(
|
||||
name,
|
||||
data_type,
|
||||
None,
|
||||
[data]
|
||||
)
|
||||
tensor = onnx.helper.make_tensor(name, data_type, None, [data])
|
||||
new_w_node = onnx.helper.make_node(
|
||||
"Constant",
|
||||
[],
|
||||
[name],
|
||||
name = name,
|
||||
value = tensor
|
||||
"Constant", [], [name], name=name, value=tensor
|
||||
)
|
||||
return new_w_node
|
||||
|
||||
@ -170,6 +164,7 @@ def scaler_to_constant(name, data, data_type=None):
|
||||
def numpy_to_constant(name, np_array):
|
||||
return list_to_constant(name, np_array.shape, np_array.flatten().tolist())
|
||||
|
||||
|
||||
def constant_to_list(node):
|
||||
"""Generate a list from the constant node
|
||||
|
||||
@ -184,27 +179,27 @@ def constant_to_list(node):
|
||||
if len(tensor.int32_data) != 0:
|
||||
data = list(tensor.int32_data)
|
||||
else:
|
||||
data = [i[0] for i in struct.iter_unpack('i', tensor.raw_data)]
|
||||
data = [i[0] for i in struct.iter_unpack("i", tensor.raw_data)]
|
||||
elif tensor.data_type == onnx.helper.TensorProto.INT64:
|
||||
if len(tensor.int64_data) != 0:
|
||||
data = list(tensor.int64_data)
|
||||
else:
|
||||
data = [i[0] for i in struct.iter_unpack('q', tensor.raw_data)]
|
||||
data = [i[0] for i in struct.iter_unpack("q", tensor.raw_data)]
|
||||
elif tensor.data_type == onnx.helper.TensorProto.INT8:
|
||||
if len(tensor.int32_data) != 0:
|
||||
data = list(tensor.int32_data)
|
||||
else:
|
||||
data = [i[0] for i in struct.iter_unpack('b', tensor.raw_data)]
|
||||
data = [i[0] for i in struct.iter_unpack("b", tensor.raw_data)]
|
||||
elif tensor.data_type == onnx.helper.TensorProto.FLOAT:
|
||||
if len(tensor.float_data) != 0:
|
||||
data = list(tensor.float_data)
|
||||
else:
|
||||
data = [i[0] for i in struct.iter_unpack('f', tensor.raw_data)]
|
||||
data = [i[0] for i in struct.iter_unpack("f", tensor.raw_data)]
|
||||
elif tensor.data_type == onnx.helper.TensorProto.DOUBLE:
|
||||
if len(tensor.double_data) != 0:
|
||||
data = list(tensor.double_data)
|
||||
else:
|
||||
data = [i[0] for i in struct.iter_unpack('d', tensor.raw_data)]
|
||||
data = [i[0] for i in struct.iter_unpack("d", tensor.raw_data)]
|
||||
else:
|
||||
print("Not supported data type {}".format(tensor.data_type))
|
||||
raise RuntimeError
|
||||
@ -214,6 +209,7 @@ def constant_to_list(node):
|
||||
shape = list(tensor.dims)
|
||||
return shape, data
|
||||
|
||||
|
||||
def constant_to_numpy(node):
|
||||
"""Generate a numpy array from the constant node
|
||||
|
||||
@ -223,6 +219,7 @@ def constant_to_numpy(node):
|
||||
shape, data = constant_to_list(node)
|
||||
return np.array(data).reshape(shape)
|
||||
|
||||
|
||||
def all_constant_input(node):
|
||||
"""Find the inputs of the given node. If the inputs of this node are all\\
|
||||
constant nodes, return True. Otherwise, return False.
|
||||
@ -234,24 +231,26 @@ def all_constant_input(node):
|
||||
return False
|
||||
isConstant = True
|
||||
for parent in node.parents:
|
||||
if parent.proto is None or parent.proto.op_type != 'Constant':
|
||||
if parent.proto is None or parent.proto.op_type != "Constant":
|
||||
isConstant = False
|
||||
break
|
||||
return isConstant
|
||||
|
||||
|
||||
def get_padding(size, kernel_size, strides):
|
||||
""" Calculate the padding array for same padding in the Tensorflow fashion.\\
|
||||
""" Calculate the padding array for same padding in the Tensorflow fashion.\\
|
||||
See https://www.tensorflow.org/api_guides/python/nn#Convolution for more.
|
||||
"""
|
||||
if size[0] % strides[0] == 0:
|
||||
pad_h = max(kernel_size[0] - strides[0], 0)
|
||||
else:
|
||||
pad_h = max(kernel_size[0] - (size[0] % strides[0]), 0)
|
||||
if size[1] % strides[1] == 0:
|
||||
pad_w = max(kernel_size[1] - strides[1], 0)
|
||||
else:
|
||||
pad_w = max(kernel_size[1] - (size[1] % strides[1]), 0)
|
||||
return [pad_h//2, pad_w//2, pad_h-pad_h//2, pad_w-pad_w//2]
|
||||
if size[0] % strides[0] == 0:
|
||||
pad_h = max(kernel_size[0] - strides[0], 0)
|
||||
else:
|
||||
pad_h = max(kernel_size[0] - (size[0] % strides[0]), 0)
|
||||
if size[1] % strides[1] == 0:
|
||||
pad_w = max(kernel_size[1] - strides[1], 0)
|
||||
else:
|
||||
pad_w = max(kernel_size[1] - (size[1] % strides[1]), 0)
|
||||
return [pad_h // 2, pad_w // 2, pad_h - pad_h // 2, pad_w - pad_w // 2]
|
||||
|
||||
|
||||
def get_shape_from_value_info(value):
|
||||
"""Get shape from a value info.
|
||||
@ -261,12 +260,13 @@ def get_shape_from_value_info(value):
|
||||
"""
|
||||
return [d.dim_value for d in value.type.tensor_type.shape.dim]
|
||||
|
||||
|
||||
def find_size_shape_from_value(value):
|
||||
'''
|
||||
"""
|
||||
Find the size of data within the value_info object.
|
||||
:param value: value_info
|
||||
:return: int size and list shape of the data in the value_info
|
||||
'''
|
||||
"""
|
||||
if not value:
|
||||
return None, None
|
||||
if not value.type.tensor_type.shape.dim:
|
||||
@ -292,6 +292,7 @@ def get_attribute_by_name(node, attr_name):
|
||||
return attr
|
||||
return None
|
||||
|
||||
|
||||
def get_list_attribute_by_name(node, attr_name: str, attr_type: str):
|
||||
"""Get list attribute with specific name in the given node proto.
|
||||
|
||||
@ -317,12 +318,13 @@ def get_list_attribute_by_name(node, attr_name: str, attr_type: str):
|
||||
print("Warning: undefined type for list attribute extraction")
|
||||
return None
|
||||
|
||||
|
||||
def get_var_attribute_by_name(node, attr_name: str, attr_type: str):
|
||||
"""Get variable attribute with specific name in the given node proto.
|
||||
|
||||
:param node: the node proto.\\
|
||||
:param attr_name: str for the name of the target.\\
|
||||
:param attr_type: str which should be "float", "int", "string" or "tensor".\\
|
||||
:param node: the node proto.
|
||||
:param attr_name: str for the name of the target.
|
||||
:param attr_type: str which should be "float", "int", "string" or "tensor".
|
||||
:return: if found, return the variable. Else, return None.
|
||||
"""
|
||||
attr_proto = get_attribute_by_name(node, attr_name)
|
||||
@ -333,7 +335,7 @@ def get_var_attribute_by_name(node, attr_name: str, attr_type: str):
|
||||
elif attr_type == "float":
|
||||
return attr_proto.f
|
||||
elif attr_type == "string":
|
||||
if type(attr_proto.s) == type(b'abc'):
|
||||
if isinstance(attr_proto.s, bytes):
|
||||
return attr_proto.s.decode("utf-8")
|
||||
else:
|
||||
return attr_proto.s
|
||||
@ -343,22 +345,25 @@ def get_var_attribute_by_name(node, attr_name: str, attr_type: str):
|
||||
print("Warning: undefined type for variable attribute extraction")
|
||||
return None
|
||||
|
||||
|
||||
def flatten_with_depth(data, depth):
|
||||
output = []
|
||||
if type(data) not in [type(np.array([1])), type([1])]:
|
||||
return [[data, 0]]
|
||||
for item in data:
|
||||
if type(item) not in [type(np.array([1])), type([1])]:
|
||||
output.append([item, depth+1])
|
||||
output.append([item, depth + 1])
|
||||
else:
|
||||
output += flatten_with_depth(item, depth+1)
|
||||
output += flatten_with_depth(item, depth + 1)
|
||||
return output
|
||||
|
||||
|
||||
def flatten_to_list(data):
|
||||
flatten_depth = flatten_with_depth(data, 0)
|
||||
flat_data = [item[0] for item in flatten_depth]
|
||||
return flat_data
|
||||
|
||||
|
||||
def get_shape(data):
|
||||
shape = []
|
||||
if type(data) not in [type(np.array([1])), type([1])]:
|
||||
@ -378,7 +383,7 @@ def slice_data(data, starts, ends, axes):
|
||||
starts_updated = []
|
||||
ends_updated = []
|
||||
for i in range(len(starts)):
|
||||
start_updated = min(starts[i], shape[i]-1) % shape[i]
|
||||
start_updated = min(starts[i], shape[i] - 1) % shape[i]
|
||||
starts_updated.append(start_updated)
|
||||
for j in range(len(starts)):
|
||||
if ends[j] >= shape[j]:
|
||||
@ -393,19 +398,21 @@ def slice_data(data, starts, ends, axes):
|
||||
index_slices.append(list(range(shape[i])))
|
||||
else:
|
||||
axe_ind = axes.index(i)
|
||||
index_slices.append(list(range(starts_updated[axe_ind], ends_updated[axe_ind])))
|
||||
index_slices.append(
|
||||
list(range(starts_updated[axe_ind], ends_updated[axe_ind]))
|
||||
)
|
||||
|
||||
indices = [1]
|
||||
for i in range(len(shape)-1, -1, -1):
|
||||
step = np.prod(shape[i+1:])
|
||||
for i in range(len(shape) - 1, -1, -1):
|
||||
step = np.prod(shape[i + 1:])
|
||||
temp_pos = indices
|
||||
new_indices = []
|
||||
for n in index_slices[i]:
|
||||
for pos in temp_pos:
|
||||
new_indices.append(int(n*step+pos))
|
||||
new_indices.append(int(n * step + pos))
|
||||
indices = new_indices
|
||||
|
||||
sliced_data = [flat_data[k-1] for k in indices]
|
||||
sliced_data = [flat_data[k - 1] for k in indices]
|
||||
|
||||
# reshape to correct shape.
|
||||
new_shape = []
|
||||
@ -414,48 +421,51 @@ def slice_data(data, starts, ends, axes):
|
||||
new_shape.append(shape[i])
|
||||
else:
|
||||
axe_ind = axes.index(i)
|
||||
new_shape.append(ends_updated[axe_ind]-starts_updated[axe_ind])
|
||||
new_shape.append(ends_updated[axe_ind] - starts_updated[axe_ind])
|
||||
if any([dim < 1 for dim in new_shape]):
|
||||
raise RuntimeError('Invalid starts ends.')
|
||||
|
||||
raise RuntimeError("Invalid starts ends.")
|
||||
|
||||
sliced_data = np.reshape(sliced_data, new_shape)
|
||||
|
||||
return sliced_data
|
||||
|
||||
|
||||
def concatenate(data_sets, axis):
|
||||
# check shapes
|
||||
shapes = []
|
||||
shapes_ = []
|
||||
for data_set in data_sets:
|
||||
shape = get_shape(data_set)
|
||||
shapes.append(list(shape))
|
||||
shape.pop(axis)
|
||||
shapes_.append(shape)
|
||||
shape = get_shape(data_set)
|
||||
shapes.append(list(shape))
|
||||
shape.pop(axis)
|
||||
shapes_.append(shape)
|
||||
if not all([s == shapes_[0] for s in shapes_]):
|
||||
raise RuntimeError('data sets shapes do not match')
|
||||
|
||||
raise RuntimeError("data sets shapes do not match")
|
||||
|
||||
new_dim = sum([s[axis] for s in shapes])
|
||||
new_shape = list(shapes[0])
|
||||
new_shape[axis] = new_dim
|
||||
|
||||
flat_data_sets = []
|
||||
for data_set in data_sets:
|
||||
flat_data_sets.append(flatten_to_list(data_set))
|
||||
|
||||
flat_data_sets.append(flatten_to_list(data_set))
|
||||
|
||||
sub_block_size = 1
|
||||
for i in range(axis+1, len(shapes[0])):
|
||||
sub_block_size *= shapes[0][i]
|
||||
|
||||
for i in range(axis + 1, len(shapes[0])):
|
||||
sub_block_size *= shapes[0][i]
|
||||
|
||||
split_num = 1
|
||||
for i in range(axis):
|
||||
split_num *= shapes[0][i]
|
||||
split_num *= shapes[0][i]
|
||||
|
||||
total_flat_data = []
|
||||
for i in range(split_num):
|
||||
for j in range(len(shapes)):
|
||||
block_size = sub_block_size*shapes[j][axis]
|
||||
total_flat_data.extend(flat_data_sets[j][i*block_size:(i+1)*block_size])
|
||||
|
||||
for j in range(len(shapes)):
|
||||
block_size = sub_block_size * shapes[j][axis]
|
||||
total_flat_data.extend(
|
||||
flat_data_sets[j][i * block_size:(i + 1) * block_size]
|
||||
)
|
||||
|
||||
new_data = np.reshape(total_flat_data, new_shape)
|
||||
|
||||
return new_data
|
||||
@ -464,158 +474,169 @@ def concatenate(data_sets, axis):
|
||||
def broadcast_data_sets(data_set_1, data_set_2):
|
||||
shape1 = get_shape(data_set_1)
|
||||
shape2 = get_shape(data_set_2)
|
||||
|
||||
|
||||
# compare shapes and get broadcasted shape
|
||||
list_a, list_b = (shape1, shape2) if len(shape1) > len(shape2) else (shape2, shape1)
|
||||
list_a, list_b = (
|
||||
(shape1, shape2) if len(shape1) > len(shape2) else (shape2, shape1)
|
||||
)
|
||||
while len(list_a) > len(list_b):
|
||||
list_b.insert(0, 0)
|
||||
broadcasted_shape = []
|
||||
for i in range(len(list_a)):
|
||||
if list_b[i] == 0:
|
||||
broadcasted_shape.append(list_a[i])
|
||||
elif list_b[i] == 1:
|
||||
broadcasted_shape.append(list_a[i])
|
||||
elif list_a[i] == 1:
|
||||
broadcasted_shape.append(list_b[i])
|
||||
elif list_a[i] == list_b[i]:
|
||||
broadcasted_shape.append(list_a[i])
|
||||
else:
|
||||
raise RuntimeError('Can not broadcast two data sets')
|
||||
if list_b[i] == 0:
|
||||
broadcasted_shape.append(list_a[i])
|
||||
elif list_b[i] == 1:
|
||||
broadcasted_shape.append(list_a[i])
|
||||
elif list_a[i] == 1:
|
||||
broadcasted_shape.append(list_b[i])
|
||||
elif list_a[i] == list_b[i]:
|
||||
broadcasted_shape.append(list_a[i])
|
||||
else:
|
||||
raise RuntimeError("Can not broadcast two data sets")
|
||||
|
||||
# prepare data for broadcasting.
|
||||
shape1 = list(map(lambda x:x if x != 0 else 1, shape1))
|
||||
shape2 = list(map(lambda x:x if x != 0 else 1, shape2))
|
||||
shape1 = list(map(lambda x: x if x != 0 else 1, shape1))
|
||||
shape2 = list(map(lambda x: x if x != 0 else 1, shape2))
|
||||
data_1 = np.reshape(data_set_1, shape1)
|
||||
data_2 = np.reshape(data_set_2, shape2)
|
||||
|
||||
for i in range(len(shape1)):
|
||||
if shape1[i] != broadcasted_shape[i]:
|
||||
new_data_total = [list(data_1) for _ in range(broadcasted_shape[i])]
|
||||
data_1 = concatenate(new_data_total, axis=i)
|
||||
if shape1[i] != broadcasted_shape[i]:
|
||||
new_data_total = [
|
||||
list(data_1) for _ in range(broadcasted_shape[i])
|
||||
]
|
||||
data_1 = concatenate(new_data_total, axis=i)
|
||||
for i in range(len(shape2)):
|
||||
if shape2[i] != broadcasted_shape[i]:
|
||||
new_data_total = [list(data_2) for _ in range(broadcasted_shape[i])]
|
||||
data_2 = concatenate(new_data_total, axis=i)
|
||||
if shape2[i] != broadcasted_shape[i]:
|
||||
new_data_total = [
|
||||
list(data_2) for _ in range(broadcasted_shape[i])
|
||||
]
|
||||
data_2 = concatenate(new_data_total, axis=i)
|
||||
|
||||
return data_1, data_2
|
||||
|
||||
|
||||
def add(data_set_1, data_set_2):
|
||||
broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(data_set_1, data_set_2)
|
||||
broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(
|
||||
data_set_1, data_set_2
|
||||
)
|
||||
|
||||
flat_data_1 = flatten_to_list(broadcasted_data_1)
|
||||
flat_data_2 = flatten_to_list(broadcasted_data_2)
|
||||
shape = get_shape(broadcasted_data_1)
|
||||
res = []
|
||||
for i in range(len(flat_data_1)):
|
||||
res.append(flat_data_1[i]+flat_data_2[i])
|
||||
|
||||
res = np.reshape(res, shape)
|
||||
flat_data_1 = flatten_to_list(broadcasted_data_1)
|
||||
flat_data_2 = flatten_to_list(broadcasted_data_2)
|
||||
shape = get_shape(broadcasted_data_1)
|
||||
res = []
|
||||
for i in range(len(flat_data_1)):
|
||||
res.append(flat_data_1[i] + flat_data_2[i])
|
||||
|
||||
return res
|
||||
res = np.reshape(res, shape)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def reduceprod(data_set, axis, keepdims=1):
|
||||
flat_data = flatten_to_list(data_set)
|
||||
old_shape = get_shape(data_set)
|
||||
flat_data = flatten_to_list(data_set)
|
||||
old_shape = get_shape(data_set)
|
||||
|
||||
temp_shape = old_shape
|
||||
temp_flat_data = flat_data
|
||||
for ax in axis:
|
||||
split_num = 1
|
||||
step = 1
|
||||
for i in range(ax):
|
||||
split_num *= temp_shape[i]
|
||||
for i in range(ax+1, len(temp_shape)):
|
||||
step *= temp_shape[i]
|
||||
|
||||
block_size = len(temp_flat_data)//split_num
|
||||
new_flat_data = []
|
||||
for j in range(split_num):
|
||||
block_data = temp_flat_data[j*block_size:(j+1)*block_size]
|
||||
reduced_block_data = []
|
||||
for k in range(step):
|
||||
val = block_data[k]
|
||||
for l in range(1, block_size//step):
|
||||
val *= block_data[k+l*step]
|
||||
reduced_block_data.append(val)
|
||||
new_flat_data.extend(reduced_block_data)
|
||||
temp_flat_data = new_flat_data
|
||||
temp_shape[ax] = 1
|
||||
|
||||
new_flat_data = temp_flat_data
|
||||
new_shape = temp_shape
|
||||
if not keepdims:
|
||||
axis = sorted(list(axis))
|
||||
for pos in axis[::-1]:
|
||||
new_shape.pop(pos)
|
||||
|
||||
return np.reshape(new_flat_data, new_shape)
|
||||
temp_shape = old_shape
|
||||
temp_flat_data = flat_data
|
||||
for ax in axis:
|
||||
split_num = 1
|
||||
step = 1
|
||||
for i in range(ax):
|
||||
split_num *= temp_shape[i]
|
||||
for i in range(ax + 1, len(temp_shape)):
|
||||
step *= temp_shape[i]
|
||||
|
||||
block_size = len(temp_flat_data) // split_num
|
||||
new_flat_data = []
|
||||
for j in range(split_num):
|
||||
block_data = temp_flat_data[j * block_size:(j + 1) * block_size]
|
||||
reduced_block_data = []
|
||||
for k in range(step):
|
||||
val = block_data[k]
|
||||
for li in range(1, block_size // step):
|
||||
val *= block_data[k + li * step]
|
||||
reduced_block_data.append(val)
|
||||
new_flat_data.extend(reduced_block_data)
|
||||
temp_flat_data = new_flat_data
|
||||
temp_shape[ax] = 1
|
||||
|
||||
new_flat_data = temp_flat_data
|
||||
new_shape = temp_shape
|
||||
if not keepdims:
|
||||
axis = sorted(list(axis))
|
||||
for pos in axis[::-1]:
|
||||
new_shape.pop(pos)
|
||||
|
||||
return np.reshape(new_flat_data, new_shape)
|
||||
|
||||
|
||||
def transpose(data_set, permutation):
|
||||
# find series of local swaps
|
||||
data_set = list(data_set)
|
||||
perm = list(permutation)
|
||||
shape = get_shape(data_set)
|
||||
flat_data = flatten_to_list(data_set)
|
||||
assert set(perm) == set(range(len(shape))), 'invalid permutation'
|
||||
# find series of local swaps
|
||||
data_set = list(data_set)
|
||||
perm = list(permutation)
|
||||
shape = get_shape(data_set)
|
||||
flat_data = flatten_to_list(data_set)
|
||||
assert set(perm) == set(range(len(shape))), "invalid permutation"
|
||||
|
||||
new_shape = [shape[i] for i in perm]
|
||||
swaps = []
|
||||
bubbled = True
|
||||
while bubbled:
|
||||
bubbled = False
|
||||
for i in range(len(new_shape)-1):
|
||||
if perm[i] > perm[i+1]:
|
||||
swaps.append([i, i+1])
|
||||
p_1, p_2 = perm[i], perm[i+1]
|
||||
perm[i], perm[i+1] = p_2, p_1
|
||||
bubbled = True
|
||||
|
||||
# apply local swaps
|
||||
current_shape = list(shape)
|
||||
temp_flat_data = flat_data
|
||||
new_shape = [shape[i] for i in perm]
|
||||
swaps = []
|
||||
bubbled = True
|
||||
while bubbled:
|
||||
bubbled = False
|
||||
for i in range(len(new_shape) - 1):
|
||||
if perm[i] > perm[i + 1]:
|
||||
swaps.append([i, i + 1])
|
||||
p_1, p_2 = perm[i], perm[i + 1]
|
||||
perm[i], perm[i + 1] = p_2, p_1
|
||||
bubbled = True
|
||||
|
||||
for swap in swaps[::-1]:
|
||||
ind_1, ind_2 = swap[0], swap[1]
|
||||
dim_1 = current_shape[ind_1]
|
||||
dim_2 = current_shape[ind_2]
|
||||
split_num = 1
|
||||
block_size = 1
|
||||
# apply local swaps
|
||||
current_shape = list(shape)
|
||||
temp_flat_data = flat_data
|
||||
|
||||
for i in range(ind_1):
|
||||
split_num *= current_shape[i]
|
||||
for i in range(ind_2+1, len(current_shape)):
|
||||
block_size *= current_shape[i]
|
||||
for swap in swaps[::-1]:
|
||||
ind_1, ind_2 = swap[0], swap[1]
|
||||
dim_1 = current_shape[ind_1]
|
||||
dim_2 = current_shape[ind_2]
|
||||
split_num = 1
|
||||
block_size = 1
|
||||
|
||||
data_blocks = np.reshape(temp_flat_data, [-1, block_size])
|
||||
flat_data_1 = []
|
||||
for k in range(split_num):
|
||||
block = []
|
||||
for m in range(dim_2):
|
||||
for n in range(dim_1):
|
||||
block_pos = k*dim_1*dim_2 + n*dim_2+m
|
||||
block.extend(data_blocks[block_pos])
|
||||
flat_data_1.extend(block)
|
||||
for i in range(ind_1):
|
||||
split_num *= current_shape[i]
|
||||
for i in range(ind_2 + 1, len(current_shape)):
|
||||
block_size *= current_shape[i]
|
||||
|
||||
temp_flat_data = flat_data_1
|
||||
current_shape[ind_1] = dim_2
|
||||
current_shape[ind_2] = dim_1
|
||||
data_blocks = np.reshape(temp_flat_data, [-1, block_size])
|
||||
flat_data_1 = []
|
||||
for k in range(split_num):
|
||||
block = []
|
||||
for m in range(dim_2):
|
||||
for n in range(dim_1):
|
||||
block_pos = k * dim_1 * dim_2 + n * dim_2 + m
|
||||
block.extend(data_blocks[block_pos])
|
||||
flat_data_1.extend(block)
|
||||
|
||||
temp_flat_data = flat_data_1
|
||||
current_shape[ind_1] = dim_2
|
||||
current_shape[ind_2] = dim_1
|
||||
|
||||
return np.reshape(temp_flat_data, current_shape)
|
||||
|
||||
return np.reshape(temp_flat_data, current_shape)
|
||||
|
||||
def subtract(data_set_1, data_set_2):
|
||||
broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(data_set_1, data_set_2)
|
||||
broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(
|
||||
data_set_1, data_set_2
|
||||
)
|
||||
|
||||
shape = get_shape(broadcasted_data_1)
|
||||
flat_data_1 = flatten_to_list(broadcasted_data_1)
|
||||
flat_data_2 = flatten_to_list(broadcasted_data_2)
|
||||
|
||||
substracted_data = [flat_data_1[i] - flat_data_2[i] for i in range(len(flat_data_1))]
|
||||
substracted_data = [
|
||||
flat_data_1[i] - flat_data_2[i] for i in range(len(flat_data_1))
|
||||
]
|
||||
|
||||
new_data = np.reshape(substracted_data, shape)
|
||||
|
||||
return new_data
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""This module contains helper functions that do graph modifications.
|
||||
"""
|
||||
This module contains helper functions that do graph modifications.
|
||||
"""
|
||||
|
||||
import onnx
|
||||
from . import helper
|
||||
|
||||
|
||||
@ -10,9 +10,10 @@ def replace_node_input(node, old_input, new_input):
|
||||
if input_name == old_input:
|
||||
node.input[i] = new_input
|
||||
|
||||
|
||||
def delete_nodes(g, node_list):
|
||||
node_to_delete = []
|
||||
#Find target nodes
|
||||
# Find target nodes
|
||||
for node in g.node:
|
||||
if node.name not in node_list:
|
||||
continue
|
||||
@ -23,16 +24,28 @@ def delete_nodes(g, node_list):
|
||||
for node in node_to_delete:
|
||||
# Check the node whether if it is valid to delete
|
||||
if len(node.input) == 0:
|
||||
print("Deleting an Constant node. Please make sure you also delete all its following nodes")
|
||||
print(
|
||||
"Deleting an Constant node. "
|
||||
"Please make sure you also delete all its following nodes"
|
||||
)
|
||||
elif len(node.input) > 1:
|
||||
print("Warning: Node {} has more than one input. This script cannot delete merge nodes.".format(node.name))
|
||||
print(
|
||||
f"Warning: Node {node.name} has more than one input. "
|
||||
"This script cannot delete merge nodes."
|
||||
)
|
||||
# Connect the nodes around the target node.
|
||||
# Set the following node input as the previous node output.
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
|
||||
following_nodes = helper.find_following_nodes_by_input_value_name(
|
||||
g, node.output[0]
|
||||
)
|
||||
if len(node.input) == 0:
|
||||
for following_node in following_nodes:
|
||||
following_node.input.remove(node.output[0])
|
||||
elif len(following_nodes) > 0 and len(node.input) == 1 and helper.find_input_by_name(g, node.input[0]) is not None:
|
||||
elif (
|
||||
len(following_nodes) > 0
|
||||
and len(node.input) == 1
|
||||
and helper.find_input_by_name(g, node.input[0]) is not None
|
||||
):
|
||||
# The node input is an input
|
||||
new_input = helper.find_value_by_name(g, node.output[0])
|
||||
g.input.append(new_input)
|
||||
@ -40,9 +53,11 @@ def delete_nodes(g, node_list):
|
||||
g.value_info.remove(new_input)
|
||||
elif len(following_nodes) > 0:
|
||||
for following_node in following_nodes:
|
||||
replace_node_input(following_node, node.output[0], node.input[0])
|
||||
replace_node_input(
|
||||
following_node, node.output[0], node.input[0]
|
||||
)
|
||||
else:
|
||||
# If the node is the output, replace the output with the previous input.
|
||||
# If the node is the output, replace it with previous input.
|
||||
value = helper.find_value_by_name(g, node.input[0])
|
||||
output_values = []
|
||||
while len(g.output):
|
||||
@ -56,6 +71,7 @@ def delete_nodes(g, node_list):
|
||||
# Remove the node and value info.
|
||||
g.node.remove(node)
|
||||
|
||||
|
||||
def delete_input(g, target_list):
|
||||
for name in target_list:
|
||||
input_value = helper.find_input_by_name(g, name)
|
||||
@ -64,6 +80,7 @@ def delete_input(g, target_list):
|
||||
continue
|
||||
g.input.remove(input_value)
|
||||
|
||||
|
||||
def delete_output(g, target_list):
|
||||
for name in target_list:
|
||||
output_value = helper.find_output_by_name(g, name)
|
||||
@ -72,6 +89,7 @@ def delete_output(g, target_list):
|
||||
continue
|
||||
g.output.remove(output_value)
|
||||
|
||||
|
||||
def delete_value_with_name_if_exists(g, name):
|
||||
value = helper.find_value_by_name(g, name)
|
||||
if value is not None:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,317 +1,368 @@
|
||||
from . import helper
|
||||
from . import other
|
||||
from . import modhelper
|
||||
from . import fusing
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnx.utils
|
||||
|
||||
def eliminate_transposes(m):
|
||||
g = m.graph
|
||||
keep_eliminating = True
|
||||
while keep_eliminating:
|
||||
while swap_transpose_with_single_next_node(g):
|
||||
pass
|
||||
splitted = split_transpose_for_multiple_next_nodes(g)
|
||||
annihilated = annihilate_transposes(g)
|
||||
multiple_trans_swapped = swap_multiple_transposes_with_node(g)
|
||||
keep_eliminating = splitted or annihilated or multiple_trans_swapped
|
||||
|
||||
if keep_eliminating:
|
||||
m = other.polish_model(m)
|
||||
g = m.graph
|
||||
|
||||
return m
|
||||
def eliminate_transposes(m):
|
||||
g = m.graph
|
||||
keep_eliminating = True
|
||||
while keep_eliminating:
|
||||
while swap_transpose_with_single_next_node(g):
|
||||
pass
|
||||
splitted = split_transpose_for_multiple_next_nodes(g)
|
||||
annihilated = annihilate_transposes(g)
|
||||
multiple_trans_swapped = swap_multiple_transposes_with_node(g)
|
||||
keep_eliminating = splitted or annihilated or multiple_trans_swapped
|
||||
|
||||
if keep_eliminating:
|
||||
m = other.polish_model(m)
|
||||
g = m.graph
|
||||
|
||||
return m
|
||||
|
||||
|
||||
def swap_transpose_with_single_next_node(g):
|
||||
swapped = False
|
||||
passable_nodes = set(['Relu', 'Neg', 'LeakyRelu', 'Sqrt', 'Reciprocal', 'Add', 'Mul', 'Tanh'])
|
||||
for node in g.node:
|
||||
trans_node = node
|
||||
# Check for transpose node
|
||||
if trans_node.op_type != 'Transpose':
|
||||
continue
|
||||
next_nodes = helper.find_nodes_by_input_name(g, trans_node.output[0])
|
||||
if len(next_nodes) != 1:
|
||||
continue
|
||||
next_node = next_nodes[0]
|
||||
# Check if the next node is the type can be swapped
|
||||
if next_node.op_type not in passable_nodes:
|
||||
continue
|
||||
swapped = False
|
||||
passable_nodes = set(
|
||||
[
|
||||
"Relu",
|
||||
"Neg",
|
||||
"LeakyRelu",
|
||||
"Sqrt",
|
||||
"Reciprocal",
|
||||
"Add",
|
||||
"Mul",
|
||||
"Tanh",
|
||||
]
|
||||
)
|
||||
for node in g.node:
|
||||
trans_node = node
|
||||
# Check for transpose node
|
||||
if trans_node.op_type != "Transpose":
|
||||
continue
|
||||
next_nodes = helper.find_nodes_by_input_name(g, trans_node.output[0])
|
||||
if len(next_nodes) != 1:
|
||||
continue
|
||||
next_node = next_nodes[0]
|
||||
# Check if the next node is the type can be swapped
|
||||
if next_node.op_type not in passable_nodes:
|
||||
continue
|
||||
|
||||
input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in next_node.input]
|
||||
input_nodes = [
|
||||
helper.find_node_by_output_name(g, input_name)
|
||||
for input_name in next_node.input
|
||||
]
|
||||
|
||||
# Check if the node has nonconstant input other than the Transpose node itself
|
||||
nonconstant_input = False
|
||||
for input_node in input_nodes:
|
||||
if input_node == None:
|
||||
nonconstant_input = True
|
||||
break
|
||||
if input_node.name == trans_node.name:
|
||||
continue
|
||||
elif input_node.op_type == 'Constant':
|
||||
continue
|
||||
else:
|
||||
nonconstant_input = True
|
||||
break
|
||||
if nonconstant_input:
|
||||
continue
|
||||
# Check if the node has nonconstant input
|
||||
# other than the Transpose node itself
|
||||
nonconstant_input = False
|
||||
for input_node in input_nodes:
|
||||
if input_node is None:
|
||||
nonconstant_input = True
|
||||
break
|
||||
if input_node.name == trans_node.name:
|
||||
continue
|
||||
elif input_node.op_type == "Constant":
|
||||
continue
|
||||
else:
|
||||
nonconstant_input = True
|
||||
break
|
||||
if nonconstant_input:
|
||||
continue
|
||||
|
||||
for input_node in input_nodes:
|
||||
if input_node.name == trans_node.name:
|
||||
# if the input is just the transpose node
|
||||
next_value_info = helper.find_value_by_name(g, next_node.output[0])
|
||||
mid_value_info = helper.find_value_by_name(g, trans_node.output[0])
|
||||
for input_node in input_nodes:
|
||||
if input_node.name == trans_node.name:
|
||||
# if the input is just the transpose node
|
||||
next_value_info = helper.find_value_by_name(
|
||||
g, next_node.output[0]
|
||||
)
|
||||
mid_value_info = helper.find_value_by_name(
|
||||
g, trans_node.output[0]
|
||||
)
|
||||
|
||||
output_nodes = helper.find_nodes_by_input_name(g, next_node.output[0])
|
||||
for out_node in output_nodes:
|
||||
modhelper.replace_node_input(out_node, next_node.output[0], trans_node.name)
|
||||
output_nodes = helper.find_nodes_by_input_name(
|
||||
g, next_node.output[0]
|
||||
)
|
||||
for out_node in output_nodes:
|
||||
modhelper.replace_node_input(
|
||||
out_node, next_node.output[0], trans_node.name
|
||||
)
|
||||
|
||||
next_node.input[0] = trans_node.input[0]
|
||||
next_node.output[0] = next_node.name
|
||||
trans_node.input[0] = next_node.name
|
||||
trans_node.output[0] = trans_node.name
|
||||
next_node.input[0] = trans_node.input[0]
|
||||
next_node.output[0] = next_node.name
|
||||
trans_node.input[0] = next_node.name
|
||||
trans_node.output[0] = trans_node.name
|
||||
|
||||
if next_value_info:
|
||||
next_value_info.name = trans_node.name
|
||||
if mid_value_info:
|
||||
g.value_info.remove(mid_value_info)
|
||||
else:
|
||||
# if the input is a constant node
|
||||
old_tensor = input_node.attribute[0].t
|
||||
old_shape, data = helper.constant_to_list(input_node)
|
||||
# If the constant node is a scaler, no action is needed
|
||||
if type(old_shape) == int:
|
||||
old_shape = [old_shape]
|
||||
permutation = list(trans_node.attribute[0].ints)
|
||||
while len(old_shape) < len(permutation):
|
||||
old_shape.insert(0, 1)
|
||||
np_data = np.reshape(data, old_shape)
|
||||
reverse_perm = []
|
||||
for i in range(len(permutation)):
|
||||
reverse_perm.append(permutation.index(i))
|
||||
np_data = np.transpose(np_data, reverse_perm)
|
||||
new_shape = np_data.shape
|
||||
new_tensor = onnx.helper.make_tensor(
|
||||
name=old_tensor.name,
|
||||
data_type=old_tensor.data_type,
|
||||
dims=new_shape,
|
||||
vals=np_data.flatten().tolist()
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
'Constant',
|
||||
[],
|
||||
[input_node.output[0]],
|
||||
name=input_node.name,
|
||||
value=new_tensor
|
||||
)
|
||||
g.node.extend([new_node])
|
||||
if next_value_info:
|
||||
next_value_info.name = trans_node.name
|
||||
if mid_value_info:
|
||||
g.value_info.remove(mid_value_info)
|
||||
else:
|
||||
# if the input is a constant node
|
||||
old_tensor = input_node.attribute[0].t
|
||||
old_shape, data = helper.constant_to_list(input_node)
|
||||
# If the constant node is a scaler, no action is needed
|
||||
if type(old_shape) == int:
|
||||
old_shape = [old_shape]
|
||||
permutation = list(trans_node.attribute[0].ints)
|
||||
while len(old_shape) < len(permutation):
|
||||
old_shape.insert(0, 1)
|
||||
np_data = np.reshape(data, old_shape)
|
||||
reverse_perm = []
|
||||
for i in range(len(permutation)):
|
||||
reverse_perm.append(permutation.index(i))
|
||||
np_data = np.transpose(np_data, reverse_perm)
|
||||
new_shape = np_data.shape
|
||||
new_tensor = onnx.helper.make_tensor(
|
||||
name=old_tensor.name,
|
||||
data_type=old_tensor.data_type,
|
||||
dims=new_shape,
|
||||
vals=np_data.flatten().tolist(),
|
||||
)
|
||||
new_node = onnx.helper.make_node(
|
||||
"Constant",
|
||||
[],
|
||||
[input_node.output[0]],
|
||||
name=input_node.name,
|
||||
value=new_tensor,
|
||||
)
|
||||
g.node.extend([new_node])
|
||||
|
||||
g.value_info.remove(helper.find_value_by_name(g, input_node.output[0]))
|
||||
g.node.remove(input_node)
|
||||
g.value_info.remove(
|
||||
helper.find_value_by_name(g, input_node.output[0])
|
||||
)
|
||||
g.node.remove(input_node)
|
||||
|
||||
swapped = True
|
||||
swapped = True
|
||||
|
||||
other.topological_sort(g)
|
||||
return swapped
|
||||
other.topological_sort(g)
|
||||
return swapped
|
||||
|
||||
|
||||
def swap_multiple_transposes_with_node(g):
|
||||
# here only consider same input transposes
|
||||
swapped = False
|
||||
passable_nodes = set(['Add', 'Mul'])
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
if node.op_type not in passable_nodes:
|
||||
continue
|
||||
input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in node.input]
|
||||
if any([input_node == None for input_node in input_nodes]):
|
||||
continue
|
||||
if any([input_node.op_type != 'Transpose' for input_node in input_nodes]):
|
||||
continue
|
||||
# here only consider same input transposes
|
||||
swapped = False
|
||||
passable_nodes = set(["Add", "Mul"])
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
if node.op_type not in passable_nodes:
|
||||
continue
|
||||
input_nodes = [
|
||||
helper.find_node_by_output_name(g, input_name)
|
||||
for input_name in node.input
|
||||
]
|
||||
if any([input_node is None for input_node in input_nodes]):
|
||||
continue
|
||||
if any(
|
||||
[input_node.op_type != "Transpose" for input_node in input_nodes]
|
||||
):
|
||||
continue
|
||||
|
||||
permutation = list(input_nodes[0].attribute[0].ints)
|
||||
if any([list(input_node.attribute[0].ints) != permutation for input_node in input_nodes]):
|
||||
continue
|
||||
|
||||
for input_name in node.input:
|
||||
input_node = helper.find_node_by_output_name(g, input_name)
|
||||
modhelper.replace_node_input(node, input_name, input_node.input[0])
|
||||
permutation = list(input_nodes[0].attribute[0].ints)
|
||||
if any(
|
||||
[
|
||||
list(input_node.attribute[0].ints) != permutation
|
||||
for input_node in input_nodes
|
||||
]
|
||||
):
|
||||
continue
|
||||
|
||||
node_to_del.extend(input_nodes)
|
||||
for input_node in input_nodes:
|
||||
input_val_info = helper.find_value_by_name(g, input_node.output[0])
|
||||
if input_val_info is not None:
|
||||
g.value_info.remove(input_val_info)
|
||||
output_val_info = helper.find_value_by_name(g, node.output[0])
|
||||
if output_val_info is not None:
|
||||
g.value_info.remove(output_val_info)
|
||||
for input_name in node.input:
|
||||
input_node = helper.find_node_by_output_name(g, input_name)
|
||||
modhelper.replace_node_input(node, input_name, input_node.input[0])
|
||||
|
||||
output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
|
||||
for i in range(len(output_nodes)):
|
||||
new_trans_node_name = node.name+'_trans_'+str(i)
|
||||
new_trans_node = onnx.helper.make_node(
|
||||
'Transpose',
|
||||
[node.output[0]],
|
||||
[new_trans_node_name],
|
||||
name=new_trans_node_name,
|
||||
perm=permutation
|
||||
)
|
||||
modhelper.replace_node_input(output_nodes[i], node.output[0], new_trans_node_name)
|
||||
|
||||
g.node.extend([new_trans_node])
|
||||
|
||||
swapped = True
|
||||
|
||||
while node_to_del:
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
|
||||
other.topological_sort(g)
|
||||
return swapped
|
||||
node_to_del.extend(input_nodes)
|
||||
for input_node in input_nodes:
|
||||
input_val_info = helper.find_value_by_name(g, input_node.output[0])
|
||||
if input_val_info is not None:
|
||||
g.value_info.remove(input_val_info)
|
||||
output_val_info = helper.find_value_by_name(g, node.output[0])
|
||||
if output_val_info is not None:
|
||||
g.value_info.remove(output_val_info)
|
||||
|
||||
output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
|
||||
for i in range(len(output_nodes)):
|
||||
new_trans_node_name = node.name + "_trans_" + str(i)
|
||||
new_trans_node = onnx.helper.make_node(
|
||||
"Transpose",
|
||||
[node.output[0]],
|
||||
[new_trans_node_name],
|
||||
name=new_trans_node_name,
|
||||
perm=permutation,
|
||||
)
|
||||
modhelper.replace_node_input(
|
||||
output_nodes[i], node.output[0], new_trans_node_name
|
||||
)
|
||||
|
||||
g.node.extend([new_trans_node])
|
||||
|
||||
swapped = True
|
||||
|
||||
while node_to_del:
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
|
||||
other.topological_sort(g)
|
||||
return swapped
|
||||
|
||||
|
||||
def annihilate_transposes(g):
|
||||
node_to_del = []
|
||||
annihilated = False
|
||||
for node in g.node:
|
||||
if node.op_type != 'Transpose':
|
||||
continue
|
||||
pre_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
if not pre_node or pre_node.op_type != 'Transpose':
|
||||
continue
|
||||
nodes_from_top_transpose = helper.find_nodes_by_input_name(g, pre_node.output[0])
|
||||
if len(nodes_from_top_transpose) > 1:
|
||||
continue
|
||||
|
||||
perm_1 = list(pre_node.attribute[0].ints)
|
||||
perm_2 = list(node.attribute[0].ints)
|
||||
if perm_1 != perm_2:
|
||||
continue
|
||||
node_to_del = []
|
||||
annihilated = False
|
||||
for node in g.node:
|
||||
if node.op_type != "Transpose":
|
||||
continue
|
||||
pre_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
if not pre_node or pre_node.op_type != "Transpose":
|
||||
continue
|
||||
nodes_from_top_transpose = helper.find_nodes_by_input_name(
|
||||
g, pre_node.output[0]
|
||||
)
|
||||
if len(nodes_from_top_transpose) > 1:
|
||||
continue
|
||||
|
||||
out_nodes = helper.find_nodes_by_input_name(g, node.output[0])
|
||||
for out_node in out_nodes:
|
||||
modhelper.replace_node_input(out_node, node.output[0], pre_node.input[0])
|
||||
|
||||
node_to_del.extend([node, pre_node])
|
||||
mid_value_info = helper.find_value_by_name(g, pre_node.output[0])
|
||||
out_value_info = helper.find_value_by_name(g, node.output[0])
|
||||
g.value_info.remove(mid_value_info)
|
||||
g.value_info.remove(out_value_info)
|
||||
perm_1 = list(pre_node.attribute[0].ints)
|
||||
perm_2 = list(node.attribute[0].ints)
|
||||
if perm_1 != perm_2:
|
||||
continue
|
||||
|
||||
annihilated = True
|
||||
while node_to_del:
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
|
||||
return annihilated
|
||||
out_nodes = helper.find_nodes_by_input_name(g, node.output[0])
|
||||
for out_node in out_nodes:
|
||||
modhelper.replace_node_input(
|
||||
out_node, node.output[0], pre_node.input[0]
|
||||
)
|
||||
|
||||
node_to_del.extend([node, pre_node])
|
||||
mid_value_info = helper.find_value_by_name(g, pre_node.output[0])
|
||||
out_value_info = helper.find_value_by_name(g, node.output[0])
|
||||
g.value_info.remove(mid_value_info)
|
||||
g.value_info.remove(out_value_info)
|
||||
|
||||
annihilated = True
|
||||
while node_to_del:
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
|
||||
return annihilated
|
||||
|
||||
|
||||
def split_transpose_for_multiple_next_nodes(g):
|
||||
splitted = False
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Transpose':
|
||||
continue
|
||||
output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
|
||||
if len(output_nodes) < 2:
|
||||
continue
|
||||
for i in range(len(output_nodes)):
|
||||
output_node = output_nodes[i]
|
||||
new_trans_node_name = node.name + '_' + str(i)
|
||||
new_trans_node = onnx.helper.make_node(
|
||||
'Transpose',
|
||||
[node.input[0]],
|
||||
[new_trans_node_name],
|
||||
name=new_trans_node_name,
|
||||
perm=list(node.attribute[0].ints)
|
||||
)
|
||||
modhelper.replace_node_input(output_node, node.output[0], new_trans_node.output[0])
|
||||
g.node.extend([new_trans_node])
|
||||
|
||||
node_to_del.append(node)
|
||||
val_info = helper.find_value_by_name(g, node.output[0])
|
||||
g.value_info.remove(val_info)
|
||||
splitted = False
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
if node.op_type != "Transpose":
|
||||
continue
|
||||
output_nodes = helper.find_nodes_by_input_name(g, node.output[0])
|
||||
if len(output_nodes) < 2:
|
||||
continue
|
||||
for i in range(len(output_nodes)):
|
||||
output_node = output_nodes[i]
|
||||
new_trans_node_name = node.name + "_" + str(i)
|
||||
new_trans_node = onnx.helper.make_node(
|
||||
"Transpose",
|
||||
[node.input[0]],
|
||||
[new_trans_node_name],
|
||||
name=new_trans_node_name,
|
||||
perm=list(node.attribute[0].ints),
|
||||
)
|
||||
modhelper.replace_node_input(
|
||||
output_node, node.output[0], new_trans_node.output[0]
|
||||
)
|
||||
g.node.extend([new_trans_node])
|
||||
|
||||
node_to_del.append(node)
|
||||
val_info = helper.find_value_by_name(g, node.output[0])
|
||||
g.value_info.remove(val_info)
|
||||
|
||||
splitted = True
|
||||
|
||||
while node_to_del:
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
|
||||
other.topological_sort(g)
|
||||
return splitted
|
||||
|
||||
splitted = True
|
||||
|
||||
while node_to_del:
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
|
||||
other.topological_sort(g)
|
||||
return splitted
|
||||
|
||||
def remove_trivial_transpose(g):
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
if node.op_type != 'Transpose':
|
||||
continue
|
||||
permutation = list(node.attribute[0].ints)
|
||||
if permutation != list(range(len(permutation))):
|
||||
continue
|
||||
|
||||
next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
|
||||
if not next_nodes:
|
||||
input_val_info = helper.find_value_by_name(g, node.input[0])
|
||||
out_val_info = helper.find_output_by_name(g, node.output[0])
|
||||
if not input_val_info:
|
||||
input_val_info = helper.find_input_by_name(g, node.input[0])
|
||||
g.output.remove(out_val_info)
|
||||
g.output.extend([input_val_info])
|
||||
else:
|
||||
out_val_info = helper.find_value_by_name(g, node.output[0])
|
||||
for next_node in next_nodes:
|
||||
modhelper.replace_node_input(next_node, node.output[0], node.input[0])
|
||||
g.value_info.remove(out_val_info)
|
||||
|
||||
node_to_del.append(node)
|
||||
|
||||
while node_to_del:
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
|
||||
other.topological_sort(g)
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
if node.op_type != "Transpose":
|
||||
continue
|
||||
permutation = list(node.attribute[0].ints)
|
||||
if permutation != list(range(len(permutation))):
|
||||
continue
|
||||
|
||||
next_nodes = helper.find_nodes_by_input_name(g, node.output[0])
|
||||
if not next_nodes:
|
||||
input_val_info = helper.find_value_by_name(g, node.input[0])
|
||||
out_val_info = helper.find_output_by_name(g, node.output[0])
|
||||
if not input_val_info:
|
||||
input_val_info = helper.find_input_by_name(g, node.input[0])
|
||||
g.output.remove(out_val_info)
|
||||
g.output.extend([input_val_info])
|
||||
else:
|
||||
out_val_info = helper.find_value_by_name(g, node.output[0])
|
||||
for next_node in next_nodes:
|
||||
modhelper.replace_node_input(
|
||||
next_node, node.output[0], node.input[0]
|
||||
)
|
||||
g.value_info.remove(out_val_info)
|
||||
|
||||
node_to_del.append(node)
|
||||
|
||||
while node_to_del:
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
|
||||
other.topological_sort(g)
|
||||
|
||||
|
||||
def fuse_Transpose_into_Gemm_weight(g):
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
# Check pattern
|
||||
if node.op_type != 'Gemm':
|
||||
continue
|
||||
prev_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
if prev_node is None or prev_node.op_type != 'Flatten':
|
||||
continue
|
||||
transpose_node = helper.find_node_by_output_name(g, prev_node.input[0])
|
||||
if transpose_node.op_type != 'Transpose':
|
||||
continue
|
||||
# Check attribute
|
||||
perm = helper.get_list_attribute_by_name(transpose_node, 'perm', 'int')
|
||||
if perm != [0, 2, 3, 1]:
|
||||
continue
|
||||
transB = helper.get_var_attribute_by_name(node, 'transB', 'int')
|
||||
if transB is not None and transB == 1:
|
||||
continue
|
||||
# Get the original weight
|
||||
origin_weight = helper.find_node_by_output_name(g, node.input[1])
|
||||
origin_np = helper.constant_to_numpy(origin_weight)
|
||||
# Calculate a new weight
|
||||
shape = helper.get_shape_from_value_info(helper.find_value_by_name(g, prev_node.input[0]))
|
||||
shape.append(-1)
|
||||
new_np = np.reshape(origin_np, shape)
|
||||
new_np = np.transpose(new_np, [0, 3, 1, 2, 4])
|
||||
new_np = np.reshape(new_np, [-1, new_np.shape[-1]])
|
||||
new_weight = helper.numpy_to_constant(origin_weight.output[0], new_np)
|
||||
# Replace and eliminate
|
||||
prev_node.input[0] = transpose_node.input[0]
|
||||
node_to_del.append(transpose_node)
|
||||
node_to_del.append(origin_weight)
|
||||
g.value_info.remove(helper.find_value_by_name(g, transpose_node.output[0]))
|
||||
g.node.extend([new_weight])
|
||||
node_to_del = []
|
||||
for node in g.node:
|
||||
# Check pattern
|
||||
if node.op_type != "Gemm":
|
||||
continue
|
||||
prev_node = helper.find_node_by_output_name(g, node.input[0])
|
||||
if prev_node is None or prev_node.op_type != "Flatten":
|
||||
continue
|
||||
transpose_node = helper.find_node_by_output_name(g, prev_node.input[0])
|
||||
if transpose_node.op_type != "Transpose":
|
||||
continue
|
||||
# Check attribute
|
||||
perm = helper.get_list_attribute_by_name(transpose_node, "perm", "int")
|
||||
if perm != [0, 2, 3, 1]:
|
||||
continue
|
||||
transB = helper.get_var_attribute_by_name(node, "transB", "int")
|
||||
if transB is not None and transB == 1:
|
||||
continue
|
||||
# Get the original weight
|
||||
origin_weight = helper.find_node_by_output_name(g, node.input[1])
|
||||
origin_np = helper.constant_to_numpy(origin_weight)
|
||||
# Calculate a new weight
|
||||
shape = helper.get_shape_from_value_info(
|
||||
helper.find_value_by_name(g, prev_node.input[0])
|
||||
)
|
||||
shape.append(-1)
|
||||
new_np = np.reshape(origin_np, shape)
|
||||
new_np = np.transpose(new_np, [0, 3, 1, 2, 4])
|
||||
new_np = np.reshape(new_np, [-1, new_np.shape[-1]])
|
||||
new_weight = helper.numpy_to_constant(origin_weight.output[0], new_np)
|
||||
# Replace and eliminate
|
||||
prev_node.input[0] = transpose_node.input[0]
|
||||
node_to_del.append(transpose_node)
|
||||
node_to_del.append(origin_weight)
|
||||
g.value_info.remove(
|
||||
helper.find_value_by_name(g, transpose_node.output[0])
|
||||
)
|
||||
g.node.extend([new_weight])
|
||||
|
||||
while node_to_del:
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
while node_to_del:
|
||||
node = node_to_del.pop()
|
||||
g.node.remove(node)
|
||||
|
||||
other.topological_sort(g)
|
||||
other.topological_sort(g)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,11 +1,10 @@
|
||||
"""Special operations on model.
|
||||
"""
|
||||
import logging
|
||||
import onnx.helper
|
||||
import numpy as np
|
||||
from . import helper
|
||||
from . import other
|
||||
from . import modhelper
|
||||
|
||||
|
||||
def change_first_conv_from_bgr_to_rgb(m):
|
||||
"""For input channel format BGR model, use this function to change the first
|
||||
@ -16,12 +15,14 @@ def change_first_conv_from_bgr_to_rgb(m):
|
||||
# Check for first node.
|
||||
g = m.graph
|
||||
input_name = g.input[0].name
|
||||
first_nodes = helper.find_following_nodes_by_input_value_name(g, input_name)
|
||||
first_nodes = helper.find_following_nodes_by_input_value_name(
|
||||
g, input_name
|
||||
)
|
||||
if len(first_nodes) > 1:
|
||||
return False
|
||||
first_node = first_nodes[0]
|
||||
# Now we have the first node. Check this first node.
|
||||
if first_node.op_type != 'Conv':
|
||||
if first_node.op_type != "Conv":
|
||||
return False
|
||||
weight_value = helper.find_value_by_name(g, first_node.input[1])
|
||||
weight_shape = helper.get_shape_from_value_info(weight_value)
|
||||
@ -41,10 +42,12 @@ def change_first_conv_from_bgr_to_rgb(m):
|
||||
other.topological_sort(g)
|
||||
return True
|
||||
|
||||
|
||||
def change_input_from_bgr_to_rgb(m):
|
||||
"""For input channel format BGR model, use this function to modify the model
|
||||
to accepct RGB image.If the first node is a non-group Conv. Modify weight to
|
||||
adapt the input into RGB. Otherwise create a new node.
|
||||
"""
|
||||
For input channel format BGR model, use this function to modify the model
|
||||
to accepct RGB image.If the first node is a non-group Conv.
|
||||
Modify weight to adapt the input into RGB. Otherwise create a new node.
|
||||
|
||||
:param m: the model proto
|
||||
"""
|
||||
@ -61,34 +64,33 @@ def change_input_from_bgr_to_rgb(m):
|
||||
return
|
||||
# Otherwise, create a special conv node and replace the input
|
||||
# Construct weight
|
||||
weight_np = np.zeros((3, 3, 3, 3)).astype('float32')
|
||||
weight_np = np.zeros((3, 3, 3, 3)).astype("float32")
|
||||
weight_np[0, 2, 1, 1] = 1.0
|
||||
weight_np[1, 1, 1, 1] = 1.0
|
||||
weight_np[2, 0, 1, 1] = 1.0
|
||||
new_weight = helper.numpy_to_constant("bgr_shuffle_weight", weight_np)
|
||||
# Construct Conv
|
||||
new_conv = onnx.helper.make_node(
|
||||
'Conv',
|
||||
['rgb_input', "bgr_shuffle_weight"],
|
||||
"Conv",
|
||||
["rgb_input", "bgr_shuffle_weight"],
|
||||
[g.input[0].name],
|
||||
name='bgr_shuffle',
|
||||
name="bgr_shuffle",
|
||||
dilations=[1, 1],
|
||||
kernel_shape=[3, 3],
|
||||
pads=[1, 1, 1, 1],
|
||||
strides=[1, 1]
|
||||
strides=[1, 1],
|
||||
)
|
||||
# Connect the graph
|
||||
old_input_value = g.input.pop()
|
||||
new_input_value = onnx.helper.make_tensor_value_info(
|
||||
'rgb_input',
|
||||
old_input_value.type.tensor_type.elem_type,
|
||||
input_shape
|
||||
"rgb_input", old_input_value.type.tensor_type.elem_type, input_shape
|
||||
)
|
||||
g.input.extend([new_input_value])
|
||||
g.node.extend([new_weight, new_conv])
|
||||
# topological sort
|
||||
other.topological_sort(g)
|
||||
|
||||
|
||||
def add_0_5_to_normalized_input(m):
|
||||
"""For normalized input between -0.5 ~ 0.5, add 0.5 to the input to keep it
|
||||
between 0 ~ 1.
|
||||
@ -105,41 +107,37 @@ def add_0_5_to_normalized_input(m):
|
||||
return
|
||||
# Construct weight
|
||||
ch = input_shape[1]
|
||||
weight_np = np.zeros((ch, ch, 3, 3)).astype('float32')
|
||||
weight_np = np.zeros((ch, ch, 3, 3)).astype("float32")
|
||||
for i in range(ch):
|
||||
weight_np[i, i, 1, 1] = 1.0
|
||||
new_weight = helper.numpy_to_constant("input_norm_weight", weight_np)
|
||||
# Construct bias
|
||||
bias_np = np.array([0.5] * ch).astype('float32')
|
||||
bias_np = np.array([0.5] * ch).astype("float32")
|
||||
new_bias = helper.numpy_to_constant("input_norm_bias", bias_np)
|
||||
# Construct Conv
|
||||
new_conv = onnx.helper.make_node(
|
||||
'Conv',
|
||||
['origin_input', "input_norm_weight", "input_norm_bias"],
|
||||
"Conv",
|
||||
["origin_input", "input_norm_weight", "input_norm_bias"],
|
||||
[g.input[0].name],
|
||||
name='input_norm',
|
||||
name="input_norm",
|
||||
dilations=[1, 1],
|
||||
kernel_shape=[3, 3],
|
||||
pads=[1, 1, 1, 1],
|
||||
strides=[1, 1]
|
||||
strides=[1, 1],
|
||||
)
|
||||
# Construct value_infos
|
||||
old_input_value = g.input.pop()
|
||||
weight_value = onnx.helper.make_tensor_value_info(
|
||||
'input_norm_weight',
|
||||
"input_norm_weight",
|
||||
old_input_value.type.tensor_type.elem_type,
|
||||
[3, 3, 3, 3]
|
||||
[3, 3, 3, 3],
|
||||
)
|
||||
bias_value = onnx.helper.make_tensor_value_info(
|
||||
'input_norm_bias',
|
||||
old_input_value.type.tensor_type.elem_type,
|
||||
[3]
|
||||
"input_norm_bias", old_input_value.type.tensor_type.elem_type, [3]
|
||||
)
|
||||
# Connect the graph
|
||||
new_input_value = onnx.helper.make_tensor_value_info(
|
||||
'origin_input',
|
||||
old_input_value.type.tensor_type.elem_type,
|
||||
input_shape
|
||||
"origin_input", old_input_value.type.tensor_type.elem_type, input_shape
|
||||
)
|
||||
g.input.extend([new_input_value])
|
||||
g.node.extend([new_weight, new_bias, new_conv])
|
||||
@ -147,9 +145,9 @@ def add_0_5_to_normalized_input(m):
|
||||
# topological sort
|
||||
other.topological_sort(g)
|
||||
|
||||
|
||||
def add_rgb2yynn_node(m):
|
||||
"""Add a conv layer which can convert rgb to yynn input.
|
||||
"""
|
||||
"""Add a conv layer which can convert rgb to yynn input."""
|
||||
g = m.graph
|
||||
if len(g.input) > 1:
|
||||
print("This model has multiple inputs. Cannot change to rgb input.")
|
||||
@ -159,37 +157,32 @@ def add_rgb2yynn_node(m):
|
||||
print("The input shape is not BCHW. Cannot normalize input.")
|
||||
return
|
||||
# Construct weight
|
||||
ch = input_shape[1]
|
||||
weight_np = np.zeros((3, 3, 4, 4)).astype('float32')
|
||||
weight_np[1, 1, :3, :2] = np.array([[[[0.299],
|
||||
[0.587],
|
||||
[0.114]]]])
|
||||
weight_np[1, 1, 3, 2:] = 1.
|
||||
weight_np = np.zeros((3, 3, 4, 4)).astype("float32")
|
||||
weight_np[1, 1, :3, :2] = np.array([[[[0.299], [0.587], [0.114]]]])
|
||||
weight_np[1, 1, 3, 2:] = 1.0
|
||||
weight_np = np.transpose(weight_np, (3, 2, 0, 1))
|
||||
new_weight = helper.numpy_to_constant("input_rgb2yynn_weight", weight_np)
|
||||
# Construct conv node
|
||||
new_conv = onnx.helper.make_node(
|
||||
'Conv',
|
||||
['new_input', "input_rgb2yynn_weight"],
|
||||
"Conv",
|
||||
["new_input", "input_rgb2yynn_weight"],
|
||||
[g.input[0].name],
|
||||
name='input_rgba2yynn',
|
||||
name="input_rgba2yynn",
|
||||
dilations=[1, 1],
|
||||
kernel_shape=[3, 3],
|
||||
pads=[1, 1, 1, 1],
|
||||
strides=[1, 1]
|
||||
strides=[1, 1],
|
||||
)
|
||||
# Construct value_infos
|
||||
old_input_value = g.input.pop()
|
||||
weight_value = onnx.helper.make_tensor_value_info(
|
||||
'input_rgb2yynn_weight',
|
||||
"input_rgb2yynn_weight",
|
||||
old_input_value.type.tensor_type.elem_type,
|
||||
[4, 4, 3, 3]
|
||||
[4, 4, 3, 3],
|
||||
)
|
||||
# Connect the graph
|
||||
new_input_value = onnx.helper.make_tensor_value_info(
|
||||
'new_input',
|
||||
old_input_value.type.tensor_type.elem_type,
|
||||
input_shape
|
||||
"new_input", old_input_value.type.tensor_type.elem_type, input_shape
|
||||
)
|
||||
g.input.extend([new_input_value])
|
||||
g.node.extend([new_weight, new_conv])
|
||||
@ -197,6 +190,7 @@ def add_rgb2yynn_node(m):
|
||||
# topological sort
|
||||
other.topological_sort(g)
|
||||
|
||||
|
||||
def swap_MatMul_inputs(g, original_matmul_node):
|
||||
# Create Transpose nodes
|
||||
input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0])
|
||||
@ -206,11 +200,12 @@ def swap_MatMul_inputs(g, original_matmul_node):
|
||||
else:
|
||||
perm = [0, 2, 1]
|
||||
new_input_b_node = onnx.helper.make_node(
|
||||
'Transpose',
|
||||
inputs = [input_a_value.name],
|
||||
outputs = [input_a_value.name + '_transposed'],
|
||||
name = f"{input_a_value.name}_transposed_for_{original_matmul_node.name}",
|
||||
perm = perm
|
||||
"Transpose",
|
||||
inputs=[input_a_value.name],
|
||||
outputs=[input_a_value.name + "_transposed"],
|
||||
name=f"{input_a_value.name}_transposed_for_"
|
||||
f"{original_matmul_node.name}",
|
||||
perm=perm,
|
||||
)
|
||||
input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1])
|
||||
input_b_shape = helper.get_shape_from_value_info(input_b_value)
|
||||
@ -219,18 +214,19 @@ def swap_MatMul_inputs(g, original_matmul_node):
|
||||
else:
|
||||
perm = [0, 1, 3, 2]
|
||||
new_input_a_node = onnx.helper.make_node(
|
||||
'Transpose',
|
||||
inputs = [input_b_value.name],
|
||||
outputs = [input_b_value.name + '_transposed'],
|
||||
name = f'{input_b_value.name}_transposed_for_{original_matmul_node.name}',
|
||||
perm = perm
|
||||
"Transpose",
|
||||
inputs=[input_b_value.name],
|
||||
outputs=[input_b_value.name + "_transposed"],
|
||||
name=f"{input_b_value.name}_transposed_for_"
|
||||
f"{original_matmul_node.name}",
|
||||
perm=perm,
|
||||
)
|
||||
# Create new MatMul node
|
||||
new_matmul_node = onnx.helper.make_node(
|
||||
'MatMul',
|
||||
inputs = [new_input_a_node.output[0], new_input_b_node.output[0]],
|
||||
outputs = [original_matmul_node.output[0] + '_transposed'],
|
||||
name = original_matmul_node.name + '_transposed'
|
||||
"MatMul",
|
||||
inputs=[new_input_a_node.output[0], new_input_b_node.output[0]],
|
||||
outputs=[original_matmul_node.output[0] + "_transposed"],
|
||||
name=original_matmul_node.name + "_transposed",
|
||||
)
|
||||
# Create final Transpose node
|
||||
output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
|
||||
@ -240,17 +236,25 @@ def swap_MatMul_inputs(g, original_matmul_node):
|
||||
else:
|
||||
perm = [0, 1, 3, 2]
|
||||
new_final_transpose_node = onnx.helper.make_node(
|
||||
'Transpose',
|
||||
inputs = [new_matmul_node.output[0]],
|
||||
outputs = [original_matmul_node.output[0]],
|
||||
name = original_matmul_node.name + '_final_transpose',
|
||||
perm = perm
|
||||
"Transpose",
|
||||
inputs=[new_matmul_node.output[0]],
|
||||
outputs=[original_matmul_node.output[0]],
|
||||
name=original_matmul_node.name + "_final_transpose",
|
||||
perm=perm,
|
||||
)
|
||||
# Add new nodes
|
||||
g.node.extend([new_input_a_node, new_input_b_node, new_matmul_node, new_final_transpose_node])
|
||||
g.node.extend(
|
||||
[
|
||||
new_input_a_node,
|
||||
new_input_b_node,
|
||||
new_matmul_node,
|
||||
new_final_transpose_node,
|
||||
]
|
||||
)
|
||||
# Delete original nodes
|
||||
g.node.remove(original_matmul_node)
|
||||
|
||||
|
||||
def split_MatMul_batch_then_concat(g, original_matmul_node):
|
||||
new_nodes = []
|
||||
final_concat_inputs = []
|
||||
@ -265,49 +269,85 @@ def split_MatMul_batch_then_concat(g, original_matmul_node):
|
||||
batch_count = input_a_shape[1]
|
||||
for i in range(batch_count):
|
||||
# Create Split nodes for input A
|
||||
starts_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_starts", (1, ), [i])
|
||||
ends_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_ends", (1, ), [i+1])
|
||||
axes_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_axes", (1, ), [len(input_a_shape) - 3])
|
||||
starts_node = helper.list_to_constant(
|
||||
f"{input_a_value.name}_sliced_{i}_starts", (1,), [i]
|
||||
)
|
||||
ends_node = helper.list_to_constant(
|
||||
f"{input_a_value.name}_sliced_{i}_ends", (1,), [i + 1]
|
||||
)
|
||||
axes_node = helper.list_to_constant(
|
||||
f"{input_a_value.name}_sliced_{i}_axes",
|
||||
(1,),
|
||||
[len(input_a_shape) - 3],
|
||||
)
|
||||
new_sliced_a_node = onnx.helper.make_node(
|
||||
'Slice',
|
||||
inputs = [input_a_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]],
|
||||
outputs = [f"{input_a_value.name}_sliced_{i}"],
|
||||
name = f"{input_a_value.name}_sliced_{i}_for_{original_matmul_node.name}"
|
||||
"Slice",
|
||||
inputs=[
|
||||
input_a_value.name,
|
||||
starts_node.output[0],
|
||||
ends_node.output[0],
|
||||
axes_node.output[0],
|
||||
],
|
||||
outputs=[f"{input_a_value.name}_sliced_{i}"],
|
||||
name=f"{input_a_value.name}_sliced_{i}_for_"
|
||||
f"{original_matmul_node.name}",
|
||||
)
|
||||
new_nodes.extend(
|
||||
[starts_node, ends_node, axes_node, new_sliced_a_node]
|
||||
)
|
||||
new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_a_node])
|
||||
# Create Split nodes for input B
|
||||
starts_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_starts", (1, ), [i])
|
||||
ends_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_ends", (1, ), [i+1])
|
||||
axes_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_axes", (1, ), [len(input_b_shape) - 3])
|
||||
new_sliced_b_node = onnx.helper.make_node(
|
||||
'Slice',
|
||||
inputs = [input_b_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]],
|
||||
outputs = [f"{input_b_value.name}_sliced_{i}"],
|
||||
name = f"{input_b_value.name}_sliced_{i}_for_{original_matmul_node.name}"
|
||||
starts_node = helper.list_to_constant(
|
||||
f"{input_b_value.name}_sliced_{i}_starts", (1,), [i]
|
||||
)
|
||||
ends_node = helper.list_to_constant(
|
||||
f"{input_b_value.name}_sliced_{i}_ends", (1,), [i + 1]
|
||||
)
|
||||
axes_node = helper.list_to_constant(
|
||||
f"{input_b_value.name}_sliced_{i}_axes",
|
||||
(1,),
|
||||
[len(input_b_shape) - 3],
|
||||
)
|
||||
new_sliced_b_node = onnx.helper.make_node(
|
||||
"Slice",
|
||||
inputs=[
|
||||
input_b_value.name,
|
||||
starts_node.output[0],
|
||||
ends_node.output[0],
|
||||
axes_node.output[0],
|
||||
],
|
||||
outputs=[f"{input_b_value.name}_sliced_{i}"],
|
||||
name=f"{input_b_value.name}_sliced_{i}_for_"
|
||||
f"{original_matmul_node.name}",
|
||||
)
|
||||
new_nodes.extend(
|
||||
[starts_node, ends_node, axes_node, new_sliced_b_node]
|
||||
)
|
||||
new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_b_node])
|
||||
# Create MatMul nodes
|
||||
new_matmul_node = onnx.helper.make_node(
|
||||
'MatMul',
|
||||
inputs = [new_sliced_a_node.output[0], new_sliced_b_node.output[0]],
|
||||
outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"],
|
||||
name = f"{original_matmul_node.name}_sliced_{i}"
|
||||
"MatMul",
|
||||
inputs=[new_sliced_a_node.output[0], new_sliced_b_node.output[0]],
|
||||
outputs=[f"{original_matmul_node.output[0]}_sliced_{i}"],
|
||||
name=f"{original_matmul_node.name}_sliced_{i}",
|
||||
)
|
||||
new_nodes.append(new_matmul_node)
|
||||
final_concat_inputs.append(new_matmul_node.output[0])
|
||||
# Create Concat nodes
|
||||
output_value = helper.find_value_by_name(g, original_matmul_node.output[0])
|
||||
if output_value is None:
|
||||
output_value = helper.find_output_by_name(g, original_matmul_node.output[0])
|
||||
output_value = helper.find_output_by_name(
|
||||
g, original_matmul_node.output[0]
|
||||
)
|
||||
if output_value is None:
|
||||
helper.logger.error(f"Cannot find value_info for {original_matmul_node.output[0]}")
|
||||
helper.logger.error(
|
||||
f"Cannot find value_info for {original_matmul_node.output[0]}"
|
||||
)
|
||||
output_shape = helper.get_shape_from_value_info(output_value)
|
||||
new_concat_node = onnx.helper.make_node(
|
||||
"Concat",
|
||||
inputs = final_concat_inputs,
|
||||
outputs = [original_matmul_node.output[0]],
|
||||
name = f"{original_matmul_node.name}_final_concat",
|
||||
axis = len(output_shape) - 3
|
||||
inputs=final_concat_inputs,
|
||||
outputs=[original_matmul_node.output[0]],
|
||||
name=f"{original_matmul_node.name}_final_concat",
|
||||
axis=len(output_shape) - 3,
|
||||
)
|
||||
new_nodes.append(new_concat_node)
|
||||
# Add new nodes
|
||||
@ -320,7 +360,9 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
|
||||
new_nodes = []
|
||||
final_concat_inputs = []
|
||||
# Get the batch count
|
||||
input_b_node = helper.find_node_by_output_name(g, original_matmul_node.input[1])
|
||||
input_b_node = helper.find_node_by_output_name(
|
||||
g, original_matmul_node.input[1]
|
||||
)
|
||||
input_b_np = helper.constant_to_numpy(input_b_node)
|
||||
if len(input_b_np.shape) == 3:
|
||||
batch_count = input_b_np.shape[0]
|
||||
@ -329,17 +371,19 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
|
||||
for i in range(batch_count):
|
||||
# Create new constant node
|
||||
if len(input_b_np.shape) == 3:
|
||||
new_np = input_b_np[i:i+1, ...]
|
||||
new_np = input_b_np[i:i + 1, ...]
|
||||
else:
|
||||
new_np = input_b_np[:, i:i+1, ...]
|
||||
new_weight = helper.numpy_to_constant(f"{input_b_node.name}_sliced_{i}", new_np)
|
||||
new_np = input_b_np[:, i:i + 1, ...]
|
||||
new_weight = helper.numpy_to_constant(
|
||||
f"{input_b_node.name}_sliced_{i}", new_np
|
||||
)
|
||||
new_nodes.append(new_weight)
|
||||
# Create MatMul nodes
|
||||
new_matmul_node = onnx.helper.make_node(
|
||||
'MatMul',
|
||||
inputs = [original_matmul_node.input[0], new_weight.output[0]],
|
||||
outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"],
|
||||
name = f"{original_matmul_node.name}_sliced_{i}"
|
||||
"MatMul",
|
||||
inputs=[original_matmul_node.input[0], new_weight.output[0]],
|
||||
outputs=[f"{original_matmul_node.output[0]}_sliced_{i}"],
|
||||
name=f"{original_matmul_node.name}_sliced_{i}",
|
||||
)
|
||||
new_nodes.append(new_matmul_node)
|
||||
final_concat_inputs.append(new_matmul_node.output[0])
|
||||
@ -348,10 +392,10 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
|
||||
output_shape = helper.get_shape_from_value_info(output_value)
|
||||
new_concat_node = onnx.helper.make_node(
|
||||
"Concat",
|
||||
inputs = final_concat_inputs,
|
||||
outputs = [original_matmul_node.output[0]],
|
||||
name = f"{original_matmul_node.name}_final_concat",
|
||||
axis = len(output_shape) - 3
|
||||
inputs=final_concat_inputs,
|
||||
outputs=[original_matmul_node.output[0]],
|
||||
name=f"{original_matmul_node.name}_final_concat",
|
||||
axis=len(output_shape) - 3,
|
||||
)
|
||||
new_nodes.append(new_concat_node)
|
||||
# Add new nodes
|
||||
@ -367,7 +411,7 @@ def split_MatMul_Constant_input_then_concat(g, original_matmul_node):
|
||||
|
||||
def special_MatMul_process(g):
|
||||
for node in g.node:
|
||||
if node.op_type != 'MatMul':
|
||||
if node.op_type != "MatMul":
|
||||
continue
|
||||
input_a_name = node.input[0]
|
||||
input_a_value = helper.find_value_by_name(g, input_a_name)
|
||||
@ -383,19 +427,30 @@ def special_MatMul_process(g):
|
||||
continue
|
||||
# Too many dimensions or too few dimensions. Not supported. Skip
|
||||
if len(input_a_shape) > 4 or len(input_b_shape) > 4:
|
||||
helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have too many dimensions.")
|
||||
helper.logger.warning(
|
||||
f"Cannot optimize MatMul {node.name}: "
|
||||
"inputs have too many dimensions."
|
||||
)
|
||||
continue
|
||||
if len(input_a_shape) < 2 or len(input_b_shape) < 2:
|
||||
helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have two few dimensions.")
|
||||
helper.logger.warning(
|
||||
f"Cannot optimize MatMul {node.name}: "
|
||||
"inputs have two few dimensions."
|
||||
)
|
||||
continue
|
||||
# For 4 dimension, check the first dimension (should be 1) and treated as 3 dimensions.
|
||||
# For 4 dimension, check the first dimension (should be 1)
|
||||
# and treated as 3 dimensions.
|
||||
extra_dim = None
|
||||
if len(input_a_shape) == 4:
|
||||
extra_dim = input_a_shape[0]
|
||||
input_a_shape = input_a_shape[1:]
|
||||
if len(input_b_shape) == 4:
|
||||
if input_b_shape[0] != extra_dim:
|
||||
helper.logger.warning(f"Cannot optimize MatMul {node.name}: input dimension batch sizes does not match ({extra_dim} vs {input_b_shape[0]}).")
|
||||
helper.logger.warning(
|
||||
f"Cannot optimize MatMul {node.name}: "
|
||||
"input dimension batch sizes does not match "
|
||||
f"({extra_dim} vs {input_b_shape[0]})."
|
||||
)
|
||||
continue
|
||||
input_b_shape = input_b_shape[1:]
|
||||
# Check input B dimension
|
||||
@ -404,20 +459,31 @@ def special_MatMul_process(g):
|
||||
continue
|
||||
# If B is B x W x V, but B is a constant.
|
||||
input_b_node = helper.find_node_by_output_name(g, input_b_name)
|
||||
if input_b_node is not None and input_b_node.op_type == 'Constant':
|
||||
if input_b_node is not None and input_b_node.op_type == "Constant":
|
||||
# Constant input
|
||||
helper.logger.debug(f"Optimizing MatMul node {node.name}: split constant input.")
|
||||
helper.logger.debug(
|
||||
f"Optimizing MatMul node {node.name}: split constant input."
|
||||
)
|
||||
split_MatMul_Constant_input_then_concat(g, node)
|
||||
# If B is B x W x V and A is 1 x H x W, do the swap.
|
||||
elif len(input_a_shape) == 2 or (input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1)):
|
||||
helper.logger.debug(f"Optimizing MatMul node {node.name}: swap input.")
|
||||
elif len(input_a_shape) == 2 or (
|
||||
input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1)
|
||||
):
|
||||
helper.logger.debug(
|
||||
f"Optimizing MatMul node {node.name}: swap input."
|
||||
)
|
||||
swap_MatMul_inputs(g, node)
|
||||
# If B is B x W x V and A is B x H x W, do the split.
|
||||
elif input_b_shape[0] == input_a_shape[0]:
|
||||
helper.logger.debug(f"Optimizing MatMul node {node.name}: split input batch.")
|
||||
helper.logger.debug(
|
||||
f"Optimizing MatMul node {node.name}: split input batch."
|
||||
)
|
||||
split_MatMul_batch_then_concat(g, node)
|
||||
# Other cases are not supported: If B is B x W x V but A is X x H x W.
|
||||
else:
|
||||
helper.logger.warning(f"Cannot optimize MatMul {node.name}: unknown reason. Might be shape mismatch.")
|
||||
helper.logger.warning(
|
||||
f"Cannot optimize MatMul {node.name}: "
|
||||
"unknown reason. Might be shape mismatch."
|
||||
)
|
||||
continue
|
||||
other.topological_sort(g)
|
||||
other.topological_sort(g)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user