import numpy as np
import torch
import os


class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Call the instance once per epoch with the latest validation loss. Each
    time the loss improves, the model's ``state_dict`` is saved to
    ``<path>/<model_name>_<epoch_label>.pth``. After ``patience`` consecutive
    calls without improvement, ``early_stop`` is set to ``True``; the training
    loop is expected to check that flag and stop.
    """

    def __init__(self, model_name='model_ft', patience=7, verbose=False, delta=0, path='./snapshots/'):
        """
        Args:
            model_name (str): Prefix of the checkpoint file names.
                Default: 'model_ft'
            patience (int): Number of consecutive non-improving calls tolerated
                before ``early_stop`` is set. Default: 7
            verbose (bool): If True, prints a message each time an improved
                model is saved. Default: False
            delta (float): Minimum decrease in validation loss that counts as
                an improvement. Default: 0
            path (str): Directory the checkpoints are written to; it is created
                on the first save if missing. Default: './snapshots/'
        """
        self.model_name = model_name
        self.patience = patience
        self.verbose = verbose
        # Number of consecutive calls without improvement.
        self.counter = 0
        # Best score seen so far (the negated validation loss); None until the first call.
        self.best_score = None
        # Becomes True once `counter` reaches `patience`.
        self.early_stop = False
        # Loss of the most recently saved checkpoint. Lowercase np.inf: the
        # np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model, epoch_label):
        """Record one epoch's validation loss and checkpoint the model on improvement.

        Args:
            val_loss: Validation loss for this epoch. Anything ``float()``
                accepts (Python/NumPy scalar, single-element tensor); it is
                converted to a plain float so no tensor is retained.
            model: Object whose ``state_dict()`` is saved on improvement.
            epoch_label: Inserted into the checkpoint file name.
        """
        # Plain float: avoids keeping a reference to a (possibly GPU or
        # autograd-tracked) tensor in best_score / val_loss_min.
        val_loss = float(val_loss)
        score = -val_loss
        if np.isnan(score):
            # NaN compares False against everything, so without this guard it
            # would fall through to the "improved" branch, save a checkpoint and
            # set best_score to NaN, after which every later loss would count as
            # an improvement and early stopping could never trigger.
            self._register_no_improvement()
        elif self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch_label)
        elif score < self.best_score + self.delta:
            self._register_no_improvement()
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch_label)
            self.counter = 0

    def _register_no_improvement(self):
        """Advance the patience counter and set ``early_stop`` once it is exhausted."""
        self.counter += 1
        print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model, epoch_label):
        """Save ``model.state_dict()`` and record ``val_loss`` as the new minimum.

        The file is ``<path>/<model_name>_<epoch_label>.pth``. Because the name
        includes ``epoch_label``, earlier checkpoints are not overwritten.
        """
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        save_filename = f'{self.model_name}_{epoch_label}.pth'
        save_path = os.path.join(self.path, save_filename)
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(self.path, exist_ok=True)
        torch.save(model.state_dict(), save_path)
        self.val_loss_min = val_loss