86 lines
3.1 KiB
Python
86 lines
3.1 KiB
Python
import argparse
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import onnx
|
|
import onnxruntime
|
|
|
|
from tools import helper
|
|
|
|
|
|
def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
    """Run a TFLite model and an ONNX model on the same random inputs.

    Both models are expected to have a single 4-D input and the ONNX model a
    single output. The TFLite input axes ``(0, 1, 2, 3)`` must correspond to
    the ONNX input axes ``(0, 2, 3, 1)``; this is the permutation normally
    used between channels-last and channels-first layouts. Each random
    float32 input is generated in the TFLite layout and transposed with
    ``[0, 3, 1, 2]`` before being fed to the ONNX model.

    Args:
        tflite_file: Path given to ``tf.lite.Interpreter(model_path=...)``.
        onnx_file: Path given to ``onnxruntime.InferenceSession`` and
            ``onnx.load``.
        total_times: Number of random inputs to run through both models.

    Returns:
        A tuple ``(tflite_results, onnx_results)``. Each is a list with one
        entry per run holding the first output of the respective model.

    Raises:
        AssertionError: If the ONNX model does not have exactly one output or
            exactly one (non-initializer) input, if the TFLite input is not
            4-D, or if the two input shapes do not match after permutation.
    """

    def _check(condition, message):
        # Raised explicitly instead of using `assert` so the validation is
        # not stripped when Python runs with -O.
        if not condition:
            raise AssertionError(message)

    # Set up the ONNX session and read its metadata.
    onnx_session = onnxruntime.InferenceSession(onnx_file, None)
    onnx_outputs = onnx_session.get_outputs()
    _check(len(onnx_outputs) == 1, "The onnx model has more than one output")
    onnx_graph = onnx.load(onnx_file).graph
    # Models with an older ONNX IR version also list their initializers
    # (weights) in graph.input; those are not real inputs to be fed.
    initializer_names = {init.name for init in onnx_graph.initializer}
    onnx_inputs = [
        value
        for value in onnx_graph.input
        if value.name not in initializer_names
    ]
    _check(len(onnx_inputs) == 1, "The onnx model has more than one input")
    # NOTE(review): the helper is assumed to return (size, shape) with the
    # shape as a list of ints -- confirm against tools.helper.
    _, onnx_input_shape = helper.find_size_shape_from_value(onnx_inputs[0])

    # Set up the TFLite session and read its metadata.
    tflite_session = tf.lite.Interpreter(model_path=tflite_file)
    tflite_session.allocate_tensors()
    tflite_inputs = tflite_session.get_input_details()
    tflite_outputs = tflite_session.get_output_details()
    tflite_input_shape = tflite_inputs[0]["shape"]

    # Compare input shapes.
    _check(
        len(onnx_input_shape) == len(tflite_input_shape),
        "TFLite and ONNX shape unmatch.",
    )
    # The permutation below indexes axis 3, so anything but a 4-D input
    # would otherwise fail with an unhelpful IndexError.
    _check(
        len(tflite_input_shape) == 4,
        "Only 4-D TFLite inputs are supported.",
    )
    expected_onnx_shape = [tflite_input_shape[axis] for axis in (0, 3, 1, 2)]
    _check(
        onnx_input_shape == expected_onnx_shape,
        "TFLite and ONNX shape unmatch.",
    )

    # These lookups do not change between runs.
    tflite_input_index = tflite_inputs[0]["index"]
    tflite_output_index = tflite_outputs[0]["index"]
    onnx_input_name = onnx_inputs[0].name

    # Generate random inputs and run both models.
    tflite_results = []
    onnx_results = []
    for _ in range(total_times):
        # Generate the input in the TFLite layout and permute it for ONNX.
        tflite_input_data = np.array(
            np.random.random_sample(tflite_input_shape), dtype=np.float32
        )
        onnx_input_data = np.transpose(tflite_input_data, [0, 3, 1, 2])

        # Run TFLite.
        tflite_session.set_tensor(tflite_input_index, tflite_input_data)
        tflite_session.invoke()
        tflite_results.append(tflite_session.get_tensor(tflite_output_index))

        # Run ONNX; an empty output-name list requests every output.
        onnx_input_dict = {onnx_input_name: onnx_input_data}
        onnx_results.append(onnx_session.run([], onnx_input_dict)[0])

    return tflite_results, onnx_results
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported here because only the command-line entry point needs it.
    import sys

    # Argument parser.
    parser = argparse.ArgumentParser(
        description="Compare a TFLite model and an ONNX model to check "
        "if they have the same output."
    )
    parser.add_argument("tflite_file", help="input tflite file")
    parser.add_argument("onnx_file", help="input ONNX file")
    parser.add_argument(
        "--times",
        type=int,
        default=10,
        help="number of random inputs to compare (default: %(default)s)",
    )
    parser.add_argument(
        "--decimal",
        type=int,
        default=8,
        help="decimal places that must agree (default: %(default)s)",
    )

    args = parser.parse_args()

    results_a, results_b = compare_tflite_and_onnx(
        args.tflite_file, args.onnx_file, total_times=args.times
    )

    # NOTE(review): each flattened item is indexed as item[0] (compared as
    # data) and item[1] (compared as shape) -- confirm against
    # tools.helper.flatten_with_depth.
    ra_flat = helper.flatten_with_depth(results_a, 0)
    rb_flat = helper.flatten_with_depth(results_b, 0)
    shape_a = [item[1] for item in ra_flat]
    shape_b = [item[1] for item in rb_flat]
    # Raised explicitly instead of using `assert` so the check still runs
    # when Python is started with -O.
    if shape_a != shape_b:
        raise AssertionError("two results data shape doesn't match")
    ra_raw = [item[0] for item in ra_flat]
    rb_raw = [item[0] for item in rb_flat]

    try:
        np.testing.assert_almost_equal(ra_raw, rb_raw, args.decimal)
    except Exception as mismatch:
        # Report the mismatch and signal failure through the exit status.
        print(mismatch)
        sys.exit(1)
    else:
        print("Two models have the same behaviour.")
|