import argparse
import logging
import sys

import onnx
import onnx.utils

from .tools import combo
# Define general pytorch exported onnx optimize process
def torch_exported_onnx_flow(
    m: onnx.ModelProto, disable_fuse_bn=False
) -> onnx.ModelProto:
    """Optimize the Pytorch exported onnx.

    Runs the standard optimization pipeline: preprocessing (optionally
    fusing BatchNormalization into Conv), Pytorch-specific constant
    folding, common graph optimizations, and a final postprocess pass.

    Args:
        m (ModelProto): the input onnx model
        disable_fuse_bn (bool, optional): do not fuse BN into Conv.
            Defaults to False.

    Returns:
        ModelProto: the optimized onnx model
    """
    m = combo.preprocess(m, disable_fuse_bn)
    # Apply the remaining passes in their fixed order.
    for optimization_pass in (
        combo.pytorch_constant_folding,
        combo.common_optimization,
        combo.postprocess,
    ):
        m = optimization_pass(m)
    return m
# Main Process
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Optimize a Pytorch generated model for Kneron compiler"
    )
    parser.add_argument("in_file", help="input ONNX")
    parser.add_argument("out_file", help="output ONNX FILE")
    parser.add_argument("--log", default="i", type=str, help="set log level")
    parser.add_argument(
        "--no-bn-fusion",
        dest="disable_fuse_bn",
        action="store_true",
        default=False,
        help="set if you have met errors which related to inferenced shape "
        "mismatch. This option will prevent fusing BatchNormalization "
        "into Conv.",
    )

    args = parser.parse_args()

    # Map the single-letter --log value to a logging level; anything
    # unrecognized falls back to INFO (the original default branch).
    log_levels = {
        "w": logging.WARN,
        "d": logging.DEBUG,
        "e": logging.ERROR,
    }
    logging.basicConfig(level=log_levels.get(args.log, logging.INFO))

    # Require a real ".onnx" extension. The previous check
    # (in_file[-4:] == "onnx") accepted any name merely ending in the
    # letters "onnx", e.g. "myonnx".
    if not args.in_file.endswith(".onnx"):
        logging.error("Invalid input file: %s", args.in_file)
        sys.exit(1)

    onnx_in = args.in_file
    onnx_out = args.out_file

    ######################################
    #          Optimize onnx             #
    ######################################

    m = onnx.load(onnx_in)

    m = torch_exported_onnx_flow(m, args.disable_fuse_bn)

    onnx.save(m, onnx_out)