Skip to content

Commit 2d55b9d

Browse files
authored
Merge pull request #10 from pytorch/cifarfix
making cifar data loader also return PIL Image
2 parents 98b9aa5 + 05bcb18 commit 2d55b9d

File tree

3 files changed

+183
-6
lines changed

3 files changed

+183
-6
lines changed

test/sanity_checks.ipynb

Lines changed: 168 additions & 5 deletions
Large diffs are not rendered by default.

torchvision/datasets/cifar.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def __getitem__(self, index):
7777
img, target = self.train_data[index], self.train_labels[index]
7878
else:
7979
img, target = self.test_data[index], self.test_labels[index]
80+
81+
# doing this so that it is consistent with all other datasets
82+
# to return a PIL Image
83+
img = Image.fromarray(np.transpose(img, (1,2,0)))
8084

8185
if self.transform is not None:
8286
img = self.transform(img)

torchvision/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
1+
import torch
2+
import math
13

24
def make_grid(tensor, nrow=8, padding=2):
35
"""
46
Given a 4D mini-batch Tensor of shape (B x C x H x W),
7+
or a list of images all of the same size,
58
makes a grid of images
69
"""
7-
import math
10+
tensorlist = None
11+
if isinstance(tensor, list):
12+
tensorlist = tensor
13+
numImages = len(tensorlist)
14+
size = torch.Size(torch.Size([long(numImages)]) + tensorlist[0].size())
15+
tensor = tensorlist[0].new(size)
16+
for i in range(numImages):
17+
tensor[i].copy_(tensorlist[i])
818
if tensor.dim() == 3: # single image
919
return tensor
1020
# make the mini-batch of images into a grid

0 commit comments

Comments
 (0)