11 lines
271 B
Python
11 lines
271 B
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
def load_loss_functions(loss_func='cross_entropy', **loss_kwargs):
    """Build and return a PyTorch loss module selected by name.

    Args:
        loss_func: Name of the loss to construct. Supported names are
            ``'cross_entropy'`` (the default), ``'nll'``, ``'mse'``, ``'l1'``,
            ``'smooth_l1'``, ``'bce'`` and ``'bce_with_logits'``.
        **loss_kwargs: Optional keyword arguments forwarded unchanged to the
            selected ``torch.nn`` loss constructor (e.g. ``reduction='sum'``).

    Returns:
        A newly constructed ``torch.nn`` loss module instance.

    Raises:
        SystemExit: If ``loss_func`` is not a supported name. The exit
            payload is a message naming the invalid value and the supported
            options; when uncaught, the interpreter prints it to stderr and
            exits with status 1.
    """
    # Dispatch table: name -> loss class. Kept local so this function carries
    # no module-level state; add new entries here to support more losses.
    losses = {
        'cross_entropy': nn.CrossEntropyLoss,
        'nll': nn.NLLLoss,
        'mse': nn.MSELoss,
        'l1': nn.L1Loss,
        'smooth_l1': nn.SmoothL1Loss,
        'bce': nn.BCELoss,
        'bce_with_logits': nn.BCEWithLogitsLoss,
    }

    try:
        loss_cls = losses[loss_func]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list), which are just
        # another invalid name. Raise SystemExit directly instead of calling
        # the site-module ``exit()`` helper: that helper is meant for the
        # interactive prompt and is missing under ``python -S``. Passing a
        # string payload makes the interpreter report it on stderr and exit
        # with a non-zero status instead of signalling success (status 0).
        supported = ', '.join(sorted(losses))
        raise SystemExit(
            f"Invalid loss function name {loss_func!r}, exiting... "
            f"(supported: {supported})"
        ) from None

    return loss_cls(**loss_kwargs)