# kneron_model_converter/ktc/opt_and_eval.py
# 2026-01-28 06:16:04 +00:00
# 53 lines, 2.6 KiB, Python

import argparse
import onnx
from .toolchain import ModelConfig, SUPPORTED_PLATFORMS
from . import onnx_optimizer
def main() -> None:
    """Optimize an ONNX model for a Kneron platform and/or run the IP Evaluator.

    Command-line flow:
      1. Parse arguments (platform, model path, step-skipping flags).
      2. For ONNX inputs (unless --evaluator-only), run the ONNX optimizer and
         save the result to --optimized-path or `<input>.opt.onnx`.
      3. Unless --optimizer-only, build a ModelConfig and run the evaluator,
         printing its NPU performance report.

    Raises:
        ValueError: if the input file is not a .onnx or .bie file.
        SystemExit: with code 0 when --optimizer-only skips the evaluator.
    """
    parser = argparse.ArgumentParser(description="Optimize ONNX model and run IP Evaluator")
    parser.add_argument("platform", type=str, choices=SUPPORTED_PLATFORMS, help="Target hardware platform.")
    parser.add_argument("path", type=str, help="Path to the ONNX/BIE model file.")
    parser.add_argument("-e", "--evaluator-only", help="Evaluator only, skip optimization step.", action="store_true")
    parser.add_argument("-E", "--evaluator-report-path", type=str, default="", help="Path to the directory to save the evaluator report.")
    parser.add_argument("-o", "--optimizer-only", help="Optimizer only, skip evaluator step.", action="store_true")
    parser.add_argument("-O", "--optimized-path", type=str, default="", help="Path to save the optimized ONNX model.")
    parser.add_argument("--deep-search", action="store_true", help="Use deep search for optimization, which may take longer but can yield better performance.")
    args = parser.parse_args()

    file_path = args.path
    if not file_path.lower().endswith(('.onnx', '.bie')):
        raise ValueError("Input file must be an ONNX or BIE file.")

    m = None
    # Optimization step: only ONNX inputs can be optimized; BIE files and
    # --evaluator-only runs skip it.
    if not (file_path.lower().endswith('.bie') or args.evaluator_only):
        m = onnx.load(file_path)
        m = onnx_optimizer.onnx2onnx_flow(m)
        # Default output path replaces the input's extension with `.opt.onnx`.
        optimized_path = args.optimized_path or file_path.rsplit('.', 1)[0] + '.opt.onnx'
        onnx.save(m, optimized_path)
        print(f"Optimized model saved to {optimized_path}")
    elif file_path.lower().endswith('.onnx'):
        # Evaluator-only run on an ONNX file: the model must still be loaded
        # so it can be handed to ModelConfig below.
        m = onnx.load(file_path)
    else:
        print("Skipping optimization step as per user request.")

    if args.optimizer_only:
        print("Skipping evaluator step as per user request.")
        raise SystemExit(0)

    # Evaluation step: BIE models are passed by path, ONNX models in memory.
    # NOTE(review): 32770 / "0001" look like a fixed model ID and version
    # required by ModelConfig — confirm against the toolchain documentation.
    if file_path.lower().endswith('.bie'):
        km = ModelConfig(32770, "0001", args.platform, bie_path=file_path)
    else:
        km = ModelConfig(32770, "0001", args.platform, onnx_model=m)

    compiler_tiling = "deep_search" if args.deep_search else "default"
    if args.evaluator_report_path:
        eval_result = km.evaluate(output_dir=args.evaluator_report_path, compiler_tiling=compiler_tiling)
    else:
        # km.evaluate falls back to its default working directory.
        print("No evaluator report path provided, using /data1/kneron_flow.")
        eval_result = km.evaluate(compiler_tiling=compiler_tiling)
    print("\nNpu performance evaluation result:\n" + str(eval_result))


if __name__ == "__main__":
    main()