88 lines
3.3 KiB
Python
88 lines
3.3 KiB
Python
import argparse
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
|
|
# Constants
# Root directory that holds the individual converter / optimizer scripts
# invoked by this wrapper.
CONVERTER_BASE = "/workspace/libs/ONNX_Convertor"

# Source platforms this wrapper can dispatch. Used as argparse `choices` so
# that an unsupported platform is rejected up front instead of being accepted
# and then silently ignored.
SUPPORTED_PLATFORMS = ('pytorch', 'caffe', 'keras', 'tf', 'tflite', 'onnx')

# Parse arguments
parser = argparse.ArgumentParser(
    description="Convert models from different platforms. "
                "For more detailed tuning, please use the converters under "
                + CONVERTER_BASE + ".")

# Positional: which front-end converter to use.
parser.add_argument(
    'platform',
    type=str,
    choices=SUPPORTED_PLATFORMS,
    metavar='platform',  # keep the usage line short; the help text lists the values
    help="Source platform: pytorch, caffe, keras, tf, tflite, onnx")

# Positional: model file to convert.
parser.add_argument(
    'input_model',
    type=str,
    help="Relative path of the input model")

# Positional: where the resulting ONNX model is written.
parser.add_argument(
    'output_model',
    type=str,
    help="Relative path of model output.")

# Optional flag, stored as `args.disable_fuse_bn` (False unless given).
parser.add_argument(
    '--no-bn-fusion',
    dest='disable_fuse_bn',
    action='store_true',
    help="Set this if you have met errors related to inferred shape "
         "mismatch. This option will prevent fusing BatchNormalization "
         "into Conv.")

# Optional weight file. `nargs=1` makes `args.weight` a one-element list
# (or None when the option is omitted); consumers read `args.weight[0]`.
parser.add_argument(
    '-w',
    '--weight',
    type=str,
    nargs=1,
    help="Additional weight input for some platforms.")

args = parser.parse_args()
|
|
|
|
# onnx2onnx caller
|
|
def call_optimizer(input_model, output_model, no_bn_fusion=False):
    """Run the ``onnx2onnx.py`` optimizer script as a child process.

    The script is looked up under ``CONVERTER_BASE/optimizer_scripts`` and is
    invoked as ``python onnx2onnx.py <input_model> -o <output_model> -t``.

    Args:
        input_model: Path of the ONNX model handed to the optimizer.
        output_model: Path passed to the optimizer's ``-o`` option.
        no_bn_fusion: When truthy, ``--no-bn-fusion`` is appended to the
            command line.

    Raises:
        subprocess.CalledProcessError: If the optimizer exits with a non-zero
            status (``check=True``).
    """
    optimizer_script = CONVERTER_BASE + '/optimizer_scripts/onnx2onnx.py'
    extra_flags = ['--no-bn-fusion'] if no_bn_fusion else []
    subprocess.run(
        ['python', optimizer_script, input_model, '-o', output_model, '-t'] + extra_flags,
        check=True,
    )
|
|
|
|
# Run converter
|
|
# Build the platform-specific front-end command that turns the input model
# into an ONNX file at `args.output_model`. `None` means no front-end step is
# needed because the input is already ONNX.
if args.platform == 'keras':
    converter_command = ['python', CONVERTER_BASE + '/keras-onnx/generate_onnx.py',
                         '-o', args.output_model,
                         '-O2', '--duplicate-shared-weights',
                         args.input_model]
elif args.platform == 'caffe':
    # The weight option is not required by the parser, but the Caffe
    # converter is always called with both a structure file and a weight file.
    if not args.weight:
        # SystemExit with a string prints it to stderr and exits with status 1.
        # Unlike quit(), it does not rely on the `site` module being loaded.
        raise SystemExit("Caffe model need both model structure and weight. Please supply weight with -w.")
    converter_command = ['python', CONVERTER_BASE + '/caffe-onnx/generate_onnx.py',
                         '-o', args.output_model,
                         '-n', args.input_model,
                         '-w', args.weight[0]]
elif args.platform == 'pytorch':
    converter_command = ['python', CONVERTER_BASE + '/optimizer_scripts/pytorch_exported_onnx_preprocess.py',
                         args.input_model, args.output_model]
    # The preprocess script receives the BN-fusion switch as well.
    if args.disable_fuse_bn:
        converter_command.append('--no-bn-fusion')
elif args.platform == 'tf':
    converter_command = ['python', CONVERTER_BASE + '/optimizer_scripts/tensorflow2onnx.py',
                         args.input_model, args.output_model]
elif args.platform == 'tflite':
    converter_command = ['python', CONVERTER_BASE + '/tflite-onnx/onnx_tflite/tflite2onnx.py',
                         '-tflite', args.input_model,
                         '-save_path', args.output_model,
                         '-release_mode', 'True']
elif args.platform == 'onnx':
    converter_command = None
else:
    # Previously an unknown platform fell through every branch and the script
    # exited successfully without producing any output.
    raise SystemExit(
        "Unsupported platform '{}'. Expected one of: "
        "pytorch, caffe, keras, tf, tflite, onnx.".format(args.platform))

if converter_command is None:
    # Already ONNX: optimize straight from the input to the output path.
    call_optimizer(args.input_model, args.output_model, args.disable_fuse_bn)
else:
    # check=True aborts with CalledProcessError if the converter fails, so
    # the optimizer never runs on a missing or partial output file.
    subprocess.run(converter_command, check=True)
    # The converter wrote the ONNX file to output_model; optimize it in place.
    call_optimizer(args.output_model, args.output_model, args.disable_fuse_bn)
|