40 lines
1.4 KiB
Python

import os
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
import numpy as np
from load_model import initialize_model
import argparse
import os
import sys
import scipy.io
import torch.onnx
def _build_parser():
    """Build the argument parser describing the converter's command line."""
    parser = argparse.ArgumentParser(description='converter.')
    parser.add_argument('--save-path', type=str, default=None,
                        help='Path to the onnx model.')
    parser.add_argument('--backbone', type=str, default='resnet18',
                        help='Backbone model.')
    parser.add_argument('--num_classes', type=int, default=0,
                        help='the number of classes.')
    parser.add_argument('--model-def-path', type=str, default=None,
                        help='Path to pretrained model definition')
    # The snapshot is mandatory: without it there are no weights to load,
    # so fail with a usage message instead of passing None to torch.load.
    parser.add_argument('--snapshot', required=True,
                        help='Path to the pretrained models.')
    return parser


def main(args=None):
    """Export a trained classification model to an ONNX file.

    The model is built with ``initialize_model``, its weights are restored
    from the ``--snapshot`` checkpoint, and the network is traced with a
    random single-image batch and written out via ``torch.onnx.export``.

    Args:
        args: Optional list of command-line tokens (e.g.
            ``['--snapshot', 'w.pth']``). When ``None`` the arguments are
            taken from ``sys.argv[1:]``.

    Raises:
        SystemExit: Raised by ``argparse`` on invalid or missing arguments
            (and after printing ``--help``).
    """
    # Parse once, honouring the caller-supplied argument list.
    options = _build_parser().parse_args(args)
    print(vars(options))

    # NOTE(review): the meaning of the positional ``False`` is defined by
    # load_model.initialize_model, which is not visible here -- confirm.
    model, input_size = initialize_model(
        options.backbone, options.num_classes, False, options.model_def_path)

    # map_location='cpu' lets checkpoints saved on a GPU be loaded on
    # machines without CUDA; the export below uses a CPU dummy input.
    # NOTE: torch.load unpickles the file -- only use trusted snapshots.
    state_dict = torch.load(options.snapshot, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # switch to inference mode before tracing

    # Single 3-channel image; input_size is indexed as a 2-element sequence
    # (presumably height, width -- verify against initialize_model).
    dummy_input = torch.randn(1, 3, input_size[0], input_size[1])

    # Default the output file name to "<backbone>.onnx".
    if options.save_path is None:
        save_path = options.backbone + '.onnx'
    else:
        save_path = options.save_path

    torch.onnx.export(model, dummy_input, save_path,
                      keep_initializers_as_inputs=True, opset_version=11)
# Run the converter only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()