2026-03-11 16:13:59 +08:00

86 lines
3.1 KiB
Python

import argparse
import numpy as np
import tensorflow as tf
import onnx
import onnxruntime
from tools import helper
def _require(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    Used in place of bare ``assert`` statements so the model sanity checks
    still run under ``python -O`` while the exception type stays the same.
    """
    if not condition:
        raise AssertionError(message)


def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10):
    """Run a TFLite model and an ONNX model on the same random inputs.

    Both models must have exactly one input and one output. The TFLite input
    must be 4-D; each random TFLite input is transposed with the axis order
    ``[0, 3, 1, 2]`` before being fed to the ONNX model, and the ONNX input
    shape must equal the TFLite input shape permuted the same way.

    Args:
        tflite_file: Path of the TFLite model file.
        onnx_file: Path of the ONNX model file.
        total_times: Number of random inputs to generate and run.

    Returns:
        A tuple ``(tflite_results, onnx_results)``. Each is a list with
        ``total_times`` entries holding the first output of the respective
        model for each random input, in the same order.

    Raises:
        AssertionError: If the ONNX model does not have exactly one input
            and one output, if the TFLite input is not 4-D, or if the two
            input shapes do not match after the layout permutation.
    """
    # Setup onnx session and get meta data
    onnx_session = onnxruntime.InferenceSession(onnx_file, None)
    onnx_outputs = onnx_session.get_outputs()
    _require(len(onnx_outputs) == 1, "The onnx model has more than one output")
    onnx_graph = onnx.load(onnx_file).graph
    # Older ONNX IR versions also list initializers (weights) in graph.input.
    # Those are not runtime inputs, so leave them out of the input count.
    # If filtering removes everything, fall back to the raw list so a model
    # that was accepted before is still accepted.
    initializer_names = {init.name for init in onnx_graph.initializer}
    runtime_inputs = [
        value for value in onnx_graph.input
        if value.name not in initializer_names
    ]
    onnx_inputs = runtime_inputs or list(onnx_graph.input)
    _require(len(onnx_inputs) == 1, "The onnx model has more than one input")
    _, onnx_input_shape = helper.find_size_shape_from_value(onnx_inputs[0])

    # Setup TFLite session and get meta data
    tflite_session = tf.lite.Interpreter(model_path=tflite_file)
    tflite_session.allocate_tensors()
    tflite_inputs = tflite_session.get_input_details()
    tflite_outputs = tflite_session.get_output_details()
    tflite_input_shape = tflite_inputs[0]["shape"]

    # Compare input shape
    _require(
        len(onnx_input_shape) == len(tflite_input_shape),
        "TFLite and ONNX shape unmatch.",
    )
    # The permutation below indexes four axes, so check the rank up front
    # instead of failing with an IndexError on lower-rank inputs.
    _require(
        len(tflite_input_shape) == 4,
        "Only 4-D TFLite inputs are supported, got shape {}.".format(
            list(tflite_input_shape)
        ),
    )
    # NOTE(review): presumably TFLite is NHWC and ONNX is NCHW -- confirm.
    dim0, dim1, dim2, dim3 = tflite_input_shape
    _require(
        onnx_input_shape == [dim0, dim3, dim1, dim2],
        "TFLite and ONNX shape unmatch.",
    )

    # Loop-invariant lookups.
    tflite_input_index = tflite_inputs[0]["index"]
    tflite_output_index = tflite_outputs[0]["index"]
    onnx_input_name = onnx_inputs[0].name

    # Generate random number and run
    tflite_results = []
    onnx_results = []
    for _ in range(total_times):
        # Generate input
        tflite_input_data = np.array(
            np.random.random_sample(tflite_input_shape), dtype=np.float32
        )
        onnx_input_data = np.transpose(tflite_input_data, [0, 3, 1, 2])
        # Run tflite
        tflite_session.set_tensor(tflite_input_index, tflite_input_data)
        tflite_session.invoke()
        tflite_results.append(tflite_session.get_tensor(tflite_output_index))
        # Run onnx
        onnx_input_dict = {onnx_input_name: onnx_input_data}
        onnx_results.append(onnx_session.run([], onnx_input_dict)[0])
    return tflite_results, onnx_results
def main():
    """Compare a TFLite model and an ONNX model given on the command line.

    Parses the two model paths, runs both models on 10 random inputs via
    ``compare_tflite_and_onnx`` and compares the flattened results to
    8 decimal places.

    Returns:
        0 when the results agree, 1 when their shapes or values differ.
    """
    # Argument parser.
    parser = argparse.ArgumentParser(
        description="Compare a TFLite model and an ONNX model to check "
        "if they have the same output."
    )
    parser.add_argument("tflite_file", help="input tflite file")
    parser.add_argument("onnx_file", help="input ONNX file")
    args = parser.parse_args()

    results_a, results_b = compare_tflite_and_onnx(
        args.tflite_file, args.onnx_file, total_times=10
    )

    # NOTE(review): flatten_with_depth presumably yields (value, shape/depth)
    # pairs; only item[0] and item[1] are used here -- confirm against helper.
    ra_flat = helper.flatten_with_depth(results_a, 0)
    rb_flat = helper.flatten_with_depth(results_b, 0)
    shape_a = [item[1] for item in ra_flat]
    shape_b = [item[1] for item in rb_flat]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if shape_a != shape_b:
        print("two results data shape doesn't match")
        return 1

    ra_raw = [item[0] for item in ra_flat]
    rb_raw = [item[0] for item in rb_flat]
    try:
        np.testing.assert_almost_equal(ra_raw, rb_raw, 8)
    except AssertionError as mismatch:
        # Only a genuine value mismatch is reported here; any other error
        # propagates with its traceback instead of looking like a mismatch.
        print(mismatch)
        return 1
    print("Two models have the same behaviour.")
    return 0


if __name__ == "__main__":
    # sys.exit rather than the `exit` builtin, which the `site` module
    # injects and which is missing under `python -S` or in frozen builds.
    import sys

    sys.exit(main())