2026-01-28 06:16:04 +00:00

39 lines
1.3 KiB
Python

import argparse
import logging
import sys

from kneronnxopt.onnx_vs_onnx import onnx_vs_onnx
# Names accepted by --log (compared case-insensitively) and their logging levels.
_LOG_LEVELS = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}


def _build_parser():
    """Return the command-line parser for comparing two ONNX models."""
    parser = argparse.ArgumentParser(
        description="Compare two ONNX models to check if they have the same output."
    )
    parser.add_argument("in_file_a", help="input ONNX file a")
    parser.add_argument("in_file_b", help="input ONNX file b")
    parser.add_argument(
        "--count", type=int, default=10, help="total times to run inference"
    )
    parser.add_argument(
        "--decimal", type=int, default=4, help="decimal places to compare"
    )
    parser.add_argument("--log", default="INFO", help="log level (default: INFO)")
    return parser


def _log_level(name):
    """Return the logging level for *name* (case-insensitive); unknown names map to INFO."""
    return _LOG_LEVELS.get(name.upper(), logging.INFO)


def main(argv=None):
    """Parse *argv*, compare the two models, and return a process exit status.

    The comparison itself is delegated to ``onnx_vs_onnx``, which receives
    ``--count`` as ``total_times`` and ``--decimal`` as ``decimal``.

    Args:
        argv: Argument list without the program name; ``None`` lets argparse
            read ``sys.argv[1:]``.

    Returns:
        0 when ``onnx_vs_onnx`` returns a truthy result, 1 otherwise.

    Raises:
        SystemExit: raised by argparse for invalid arguments or ``--help``.
    """
    args = _build_parser().parse_args(argv)
    logging.basicConfig(level=_log_level(args.log))
    if args.log.upper() not in _LOG_LEVELS:
        # Keep the INFO fallback for compatibility, but make typos such as
        # "--log DEBG" visible instead of silently ignoring them.
        logging.warning("Unknown log level %r; using INFO.", args.log)
    same_output = onnx_vs_onnx(
        args.in_file_a, args.in_file_b, total_times=args.count, decimal=args.decimal
    )
    return 0 if same_output else 1


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(): the latter is intended for
    # the interactive shell and is unavailable under `python -S`.
    sys.exit(main())