670 lines
26 KiB
Python
670 lines
26 KiB
Python
import collections
|
|
import struct
|
|
import onnx
|
|
import numpy as np
|
|
from . import other
|
|
from . import helper
|
|
from . import modhelper
|
|
from .general_graph import Graph
|
|
|
|
def eliminate_Identify_and_Dropout(g):
    """
    Eliminate Identity and Dropout layers.

    Every node consuming the output of such a layer is rewired to read the
    layer's first input instead, and the value info of the removed output is
    dropped. Layers whose output is a graph output are left untouched here;
    `remove_useless_last_nodes` handles the tail of the graph.

    :param g: the onnx graph
    """
    missing_info_msg = "No value info to delete while eliminating identity layers."
    node_to_remove = []
    for node in g.node:
        if node.op_type not in ('Identity', 'Dropout'):
            continue
        # If this node is the last node, leave it to `remove_useless_last_nodes`.
        if helper.find_output_by_name(g, node.output[0]) is not None:
            continue
        # Replace the parents in all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0])
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, node.output[0], node.input[0])
        # Delete value info. A missing entry is reported, but any other error
        # is no longer silently swallowed by a bare `except`.
        value_between = helper.find_value_by_name(g, node.output[0])
        if value_between is None:
            print(missing_info_msg)
        else:
            try:
                g.value_info.remove(value_between)
            except ValueError:
                print(missing_info_msg)
        # Node is waiting for elimination
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)
|
|
|
|
# Remove last useless nodes
|
|
def remove_useless_last_nodes(g):
    """Remove useless nodes from the tail of the graph.

    Starting from the output nodes that have no children, the graph is walked
    backwards breadth-first. Each visited node whose operator is in the
    "useless" list is removed; when the removed node produced a graph output,
    the value infos of its parents are promoted to graph outputs instead.

    :param g: the onnx graph
    """
    useless_ops = ("Reshape", "Identity", "Transpose", "Flatten", "Dropout",
                   "Mystery", "Constant", "Squeeze", "Unsqueeze", 'Softmax')
    graph = Graph(g)
    pending = collections.deque(
        tail for tail in graph.output_nodes if len(tail.children) == 0
    )
    protos_to_remove = []
    while pending:
        # Breadth-first search for nodes to remove.
        current = pending.popleft()
        proto = current.proto
        if proto is None or proto.op_type not in useless_ops:
            continue
        # Graph output produced by this node, if any.
        graph_output = helper.find_output_by_name(g, proto.output[0])
        for parent in current.parents:
            parent.children.remove(current)
            # A parent left without children becomes a new tail candidate.
            if len(parent.children) == 0:
                pending.append(parent)
            if graph_output is None:
                continue
            parent_out_name = parent.proto.output[0]
            parent_value = helper.find_value_by_name(g, parent_out_name)
            parent_as_output = helper.find_output_by_name(g, parent_out_name)
            if parent_value is not None and parent_as_output is None:
                g.output.extend([parent_value])
        protos_to_remove.append(proto)
        try:
            g.value_info.remove(helper.find_value_by_name(g, proto.output[0]))
        except ValueError:
            pass
        if graph_output is not None:
            g.output.remove(graph_output)
        # Mark the wrapper as dead so that it is skipped if reached again.
        current.proto = None
        current.parents.clear()
    for proto in protos_to_remove:
        g.node.remove(proto)
|
|
|
|
######################################
|
|
# TF only optimization passes #
|
|
######################################
|
|
|
|
def eliminate_shape_changing_after_input(g):
    """
    Eliminate a shape changing node placed right after a graph input.

    Handles Reshape, Transpose, Flatten, Dropout, Squeeze and Unsqueeze. The
    node is removed and the graph input is replaced by a new input that is
    named after the node's output and carries the shape the node would have
    produced. Inputs consumed by more than one node are left alone, as are
    nodes whose resulting shape cannot be determined (missing value info,
    non-constant Reshape target, missing `axes` attribute).

    :param g: the onnx graph
    :raises RuntimeError: if a Squeeze removes an axis whose size is not 1
    """
    node_to_remove = []
    REMOVE_LIST = ["Reshape", "Transpose", "Flatten", "Dropout", "Squeeze", "Unsqueeze"]
    for node in g.node:
        # Find an input and the shape node
        if node.op_type not in REMOVE_LIST:
            continue
        old_input = helper.find_input_by_name(g, node.input[0])
        if old_input is None:
            continue
        # If the input is used by multiple nodes, skip.
        if sum(1 for tnode in g.node if old_input.name in tnode.input) > 1:
            continue
        # The new input reuses the name and element type recorded for the
        # node output; without that record there is nothing to build from.
        output_val_info = helper.find_value_by_name(g, node.output[0])
        if output_val_info is None:
            continue

        shape_node = None
        if node.op_type == 'Dropout':
            # Dropout keeps the shape, so the output value info is reused.
            new_input = output_val_info
        else:
            pre_shape = [int(dim) for dim in helper.get_shape_from_value_info(old_input)]
            new_shape = None
            if node.op_type == 'Reshape':
                if len(node.input) < 2:
                    continue
                shape_node = helper.find_node_by_output_name(g, node.input[1])
                if shape_node is None or shape_node.op_type != 'Constant':
                    continue
                _, target_shape = helper.constant_to_list(shape_node)
                new_shape = _resolve_reshape_target(pre_shape, target_shape)
            elif node.op_type == 'Transpose':
                new_shape = _transposed_shape(pre_shape, _find_attribute(node, 'perm'))
            elif node.op_type == 'Flatten':
                new_shape = _flattened_shape(pre_shape, _find_attribute(node, 'axis'))
            elif node.op_type == 'Squeeze':
                new_shape = _squeezed_shape(pre_shape, _find_attribute(node, 'axes'))
            elif node.op_type == 'Unsqueeze':
                new_shape = _unsqueezed_shape(pre_shape, _find_attribute(node, 'axes'))
            if new_shape is None:
                # The resulting shape is unknown: keep the node in the graph.
                continue
            new_input = onnx.helper.make_tensor_value_info(
                output_val_info.name,
                output_val_info.type.tensor_type.elem_type,
                new_shape
            )

        node_to_remove.append(node)

        if shape_node is not None:
            # Drop the Reshape target constant when nothing else reads it.
            shape_outputs = helper.find_nodes_by_input_name(g, shape_node.output[0])
            if len(shape_outputs) == 1:
                node_to_remove.append(shape_node)
                shape_val_info = helper.find_value_by_name(g, shape_node.output[0])
                if shape_val_info is not None:
                    g.value_info.remove(shape_val_info)

        g.input.remove(old_input)
        g.input.extend([new_input])
        g.value_info.remove(output_val_info)

    for node in node_to_remove:
        g.node.remove(node)

    other.topological_sort(g)


def _find_attribute(node, name):
    """Return the attribute of `node` called `name`, or None when absent."""
    for attribute in node.attribute:
        if attribute.name == name:
            return attribute
    return None


def _element_count(dims):
    """Return the product of `dims` as a plain Python int (1 when empty)."""
    count = 1
    for dim in dims:
        count *= int(dim)
    return count


def _resolve_reshape_target(pre_shape, target_shape):
    """Resolve the 0 and -1 placeholders of a Reshape target shape."""
    resolved = [int(dim) for dim in target_shape]
    # 0 means "copy the dimension from the input".
    for index, dim in enumerate(resolved):
        if dim == 0 and index < len(pre_shape):
            resolved[index] = pre_shape[index]
    # -1 takes whatever is left of the *input* element count.
    if -1 in resolved:
        known = _element_count(dim for dim in resolved if dim != -1)
        wildcard = resolved.index(-1)
        resolved[wildcard] = _element_count(pre_shape) // known if known else 0
    return resolved


def _transposed_shape(pre_shape, perm_attribute):
    """Permute `pre_shape`; a missing `perm` reverses the dimensions."""
    if perm_attribute is None:
        return list(reversed(pre_shape))
    return [pre_shape[index] for index in perm_attribute.ints]


def _flattened_shape(pre_shape, axis_attribute):
    """Return the 2-D shape produced by Flatten (`axis` defaults to 1)."""
    axis = axis_attribute.i if axis_attribute is not None else 1
    if axis < 0:
        axis += len(pre_shape)
    return [_element_count(pre_shape[:axis]), _element_count(pre_shape[axis:])]


def _squeezed_shape(pre_shape, axes_attribute):
    """Remove the axes listed in `axes`; None when the attribute is absent."""
    if axes_attribute is None:
        return None
    rank = len(pre_shape)
    axes = {axis + rank if axis < 0 else axis for axis in axes_attribute.ints}
    squeezed = list(pre_shape)
    # Pop from the back so that earlier positions stay valid.
    for position in sorted(axes, reverse=True):
        if squeezed[position] != 1:
            raise RuntimeError('invalid axis for squeeze')
        squeezed.pop(position)
    return squeezed


def _unsqueezed_shape(pre_shape, axes_attribute):
    """Insert size-1 axes listed in `axes`; None when the attribute is absent."""
    if axes_attribute is None:
        return None
    axes = list(axes_attribute.ints)
    out_rank = len(pre_shape) + len(axes)
    # Axes index the output tensor, so they are inserted in ascending order.
    positions = sorted(axis + out_rank if axis < 0 else axis for axis in axes)
    unsqueezed = list(pre_shape)
    for position in positions:
        unsqueezed.insert(position, 1)
    return unsqueezed
|
|
|
|
|
|
def eliminate_Reshape_Cast(g):
    """Eliminate the cast layer for shape of Reshape layer

    Looks for the pattern Constant -> Cast -> Reshape(shape input). The
    constant is converted to INT64 in place, its output is renamed to the
    Cast output consumed by the Reshape, and the Cast node is removed.
    Patterns where the node before the Cast is not a Constant, or where the
    constant feeds other nodes as well, are left untouched.

    :param g: the onnx graph
    :raises RuntimeError: if nothing produces the Cast input
    :raises NotImplementedError: if the constant stores its data in a form
        this pass cannot convert
    """
    int64_type = onnx.TensorProto.INT64
    # numpy dtypes for the raw_data layouts this pass can widen; raw_data is
    # little-endian.
    raw_dtypes = {
        onnx.TensorProto.FLOAT: '<f4',
        onnx.TensorProto.INT32: '<i4',
        onnx.TensorProto.INT64: '<i8',
        onnx.TensorProto.DOUBLE: '<f8',
    }
    # Iterate over a snapshot: Cast nodes are removed from g.node on the fly.
    for node in list(g.node):
        if node.op_type != 'Reshape' or len(node.input) < 2:
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[1])
        if prev_node is None or prev_node.op_type != 'Cast':
            continue
        # Now we find the cast weight pattern. Cast the weight, delete the cast.
        reshape_node = node
        cast_node = prev_node
        weight_node = helper.find_node_by_output_name(g, cast_node.input[0])
        if weight_node is None:
            raise RuntimeError("Unexpected None before Cast-Reshape.")
        if weight_node.op_type != 'Constant':
            continue
        # Renaming the constant output would break any other consumer.
        weight_consumers = helper.find_following_nodes_by_input_value_name(g, weight_node.output[0])
        if len(weight_consumers) > 1:
            continue

        tensor = weight_node.attribute[0].t
        if tensor.raw_data:
            source_dtype = raw_dtypes.get(tensor.data_type)
            if source_dtype is None:
                raise NotImplementedError()
            values = np.frombuffer(tensor.raw_data, dtype=source_dtype)
            # Write the widened bytes back so they match the new data type.
            tensor.raw_data = values.astype('<i8').tobytes()
        elif len(tensor.int64_data) > 0:
            # It's already int64. Do nothing
            pass
        elif len(tensor.int32_data) > 0:
            # INT64 tensors keep their values in int64_data.
            values = list(tensor.int32_data)
            tensor.ClearField('int32_data')
            tensor.int64_data.extend(values)
        else:
            raise NotImplementedError()
        tensor.data_type = int64_type

        # Change Value info
        origin_weight_out = helper.find_value_by_name(g, weight_node.output[0])
        weight_node.output.pop()
        weight_node.output.extend([reshape_node.input[1]])
        # Delete
        if origin_weight_out is not None:
            g.value_info.remove(origin_weight_out)
        g.node.remove(cast_node)
|
|
|
|
def eliminate_Cast_after_input(g):
    """Eliminate the cast layer right after the input

    The graph input feeding the Cast is replaced by a new input that is named
    after the Cast output and uses the Cast target type. The Cast is kept
    when the input is consumed by other nodes too, or when the shape of the
    Cast output is not recorded in the value infos.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Cast':
            continue
        old_input = helper.find_input_by_name(g, node.input[0])
        if old_input is None:
            continue
        # Removing an input that other nodes still read would leave them
        # dangling, so shared inputs are skipped.
        consumer_count = sum(1 for consumer in g.node if old_input.name in consumer.input)
        if consumer_count > 1:
            continue
        next_val_info = helper.find_value_by_name(g, node.output[0])
        if next_val_info is None:
            continue
        # Look the target type up by name rather than by position.
        target_type = next((att.i for att in node.attribute if att.name == 'to'), None)
        if target_type is None:
            continue
        shape = helper.get_shape_from_value_info(next_val_info)
        new_val_info = onnx.helper.make_tensor_value_info(
            next_val_info.name,
            target_type,
            shape
        )
        # Delete old value_info
        g.input.remove(old_input)
        g.value_info.remove(next_val_info)
        # Append nodes to node_to_remove
        node_to_remove.append(node)
        # Add new input
        g.input.extend([new_val_info])
    for node in node_to_remove:
        g.node.remove(node)
|
|
|
|
def eliminate_consecutive_Cast(g):
    """If two cast is next to each other, remove the first cast

    The second Cast is rewired to read the input of the first one. The first
    Cast is only deleted once nothing else needs it: it is kept while other
    nodes still consume its output or while that output is a graph output.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Cast':
            continue
        first_node = helper.find_node_by_output_name(g, node.input[0])
        if first_node is None or first_node.op_type != 'Cast':
            continue
        # Here we have two consecutive Cast Node
        # Reset the input of the later node
        first_output_name = first_node.output[0]
        node.input[0] = first_node.input[0]
        # Keep the first node while something else still depends on it.
        remaining = helper.find_following_nodes_by_input_value_name(g, first_output_name)
        if len(remaining) > 0:
            continue
        if helper.find_output_by_name(g, first_output_name) is not None:
            continue
        if first_node in node_to_remove:
            continue
        # Remove the first node and its output value info
        node_to_remove.append(first_node)
        first_output = helper.find_value_by_name(g, first_output_name)
        if first_output is not None:
            g.value_info.remove(first_output)
    for node in node_to_remove:
        g.node.remove(node)
|
|
|
|
def eliminate_Squeeze_before_Reshape(g):
    """If Squeeze and Reshape is next to each other, remove the first node

    The Reshape is rewired to read the input of the Squeeze. The Squeeze is
    only deleted once nothing else needs it: it is kept while other nodes
    still consume its output or while that output is a graph output.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Reshape':
            continue
        first_node = helper.find_node_by_output_name(g, node.input[0])
        if first_node is None or first_node.op_type != 'Squeeze':
            continue
        # Here we have a Squeeze directly followed by a Reshape.
        # Reset the input of the later node
        first_output_name = first_node.output[0]
        node.input[0] = first_node.input[0]
        # Keep the Squeeze while something else still depends on it.
        remaining = helper.find_following_nodes_by_input_value_name(g, first_output_name)
        if len(remaining) > 0:
            continue
        if helper.find_output_by_name(g, first_output_name) is not None:
            continue
        if first_node in node_to_remove:
            continue
        # Remove the first node and its output value info
        node_to_remove.append(first_node)
        first_output = helper.find_value_by_name(g, first_output_name)
        if first_output is not None:
            g.value_info.remove(first_output)
    for node in node_to_remove:
        g.node.remove(node)
|
|
|
|
def eliminate_no_children_input(g):
    """Eliminate inputs with no children at all.

    :param g: the onnx graph
    """
    # Start from every input name and strike out the ones a node consumes.
    unused_names = {graph_input.name for graph_input in g.input}
    for graph_node in g.node:
        for consumed_name in graph_node.input:
            unused_names.discard(consumed_name)
    # Whatever is left is not read by any node.
    for unused_name in unused_names:
        g.input.remove(helper.find_input_by_name(g, unused_name))
|
|
|
|
def eliminate_consecutive_reshape(g):
    """Replace consecutive reshape nodes by a single node.

    When a Reshape with a constant target shape reads the output of another
    Reshape with a constant target shape, the second Reshape is rewired to
    read the data input of the first one. The first Reshape and its shape
    constant are deleted only when nothing else uses them. Rewiring in place
    keeps the node order valid and lets chains of three or more Reshape nodes
    collapse correctly.

    :param g: the onnx graph
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Reshape' or len(node.input) < 2:
            continue
        pre_data_node = helper.find_node_by_output_name(g, node.input[0])
        pre_shape_node = helper.find_node_by_output_name(g, node.input[1])
        if pre_data_node is None or pre_shape_node is None:
            continue
        if pre_shape_node.op_type != 'Constant':
            continue
        if pre_data_node.op_type != 'Reshape' or len(pre_data_node.input) < 2:
            continue

        pre_pre_shape_node = helper.find_node_by_output_name(g, pre_data_node.input[1])
        if pre_pre_shape_node is None or pre_pre_shape_node.op_type != 'Constant':
            continue

        # Bypass the first Reshape. The surviving node is named after its
        # output, as the merged node always has been.
        intermediate_name = pre_data_node.output[0]
        node.input[0] = pre_data_node.input[0]
        node.name = node.output[0]

        # Keep the first Reshape while something else still depends on it.
        remaining = helper.find_following_nodes_by_input_value_name(g, intermediate_name)
        if len(remaining) > 0:
            continue
        if helper.find_output_by_name(g, intermediate_name) is not None:
            continue
        if pre_data_node in node_to_del:
            continue
        node_to_del.append(pre_data_node)
        intermediate_info = helper.find_value_by_name(g, intermediate_name)
        if intermediate_info is not None:
            g.value_info.remove(intermediate_info)

        # The shape constant of the first Reshape goes too, unless a node
        # that is not scheduled for deletion still reads it.
        shape_name = pre_pre_shape_node.output[0]
        shape_readers = helper.find_following_nodes_by_input_value_name(g, shape_name)
        if any(reader not in node_to_del for reader in shape_readers):
            continue
        if pre_pre_shape_node in node_to_del:
            continue
        node_to_del.append(pre_pre_shape_node)
        shape_info = helper.find_value_by_name(g, shape_name)
        if shape_info is not None:
            g.value_info.remove(shape_info)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)
|
|
|
|
def eliminate_single_input_Concat(g):
    """
    Eliminate single input Concat layers

    Consumers of the Concat output are rewired to the Concat input. When the
    Concat output is a graph output, the input takes its place as graph
    output; if the input is not yet an output and has no value info to
    promote, the Concat is kept so that the graph does not lose an output.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Concat':
            continue
        # Only Concat nodes with exactly one input are no-ops.
        if len(node.input) != 1:
            continue
        input_name = node.input[0]
        output_name = node.output[0]
        # If this node is the output node, set its previous node as output nodes.
        graph_output = helper.find_output_by_name(g, output_name)
        if graph_output is not None:
            if helper.find_output_by_name(g, input_name) is None:
                the_input_value = helper.find_value_by_name(g, input_name)
                if the_input_value is None:
                    continue
                g.output.remove(graph_output)
                g.output.extend([the_input_value])
            else:
                g.output.remove(graph_output)
        # Replace the parents in all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(g, output_name)
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, output_name, input_name)
        # Delete value info
        value_between = helper.find_value_by_name(g, output_name)
        if value_between is not None:
            g.value_info.remove(value_between)
        elif graph_output is None:
            print("No value info to delete while eliminating single input Concat layers.")
        # Node is waiting for elimination
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)
|
|
|
|
def eliminate_nop_Maxpool_and_AveragePool(g):
    """
    Eliminate do nothing MaxPool and AveragePool layers.
    Those layers have valid padding, 1x1 kernel and [1,1] strides.

    Consumers of the pooling output are rewired to the pooling input. When
    the pooling output is a graph output, the input takes its place as graph
    output; if the input is not yet an output and has no value info to
    promote, the node is kept so that the graph does not lose an output.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type not in ('MaxPool', 'AveragePool'):
            continue
        # If this node is actually working, continue.
        kernel = helper.get_list_attribute_by_name(node, "kernel_shape", "int")
        pads = helper.get_list_attribute_by_name(node, "pads", "int")
        strides = helper.get_list_attribute_by_name(node, "strides", "int")
        if kernel != [1, 1] or pads != [0, 0, 0, 0] or strides != [1, 1]:
            continue
        input_name = node.input[0]
        output_name = node.output[0]
        # If this node is the output node, set its previous node as output nodes.
        graph_output = helper.find_output_by_name(g, output_name)
        if graph_output is not None:
            if helper.find_output_by_name(g, input_name) is None:
                the_input_value = helper.find_value_by_name(g, input_name)
                if the_input_value is None:
                    continue
                g.output.remove(graph_output)
                g.output.extend([the_input_value])
            else:
                g.output.remove(graph_output)
        # Replace the parents in all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(g, output_name)
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, output_name, input_name)
        # Delete value info
        value_between = helper.find_value_by_name(g, output_name)
        if value_between is not None:
            g.value_info.remove(value_between)
        elif graph_output is None:
            print("No value info to delete while eliminating nop pooling layers.")
        # Node is waiting for elimination
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)
|
|
|
|
|
|
def eliminate_trivial_maxpool(g):
    """Eliminate MaxPool layers that cannot change their input.

    A MaxPool is trivial when every kernel dimension is 1 and any pads,
    strides and dilations it declares are all neutral (0, 1 and 1). Its
    consumers are rewired to the MaxPool input; when its output is a graph
    output, the input takes its place. The node is kept when the kernel shape
    is missing, or when it is a graph output whose input has no value info
    to promote. The graph is topologically sorted afterwards.

    :param g: the onnx graph
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'MaxPool':
            continue
        pads = None
        strides = None
        dilation = None
        kernel_shape = None
        for att in node.attribute:
            if att.name == 'pads':
                pads = list(att.ints)
            elif att.name == 'strides':
                strides = list(att.ints)
            elif att.name == 'kernel_shape':
                kernel_shape = list(att.ints)
            elif att.name in ('dilations', 'dilation'):
                # ONNX names the attribute `dilations`; the singular form is
                # still accepted for graphs that were written with it.
                dilation = list(att.ints)
        if kernel_shape is None:
            continue
        if pads and any(pad != 0 for pad in pads):
            continue
        if strides and any(stride != 1 for stride in strides):
            continue
        if dilation and any(dila != 1 for dila in dilation):
            continue
        if any(dim != 1 for dim in kernel_shape):
            continue

        input_name = node.input[0]
        output_name = node.output[0]

        # Move the graph output, if any, from the MaxPool to its input.
        output_value = helper.find_output_by_name(g, output_name)
        if output_value is not None:
            if helper.find_output_by_name(g, input_name) is None:
                pre_val_info = helper.find_value_by_name(g, input_name)
                if pre_val_info is None:
                    continue
                g.output.extend([pre_val_info])
            g.output.remove(output_value)

        next_nodes = helper.find_nodes_by_input_name(g, output_name) or []
        for next_node in next_nodes:
            if next_node is None:
                continue
            modhelper.replace_node_input(next_node, output_name, input_name)

        next_val_info = helper.find_value_by_name(g, output_name)
        if next_val_info is not None:
            g.value_info.remove(next_val_info)

        # Only schedule the node once everything above succeeded.
        node_to_del.append(node)

    while node_to_del:
        g.node.remove(node_to_del.pop())

    other.topological_sort(g)
|
|
|
|
def eliminate_empty_value_infos(g):
    """Remove every value info whose tensor shape has no dimensions.

    :param g: the onnx graph
    """
    # Collect first: the container must not shrink while it is iterated.
    shapeless = [
        info for info in g.value_info
        if len(info.type.tensor_type.shape.dim) == 0
    ]
    for info in shapeless:
        g.value_info.remove(info)
|
|
|
|
def eliminate_nop_pads(g):
    """Eliminate Pad layers whose constant pads are all zero.

    Only Pad nodes that take their pads from a Constant node through the
    second input are considered. Consumers of the Pad output are rewired to
    the Pad input. When the Pad output is a graph output, the input takes its
    place; if the input is not yet an output and has no value info to
    promote, the Pad is kept so that the graph does not lose an output.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Pad' or len(node.input) < 2:
            continue
        # Check if the Pad is empty or not
        pads_node = helper.find_node_by_output_name(g, node.input[1])
        if pads_node is None or pads_node.op_type != 'Constant':
            continue
        pads_np = helper.constant_to_numpy(pads_node)
        if np.any(pads_np):
            continue
        input_name = node.input[0]
        output_name = node.output[0]
        # If this node is the output node, set its previous node as output nodes.
        graph_output = helper.find_output_by_name(g, output_name)
        if graph_output is not None:
            if helper.find_output_by_name(g, input_name) is None:
                the_input_value = helper.find_value_by_name(g, input_name)
                if the_input_value is None:
                    continue
                g.output.remove(graph_output)
                g.output.extend([the_input_value])
            else:
                g.output.remove(graph_output)
        # Replace the parents in all the following nodes
        following_nodes = helper.find_following_nodes_by_input_value_name(g, output_name)
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, output_name, input_name)
        # Delete value info
        value_between = helper.find_value_by_name(g, output_name)
        if value_between is not None:
            g.value_info.remove(value_between)
        elif graph_output is None:
            helper.logger.info("No value info to delete while eliminating nop Pad layers.")
        # Node is waiting for elimination
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)
|
|
|
|
def eliminate_trivial_elementwise_calculation(g):
    """Eliminate Add, Sub, Mul, Div nodes which do nothing.

    A node is trivial when its second input is a Constant holding only the
    neutral element of the operation (0 for Add/Sub, 1 for Mul/Div). The node
    is kept when the recorded input and output shapes differ (the constant
    broadcasts the data), or when it is a graph output whose input has no
    value info to promote. The constant is removed along with the node when
    no other node reads it.

    :param g: the onnx graph
    """
    neutral_element = {'Add': 0, 'Sub': 0, 'Mul': 1, 'Div': 1}
    node_to_remove = []
    for node in g.node:
        if node.op_type not in neutral_element or len(node.input) < 2:
            # For other nodes, just skip
            continue
        weight_node = helper.find_node_by_output_name(g, node.input[1])
        if weight_node is None or weight_node.op_type != 'Constant':
            continue
        weight_np = helper.constant_to_numpy(weight_node)
        if np.any(weight_np != neutral_element[node.op_type]):
            continue

        input_name = node.input[0]
        output_name = node.output[0]
        output_value_info = helper.find_value_by_name(g, output_name)
        todel_output = helper.find_output_by_name(g, output_name)

        # A constant that broadcasts the data changes the shape, so the node
        # is not a no-op even though the values are neutral.
        input_info = helper.find_value_by_name(g, input_name)
        if input_info is None:
            input_info = helper.find_input_by_name(g, input_name)
        output_info = output_value_info if output_value_info is not None else todel_output
        if input_info is not None and output_info is not None:
            input_shape = list(helper.get_shape_from_value_info(input_info))
            output_shape = list(helper.get_shape_from_value_info(output_info))
            if input_shape and output_shape and input_shape != output_shape:
                continue

        # Work out the replacement graph output before touching the graph.
        the_input_value = None
        if todel_output is not None and helper.find_output_by_name(g, input_name) is None:
            the_input_value = helper.find_value_by_name(g, input_name)
            if the_input_value is None:
                continue

        # Remove the node
        node_to_remove.append(node)
        if output_value_info is not None:
            g.value_info.remove(output_value_info)
        # Replace next node input if any.
        following_nodes = helper.find_following_nodes_by_input_value_name(g, output_name)
        for following_node in following_nodes:
            modhelper.replace_node_input(following_node, output_name, input_name)
        if todel_output is not None:
            g.output.remove(todel_output)
            if the_input_value is not None:
                g.output.extend([the_input_value])
        # Delete the constant node if it is not used by other nodes
        constant_following_nodes = helper.find_following_nodes_by_input_value_name(g, weight_node.output[0])
        if len(constant_following_nodes) == 1:
            node_to_remove.append(weight_node)
            weight_value_info = helper.find_value_by_name(g, weight_node.output[0])
            if weight_value_info is not None:
                g.value_info.remove(weight_value_info)
    for node in node_to_remove:
        g.node.remove(node)
|
|
|
|
def eliminate_nop_cast(g):
    """Eliminate do nothing Cast nodes.

    A Cast is a no-op when the element type recorded for its input equals the
    one recorded for its output. Cast nodes for which either record cannot be
    found are skipped.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Cast':
            continue
        # Element type on the way in.
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            helper.logger.debug(f"Cannot find the input value_info for Cast node {node.name}. Skip elimination check.")
            continue
        # Element type on the way out: value infos first, then graph outputs.
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value is None:
            helper.logger.debug(f"Cannot find the output value_info for Cast node {node.name}. Skip elimination check.")
            continue
        # A Cast that changes the type is doing real work.
        input_type = input_value.type.tensor_type.elem_type
        output_type = output_value.type.tensor_type.elem_type
        if input_type != output_type:
            continue
        # When the Cast feeds a graph output, its input becomes the output.
        todel_output = helper.find_output_by_name(g, node.output[0])
        if todel_output is not None:
            g.output.remove(todel_output)
            if helper.find_output_by_name(g, node.input[0]) is None:
                the_input_value = helper.find_value_by_name(g, node.input[0])
                if the_input_value is not None:
                    g.output.extend([the_input_value])
        # Point every consumer at the Cast input.
        for following_node in helper.find_following_nodes_by_input_value_name(g, node.output[0]):
            modhelper.replace_node_input(following_node, node.output[0], node.input[0])
        # Drop the value info of the removed tensor.
        value_between = helper.find_value_by_name(g, node.output[0])
        if value_between is not None:
            g.value_info.remove(value_between)
        # The node itself is removed once the scan is over.
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)
|