"""Command-line converter: load a PyTorch model snapshot and export it to ONNX."""
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision
|
|
from torchvision import datasets, models, transforms
|
|
import numpy as np
|
|
from load_model import initialize_model
|
|
import argparse
|
|
import os
|
|
import sys
|
|
import scipy.io
|
|
import torch.onnx
|
|
|
|
|
|
def main(args=None):
    """Load a trained PyTorch model snapshot and export it as an ONNX file.

    The network is built by ``initialize_model`` from the ``--backbone``,
    ``--num_classes`` and ``--model-def-path`` options, the weights stored at
    ``--snapshot`` are loaded into it, and the model (in eval mode) is traced
    with a random input of shape ``(1, 3, input_size[0], input_size[1])``.

    Args:
        args: Optional list of command-line argument strings. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Raises:
        SystemExit: If the command line is invalid, e.g. ``--snapshot`` is
            missing (raised by argparse with exit status 2).
    """
    parser = argparse.ArgumentParser(description='converter.')
    parser.add_argument('--save-path', type=str, default=None,
                        help='Path to the onnx model.')
    parser.add_argument('--backbone', type=str, default='resnet18',
                        help='Backbone model.')
    parser.add_argument('--num_classes', type=int, default=0,
                        help='the number of classes.')
    parser.add_argument('--model-def-path', type=str, default=None,
                        help='Path to pretrained model definition')
    # The snapshot is loaded unconditionally below, so reject a missing value
    # up front with a usage message instead of failing inside torch.load.
    parser.add_argument('--snapshot', required=True,
                        help='Path to the pretrained models.')

    # Parse once, honouring an explicitly supplied argument list.
    args = parser.parse_args(args)
    print(vars(args))

    model_structure, input_size = initialize_model(
        args.backbone, args.num_classes, False, args.model_def_path)

    # The export below runs on CPU (the dummy input is a CPU tensor), so map
    # the stored tensors to CPU; this also lets GPU-saved snapshots load on
    # hosts without CUDA.
    state_dict = torch.load(args.snapshot, map_location='cpu')
    model_structure.load_state_dict(state_dict)
    model = model_structure.eval()

    dummy_input = torch.randn(1, 3, input_size[0], input_size[1])

    # Default output name is derived from the backbone when no path is given.
    save_path = args.save_path
    if save_path is None:
        save_path = args.backbone + '.onnx'

    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        keep_initializers_as_inputs=True,
        opset_version=11,
    )
|
|
|
|
|
|
# Script entry point: run the converter only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
|
|
|