139 lines
5.5 KiB
Python
139 lines
5.5 KiB
Python
import sys
|
|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision
|
|
from torchvision import datasets, models, transforms
|
|
|
|
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of *model* when *feature_extracting* is truthy.

    With a falsy flag the model is left untouched. The ``requires_grad``
    flags are modified in place; nothing is returned.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False
|
|
|
|
def intersect_dicts(da, db, exclude=()):
    """Return the entries of *da* that also fit the mapping *db*.

    A key is kept when it is present in *db*, contains none of the
    substrings listed in *exclude*, and its value has the same ``.shape``
    as ``db[key]``. The kept values are taken from *da*, in *da*'s order.
    """
    kept = {}
    for key, value in da.items():
        # Skip keys the target lacks or that the caller asked to exclude.
        if key not in db or any(token in key for token in exclude):
            continue
        if value.shape == db[key].shape:
            kept[key] = value
    return kept
|
|
|
|
def initialize_weights(model_ft, pretrained=''):
    """Copy every compatible tensor from a checkpoint file into *model_ft*.

    Only entries whose key exists in ``model_ft.state_dict()`` with an
    identical shape are transferred (via ``intersect_dicts``). Everything
    else is skipped, and the load is non-strict, so a model with a different
    head can still receive the backbone weights. A one-line transfer summary
    is printed.

    Args:
        model_ft: the module to initialise in place.
        pretrained: path to a checkpoint. ``torch.load`` is assumed to return
            a flat name -> tensor mapping (TODO confirm for checkpoints that
            wrap the weights in an outer dict).
    """
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict then copies the values onto the model's own device.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(pretrained, map_location='cpu')
    model_state = model_ft.state_dict()
    matched = intersect_dicts(checkpoint, model_state)  # keep same-name, same-shape entries
    model_ft.load_state_dict(matched, strict=False)
    print('Transferred %g/%g items from %s' % (len(matched), len(model_state), pretrained))  # report
|
|
|
|
# Model names served by the EfficientNet branch of initialize_model.
_EFFICIENTNET_NAMES = (
    'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
    'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
)


def _add_model_path(model_def_path, default_path):
    """Put the directory holding a model definition on ``sys.path`` (once)."""
    if model_def_path is None:
        model_def_path = default_path
    # Avoid growing sys.path with duplicates on repeated calls.
    if model_def_path not in sys.path:
        sys.path.append(model_def_path)


def _unfreeze(module):
    """Re-enable gradients for every parameter of *module*."""
    for param in module.parameters():
        param.requires_grad = True


def _abort(message):
    """Print *message* and terminate with a non-zero exit status."""
    print(message)
    # sys.exit is always available (unlike the site-provided ``exit``), and a
    # non-zero status lets calling scripts detect the failure.
    sys.exit(1)


def _build_local_model(factory, num_classes, feature_extract, use_pretrained, head_of):
    """Build ``factory(num_classes)`` and optionally transfer checkpoint weights.

    When *use_pretrained* is given, shape-matching weights are copied with
    ``initialize_weights``. If *feature_extract* is also true, the model is
    frozen except for the module returned by ``head_of(model)``.
    """
    model_ft = factory(num_classes)
    if use_pretrained:
        initialize_weights(model_ft, use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        if feature_extract:
            _unfreeze(head_of(model_ft))
    return model_ft


def initialize_model(model_name, num_classes, feature_extract, model_def_path=None, use_pretrained=None):
    """Build a classification model by name and return it with its input size.

    Model definitions are imported from project directories that are added
    to ``sys.path`` on demand.

    Args:
        model_name: 'FP_classifier', 'mobilenetv2', 'resnet18', 'resnet50' or
            'efficientnet-b0' ... 'efficientnet-b7'.
        num_classes: number of output classes ('FP_classifier' requires 2).
        feature_extract: only used when weights are loaded. It freezes the
            network and leaves only the classifier head trainable.
        model_def_path: directory containing the model definition module, or
            None for the per-model default under ``./models/``.
        use_pretrained: checkpoint path. None or '' keeps the model's
            default initialisation.

    Returns:
        ``(model_ft, input_size)``, where ``input_size`` is the 2-tuple image
        size used with that model ((56, 32) or (224, 224)).

    Raises:
        SystemExit: (status 1, after printing a message) for an unknown
            model name, or for 'FP_classifier' with ``num_classes != 2``.
    """
    if model_name == 'FP_classifier':
        if num_classes != 2:
            _abort("Number of classes should be two, exiting...")
        _add_model_path(model_def_path, './models/FP_classifier/')
        from Mobilenet_v2_small import mobile_net_v2
        model_ft = mobile_net_v2(num_classes)
        if use_pretrained:
            # Strict load: the checkpoint must match this architecture exactly.
            model_ft.load_state_dict(torch.load(use_pretrained, map_location='cpu'))
            set_parameter_requires_grad(model_ft, feature_extract)
            if feature_extract:
                _unfreeze(model_ft.model.classifier[1])
        input_size = (56, 32)

    elif model_name == 'mobilenetv2':
        _add_model_path(model_def_path, './models/MobileNetV2/')
        from Mobilenet_v2 import mobilenet_v2
        model_ft = _build_local_model(mobilenet_v2, num_classes, feature_extract, use_pretrained,
                                      lambda m: m.model.classifier[1])
        input_size = (224, 224)

    elif model_name == 'resnet18':
        _add_model_path(model_def_path, './models/ResNet18/')
        from ResNet18 import resnet18
        model_ft = _build_local_model(resnet18, num_classes, feature_extract, use_pretrained,
                                      lambda m: m.model.fc)
        input_size = (224, 224)

    elif model_name == 'resnet50':
        _add_model_path(model_def_path, './models/ResNet50/')
        from ResNet50 import resnet50
        model_ft = _build_local_model(resnet50, num_classes, feature_extract, use_pretrained,
                                      lambda m: m.model.fc)
        input_size = (224, 224)

    elif model_name in _EFFICIENTNET_NAMES:
        _add_model_path(model_def_path, './models/EfficientNet/')
        from EfficientNet_520 import EfficientNet
        if use_pretrained:
            # Build the default-sized network so the checkpoint loads strictly,
            # then swap the classifier if the class count differs.
            model_ft = EfficientNet.from_name(model_name)
            model_ft.set_swish(memory_efficient=False)
            model_ft.load_state_dict(torch.load(use_pretrained, map_location='cpu'))
            set_parameter_requires_grad(model_ft, feature_extract)
            if model_ft._fc.out_features != num_classes:
                # A freshly created head is trainable by default.
                model_ft._fc = nn.Linear(model_ft._fc.in_features, num_classes, bias=True)
            elif feature_extract:
                # Keep the reused head trainable, as the other branches do.
                _unfreeze(model_ft._fc)
        else:
            model_ft = EfficientNet.from_name(model_name, num_classes=num_classes)
        input_size = (224, 224)

    else:
        _abort("Invalid model name, exiting...")

    return model_ft, input_size
|
|
|
|
if __name__ == '__main__':
    # Smoke check: build a 1000-class ResNet18, transfer weights from
    # ResNet18.pth in the working directory, and print the architecture.
    model_ft, input_size = initialize_model(
        model_name='resnet18',
        num_classes=1000,
        feature_extract=False,
        model_def_path=None,
        use_pretrained='ResNet18.pth',
    )
    print(model_ft)
    # Optional export, kept disabled:
    # from save_model import save_model
    # save_model(model_ft, 'mobilenetv2', 'exp/', 0, 'cpu')