Yolov5s/ai_training/classification/load_optimizer.py

34 lines
1.0 KiB
Python

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
def load_optimizer(model_ft, lr=0.001, momentum=0.9, freeze_backbone=True, op_type='SGD'):
    """Build an optimizer for ``model_ft`` and print the names of trainable parameters.

    Args:
        model_ft: Model whose parameters are optimized (anything exposing
            ``named_parameters()`` / ``parameters()``, e.g. ``torch.nn.Module``).
        lr: Learning rate passed to the optimizer.
        momentum: Momentum factor; only used when ``op_type`` is ``'SGD'``.
        freeze_backbone: If True, only parameters with ``requires_grad=True``
            are handed to the optimizer, so parameters frozen beforehand are
            excluded. If False, every parameter of the model is passed.
        op_type: Optimizer name, one of ``'SGD'``, ``'ASGD'`` or ``'ADAM'``
            (matched case-insensitively).

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``op_type`` is not one of the supported names.
    """
    # Scan once: collect and report parameters that currently require grads.
    print("Params to learn:")
    trainable = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            trainable.append(param)
            print("\t", name)

    # Frozen backbone -> optimize only the trainable subset; otherwise pass
    # every parameter, as the original code did.
    params_to_update = trainable if freeze_backbone else list(model_ft.parameters())

    key = str(op_type).upper()
    if key == 'SGD':
        return optim.SGD(params_to_update, lr=lr, momentum=momentum)
    if key == 'ASGD':
        return optim.ASGD(params_to_update, lr=lr)
    if key == 'ADAM':
        # Previously this optimizer was built but never assigned/returned.
        return optim.Adam(params_to_update, lr=lr)
    # Raise instead of exit() so callers can handle a bad configuration.
    raise ValueError(
        f"Invalid optimizer name {op_type!r}; expected one of 'SGD', 'ASGD', 'ADAM'"
    )