Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add fashion mnist and move mnists to s3 (#7635)
Browse files Browse the repository at this point in the history
* add fashion mnist and move mnists to s3

* refactor
  • Loading branch information
szha authored and piiswrong committed Aug 28, 2017
1 parent aceef5a commit e845cec
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 16 deletions.
68 changes: 52 additions & 16 deletions python/mxnet/gluon/data/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def __init__(self, root, train, transform):
self._data = None
self._label = None

if not os.path.isdir(self._root):
os.makedirs(self._root)
self._get_data()

def __getitem__(self, idx):
Expand Down Expand Up @@ -70,24 +72,29 @@ class MNIST(_DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
def __init__(self, root='~/.mxnet/datasets/', train=True,
def __init__(self, root='~/.mxnet/datasets/mnist', train=True,
transform=None):
self._base_url = 'https://apache-mxnet.s3.amazonaws.com/gluon/dataset/mnist/'
self._train_data = ('train-images-idx3-ubyte.gz',
'6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d')
self._train_label = ('train-labels-idx1-ubyte.gz',
'2a80914081dc54586dbdf242f9805a6b8d2a15fc')
self._test_data = ('t10k-images-idx3-ubyte.gz',
'c3a25af1f52dad7f726cce8cacb138654b760d48')
self._test_label = ('t10k-labels-idx1-ubyte.gz',
'763e7fa3757d93b0cdec073cef058b2004252c17')
super(MNIST, self).__init__(root, train, transform)

def _get_data(self):
if not os.path.isdir(self._root):
os.makedirs(self._root)
url = 'http://data.mxnet.io/data/mnist/'
if self._train:
data_file = download(url+'train-images-idx3-ubyte.gz', self._root,
sha1_hash='6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d')
label_file = download(url+'train-labels-idx1-ubyte.gz', self._root,
sha1_hash='2a80914081dc54586dbdf242f9805a6b8d2a15fc')
data, label = self._train_data, self._train_label
else:
data_file = download(url+'t10k-images-idx3-ubyte.gz', self._root,
sha1_hash='c3a25af1f52dad7f726cce8cacb138654b760d48')
label_file = download(url+'t10k-labels-idx1-ubyte.gz', self._root,
sha1_hash='763e7fa3757d93b0cdec073cef058b2004252c17')
data, label = self._test_data, self._test_label

data_file = download(self._base_url + data[0], self._root,
sha1_hash=data[1])
label_file = download(self._base_url + label[0], self._root,
sha1_hash=label[1])

with gzip.open(label_file, 'rb') as fin:
struct.unpack(">II", fin.read(8))
Expand All @@ -102,6 +109,38 @@ def _get_data(self):
self._label = label


class FashionMNIST(MNIST):
"""A dataset of Zalando's article images consisting of fashion products,
a drop-in replacement of the original MNIST dataset from
`https://github.com/zalandoresearch/fashion-mnist`_.
Each sample is an image (in 3D NDArray) with shape (28, 28, 1).
Parameters
----------
root : str
Path to temp folder for storing data.
train : bool
Whether to load the training or testing set.
transform : function
A user defined callback that transforms each instance. For example::
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
def __init__(self, root='~/.mxnet/datasets/fashion-mnist', train=True,
transform=None):
self._base_url = 'https://apache-mxnet.s3.amazonaws.com/gluon/dataset/fashion-mnist/'
self._train_data = ('train-images-idx3-ubyte.gz',
'0cf37b0d40ed5169c6b3aba31069a9770ac9043d')
self._train_label = ('train-labels-idx1-ubyte.gz',
'236021d52f1e40852b06a4c3008d8de8aef1e40b')
self._test_data = ('t10k-images-idx3-ubyte.gz',
'626ed6a7c06dd17c0eec72fa3be1740f146a2863')
self._test_label = ('t10k-labels-idx1-ubyte.gz',
'17f9ab60e7257a1620f4ad76bbbaf857c3920701')
super(FashionMNIST, self).__init__(root, train, transform)


class CIFAR10(_DownloadedDataset):
"""CIFAR10 image classification dataset from `https://www.cs.toronto.edu/~kriz/cifar.html`_.
Expand All @@ -118,7 +157,7 @@ class CIFAR10(_DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
def __init__(self, root='~/.mxnet/datasets/', train=True,
def __init__(self, root='~/.mxnet/datasets/cifar10', train=True,
transform=None):
self._file_hashes = {'data_batch_1.bin': 'aadd24acce27caa71bf4b10992e9e7b2d74c2540',
'data_batch_2.bin': 'c0ba65cce70568cd57b4e03e9ac8d2a5367c1795',
Expand All @@ -136,9 +175,6 @@ def _read_batch(self, filename):
data[:, 0].astype(np.int32)

def _get_data(self):
if not os.path.isdir(self._root):
os.makedirs(self._root)

file_paths = [(name, os.path.join(self._root, 'cifar-10-batches-bin/', name))
for name in self._file_hashes]
if any(not os.path.exists(path) or not check_sha1(path, self._file_hashes[name])
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def test_sampler():

def test_datasets():
assert len(gluon.data.vision.MNIST(root='data')) == 60000
assert len(gluon.data.vision.FashionMNIST(root='data')) == 60000
assert len(gluon.data.vision.CIFAR10(root='data', train=False)) == 10000

def test_image_folder_dataset():
Expand Down

0 comments on commit e845cec

Please sign in to comment.