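# STDC/tools/optimizer_scripts/tflite_vs_onnx.py
#
# Compare a TFLite model and an ONNX model by running both on the same random
# inputs and checking that their outputs match.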

import argparse
import numpy as np
import tensorflow as tf
import onnx
import onnxruntime
from tools import helper
def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
    """Run a TFLite model and an ONNX model on identical random inputs.

    The TFLite input is treated as NHWC. The ONNX input must be the same
    tensor in NCHW order; each random batch is transposed with
    ``[0, 3, 1, 2]`` before it is fed to ONNX Runtime.

    Args:
        tflite_file: Path to the TFLite model.
        onnx_file: Path to the ONNX model.
        total_times: Number of random inputs to run through both models.

    Returns:
        A ``(tflite_results, onnx_results)`` tuple of lists with one entry
        per run: the first output tensor of the TFLite model and the first
        output of the ONNX session, respectively.

    Raises:
        ValueError: If either model does not have exactly one input, the
            ONNX model does not have exactly one output, or the two input
            shapes are not 4-D NHWC/NCHW transposes of each other.
    """
    # Set up the ONNX Runtime session and read the model metadata.
    onnx_session = onnxruntime.InferenceSession(onnx_file, None)
    if len(onnx_session.get_outputs()) != 1:
        raise ValueError("The onnx model must have exactly one output")
    onnx_graph = onnx.load(onnx_file).graph
    # Older ONNX models (IR version < 4) also list every initializer in
    # graph.input. Those are weights, not values we feed, so skip them.
    initializer_names = {init.name for init in onnx_graph.initializer}
    onnx_inputs = [value for value in onnx_graph.input
                   if value.name not in initializer_names]
    if len(onnx_inputs) != 1:
        raise ValueError("The onnx model must have exactly one input")
    _, onnx_input_shape = helper.find_size_shape_from_value(onnx_inputs[0])

    # Set up the TFLite interpreter and read the model metadata.
    tflite_session = tf.lite.Interpreter(model_path=tflite_file)
    tflite_session.allocate_tensors()
    tflite_inputs = tflite_session.get_input_details()
    if len(tflite_inputs) != 1:
        # Only one input is filled below; extra inputs would be left unset.
        raise ValueError("The tflite model must have exactly one input")
    tflite_input = tflite_inputs[0]
    tflite_output = tflite_session.get_output_details()[0]
    tflite_input_shape = [int(dim) for dim in tflite_input['shape']]

    # Compare input shapes: TFLite NHWC must transpose to ONNX NCHW.
    if len(onnx_input_shape) != len(tflite_input_shape):
        raise ValueError("TFLite and ONNX shape unmatch.")
    if len(tflite_input_shape) != 4:
        raise ValueError(
            "Only 4-D (NHWC/NCHW) inputs are supported, got shape {}".format(
                tflite_input_shape))
    nhwc_to_nchw = [0, 3, 1, 2]
    expected_onnx_shape = [tflite_input_shape[axis] for axis in nhwc_to_nchw]
    if list(onnx_input_shape) != expected_onnx_shape:
        raise ValueError("TFLite and ONNX shape unmatch.")

    # Generate random inputs and run both models on each of them.
    onnx_input_name = onnx_inputs[0].name
    tflite_results = []
    onnx_results = []
    for _ in range(total_times):
        tflite_input_data = np.random.random_sample(
            tflite_input_shape).astype(np.float32)
        onnx_input_data = np.transpose(tflite_input_data, nhwc_to_nchw)

        # Run TFLite.
        tflite_session.set_tensor(tflite_input['index'], tflite_input_data)
        tflite_session.invoke()
        tflite_results.append(tflite_session.get_tensor(tflite_output['index']))

        # Run ONNX; an empty output-name list requests all outputs.
        onnx_results.append(
            onnx_session.run([], {onnx_input_name: onnx_input_data})[0])
    return tflite_results, onnx_results
if __name__ == '__main__':
    # Script entry point: run both models on random inputs and exit with
    # status 1 if their flattened results differ in structure or value.
    import sys

    # Argument parser.
    parser = argparse.ArgumentParser(
        description="Compare a TFLite model and an ONNX model to check if "
                    "they have the same output.")
    parser.add_argument('tflite_file', help='input tflite file')
    parser.add_argument('onnx_file', help='input ONNX file')
    parser.add_argument('--times', type=int, default=10,
                        help='number of random inputs to compare (default: 10)')
    # NOTE(review): if the models emit float32 outputs, agreement to 8
    # decimals across runtimes is very strict; lower --decimal if needed.
    parser.add_argument('--decimal', type=int, default=8,
                        help='decimal precision passed to '
                             'numpy.testing.assert_almost_equal (default: 8)')
    args = parser.parse_args()

    results_a, results_b = compare_tflite_and_onnx(
        args.tflite_file, args.onnx_file, total_times=args.times)
    ra_flat = helper.flatten_with_depth(results_a, 0)
    rb_flat = helper.flatten_with_depth(results_b, 0)

    # Second element of each flattened item is compared structurally first;
    # presumably it encodes nesting/shape information — verify in helper.
    shape_a = [item[1] for item in ra_flat]
    shape_b = [item[1] for item in rb_flat]
    if shape_a != shape_b:
        # Explicit check instead of `assert`, which is stripped under -O.
        print('two results data shape doesn\'t match', file=sys.stderr)
        sys.exit(1)

    ra_raw = [item[0] for item in ra_flat]
    rb_raw = [item[0] for item in rb_flat]
    try:
        np.testing.assert_almost_equal(ra_raw, rb_raw, decimal=args.decimal)
    except AssertionError as mismatch:
        # assert_almost_equal reports a value mismatch via AssertionError;
        # anything else is a genuine error and should surface as a traceback.
        print(mismatch, file=sys.stderr)
        sys.exit(1)
    print('Two models have the same behaviour.')