"""Convert a pretrained PyTorch classification model to ONNX."""
import argparse
import os
import sys

import numpy as np
import scipy.io
import torch
import torch.nn as nn
import torch.onnx
import torchvision
from torchvision import datasets, models, transforms

from load_model import initialize_model


def main(args=None):
    """Parse CLI arguments, load a model snapshot and export it to ONNX.

    Args:
        args: Optional list of command-line argument strings. When ``None``
            (the default), ``sys.argv[1:]`` is parsed, so running the file
            as a script behaves as before.
    """
    parser = argparse.ArgumentParser(description='converter.')
    parser.add_argument('--save-path', type=str, help='Path to the onnx model.', default=None)
    parser.add_argument('--backbone', help='Backbone model.', default='resnet18', type=str)
    parser.add_argument('--num_classes', help='the number of classes.', type=int, default=0)
    parser.add_argument('--model-def-path', type=str, help='Path to pretrained model definition', default=None)
    # Required: without a snapshot there is nothing to load, and torch.load(None)
    # would otherwise fail with an obscure error.
    parser.add_argument('--snapshot', help='Path to the pretrained models.', required=True)

    # Parse once, honouring an explicitly passed argument list.
    args = parser.parse_args(args)
    print(vars(args))

    # initialize_model is project-local. It is assumed to return
    # (module, input_size), where input_size is indexable as (height, width).
    # TODO confirm against load_model.
    model_structure, input_size = initialize_model(
        args.backbone, args.num_classes, False, args.model_def_path)

    # map_location='cpu' lets checkpoints saved on GPU load on CPU-only hosts.
    # The dummy input below is a CPU tensor, so the model must be on CPU too.
    # NOTE(review): torch.load unpickles the file; only load trusted snapshots.
    state_dict = torch.load(args.snapshot, map_location='cpu')
    model_structure.load_state_dict(state_dict)
    model = model_structure.eval()

    # Batch of one 3-channel image at the model's expected spatial size.
    dummy_input = torch.randn(1, 3, input_size[0], input_size[1])

    save_path = args.save_path
    if save_path is None:
        save_path = args.backbone + '.onnx'

    torch.onnx.export(model, dummy_input, save_path,
                      keep_initializers_as_inputs=True, opset_version=11)


if __name__ == '__main__':
    main()