13 lines
435 B
Python

import torch
import os
def save_model(network, model_name, snapshot_path, epoch_label, device):
    """Save ``network``'s state dict to ``<snapshot_path>/<model_name>_<epoch_label>.pth``.

    The module is moved to the CPU before its state dict is taken, so the
    checkpoint contains CPU tensors. Afterwards it is moved to ``device``.

    Args:
        network: The ``torch.nn.Module`` to checkpoint. ``Module.cpu()`` and
            ``Module.to()`` move the module in place, so the caller's object
            is the one being moved.
        model_name: Filename prefix of the checkpoint.
        snapshot_path: Directory to write into. It is created, with parents,
            if missing. An empty string means the current working directory.
        epoch_label: Suffix placed after an underscore in the filename
            (any value; it is formatted with ``str``).
        device: Device the network is moved to after saving.

    Returns:
        ``network`` after being moved to ``device``.
    """
    save_filename = f"{model_name}_{epoch_label}.pth"
    save_path = os.path.join(snapshot_path, save_filename)
    # exist_ok avoids the check-then-create race. An empty path means "cwd";
    # os.makedirs("") would raise, so skip directory creation in that case.
    if snapshot_path:
        os.makedirs(snapshot_path, exist_ok=True)
    print('saving model ', save_path)
    try:
        torch.save(network.cpu().state_dict(), save_path)
    finally:
        # .cpu() moved the caller's module in place. Move it back even if
        # serialisation failed, so the caller is not left with a CPU model.
        network = network.to(device)
    return network