39 lines
1.3 KiB
Python
39 lines
1.3 KiB
Python
import argparse
import logging
import sys

from kneronnxopt.onnx_vs_onnx import onnx_vs_onnx
def _build_parser():
    """Build the command-line parser for comparing two ONNX models."""
    parser = argparse.ArgumentParser(
        description="Compare two ONNX models to check if they have the same output."
    )
    parser.add_argument("in_file_a", help="input ONNX file a")
    parser.add_argument("in_file_b", help="input ONNX file b")
    parser.add_argument(
        "--count", type=int, default=10, help="total times to run inference"
    )
    parser.add_argument(
        "--decimal", type=int, default=4, help="decimal places to compare"
    )
    parser.add_argument("--log", default="INFO", help="log level (default: INFO)")
    return parser


# Level names accepted by ``--log`` (matched case-insensitively);
# any other value falls back to INFO.
_LOG_LEVELS = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}


def main(argv=None):
    """Parse arguments, configure logging and run ``onnx_vs_onnx``.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            means ``sys.argv[1:]`` (standard ``argparse`` behaviour).

    Returns:
        Process exit status: 0 when ``onnx_vs_onnx`` returns a truthy value
        (presumably "outputs match" - confirm against its docs), 1 otherwise.
    """
    args = _build_parser().parse_args(argv)

    # Upper-case the name so ``--log debug`` works. Unknown names keep the
    # original INFO fallback, but are reported instead of ignored silently.
    level = _LOG_LEVELS.get(args.log.upper())
    logging.basicConfig(level=logging.INFO if level is None else level)
    if level is None:
        logging.warning("Unknown log level %r; using INFO.", args.log)

    matched = onnx_vs_onnx(
        args.in_file_a, args.in_file_b, total_times=args.count, decimal=args.decimal
    )
    return 0 if matched else 1


if __name__ == "__main__":
    # sys.exit rather than the site-injected ``exit`` builtin, which is not
    # guaranteed to exist (e.g. ``python -S`` or frozen executables).
    sys.exit(main())