148 lines
5.5 KiB
Python
148 lines
5.5 KiB
Python
import tensorflow as tf
|
|
import tf2onnx
|
|
import argparse
|
|
import logging
|
|
import sys
|
|
import onnx
|
|
import onnx.utils
|
|
from tensorflow.python.platform import gfile
|
|
from tools import combo, eliminating, replacing, other
|
|
|
|
def _tf2onnx_version() -> tuple:
    """Return the installed tf2onnx version as a ``(major, minor)`` tuple.

    Comparing tuples avoids the pitfalls of gluing the dotted version into one
    integer, where e.g. ``'1.10'`` became ``110`` and sorted below ``160``.

    Raises:
        ValueError: if the version string does not start with ``<int>.<int>``.
    """
    import re

    version = tf2onnx.version.version
    match = re.match(r'(\d+)\.(\d+)', version)
    if match is None:
        raise ValueError('unrecognized tf2onnx version: ' + repr(version))
    return int(match.group(1)), int(match.group(2))


def _collect_onnx_nodes(pb_path: str, new_api: bool) -> list:
    """Load a frozen graph and return tf2onnx's node list for it.

    Only the node list returned by ``tflist_to_onnx`` is used; the other values
    in its result tuple are ignored. Nodes without any output are dropped.

    Args:
        pb_path (str): path of the frozen graph ``.pb`` file.
        new_api (bool): True to use ``tf2onnx.tf_utils.tflist_to_onnx`` (which
            takes the graph), False to use ``tf2onnx.tfonnx.tflist_to_onnx``
            (which takes the list of operations).

    Returns:
        list: the tf2onnx nodes that have at least one output.
    """
    # always reset tensorflow session at begin
    tf.reset_default_graph()

    with tf.Session() as sess:
        with gfile.FastGFile(pb_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        with sess.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        if new_api:
            onnx_nodes = tf2onnx.tf_utils.tflist_to_onnx(sess.graph, {})[0]
        else:
            onnx_nodes = tf2onnx.tfonnx.tflist_to_onnx(sess.graph.get_operations(), {})[0]

    # Build a new list rather than calling remove() inside a loop over the same
    # list: that skips the element following every removed one.
    return [node for node in onnx_nodes if len(node.output) > 0]


def _find_graph_io(onnx_nodes: list) -> tuple:
    """Infer graph input and output tensor names from a tf2onnx node list.

    * Inputs: tensors consumed by a non-Placeholder node but produced by no
      non-Placeholder node.
    * Outputs: the first output of each node that has inputs, whose first
      input is produced by a non-Placeholder node, and whose first output is
      consumed by no non-Placeholder node.

    Args:
        onnx_nodes (list): tf2onnx nodes, each with at least one output.

    Returns:
        tuple: ``(input_names, output_names)``, each a sorted list so the
        converted model's input/output order does not depend on set iteration
        order (which varies between runs for strings).
    """
    consumed = set()
    produced = set()
    for node in onnx_nodes:
        # Placeholder outputs are intentionally not counted as produced, so the
        # tensors they feed show up as graph inputs.
        if node.op_type == 'Placeholder':
            continue
        consumed.update(node.input)
        produced.update(node.output)

    input_names = consumed - produced
    output_names = {
        node.output[0]
        for node in onnx_nodes
        if node.input
        and node.input[0] in produced
        and node.output[0] not in consumed
    }
    return sorted(input_names), sorted(output_names)


def tf2onnx_flow(pb_path: str, test_mode: bool = False) -> onnx.ModelProto:
    """Convert a TensorFlow frozen-graph ``.pb`` file into an optimized ONNX model.

    The graph inputs/outputs are inferred from the node list, the graph is
    converted with tf2onnx at opset 11, optimized with tf2onnx's optimizer and
    then run through the ``combo`` preprocess / common / tensorflow /
    postprocess passes before a final ``other.polish_model``.

    Args:
        pb_path (str): input pb file path; must end with ``.pb``.
        test_mode (bool, optional): when True, skip
            ``eliminating.eliminate_shape_changing_after_input``.
            Defaults to False.

    Raises:
        ValueError: if ``pb_path`` does not end with ``.pb`` (a subclass of
            ``Exception``, so existing ``except Exception`` handlers still
            catch it), or if the tf2onnx version string cannot be parsed.

    Returns:
        onnx.ModelProto: converted onnx
    """
    import os

    if not pb_path.endswith('.pb'):
        raise ValueError('expect .pb file as input, but got "' + str(pb_path) + '"')

    # From tf2onnx 1.6 on this code uses tf_loader / tf_utils; older releases
    # use loader / tfonnx.
    new_api = _tf2onnx_version() >= (1, 6)
    if new_api:
        from tf2onnx import tf_loader
    else:
        from tf2onnx import loader as tf_loader

    # File name without the '.pb' suffix; basename honours the OS separator.
    model_name = os.path.basename(pb_path)[:-len('.pb')]

    # find inputs and outputs of graph
    onnx_nodes = _collect_onnx_nodes(pb_path, new_api)
    graph_input_names, graph_output_names = _find_graph_io(onnx_nodes)
    logging.info('Model Inputs: %s', str(graph_input_names))
    logging.info('Model Outputs: %s', str(graph_output_names))

    graph_def, inputs, outputs = tf_loader.from_graphdef(model_path=pb_path,
                                                         input_names=graph_input_names,
                                                         output_names=graph_output_names)

    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(graph_def, name='')

    # Only the session type differs between tf2onnx API generations.
    if new_api:
        session = tf_loader.tf_session(graph=tf_graph)
    else:
        session = tf.Session(graph=tf_graph)
    with session:
        onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
                                                     input_names=inputs,
                                                     output_names=outputs,
                                                     opset=11)

    # Optimize with tf2onnx.optimizer
    onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph)
    model_proto = onnx_graph.make_model(model_name)

    # Make tf2onnx output compatible with the spec. of other.polish_model
    replacing.replace_initializer_with_Constant(model_proto.graph)
    m = other.polish_model(model_proto)

    m = combo.preprocess(m)
    m = combo.common_optimization(m)
    m = combo.tensorflow_optimization(m)
    m = combo.postprocess(m)

    if not test_mode:
        eliminating.eliminate_shape_changing_after_input(m.graph)

    m = other.polish_model(m)
    return m
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser(description='Convert tensorflow pb file to onnx file and optimized onnx file. Or just optimize tensorflow onnx file.')
|
|
parser.add_argument('in_file', help='input file')
|
|
parser.add_argument('out_file', help='output optimized model file')
|
|
parser.add_argument('-t', '--test_mode', default=False, help='test mode will not eliminate shape changes after input')
|
|
|
|
args = parser.parse_args()
|
|
logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO)
|
|
m = tf2onnx_flow(args.in_file, args.test_mode)
|
|
onnx.save(m, args.out_file)
|
|
logging.info('Save Optimized ONNX: %s', args.out_file)
|