# Data-loading utilities: builds train/val DataLoaders from an
# ImageFolder-style directory tree and exports the class-index mapping.
import json
import os

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms

def load_data(data_dir, batch_size, input_size, worker):
    """Build train/val DataLoaders and dump the class-index mapping to JSON.

    Expects ``data_dir`` to contain ``train/`` and ``val/`` subdirectories,
    each laid out for ``torchvision.datasets.ImageFolder``. Writes the
    idx -> class-name mapping to ``./eval_utils/class_id.json`` as a side
    effect.

    Args:
        data_dir (str): Root directory with 'train' and 'val' subfolders.
        batch_size (int): Batch size for both loaders.
        input_size (int): Target spatial size for the image transforms.
        worker (int): Number of DataLoader worker processes.

    Returns:
        dict: ``{'train': DataLoader, 'val': DataLoader}``.
    """
    # The two chained Normalize calls first undo ToTensor's [0, 1] scaling
    # (dividing by 1/255 multiplies by 255), then map pixels to roughly
    # [-0.5, 0.5] via (x - 128) / 256 — presumably the statistics the model
    # was trained with. NOTE(review): confirm this chain is intentional.
    transform_train_list = [
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0, 0, 0], [1/255.0, 1/255.0, 1/255.0]),
        transforms.Normalize([0.5*256, 0.5*256, 0.5*256], [256.0, 256.0, 256.0])
    ]
    # NOTE(review): Resize(int) scales only the shorter side, keeping aspect
    # ratio, so val batches with batch_size > 1 can fail to collate unless
    # all source images share one aspect ratio — consider appending
    # CenterCrop(input_size). Left unchanged to preserve behavior.
    transform_val_list = [
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0, 0, 0], [1/255.0, 1/255.0, 1/255.0]),
        transforms.Normalize([0.5*256, 0.5*256, 0.5*256], [256.0, 256.0, 256.0])
    ]

    data_transforms = {
        'train': transforms.Compose(transform_train_list),
        'val': transforms.Compose(transform_val_list)
    }

    print("Initializing Datasets and Dataloaders...")
    # Create training and validation datasets
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
        for x in ['train', 'val']
    }
    # Create training and validation dataloaders.
    # Fix: shuffle only the training split — a shuffled validation loader
    # makes evaluation order nondeterministic for no benefit.
    dataloaders = {
        x: torch.utils.data.DataLoader(
            image_datasets[x],
            batch_size=batch_size,
            shuffle=(x == 'train'),
            num_workers=worker,
            pin_memory=True,
        )
        for x in ['train', 'val']
    }

    print('-------------Label mapping to Idx:--------------')
    class_id = image_datasets['train'].class_to_idx
    # Invert the mapping to idx -> class name for use at evaluation time.
    class_id = dict([(value, key) for key, value in class_id.items()])
    print(class_id)
    print('------------------------------------------------')
    # Fix: ensure the output directory exists — otherwise open() raises
    # FileNotFoundError on a fresh checkout without ./eval_utils/.
    os.makedirs("./eval_utils", exist_ok=True)
    with open("./eval_utils/class_id.json", "w") as outfile:
        json.dump(class_id, outfile)

    return dataloaders