import torch
import torch.nn as nn

# Registry mapping a loss-function name to its ``torch.nn`` loss class.
# Add an entry here to make a new loss selectable by name.
_LOSS_FUNCTIONS = {
    'cross_entropy': nn.CrossEntropyLoss,
    'nll': nn.NLLLoss,
    'mse': nn.MSELoss,
    'l1': nn.L1Loss,
    'smooth_l1': nn.SmoothL1Loss,
    'bce': nn.BCELoss,
    'bce_with_logits': nn.BCEWithLogitsLoss,
}


def load_loss_functions(loss_func='cross_entropy', **loss_kwargs):
    """Instantiate a ``torch.nn`` loss module by name.

    Args:
        loss_func: Name of the loss to build; must be a key of
            ``_LOSS_FUNCTIONS``. Defaults to ``'cross_entropy'``.
        **loss_kwargs: Extra keyword arguments forwarded unchanged to the
            loss class constructor (e.g. ``reduction='sum'``).

    Returns:
        A newly constructed loss module instance.

    Raises:
        ValueError: If ``loss_func`` is not a supported loss name.
    """
    try:
        loss_cls = _LOSS_FUNCTIONS[loss_func]
    except KeyError:
        supported = ', '.join(sorted(_LOSS_FUNCTIONS))
        # Raise instead of print()+exit(): exit() is an interactive-shell
        # helper, would report success (status 0), and would kill the
        # caller's process instead of letting it handle the error.
        raise ValueError(
            f"Invalid loss function name {loss_func!r}; "
            f"expected one of: {supported}"
        ) from None
    return loss_cls(**loss_kwargs)