diff --git a/tutorials/torch/mnist_tutorial.py b/tutorials/torch/mnist_tutorial.py index 9db4bd1d9..559b5a4ae 100644 --- a/tutorials/torch/mnist_tutorial.py +++ b/tutorials/torch/mnist_tutorial.py @@ -5,12 +5,14 @@ import torch.nn as nn import torch.nn.functional as F import torchvision +from datasets import MNISTDataset from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method from cleverhans.torch.attacks.projected_gradient_descent import ( projected_gradient_descent, ) + FLAGS = flags.FLAGS @@ -79,12 +81,13 @@ def ld_mnist(): test_transforms = torchvision.transforms.Compose( [torchvision.transforms.ToTensor()] ) - train_dataset = torchvision.datasets.MNIST( - root="/tmp/data", train=True, transform=train_transforms, download=True - ) - test_dataset = torchvision.datasets.MNIST( - root="/tmp/data", train=False, transform=test_transforms, download=True + + # Load MNIST dataset + train_dataset = MNISTDataset(root="/tmp/data", transform=train_transforms) + test_dataset = MNISTDataset( + root="/tmp/data", train=False, transform=test_transforms ) + train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=128, shuffle=True, num_workers=2 )