STDC/tools/optimizer_scripts/pytorch_exported_onnx_preprocess.py
2022-04-12 14:26:54 +08:00

83 lines
2.3 KiB
Python

import onnx
import onnx.utils
import logging
import argparse
from .tools import combo
# General optimization flow for PyTorch-exported ONNX models
def torch_exported_onnx_flow(
    m: onnx.ModelProto, disable_fuse_bn=False
) -> onnx.ModelProto:
    """Run the standard optimization pipeline on a PyTorch-exported ONNX model.

    The model is passed through four ``combo`` stages, in order:
    preprocessing, PyTorch-specific constant folding, common
    optimizations, and postprocessing. Each stage receives the model
    returned by the previous one.

    Args:
        m (ModelProto): the ONNX model exported from PyTorch.
        disable_fuse_bn (bool, optional): forwarded to the preprocessing
            stage; when True, BatchNormalization is not fused into Conv.
            Defaults to False.

    Returns:
        ModelProto: the model produced by the final postprocessing stage.
    """
    preprocessed = combo.preprocess(m, disable_fuse_bn)
    folded = combo.pytorch_constant_folding(preprocessed)
    optimized = combo.common_optimization(folded)
    return combo.postprocess(optimized)
# Main Process
# Map of ``--log`` flag values to logging levels; any value not listed
# here (including the default "i") falls back to INFO.
_LOG_LEVELS = {
    "w": logging.WARN,
    "d": logging.DEBUG,
    "e": logging.ERROR,
}


def _log_level(flag):
    """Return the logging level selected by a ``--log`` flag value.

    Args:
        flag (str): "w", "d", "e" or "i"; unrecognized values are treated
            like "i".

    Returns:
        int: the matching ``logging`` level, ``logging.INFO`` by default.
    """
    return _LOG_LEVELS.get(flag, logging.INFO)


def _is_onnx_path(path):
    """Return True if *path* ends with the ``.onnx`` extension.

    Only the name is checked; the file is not opened.
    """
    # Require the dot so names that merely end in "onnx" (e.g. "modelonnx")
    # are rejected. Names too short to carry the extension fail naturally.
    return path.endswith(".onnx")


def _parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv (list of str, optional): arguments to parse; when None,
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: with ``in_file``, ``out_file``, ``log`` and
        ``disable_fuse_bn`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Optimize a Pytorch generated model for Kneron compiler"
    )
    parser.add_argument("in_file", help="input ONNX")
    parser.add_argument("out_file", help="output ONNX FILE")
    parser.add_argument(
        "--log",
        default="i",
        type=str,
        help="set log level: w (warning), d (debug), e (error) or "
        "i (info, default); unrecognized values mean info",
    )
    parser.add_argument(
        "--no-bn-fusion",
        dest="disable_fuse_bn",
        action="store_true",
        default=False,
        help="set if you have met errors which related to inferenced shape "
        "mismatch. This option will prevent fusing BatchNormalization "
        "into Conv.",
    )
    return parser.parse_args(argv)


# Main Process: load the input model, optimize it, and save the result.
if __name__ == "__main__":
    import sys

    args = _parse_args()
    logging.basicConfig(level=_log_level(args.log))

    if not _is_onnx_path(args.in_file):
        # The input name is too short or lacks the ".onnx" extension.
        logging.error("Invalid input file: %s", args.in_file)
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. when Python runs with -S).
        sys.exit(1)

    ######################################
    #           Optimize onnx            #
    ######################################
    m = onnx.load(args.in_file)
    m = torch_exported_onnx_flow(m, args.disable_fuse_bn)
    onnx.save(m, args.out_file)