Source code for ares.dataset.cifar_dataset

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10


def cifar10(batchsize, cifar10_path):
    """Build a DataLoader over the CIFAR-10 ``train=False`` split.

    Args:
        batchsize: Number of samples per batch. It is also stored on the
            returned loader as its ``batch`` attribute.
        cifar10_path: Root directory handed to ``CIFAR10``. The dataset is
            constructed with ``download=True``.

    Returns:
        A ``DataLoader`` (no shuffling, one worker, last partial batch kept)
        carrying two extra attributes: ``name`` set to ``"cifar10"`` and
        ``batch`` set to ``batchsize``.
    """
    dataset = CIFAR10(
        root=cifar10_path,
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )

    loader_options = dict(
        batch_size=batchsize,
        shuffle=False,
        num_workers=1,
        pin_memory=False,
        drop_last=False,
    )
    loader = DataLoader(dataset, **loader_options)

    # Extra attributes attached to the loader object for callers to read.
    loader.name = "cifar10"
    loader.batch = batchsize
    return loader