import torch import torch.nn as nn import torch.optim as optim import numpy as np import torchvision from torchvision import datasets, models, transforms def load_optimizer(model_ft, lr=0.001, momentum=0.9, freeze_backbone = True, op_type='SGD'): params_to_update = model_ft.parameters() print("Params to learn:") if freeze_backbone: params_to_update = [] for name,param in model_ft.named_parameters(): if param.requires_grad == True: params_to_update.append(param) print("\t",name) else: for name,param in model_ft.named_parameters(): if param.requires_grad == True: print("\t",name) if op_type == 'SGD': optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum=momentum) elif op_type == 'ASGD': optimizer_ft = optim.ASGD(params_to_update, lr=lr) elif op_type == 'ADAM': optim.Adam(params_to_update, lr=lr) else: print("Invalid optimizer name, exiting...") exit() return optimizer_ft