# Yolov5s/ai_training/classification/loss_functions.py
#
# (listing metadata: 11 lines, 271 B, Python)

import torch
import torch.nn as nn
def load_loss_functions(loss_func: str = 'cross_entropy') -> nn.Module:
    """Build the training criterion selected by name.

    Args:
        loss_func: Name of the loss to construct. Currently only
            ``'cross_entropy'`` is supported.

    Returns:
        A newly constructed ``nn.CrossEntropyLoss`` (default arguments)
        when ``loss_func == 'cross_entropy'``.

    Raises:
        SystemExit: If ``loss_func`` is not a supported name. The exit
            status is non-zero and the message names the rejected value
            and the accepted choices.
    """
    # Accepted names -> zero-argument factories. Each call builds a fresh
    # module instance, and new losses can be registered in this one place.
    factories = {
        'cross_entropy': nn.CrossEntropyLoss,
    }
    # The isinstance guard keeps unhashable inputs (e.g. a list) on the
    # "invalid name" path instead of surfacing a TypeError from the lookup.
    factory = factories.get(loss_func) if isinstance(loss_func, str) else None
    if factory is None:
        # Raise SystemExit directly rather than calling the site-provided
        # exit(). The interpreter prints a string code to stderr and exits
        # with status 1, and this works even when `site` is not loaded.
        raise SystemExit(
            f"Invalid loss function name {loss_func!r}; "
            f"expected one of {sorted(factories)}. Exiting..."
        )
    return factory()