88 lines
4.0 KiB
Python
88 lines
4.0 KiB
Python
import argparse
|
|
import os
|
|
import sys
|
|
from datetime import date
|
|
|
|
import torch
|
|
from load_data import load_data
|
|
from loss_functions import load_loss_functions
|
|
from load_optimizer import load_optimizer
|
|
from load_lr_scheduler import load_lr_scheduler
|
|
from train_model import train_model
|
|
from load_model import initialize_model
|
|
from save_model import save_model
|
|
|
|
def makedirs(path):
    """Create the directory ``path``, including any missing parent directories.

    Intended behavior: try to create the directory, pass if the directory
    exists already, fail otherwise.

    Args
        path: directory path to create.
    Raises
        OSError: if the directory cannot be created, e.g. ``path`` already
            exists but is not a directory.
    """
    # exist_ok=True suppresses only the "already exists as a directory" case,
    # so there is no need to catch OSError and re-check with os.path.isdir.
    os.makedirs(path, exist_ok=True)
|
|
|
|
def check_args(parsed_args):
    """ Function to check for inherent contradictions within parsed arguments.

    A negative ``gpu`` value means "run on cpu" and is always accepted. A
    non-negative value must name a CUDA device that torch can actually see.

    Args
        parsed_args: parser.parse_args()
    Returns
        parsed_args (the same object, unmodified)
    Raises
        ValueError: if a GPU was requested but CUDA is unavailable, or the
            requested GPU id is not below the number of visible devices.
    """
    if parsed_args.gpu >= 0:
        if not torch.cuda.is_available():
            raise ValueError("No gpu is available")
        # Reject out-of-range ids here instead of failing later when the
        # model or a batch is first moved to the device.
        num_gpus = torch.cuda.device_count()
        if parsed_args.gpu >= num_gpus:
            raise ValueError(
                "GPU id {} is out of range, only {} gpu(s) available".format(
                    parsed_args.gpu, num_gpus))
    return parsed_args
|
|
|
|
|
|
def parse_args(args):
    """
    Parse the arguments.

    Builds the command-line parser for the training script, parses ``args``
    once, prints the resulting option dictionary and validates it with
    ``check_args``.

    Args
        args: list of command-line argument strings (without the program name).
    Returns
        The parsed namespace, as returned by ``check_args``.
    """
    # The default snapshot directory is stamped with today's date.
    today = str(date.today())

    parser = argparse.ArgumentParser(description='Simple training script for training a image classification network.')
    parser.add_argument('data_dir', type=str, help='Path to your dataset')
    parser.add_argument('--model-name', type=str, help='Name of your model', default='model_ft')
    parser.add_argument('--model-def-path', type=str, help='Path to pretrained model definition', default=None)
    parser.add_argument('--lr', type=float, help='Learning rate', default=5e-3)
    parser.add_argument('--backbone', help='Backbone model.', default='resnet18', type=str)
    parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi). (-1 for cpu)', type=int, default=-1)
    parser.add_argument('--workers', help='The number of dataloader workers', type=int, default=1)
    parser.add_argument('--epochs', help='Number of epochs to train.', type=int, default=100)
    parser.add_argument('--freeze-backbone', help='Freeze training of backbone layers.', type=int, default=0)
    parser.add_argument('--batch-size', help='Size of the batches.', default=128, type=int)
    parser.add_argument('--snapshot', help='Path to the pretrained models.')
    parser.add_argument('--snapshot-path', help='Path to store snapshots of models during training (defaults to \'snapshots\')', default='./snapshots/{}'.format(today))
    parser.add_argument('--optimizer', help='Choose an optimizer from SGD, ASGD and ADAM', type=str, default='SGD')
    parser.add_argument('--loss', help='Choose a loss function', type=str, default='cross_entropy')
    parser.add_argument('--early-stop', help='Choose if early stopping', type=int, default=1)
    parser.add_argument('--patience', help='Choose patience for early stopping', type=int, default=7)

    # Parse once and reuse the namespace for both the echo and the validation
    # instead of running the parser twice over the same argument list.
    parsed_args = parser.parse_args(args)
    print(vars(parsed_args))
    return check_args(parsed_args)
|
|
|
|
|
|
|
|
def main(args=None):
    """Train a classification model end to end and return it.

    Parses the options, builds the model, data loaders, optimizer, learning
    rate scheduler and loss, runs ``train_model`` and finally hands the
    trained model to ``save_model`` under the tag 'best'.

    Args
        args: optional list of command-line argument strings; when None the
            arguments are taken from ``sys.argv[1:]``.
    Returns
        The model returned as the first element of ``train_model``'s result.
    """
    # Fall back to the process command line when no explicit list is given.
    cli_args = sys.argv[1:] if args is None else args
    opts = parse_args(cli_args)

    # A non-negative GPU id selects that CUDA device, anything else means cpu.
    if opts.gpu >= 0:
        device = "cuda:{}".format(opts.gpu)
    else:
        device = "cpu"

    # Every non-hidden entry of <data_dir>/train is counted as one class.
    train_dir = os.path.join(opts.data_dir, 'train')
    num_classes = sum(1 for entry in os.listdir(train_dir) if not entry.startswith('.'))

    model_ft, input_size = initialize_model(
        opts.backbone,
        num_classes,
        opts.freeze_backbone,
        model_def_path=opts.model_def_path,
        use_pretrained=opts.snapshot,
    )
    dataloaders_dict = load_data(opts.data_dir, opts.batch_size, input_size, opts.workers)
    optimizer_ft = load_optimizer(
        model_ft,
        lr=opts.lr,
        freeze_backbone=opts.freeze_backbone,
        op_type=opts.optimizer,
    )
    lr_scheduler_ft = load_lr_scheduler(optimizer_ft)
    criterion = load_loss_functions(loss_func=opts.loss)

    # Train; only the first element of the returned pair is used here.
    training_result = train_model(
        model_ft,
        dataloaders_dict,
        criterion,
        optimizer_ft,
        lr_scheduler_ft,
        device,
        opts.snapshot_path,
        model_name=opts.model_name,
        num_epochs=opts.epochs,
        early_stop=opts.early_stop,
        patience=opts.patience,
    )
    model_ft, _ = training_result

    save_model(model_ft, opts.model_name, opts.snapshot_path, 'best', device)
    return model_ft
|
|
|
|
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|
|