import sys import os import torch import torch.nn as nn import torchvision from torchvision import datasets, models, transforms def set_parameter_requires_grad(model, feature_extracting): if feature_extracting: for param in model.parameters(): param.requires_grad = False def intersect_dicts(da, db, exclude=()): # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape} def initialize_weights(model_ft, pretrained=''): state_dict = torch.load(pretrained) # load checkpoint state_dict = intersect_dicts(state_dict, model_ft.state_dict()) # intersect model_ft.load_state_dict(state_dict, strict=False) # load print('Transferred %g/%g items from %s' % (len(state_dict), len(model_ft.state_dict()), pretrained)) # report def initialize_model(model_name, num_classes, feature_extract, model_def_path=None, use_pretrained=None): # Initialize these variables which will be set in this if statement. Each of these # variables is model specific. model_ft = None input_size = 0 current_path=os.getcwd() if model_name == 'FP_classifier': if num_classes != 2: print("Number of classes should be two, exiting...") exit() if model_def_path == None: model_def_path = './models/FP_classifier/' sys.path.append(model_def_path) from Mobilenet_v2_small import mobile_net_v2 if use_pretrained: model_ft = mobile_net_v2(num_classes) model_ft.load_state_dict(torch.load(use_pretrained)) set_parameter_requires_grad(model_ft, feature_extract) if feature_extract: for param in model_ft.model.classifier[1].parameters(): param.requires_grad = True else: model_ft = mobile_net_v2(num_classes) input_size = (56,32) elif model_name == 'mobilenetv2': """ Mobilenetv2 """ if model_def_path == None: model_def_path = './models/MobileNetV2/' sys.path.append(model_def_path) from Mobilenet_v2 import mobilenet_v2 if use_pretrained is not None and len(use_pretrained)>0: model_ft = mobilenet_v2(num_classes) initialize_weights(model_ft, use_pretrained) set_parameter_requires_grad(model_ft, feature_extract) if feature_extract: for param in model_ft.model.classifier[1].parameters(): param.requires_grad = True else: model_ft = mobilenet_v2(num_classes) input_size = (224,224) elif model_name == 'resnet18': """ ResNet18 """ if model_def_path == None: model_def_path = './models/ResNet18/' sys.path.append(model_def_path) from ResNet18 import resnet18 if use_pretrained is not None and len(use_pretrained)>0: model_ft = resnet18(num_classes) initialize_weights(model_ft, use_pretrained) set_parameter_requires_grad(model_ft, feature_extract) if feature_extract: for param in model_ft.model.fc.parameters(): param.requires_grad = True else: model_ft = resnet18(num_classes) input_size = (224,224) elif model_name == 'resnet50': """ ResNet50 """ if model_def_path == None: model_def_path = './models/ResNet50/' sys.path.append(model_def_path) from ResNet50 import resnet50 if use_pretrained is not None and len(use_pretrained)>0: model_ft = resnet50(num_classes) initialize_weights(model_ft, use_pretrained) set_parameter_requires_grad(model_ft, feature_extract) if feature_extract: for param in model_ft.model.fc.parameters(): param.requires_grad = True else: model_ft = resnet50(num_classes) input_size = (224,224) elif model_name in [ 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7']: """ EfficientNet """ if model_def_path == None: model_def_path = './models/EfficientNet/' sys.path.append(sys.path.append(model_def_path)) from EfficientNet_520 import EfficientNet if use_pretrained is not None and len(use_pretrained)>0: model_ft = EfficientNet.from_name(model_name) model_ft.set_swish(memory_efficient=False) model_ft.load_state_dict(torch.load(use_pretrained) ) set_parameter_requires_grad(model_ft, feature_extract) if imagenet != 0: num_ftrs = model_ft._fc.in_features model_ft._fc = nn.Linear(num_ftrs, num_classes, bias=True) else: model_ft = EfficientNet.from_name(model_name,num_classes=num_classes) input_size = (224,224) else: print("Invalid model name, exiting...") exit() return model_ft, input_size if __name__ == '__main__': model_ft, input_size = initialize_model('resnet18', 1000, False, model_def_path=None, use_pretrained='ResNet18.pth') print(model_ft) #from save_model import save_model #save_model(model_ft, 'mobilenetv2', 'exp/', 0, 'cpu')