2026-01-28 06:16:04 +00:00

88 lines
3.3 KiB
Python

import argparse
import os
import shutil
import subprocess
# Constants
# Root of the converter collection; each platform-specific converter lives in
# its own subdirectory under this path.
CONVERTER_BASE = "/workspace/libs/ONNX_Convertor"
# Parse arguments
# NOTE: typos fixed in the user-facing help text ("differenct" -> "different",
# "tunning" -> "tuning", "BatchNormailization" -> "BatchNormalization",
# "platfrom" -> "platform").
parser = argparse.ArgumentParser(
    description="Convert models from different platforms. For more detailed "
                "tuning, please use the converters under "
                "/workspace/libs/ONNX_Convertor.")
parser.add_argument(
    'platform',
    type=str,
    help="Source platform: pytorch, caffe, keras, tf, tflite, onnx")
parser.add_argument(
    'input_model',
    type=str,
    help="Relative path of the input model")
parser.add_argument(
    'output_model',
    type=str,
    help="Relative path of model output.")
parser.add_argument(
    '--no-bn-fusion',
    dest='disable_fuse_bn',
    action='store_true',
    default=False,
    help="set if you have met errors which related to inferenced shape "
         "mismatch. This option will prevent fusing BatchNormalization into "
         "Conv.")
parser.add_argument(
    '-w',
    '--weight',
    type=str,
    # nargs=1 yields a one-element list; downstream code reads args.weight[0].
    nargs=1,
    help="Additional weight input for some platform.")
args = parser.parse_args()
# onnx2onnx caller
def call_optimizer(input_model, output_model, no_bn_fusion=False):
    """Run the onnx2onnx optimizer script on *input_model*.

    Writes the optimized graph to *output_model* (the two paths may be the
    same file, overwriting in place). When *no_bn_fusion* is true, the
    '--no-bn-fusion' flag is forwarded so BatchNorm layers are not folded
    into preceding Conv layers. Raises CalledProcessError on failure
    (check=True).
    """
    optimizer_script = os.path.join(CONVERTER_BASE,
                                    'optimizer_scripts', 'onnx2onnx.py')
    cmd = ['python', optimizer_script, input_model, '-o', output_model, '-t']
    if no_bn_fusion:
        cmd.append('--no-bn-fusion')
    subprocess.run(cmd, check=True)
# Run converter: each branch invokes the platform-specific converter, then
# post-processes the result with the onnx2onnx optimizer (call_optimizer).
if args.platform == 'keras':
    subprocess.run(['python', CONVERTER_BASE + '/keras-onnx/generate_onnx.py',
                    '-o', args.output_model,
                    '-O2', '--duplicate-shared-weights',
                    args.input_model], check=True)
    call_optimizer(args.output_model, args.output_model, args.disable_fuse_bn)
elif args.platform == 'caffe':
    # Caffe keeps the network structure and the weights in separate files,
    # so the extra -w weight argument is mandatory for this platform.
    if args.weight is None or len(args.weight) == 0:
        print("Caffe models need both the model structure and the weights. "
              "Please supply the weight file with -w.")
        # was quit(1): quit() is a site-injected interactive helper and may be
        # absent when run with -S; SystemExit is the reliable equivalent.
        raise SystemExit(1)
    subprocess.run(['python', CONVERTER_BASE + '/caffe-onnx/generate_onnx.py',
                    '-o', args.output_model,
                    '-n', args.input_model,
                    '-w', args.weight[0]], check=True)
    call_optimizer(args.output_model, args.output_model, args.disable_fuse_bn)
elif args.platform == 'pytorch':
    command = ['python', CONVERTER_BASE + '/optimizer_scripts/pytorch_exported_onnx_preprocess.py',
               args.input_model, args.output_model]
    # The preprocess script also fuses BN, so the flag must be forwarded here
    # as well as to the final optimizer pass below.
    if args.disable_fuse_bn:
        command.append('--no-bn-fusion')
    subprocess.run(command, check=True)
    call_optimizer(args.output_model, args.output_model, args.disable_fuse_bn)
elif args.platform == 'tf':
    subprocess.run(['python', CONVERTER_BASE + '/optimizer_scripts/tensorflow2onnx.py',
                    args.input_model, args.output_model], check=True)
    call_optimizer(args.output_model, args.output_model, args.disable_fuse_bn)
elif args.platform == 'tflite':
    subprocess.run(['python', CONVERTER_BASE + '/tflite-onnx/onnx_tflite/tflite2onnx.py',
                    '-tflite', args.input_model,
                    '-save_path', args.output_model,
                    '-release_mode', 'True'], check=True)
    call_optimizer(args.output_model, args.output_model, args.disable_fuse_bn)
elif args.platform == 'onnx':
    call_optimizer(args.input_model, args.output_model, args.disable_fuse_bn)
else:
    # Previously an unrecognized platform fell through silently and produced
    # no output at all; fail loudly with a nonzero exit code instead.
    print("Unknown platform: " + args.platform
          + ". Supported platforms: pytorch, caffe, keras, tf, tflite, onnx.")
    raise SystemExit(1)