STDC/tools/optimizer_scripts/tensorflow2onnx.py

148 lines
5.5 KiB
Python

import argparse
import logging
import os
import re
import sys

import tensorflow as tf
import tf2onnx
import onnx
import onnx.utils
from tensorflow.python.platform import gfile

from tools import combo, eliminating, replacing, other
def _version_tuple(version: str) -> tuple:
    """Parse the leading numeric part of a dotted version string.

    ``'1.9.3'`` -> ``(1, 9, 3)``; ``'1.10.0rc1'`` -> ``(1, 10, 0)``.
    Returns ``()`` when the string does not start with a digit.
    """
    match = re.match(r'\d+(?:\.\d+)*', version)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split('.'))


def _find_graph_io(onnx_nodes) -> tuple:
    """Infer graph input and output tensor names from tf2onnx-converted nodes.

    Args:
        onnx_nodes: nodes exposing ``op_type``, ``input`` and ``output``;
            every node is expected to have at least one output.

    Returns:
        ``(input_names, output_names)``, each a sorted list of tensor names
        (sorted so the conversion is deterministic across runs).
        Inputs are tensors consumed by some non-Placeholder node but produced
        by none. Outputs are the first output of every node that has inputs,
        whose first input is produced by a non-Placeholder node, and whose
        first output is consumed by no non-Placeholder node.
    """
    consumed = set()
    produced = set()
    for node in onnx_nodes:
        # Placeholders are excluded so the tensors they feed appear as
        # "consumed but never produced", i.e. graph inputs.
        if node.op_type == 'Placeholder':
            continue
        consumed.update(node.input)
        produced.update(node.output)

    input_names = sorted(consumed - produced)
    output_names = sorted({
        node.output[0]
        for node in onnx_nodes
        if node.input
        and node.input[0] in produced
        and node.output[0] not in consumed
    })
    return input_names, output_names


def tf2onnx_flow(pb_path: str, test_mode: bool = False) -> onnx.ModelProto:
    """Convert a frozen TensorFlow graph (.pb) into an optimized ONNX model.

    Graph inputs/outputs are inferred from the graph itself, the graph is
    converted with tf2onnx at opset 11, and the result is run through the
    ``combo`` optimization passes.

    Args:
        pb_path (str): path to the frozen graph; must end with ``.pb``.
        test_mode (bool, optional): when true, skip
            ``eliminating.eliminate_shape_changing_after_input`` and the final
            polish. Defaults to False.

    Raises:
        ValueError: if ``pb_path`` does not end with ``.pb`` (a subclass of
            ``Exception``, so existing ``except Exception`` handlers still work).

    Returns:
        onnx.ModelProto: converted and optimized model.
    """
    if not pb_path.endswith('.pb'):
        raise ValueError('expect .pb file as input, but got "' + str(pb_path) + '"')

    # Releases >= 1.6 expose ``tf_loader`` and ``tf_utils.tflist_to_onnx(graph, ...)``;
    # older ones expose ``loader`` and ``tfonnx.tflist_to_onnx(ops, ...)``.
    # Compare parsed tuples: gluing digits together ("1.5.10" -> 1510)
    # misorders multi-digit components and fails on suffixes like "rc1".
    new_api = _version_tuple(tf2onnx.version.version) >= (1, 6)
    if new_api:
        from tf2onnx import tf_loader
    else:
        from tf2onnx import loader as tf_loader

    # basename() honours the platform's path separators, unlike split('/').
    model_name = os.path.basename(pb_path)[:-len('.pb')]

    # Always reset the TensorFlow default graph at the beginning so nodes
    # from a previous call do not leak into this conversion.
    tf.reset_default_graph()
    with gfile.FastGFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Session() as sess:
        tf.import_graph_def(graph_def, name='')
        # Only the node list (first element of the returned tuple) is needed.
        if new_api:
            onnx_nodes = tf2onnx.tf_utils.tflist_to_onnx(sess.graph, {})[0]
        else:
            onnx_nodes = tf2onnx.tfonnx.tflist_to_onnx(sess.graph.get_operations(), {})[0]

    # Drop nodes without outputs. Build a new list instead of calling
    # remove() while iterating, which skips the element after each removal.
    onnx_nodes = [node for node in onnx_nodes if len(node.output) > 0]

    graph_input_names, graph_output_names = _find_graph_io(onnx_nodes)
    logging.info('Model Inputs: %s', str(graph_input_names))
    logging.info('Model Outputs: %s', str(graph_output_names))

    graph_def, inputs, outputs = tf_loader.from_graphdef(model_path=pb_path,
                                                         input_names=graph_input_names,
                                                         output_names=graph_output_names)
    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(graph_def, name='')
        session = tf_loader.tf_session(graph=tf_graph) if new_api else tf.Session(graph=tf_graph)
        with session:
            onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
                                                         input_names=inputs,
                                                         output_names=outputs,
                                                         opset=11)

    # Optimize with tf2onnx.optimizer before exporting the ModelProto.
    onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph)
    model = onnx_graph.make_model(model_name)
    # Make tf2onnx output compatible with the spec. of other.polish_model
    replacing.replace_initializer_with_Constant(model.graph)
    model = other.polish_model(model)

    model = combo.preprocess(model)
    model = combo.common_optimization(model)
    model = combo.tensorflow_optimization(model)
    model = combo.postprocess(model)
    if not test_mode:
        eliminating.eliminate_shape_changing_after_input(model.graph)
        model = other.polish_model(model)
    return model
def parse_bool(value) -> bool:
    """Interpret a command-line string as a boolean.

    Accepts, case-insensitively and ignoring surrounding whitespace,
    ``1/true/t/yes/y/on`` and ``0/false/f/no/n/off``. A real ``bool`` is
    returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling;
            argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(description='Convert tensorflow pb file to onnx file and optimized onnx file. Or just optimize tensorflow onnx file.')
    parser.add_argument('in_file', help='input file')
    parser.add_argument('out_file', help='output optimized model file')
    # Parse the value as a real boolean so "-t False" disables test mode
    # (any non-empty string used to be truthy); a bare "-t" enables it.
    parser.add_argument('-t', '--test_mode', nargs='?', const=True, default=False, type=parse_bool,
                        help='test mode will not eliminate shape changes after input '
                             '(enable with "-t" or "-t true")')
    return parser


def main(argv=None) -> None:
    """Convert ``in_file`` and save the optimized ONNX model to ``out_file``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)
    logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO)
    model = tf2onnx_flow(args.in_file, args.test_mode)
    onnx.save(model, args.out_file)
    logging.info('Save Optimized ONNX: %s', args.out_file)


if __name__ == '__main__':
    main()