# 139 lines
# 5.5 KiB
# Python

import sys
import os
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    When ``feature_extracting`` is falsy the model is left untouched; callers
    re-enable gradients on whichever head they want to train afterwards.
    """
    if not feature_extracting:
        return
    for p in model.parameters():
        p.requires_grad = False
def intersect_dicts(da, db, exclude=()):
    """Return the entries of ``da`` whose key is in ``db`` with an equal ``.shape``.

    Keys containing any substring listed in ``exclude`` are dropped. Values
    (and insertion order) are taken from ``da``.
    """
    result = {}
    for key, value in da.items():
        if key not in db:
            continue
        if any(token in key for token in exclude):
            continue
        if value.shape != db[key].shape:
            continue
        result[key] = value
    return result
def initialize_weights(model_ft, pretrained=''):
    """Copy matching tensors from the checkpoint at ``pretrained`` into ``model_ft``.

    Only entries whose key exists in ``model_ft.state_dict()`` with the same
    shape are loaded (via ``intersect_dicts``). Everything else keeps its
    current value, so e.g. a checkpoint whose output layer has a different
    size can still be used. A one-line transfer report is printed.

    Args:
        model_ft: The ``nn.Module`` to initialise in place.
        pretrained: Path of the checkpoint file. NOTE(review): assumed to hold
            a flat ``name -> tensor`` state dict (no wrapper key) — confirm.
    """
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict then copies each tensor into the model's own parameters.
    checkpoint = torch.load(pretrained, map_location='cpu')
    model_state = model_ft.state_dict()
    matched = intersect_dicts(checkpoint, model_state)
    # strict=False because unmatched keys are intentionally left out above.
    model_ft.load_state_dict(matched, strict=False)
    print('Transferred %g/%g items from %s' % (len(matched), len(model_state), pretrained))  # report
def _add_model_def_path(model_def_path):
    """Make ``model_def_path`` importable, appending it to ``sys.path`` only once.

    ``initialize_model`` may be called repeatedly; appending unconditionally
    would grow ``sys.path`` with a duplicate entry on every call.
    """
    if model_def_path not in sys.path:
        sys.path.append(model_def_path)


def _unfreeze(module):
    """Set ``requires_grad = True`` on every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = True


def initialize_model(model_name, num_classes, feature_extract, model_def_path=None, use_pretrained=None):
    """Build the network named ``model_name`` and return it with its input size.

    The model class is imported from a local module located in
    ``model_def_path`` (default: a per-model directory under ``./models/``,
    resolved against the current working directory). That directory is added
    to ``sys.path`` as a side effect.

    Args:
        model_name: ``'FP_classifier'``, ``'mobilenetv2'``, ``'resnet18'``,
            ``'resnet50'`` or ``'efficientnet-b0'`` .. ``'efficientnet-b7'``.
        num_classes: Number of output classes (``'FP_classifier'`` requires 2).
        feature_extract: Only applied when a checkpoint is loaded. When truthy,
            all parameters are frozen except the classifier head.
        model_def_path: Directory containing the model definition module;
            ``None`` selects the default directory for ``model_name``.
        use_pretrained: Checkpoint path. ``None`` or ``''`` builds the model
            without loading any checkpoint.

    Returns:
        ``(model_ft, input_size)``. ``input_size`` is a 2-tuple (presumably
        ``(height, width)`` — TODO confirm, notably for FP_classifier's
        ``(56, 32)``).

    Raises:
        SystemExit: with status 1 for an unknown ``model_name``, or for
            ``'FP_classifier'`` with ``num_classes != 2``.
        ImportError: if the model definition module cannot be found.
    """
    if model_name == 'FP_classifier':
        if num_classes != 2:
            print("Number of classes should be two, exiting...")
            sys.exit(1)
        _add_model_def_path(model_def_path if model_def_path is not None
                            else './models/FP_classifier/')
        from Mobilenet_v2_small import mobile_net_v2
        model_ft = mobile_net_v2(num_classes)
        if use_pretrained:
            # Strict load: the checkpoint must match this architecture exactly.
            model_ft.load_state_dict(torch.load(use_pretrained, map_location='cpu'))
            set_parameter_requires_grad(model_ft, feature_extract)
            if feature_extract:
                _unfreeze(model_ft.model.classifier[1])
        input_size = (56, 32)

    elif model_name == 'mobilenetv2':
        _add_model_def_path(model_def_path if model_def_path is not None
                            else './models/MobileNetV2/')
        from Mobilenet_v2 import mobilenet_v2
        model_ft = mobilenet_v2(num_classes)
        if use_pretrained:
            # Partial load: only tensors whose name and shape match are copied.
            initialize_weights(model_ft, use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            if feature_extract:
                _unfreeze(model_ft.model.classifier[1])
        input_size = (224, 224)

    elif model_name == 'resnet18':
        _add_model_def_path(model_def_path if model_def_path is not None
                            else './models/ResNet18/')
        from ResNet18 import resnet18
        model_ft = resnet18(num_classes)
        if use_pretrained:
            initialize_weights(model_ft, use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            if feature_extract:
                _unfreeze(model_ft.model.fc)
        input_size = (224, 224)

    elif model_name == 'resnet50':
        _add_model_def_path(model_def_path if model_def_path is not None
                            else './models/ResNet50/')
        from ResNet50 import resnet50
        model_ft = resnet50(num_classes)
        if use_pretrained:
            initialize_weights(model_ft, use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            if feature_extract:
                _unfreeze(model_ft.model.fc)
        input_size = (224, 224)

    elif model_name in ('efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
                        'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7'):
        _add_model_def_path(model_def_path if model_def_path is not None
                            else './models/EfficientNet/')
        from EfficientNet_520 import EfficientNet
        if use_pretrained:
            # Build with the default head so the checkpoint loads strictly,
            # then adapt the head to num_classes if it differs.
            model_ft = EfficientNet.from_name(model_name)
            model_ft.set_swish(memory_efficient=False)
            model_ft.load_state_dict(torch.load(use_pretrained, map_location='cpu'))
            set_parameter_requires_grad(model_ft, feature_extract)
            # NOTE(review): assumes _fc is an nn.Linear (in_features/out_features).
            if model_ft._fc.out_features != num_classes:
                # A freshly created Linear is trainable even when feature_extract.
                num_ftrs = model_ft._fc.in_features
                model_ft._fc = nn.Linear(num_ftrs, num_classes, bias=True)
            elif feature_extract:
                _unfreeze(model_ft._fc)
        else:
            model_ft = EfficientNet.from_name(model_name, num_classes=num_classes)
        input_size = (224, 224)

    else:
        print("Invalid model name, exiting...")
        sys.exit(1)

    return model_ft, input_size
if __name__ == '__main__':
    # Smoke run: build a 1000-class ResNet18 and initialise it from
    # 'ResNet18.pth' in the current working directory, then print it.
    model, size = initialize_model(
        'resnet18',
        1000,
        False,
        model_def_path=None,
        use_pretrained='ResNet18.pth',
    )
    print(model)
    # Optional export step (currently disabled):
    # from save_model import save_model
    # save_model(model, 'mobilenetv2', 'exp/', 0, 'cpu')