STDC/tools/optimizer_scripts/pytorch2onnx.py

82 lines
2.9 KiB
Python

import argparse
import logging
import struct
import sys

import numpy as np
import onnx
import onnx.utils

try:
    from onnx import optimizer
except ImportError:
    # onnx>=1.9 moved the optimizer into the separate onnxoptimizer package.
    import onnxoptimizer as optimizer

from tools import eliminating
from tools import fusing
from tools import replacing
from tools import other
from tools import combo
from tools import special
from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
# Debug use: uncomment to see DEBUG-level log messages.
# logging.basicConfig(level=logging.DEBUG)


def _build_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Optimize a Pytorch generated model for Kneron compiler")
    parser.add_argument('in_file', help='input ONNX or PTH FILE')
    parser.add_argument('out_file', help="output ONNX FILE")
    parser.add_argument(
        '--input-size', dest='input_size', nargs=3, type=int,
        help='if you are using a pth file, please use this argument to set up '
             'the input size of the model. It should be in \'CH H W\' format, '
             'e.g. \'--input-size 3 256 512\'.')
    parser.add_argument(
        '--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False,
        help="set this if you have met errors related to inferred shape mismatch. "
             "This option will prevent fusing BatchNormalization into Conv.")
    return parser


def _export_pth_to_onnx(pth_path, onnx_path, input_size):
    """Load a pickled PyTorch model from *pth_path* and export it to *onnx_path*.

    Args:
        pth_path: path of a file written with ``torch.save(model, ...)``.
        onnx_path: destination of the exported ONNX model (opset 11).
        input_size: three ints ``(CH, H, W)``; a batch dimension of 1 is prepended.

    Exits with status 1 when the file holds a plain dict (e.g. a state_dict)
    instead of a model object.
    """
    # Imported lazily so torch is only required when converting a pth file.
    import torch
    import torch.onnx

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only convert trusted .pth files. Newer PyTorch releases default
    # to weights_only=True, which may reject full-model pickles; TODO confirm
    # against the PyTorch version this tool is pinned to.
    model = torch.load(pth_path, map_location='cpu')
    if isinstance(model, dict):
        logging.error("{} contains a state_dict, not a model; save the full model "
                      "with torch.save(model, path).".format(pth_path))
        sys.exit(1)
    # Values don't matter as only the network structure is traced,
    # but they can also be real inputs.
    dummy_input = torch.randn(1, *input_size)
    torch.onnx.export(model, dummy_input, onnx_path, opset_version=11)


def main(argv=None):
    """Convert/optimize a PyTorch-exported model for the Kneron compiler.

    A ``.pth`` input is first exported to ONNX at ``out_file``; an ``.onnx``
    input is read directly. The ONNX model is then passed through
    ``torch_exported_onnx_flow`` and written to ``out_file``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Exits with status 1 on an unsupported input file or a missing
    ``--input-size`` for pth input (argparse itself exits with 2 on
    malformed arguments).
    """
    args = _build_parser().parse_args(argv)
    in_name = args.in_file.lower()

    if in_name.endswith('.pth'):
        logging.warning("Converting from pth to onnx is not recommended.")
        # Validate before importing torch so a usage error fails fast.
        if args.input_size is None:
            logging.error("\'--input-size\' is required for the pth input file.")
            sys.exit(1)
        _export_pth_to_onnx(args.in_file, args.out_file, args.input_size)
        # The exported model is written to out_file and then optimized in place.
        onnx_in = args.out_file
    elif in_name.endswith('.onnx'):
        onnx_in = args.in_file
    else:
        # The file is neither an onnx nor a pytorch pth.
        logging.error("Invalid input file: {}".format(args.in_file))
        sys.exit(1)

    m = onnx.load(onnx_in)
    m = torch_exported_onnx_flow(m, args.disable_fuse_bn)
    onnx.save(m, args.out_file)


if __name__ == "__main__":
    main()