68 lines
2.9 KiB
Python
68 lines
2.9 KiB
Python
import argparse
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import onnx
|
|
import onnxruntime
|
|
|
|
from tools import helper
|
|
|
|
def _channels_last_to_first_perm(rank):
    """Return the axis permutation that moves the last axis to position 1.

    For rank 4 this is ``[0, 3, 1, 2]`` (NHWC -> NCHW). Ranks below 3 have no
    axis to move past, so the identity permutation is returned.
    """
    if rank < 3:
        return list(range(rank))
    return [0, rank - 1] + list(range(1, rank - 1))


def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10, seed=None):
    """Run a TFLite model and an ONNX model on the same random inputs.

    Both models must have exactly one input and one output. The TFLite input
    is treated as channels-last (e.g. NHWC) and the ONNX input as
    channels-first (e.g. NCHW): each random float32 TFLite input is
    transposed accordingly before being fed to ONNX Runtime. Outputs are
    collected as returned by each runtime, without any layout conversion.

    Args:
        tflite_file: Path to the ``.tflite`` model.
        onnx_file: Path to the ``.onnx`` model.
        total_times: Number of random inputs to run through both models.
        seed: Optional seed for the random inputs. ``None`` (the default)
            draws from NumPy's global random state, as before.

    Returns:
        A tuple ``(tflite_results, onnx_results)`` of two lists, each holding
        ``total_times`` arrays: the first (only) output of each model per run.

    Raises:
        ValueError: If either model does not have exactly one input and one
            output, or if the two input shapes do not correspond.
    """
    # Set up the ONNX Runtime session and collect model metadata.
    onnx_session = onnxruntime.InferenceSession(onnx_file, None)
    onnx_output_count = len(onnx_session.get_outputs())
    if onnx_output_count != 1:
        raise ValueError(
            "The onnx model must have exactly one output (found %d)" % onnx_output_count)
    onnx_graph = onnx.load(onnx_file).graph
    # Models with IR version < 4 also list initializers (weights) in
    # graph.input; exclude them so only real runtime inputs are counted.
    initializer_names = {init.name for init in onnx_graph.initializer}
    onnx_inputs = [value for value in onnx_graph.input
                   if value.name not in initializer_names]
    if len(onnx_inputs) != 1:
        raise ValueError(
            "The onnx model must have exactly one input (found %d)" % len(onnx_inputs))
    _, onnx_input_shape = helper.find_size_shape_from_value(onnx_inputs[0])

    # Set up the TFLite interpreter and collect model metadata.
    tflite_session = tf.lite.Interpreter(model_path=tflite_file)
    tflite_session.allocate_tensors()
    tflite_inputs = tflite_session.get_input_details()
    tflite_outputs = tflite_session.get_output_details()
    if len(tflite_inputs) != 1:
        raise ValueError(
            "The tflite model must have exactly one input (found %d)" % len(tflite_inputs))
    if len(tflite_outputs) != 1:
        raise ValueError(
            "The tflite model must have exactly one output (found %d)" % len(tflite_outputs))
    tflite_input_shape = tflite_inputs[0]['shape']

    # Compare input shapes: the ONNX shape must be the TFLite shape with the
    # last (channel) axis moved to position 1.
    perm = _channels_last_to_first_perm(len(tflite_input_shape))
    expected_onnx_shape = [int(tflite_input_shape[axis]) for axis in perm]
    if (len(onnx_input_shape) != len(tflite_input_shape)
            or list(onnx_input_shape) != expected_onnx_shape):
        raise ValueError(
            "TFLite and ONNX shape unmatch: TFLite %s, ONNX %s (expected %s)"
            % (list(tflite_input_shape), list(onnx_input_shape), expected_onnx_shape))

    # np.random exposes random_sample at module level, so the global state is
    # used unless a seed was requested.
    rng = np.random if seed is None else np.random.RandomState(seed)
    onnx_input_name = onnx_inputs[0].name

    tflite_results = []
    onnx_results = []
    for _ in range(total_times):
        # Generate one random input and its channels-first counterpart.
        tflite_input_data = rng.random_sample(tflite_input_shape).astype(np.float32)
        onnx_input_data = np.transpose(tflite_input_data, perm)

        # Run TFLite.
        tflite_session.set_tensor(tflite_inputs[0]['index'], tflite_input_data)
        tflite_session.invoke()
        tflite_results.append(tflite_session.get_tensor(tflite_outputs[0]['index']))

        # Run ONNX; None requests all outputs (there is exactly one).
        onnx_results.append(onnx_session.run(None, {onnx_input_name: onnx_input_data})[0])

    return tflite_results, onnx_results
|
|
|
|
|
|
def main():
    """Command-line entry point: compare a TFLite and an ONNX model.

    Runs both models on ``--times`` random inputs via
    ``compare_tflite_and_onnx``, flattens both result lists with
    ``helper.flatten_with_depth`` and compares them. It exits with status 1 if
    the structural entries (``item[1]``) differ, or if the values
    (``item[0]``) differ at ``--decimal`` decimal places. Otherwise it prints
    a success message.
    """
    import sys

    parser = argparse.ArgumentParser(description="Compare a TFLite model and an ONNX model to check if they have the same output.")
    parser.add_argument('tflite_file', help='input tflite file')
    parser.add_argument('onnx_file', help='input ONNX file')
    parser.add_argument('--times', type=int, default=10,
                        help='number of random inputs to compare (default: 10)')
    parser.add_argument('--decimal', type=int, default=8,
                        help='decimal places for numpy.testing.assert_almost_equal '
                             '(default: 8; float32 outputs may need a smaller value)')
    args = parser.parse_args()

    results_a, results_b = compare_tflite_and_onnx(args.tflite_file, args.onnx_file, total_times=args.times)

    # NOTE(review): flatten_with_depth appears to yield (value, structure-info)
    # pairs; item[1] is compared for structure, item[0] for values.
    ra_flat = helper.flatten_with_depth(results_a, 0)
    rb_flat = helper.flatten_with_depth(results_b, 0)
    shape_a = [item[1] for item in ra_flat]
    shape_b = [item[1] for item in rb_flat]
    if shape_a != shape_b:
        # Explicit check instead of `assert`, which `python -O` would strip.
        sys.exit('two results data shape doesn\'t match')

    ra_raw = [item[0] for item in ra_flat]
    rb_raw = [item[0] for item in rb_flat]
    try:
        np.testing.assert_almost_equal(ra_raw, rb_raw, args.decimal)
    except AssertionError as mismatch:
        # assert_almost_equal reports value mismatches as AssertionError.
        print(mismatch)
        sys.exit(1)
    print('Two models have the same behaviour.')


if __name__ == '__main__':
    main()