13 lines
435 B
Python
13 lines
435 B
Python
import torch
|
|
import os
|
|
|
|
def save_model(network, model_name, snapshot_path, epoch_label, device):
    """Save ``network``'s state dict to ``<snapshot_path>/<model_name>_<epoch_label>.pth``.

    The target directory is created (including missing parents) if needed;
    an empty ``snapshot_path`` means the current working directory. The
    network is moved to the CPU before ``state_dict()`` is taken, so every
    saved tensor is a CPU tensor. Afterwards it is moved to ``device`` again,
    and this also happens when saving fails.

    Args:
        network: model to save; presumably a ``torch.nn.Module`` (it must
            provide ``cpu()``, ``state_dict()`` and ``to()``).
        model_name: filename prefix.
        snapshot_path: directory the checkpoint is written into.
        epoch_label: suffix identifying the snapshot, formatted with ``%s``.
        device: argument passed to ``network.to()`` after saving.

    Returns:
        The result of ``network.to(device)``.

    Raises:
        OSError: if the directory cannot be created (e.g. ``snapshot_path``
            already exists as a regular file).
        Any error raised by ``torch.save`` propagates unchanged.
    """
    # A tuple argument also handles non-str names and tuple labels, which
    # `'_%s.pth' % epoch_label` would reject.
    save_filename = '%s_%s.pth' % (model_name, epoch_label)
    save_path = os.path.join(snapshot_path, save_filename)

    # exist_ok=True removes the race between a separate isdir() check and
    # makedirs(). An empty path means "current directory", which makedirs
    # would reject, so skip it.
    if snapshot_path:
        os.makedirs(snapshot_path, exist_ok=True)

    print('saving model ', save_path)
    try:
        # nn.Module.cpu() moves parameters and buffers in place.
        torch.save(network.cpu().state_dict(), save_path)
    finally:
        # Always restore the model to `device`. Otherwise a failed save
        # would leave the caller's model stranded on the CPU.
        network = network.to(device)
    return network
|