"""Command-line tool: optimize an ONNX model and/or run the IP Evaluator on it.

Uses relative imports, so it must be run as a module of its package
(``python -m <package>.<module> PLATFORM PATH [options]``).
"""
import argparse
import sys

import onnx

from .toolchain import ModelConfig, SUPPORTED_PLATFORMS
from . import onnx_optimizer


def _build_parser():
    """Return the argument parser describing this tool's command line."""
    parser = argparse.ArgumentParser(description="Optimize ONNX model and run IP Evaluator")
    parser.add_argument("platform", type=str, choices=SUPPORTED_PLATFORMS, help="Target hardware platform.")
    parser.add_argument("path", type=str, help="Path to the ONNX/BIE model file.")
    # -e and -o each skip the *other* half of the pipeline; passing both would
    # make the tool do nothing, so argparse rejects the combination.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("-e", "--evaluator-only", help="Evaluator only, skip optimization step.", action="store_true")
    parser.add_argument("-E", "--evaluator-report-path", type=str, default="",
                        help="Path to the directory to save the evaluator report.")
    mode.add_argument("-o", "--optimizer-only", help="Optimizer only, skip evaluator step.", action="store_true")
    parser.add_argument("-O", "--optimized-path", type=str, default="", help="Path to save the optimized ONNX model.")
    parser.add_argument("--deep-search", action="store_true",
                        help="Use deep search for optimization, which may take longer but can yield better performance.")
    return parser


def main(argv=None):
    """Run the optimize/evaluate pipeline described by the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status (0 on success). Invalid arguments make argparse
        print a usage message and exit with status 2.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    file_path = args.path
    lowered = file_path.lower()
    if not lowered.endswith(('.onnx', '.bie')):
        parser.error("Input file must be an ONNX or BIE file.")
    is_bie = lowered.endswith('.bie')

    # Optimization step. Only ONNX input goes through onnx2onnx_flow; a BIE
    # file is handed to the evaluator as-is.
    m = None
    if is_bie:
        print("Skipping optimization step for BIE input.")
    elif args.evaluator_only:
        m = onnx.load(file_path)
        print("Skipping optimization step as per user request.")
    else:
        m = onnx_optimizer.onnx2onnx_flow(onnx.load(file_path))
        # Default output path: "<input without extension>.opt.onnx".
        optimized_path = args.optimized_path or file_path.rsplit('.', 1)[0] + '.opt.onnx'
        onnx.save(m, optimized_path)
        print(f"Optimized model saved to {optimized_path}")

    if args.optimizer_only:
        print("Skipping evaluator step as per user request.")
        return 0

    # NOTE(review): 32770 and "0001" are passed positionally to ModelConfig;
    # presumably a model id and version string — confirm against toolchain.
    if is_bie:
        km = ModelConfig(32770, "0001", args.platform, bie_path=file_path)
    else:
        km = ModelConfig(32770, "0001", args.platform, onnx_model=m)

    compiler_tiling = "deep_search" if args.deep_search else "default"
    if args.evaluator_report_path:
        eval_result = km.evaluate(output_dir=args.evaluator_report_path, compiler_tiling=compiler_tiling)
    else:
        # NOTE(review): this message assumes ModelConfig.evaluate defaults its
        # output_dir to /data1/kneron_flow — verify it stays in sync.
        print("No evaluator report path provided, using /data1/kneron_flow.")
        eval_result = km.evaluate(compiler_tiling=compiler_tiling)
    print("\nNpu performance evaluation result:\n" + str(eval_result))
    return 0


if __name__ == "__main__":
    sys.exit(main())