81 lines
2.4 KiB
Python
81 lines
2.4 KiB
Python
import onnx
|
|
import onnx.utils
|
|
try:
|
|
from onnx import optimizer
|
|
except ImportError:
|
|
import onnxoptimizer as optimizer
|
|
import sys
|
|
import numpy as np
|
|
import struct
|
|
import logging
|
|
import argparse
|
|
|
|
from .tools import eliminating
|
|
from .tools import fusing
|
|
from .tools import replacing
|
|
from .tools import other
|
|
from .tools import combo
|
|
from .tools import special
|
|
|
|
# General optimization flow for ONNX models exported from PyTorch.
|
|
def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.ModelProto:
    """Run the standard optimization pipeline on a PyTorch-exported ONNX model.

    The model passes through four ``combo`` stages in order: preprocessing
    (which can be told not to fuse BatchNormalization into Conv),
    PyTorch-specific constant folding, common optimizations, and
    postprocessing. Each stage's output feeds the next.

    Args:
        m (ModelProto): the input onnx model
        disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False.

    Returns:
        ModelProto: the optimized onnx model
    """
    # Preprocessing is the only stage that takes an extra option.
    model = combo.preprocess(m, disable_fuse_bn)
    # The remaining stages each take and return a single model.
    for stage in (combo.pytorch_constant_folding,
                  combo.common_optimization,
                  combo.postprocess):
        model = stage(model)
    return model
|
|
|
|
|
|
# Main Process
def main(argv=None):
    """Command-line entry point: optimize a PyTorch-exported ONNX file.

    Parses the arguments, configures logging, validates the input filename,
    then loads the model, runs ``torch_exported_onnx_flow`` on it and saves
    the result to the output path.

    Args:
        argv (list of str, optional): command-line arguments excluding the
            program name. Defaults to ``None``, in which case argparse reads
            ``sys.argv[1:]``.

    Raises:
        SystemExit: with code 1 when the input filename is 4 characters or
            shorter or does not end in ``onnx``; argparse also exits on
            malformed arguments.
    """
    parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler")
    parser.add_argument('in_file', help='input ONNX')
    parser.add_argument('out_file', help="output ONNX FILE")
    parser.add_argument('--log', default='i', type=str, help="set log level")
    parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
                        help="set if you have met errors which related to inferenced shape mismatch. "
                             "This option will prevent fusing BatchNormalization into Conv.")
    args = parser.parse_args(argv)

    # Map the single-letter --log flag to a logging level; any other value
    # (including the default 'i') falls back to INFO.
    log_levels = {'w': logging.WARNING, 'd': logging.DEBUG, 'e': logging.ERROR}
    logging.basicConfig(level=log_levels.get(args.log, logging.INFO))

    # Reject filenames that are too short or do not end in "onnx".
    in_file = args.in_file
    if len(in_file) <= 4 or not in_file.endswith('onnx'):
        logging.error("Invalid input file: {}".format(in_file))
        # Use sys.exit: the bare exit() builtin is injected by the `site`
        # module and is not guaranteed to exist (e.g. `python -S`, frozen apps).
        sys.exit(1)

    ######################################
    #           Optimize onnx            #
    ######################################
    m = onnx.load(in_file)
    m = torch_exported_onnx_flow(m, args.disable_fuse_bn)
    onnx.save(m, args.out_file)


if __name__ == "__main__":
    main()
|