STDC/tools/optimizer_scripts/pytorch2onnx.py
2022-04-12 14:26:54 +08:00

94 lines
2.7 KiB
Python

import onnx
import onnx.utils
import sys
import logging
import argparse
from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow
# Debug use
# logging.basicConfig(level=logging.DEBUG)
######################################
# Generate a prototype onnx #
######################################
# Command-line interface: two positional file paths plus options controlling
# the pth export (--input-size) and the optimization flow (--no-bn-fusion).
parser = argparse.ArgumentParser(
    description="Optimize a Pytorch generated model for Kneron compiler"
)
parser.add_argument("in_file", help="input ONNX or PTH FILE")
parser.add_argument("out_file", help="output ONNX FILE")
parser.add_argument(
    "--input-size",
    dest="input_size",
    nargs=3,
    # Validate at parse time so a non-integer dimension yields a usage error
    # instead of a ValueError traceback later in the script.
    type=int,
    metavar=("CH", "H", "W"),
    help="if you are using pth, please use this argument to set up the "
    "input size of the model. It should be in 'CH H W' format, "
    "e.g. '--input-size 3 256 512'.",
)
parser.add_argument(
    "--no-bn-fusion",
    dest="disable_fuse_bn",
    action="store_true",
    default=False,
    help="set if you have met errors related to inferred shape mismatch. "
    "This option will prevent fusing BatchNormalization into Conv.",
)
# Parses the process argv; exits with a usage message on invalid arguments.
args = parser.parse_args()
# Decide which ONNX file will be optimized, based on the input file's suffix.
# A ".pth" input is first exported to ONNX at args.out_file, and that exported
# file becomes the optimizer's input; an ONNX input is used as-is.
if len(args.in_file) <= 4:
    # The name is too short to be anything more than a bare extension.
    logging.error("Invalid input file: {}".format(args.in_file))
    sys.exit(1)
elif args.in_file.endswith(".pth"):
    # PyTorch pth case.
    logging.warning("Converting from pth to onnx is not recommended.")
    # Check the required option before importing torch, so a missing
    # '--input-size' is reported even when PyTorch is unavailable.
    if args.input_size is None:
        logging.error("'--input-size' is required for the pth input file.")
        sys.exit(1)
    onnx_in = args.out_file
    # Imported lazily so that ONNX-only runs do not require PyTorch.
    import torch
    import torch.onnx

    # Values don't matter as we only care about the network structure; the
    # dummy input just fixes the (1, CH, H, W) shape traced during export.
    channels, height, width = (int(dim) for dim in args.input_size)
    dummy_input = torch.randn(1, channels, height, width)
    # Load from the parsed positional argument rather than sys.argv[1], which
    # is wrong whenever an option is given before the file names.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model = torch.load(args.in_file, map_location="cpu")
    torch.onnx.export(model, dummy_input, args.out_file, opset_version=11)
elif args.in_file.endswith("onnx"):
    # Suffix test deliberately left without the dot to keep accepting every
    # file name the script accepted before.
    onnx_in = args.in_file
else:
    # When the file is neither an onnx nor a pytorch pth.
    logging.error("Invalid input file: {}".format(args.in_file))
    sys.exit(1)
onnx_out = args.out_file

######################################
#  Optimize onnx                     #
######################################
# Load the selected ONNX file, run it through the optimization flow for
# PyTorch-exported models, and write the result to the output path. For a
# pth input, onnx_in and onnx_out are the same path, so the freshly exported
# file is overwritten by its optimized version.
optimized_model = torch_exported_onnx_flow(
    onnx.load(onnx_in), args.disable_fuse_bn
)
onnx.save(optimized_model, onnx_out)