# Command-line entry point: optimize an ONNX model and run the NPU IP Evaluator.
import argparse
import sys

import onnx

from .toolchain import ModelConfig, SUPPORTED_PLATFORMS
from . import onnx_optimizer
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Optimize ONNX model and run IP Evaluator")
    parser.add_argument("platform", type=str, choices=SUPPORTED_PLATFORMS, help="Target hardware platform.")
    parser.add_argument("path", type=str, help="Path to the ONNX/BIE model file.")
    parser.add_argument("-e", "--evaluator-only", help="Evaluator only, skip optimization step.", action="store_true")
    parser.add_argument("-E", "--evaluator-report-path", type=str, default="", help="Path to the directory to save the evaluator report.")
    parser.add_argument("-o", "--optimizer-only", help="Optimizer only, skip evaluator step.", action="store_true")
    parser.add_argument("-O", "--optimized-path", type=str, default="", help="Path to save the optimized ONNX model.")
    parser.add_argument("--deep-search", action="store_true", help="Use deep search for optimization, which may take longer but can yield better performance.")
    args = parser.parse_args()

    file_path = args.path
    # Only ONNX and BIE inputs are supported (case-insensitive suffix check).
    if not file_path.lower().endswith(('.onnx', '.bie')):
        raise ValueError("Input file must be an ONNX or BIE file.")

    is_bie = file_path.lower().endswith('.bie')

    m = None
    # Optimization step. Only ONNX inputs can be optimized; BIE inputs are
    # passed straight to the evaluator via `bie_path` below.
    if not (is_bie or args.evaluator_only):
        m = onnx.load(file_path)
        m = onnx_optimizer.onnx2onnx_flow(m)
        # Default the output path to "<stem>.opt.onnx" next to the input
        # when -O/--optimized-path is not given.
        optimized_path = args.optimized_path or (file_path.rsplit('.', 1)[0] + '.opt.onnx')
        onnx.save(m, optimized_path)
        print(f"Optimized model saved to {optimized_path}")
    elif not is_bie:
        # --evaluator-only with an ONNX input: load the model unmodified.
        m = onnx.load(file_path)
    elif args.evaluator_only:
        print("Skipping optimization step as per user request.")
    else:
        # BIE input without -e: the skip is forced by the input format,
        # not requested by the user — say so.
        print("Skipping optimization step: BIE input cannot be optimized.")

    if args.optimizer_only:
        print("Skipping evaluator step as per user request.")
        sys.exit(0)

    # Build the evaluator model config.
    # NOTE(review): 32770 and "0001" look like a fixed model id and version
    # expected by ModelConfig — confirm against the toolchain documentation.
    if is_bie:
        km = ModelConfig(32770, "0001", args.platform, bie_path=file_path)
    else:
        km = ModelConfig(32770, "0001", args.platform, onnx_model=m)

    # Deep search trades longer compile time for potentially better tiling.
    compiler_tiling = "deep_search" if args.deep_search else "default"
    if args.evaluator_report_path:
        eval_result = km.evaluate(output_dir=args.evaluator_report_path, compiler_tiling=compiler_tiling)
    else:
        print("No evaluator report path provided, using /data1/kneron_flow.")
        eval_result = km.evaluate(compiler_tiling=compiler_tiling)
    print("\nNpu performance evaluation result:\n" + str(eval_result))
|