34 lines
1.0 KiB
Python
34 lines
1.0 KiB
Python
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import numpy as np
|
|
import torchvision
|
|
from torchvision import datasets, models, transforms
|
|
|
|
def load_optimizer(model_ft, lr=0.001, momentum=0.9, freeze_backbone=True, op_type='SGD'):
    """Build a torch optimizer over the trainable parameters of *model_ft*.

    Prints the names of the parameters that will be learned, then constructs
    the requested optimizer.

    Args:
        model_ft: the model whose parameters are optimized.
        lr: learning rate passed to the optimizer.
        momentum: momentum value (used only for ``op_type='SGD'``).
        freeze_backbone: when True, only parameters with ``requires_grad``
            set are handed to the optimizer (typical fine-tuning setup);
            when False, all parameters are handed over but only the
            trainable ones are printed.
        op_type: one of ``'SGD'``, ``'ASGD'``, ``'ADAM'``.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: if *op_type* is not a recognized optimizer name.
    """
    params_to_update = model_ft.parameters()
    print("Params to learn:")
    if freeze_backbone:
        # Collect only the trainable parameters so frozen backbone weights
        # are excluded from the optimizer entirely.
        params_to_update = []
        for name, param in model_ft.named_parameters():
            if param.requires_grad:
                params_to_update.append(param)
                print("\t", name)
    else:
        # All parameters go to the optimizer; just report the trainable ones.
        for name, param in model_ft.named_parameters():
            if param.requires_grad:
                print("\t", name)

    if op_type == 'SGD':
        optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum=momentum)
    elif op_type == 'ASGD':
        optimizer_ft = optim.ASGD(params_to_update, lr=lr)
    elif op_type == 'ADAM':
        # BUG FIX: the original discarded the Adam instance (missing
        # assignment), causing UnboundLocalError at the return below.
        optimizer_ft = optim.Adam(params_to_update, lr=lr)
    else:
        # BUG FIX: was print(...) + exit(), which kills the whole process
        # from library code; raise a catchable error instead.
        raise ValueError(f"Invalid optimizer name: {op_type!r}")

    return optimizer_ft
|