import json
import os

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms


def load_data(data_dir, batch_size, input_size, worker):
    """Create 'train' and 'val' dataloaders and save the index-to-label map.

    Images are read with ``datasets.ImageFolder`` from ``<data_dir>/train``
    and ``<data_dir>/val``. The inverted ``class_to_idx`` mapping of the
    training set (index -> class name) is printed and written to
    ``./eval_utils/class_id.json`` (relative to the current working
    directory); the ``eval_utils`` directory is created if it is missing.

    Args:
        data_dir: Directory containing the ``train`` and ``val`` sub-folders.
        batch_size: Batch size passed to both dataloaders.
        input_size: Size passed to ``RandomResizedCrop`` (train) and
            ``Resize`` (val).
        worker: Value passed as ``num_workers`` to both dataloaders.

    Returns:
        A dict with the keys ``'train'`` and ``'val'``, each mapping to a
        ``torch.utils.data.DataLoader``.
    """
    # After ToTensor, the first Normalize divides by 1/255 (mean 0) and the
    # second subtracts 0.5 * 256 and divides by 256. The two stages are kept
    # separate, exactly as in the original pipeline.
    scale_mean, scale_std = [0, 0, 0], [1/255.0, 1/255.0, 1/255.0]
    shift_mean, shift_std = [0.5*256, 0.5*256, 0.5*256], [256.0, 256.0, 256.0]

    transform_train_list = [
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(scale_mean, scale_std),
        transforms.Normalize(shift_mean, shift_std),
    ]
    # NOTE(review): with an int input_size, Resize only matches the shorter
    # image edge, so non-square images keep their aspect ratio and may come
    # out with differing shapes. No crop is applied here (unchanged from the
    # original) -- confirm the validation images are square or that
    # input_size is an (h, w) pair, otherwise batching may fail.
    transform_val_list = [
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(scale_mean, scale_std),
        transforms.Normalize(shift_mean, shift_std),
    ]
    data_transforms = {
        'train': transforms.Compose(transform_train_list),
        'val': transforms.Compose(transform_val_list),
    }

    print("Initializing Datasets and Dataloaders...")

    # Create training and validation datasets
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
        for x in ['train', 'val']
    }
    # Create training and validation dataloaders
    # NOTE(review): shuffle=True is applied to 'val' as well as 'train';
    # kept as in the original.
    dataloaders = {
        x: torch.utils.data.DataLoader(
            image_datasets[x],
            batch_size=batch_size,
            shuffle=True,
            num_workers=worker,
            pin_memory=True,
        )
        for x in ['train', 'val']
    }

    print('-------------Label mapping to Idx:--------------')
    # Invert class_to_idx so the mapping reads index -> class name.
    class_id = {
        value: key
        for key, value in image_datasets['train'].class_to_idx.items()
    }
    print(class_id)
    print('------------------------------------------------')

    class_id_path = "./eval_utils/class_id.json"
    # Make sure the output directory exists; otherwise open() raises
    # FileNotFoundError after all the datasets have already been built.
    os.makedirs(os.path.dirname(class_id_path), exist_ok=True)
    # json.dump serialises the integer keys as strings.
    with open(class_id_path, "w") as outfile:
        json.dump(class_id, outfile)

    return dataloaders