import os

import torch


def save_model(network, model_name, snapshot_path, epoch_label, device):
    """Save ``network``'s state_dict to ``<snapshot_path>/<model_name>_<epoch_label>.pth``.

    The module is moved to CPU before its state_dict is taken, so the saved
    tensors are CPU tensors. Afterwards the module is moved to ``device``.
    ``Module.cpu()`` / ``Module.to()`` act in place, so the caller's object is
    moved as well.

    Args:
        network: the ``torch.nn.Module`` to save.
        model_name: filename prefix; must be a ``str`` (it is concatenated).
        snapshot_path: target directory, created (with parents) if missing.
            An empty string means the current working directory.
        epoch_label: any value; formatted with ``%s`` into the filename.
        device: device the module is moved to after saving.

    Returns:
        The network, moved to ``device``.

    Raises:
        Whatever ``os.makedirs`` or ``torch.save`` raise. The network is still
        moved back to ``device`` before the exception propagates.
    """
    save_filename = model_name + '_%s.pth' % epoch_label
    save_path = os.path.join(snapshot_path, save_filename)
    # os.makedirs('') raises FileNotFoundError, so only create a directory when
    # one was actually given. An empty path means "current directory".
    # exist_ok=True avoids the check-then-create race of a separate isdir() test.
    if snapshot_path:
        os.makedirs(snapshot_path, exist_ok=True)
    print('saving model ', save_path)
    try:
        torch.save(network.cpu().state_dict(), save_path)
    finally:
        # Always move the network back, even if saving failed.
        # Otherwise the caller's model would silently be left on CPU.
        network = network.to(device)
    return network