42 lines
1.7 KiB
Python

import os
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
import json
def load_data(data_dir, batch_size, input_size, worker):
    """Build the train/val dataloaders and export the index-to-class mapping.

    ``data_dir`` must contain ``train`` and ``val`` sub-directories, each of
    which is read with ``torchvision.datasets.ImageFolder``.

    Args:
        data_dir: Root directory holding the ``train`` and ``val`` folders.
        batch_size: Batch size used for both dataloaders.
        input_size: Size handed to ``RandomResizedCrop`` (train) and
            ``Resize`` (val).
        worker: Value passed as ``num_workers`` to both dataloaders.

    Returns:
        A dict with the keys ``'train'`` and ``'val'``, each mapping to a
        ``torch.utils.data.DataLoader``.

    Side effects:
        Prints progress information and the index-to-class mapping, and
        writes that mapping to ``./eval_utils/class_id.json`` (relative to
        the current working directory), creating the directory if needed.
    """
    phases = ['train', 'val']

    def _to_scaled_tensor():
        # ToTensor output is first divided by 1/255 (i.e. multiplied by 255),
        # then shifted by 128 and divided by 256.  The two Normalize steps are
        # kept separate so the floating-point results match the original
        # pipeline exactly.
        return [
            transforms.ToTensor(),
            transforms.Normalize([0, 0, 0], [1/255.0, 1/255.0, 1/255.0]),
            transforms.Normalize([0.5*256, 0.5*256, 0.5*256], [256.0, 256.0, 256.0]),
        ]

    transform_train_list = [
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
    ] + _to_scaled_tensor()

    # NOTE(review): when input_size is an int, Resize only fixes the shorter
    # edge, so non-square validation images keep their aspect ratio -- confirm
    # that the val images are uniform, or consider adding a CenterCrop.
    transform_val_list = [
        transforms.Resize(input_size),
    ] + _to_scaled_tensor()

    data_transforms = {
        'train': transforms.Compose(transform_train_list),
        'val': transforms.Compose(transform_val_list),
    }

    print("Initializing Datasets and Dataloaders...")

    # Create training and validation datasets
    image_datasets = {
        phase: datasets.ImageFolder(os.path.join(data_dir, phase), data_transforms[phase])
        for phase in phases
    }

    # Create training and validation dataloaders
    # NOTE(review): the validation loader is shuffled as well; kept as-is to
    # preserve existing behaviour.
    dataloaders = {
        phase: torch.utils.data.DataLoader(
            image_datasets[phase],
            batch_size=batch_size,
            shuffle=True,
            num_workers=worker,
            pin_memory=True,
        )
        for phase in phases
    }

    print('-------------Label mapping to Idx:--------------')
    # Invert ImageFolder's class-name -> index mapping into index -> class-name.
    class_id = {
        index: name
        for name, index in image_datasets['train'].class_to_idx.items()
    }
    print(class_id)
    print('------------------------------------------------')

    # Make sure the output directory exists; otherwise open() would raise
    # FileNotFoundError on a fresh checkout.
    class_id_path = "./eval_utils/class_id.json"
    os.makedirs(os.path.dirname(class_id_path), exist_ok=True)
    # json.dump serialises the integer keys as strings.
    with open(class_id_path, "w") as outfile:
        json.dump(class_id, outfile)

    return dataloaders