Skip to content

Commit

Permalink
fix black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
tejuafonja committed Mar 16, 2021
1 parent 8575736 commit 448ae01
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions tutorials/torch/mnist_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from datasets import MNISTDataset

from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.torch.attacks.projected_gradient_descent import (
projected_gradient_descent,
)


FLAGS = flags.FLAGS


Expand Down Expand Up @@ -79,12 +81,13 @@ def ld_mnist():
test_transforms = torchvision.transforms.Compose(
[torchvision.transforms.ToTensor()]
)
train_dataset = torchvision.datasets.MNIST(
root="/tmp/data", train=True, transform=train_transforms, download=True
)
test_dataset = torchvision.datasets.MNIST(
root="/tmp/data", train=False, transform=test_transforms, download=True

# Load MNIST dataset
train_dataset = MNISTDataset(root="/tmp/data", transform=train_transforms)
test_dataset = MNISTDataset(
root="/tmp/data", train=False, transform=test_transforms
)

train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=128, shuffle=True, num_workers=2
)
Expand Down

0 comments on commit 448ae01

Please sign in to comment.