Skip to content

Commit cca77a2

Browse files
committed
✨ add complete mnist file
1 parent 53ffd87 commit cca77a2

File tree

1 file changed

+147
-8
lines changed

1 file changed

+147
-8
lines changed

mnist_parser/__init__.py

+147-8
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,90 @@
1+
'''
2+
Python utilities from https://github.com/datapythonista/mnist
3+
'''
4+
15
import os
26
import functools
37
import operator
48
import gzip
59
import struct
610
import array
7-
import numpy as np
11+
import tempfile
12+
13+
try:
14+
from urllib.request import urlretrieve
15+
except ImportError:
16+
from urllib import urlretrieve # py2
17+
try:
18+
from urllib.parse import urljoin
19+
except ImportError:
20+
from urlparse import urljoin
21+
import numpy
22+
23+
24+
__version__ = '0.2.2'
25+
26+
27+
# `datasets_url` and `temporary_dir` can be set by the user using:
28+
# >>> mnist.datasets_url = 'http://my.mnist.url'
29+
# >>> mnist.temporary_dir = lambda: '/tmp/mnist'
30+
datasets_url = 'http://yann.lecun.com/exdb/mnist/'
31+
temporary_dir = tempfile.gettempdir
32+
33+
34+
class IdxDecodeError(ValueError):
35+
"""Raised when an invalid idx file is parsed."""
36+
pass
37+
38+
39+
def download_file(fname, target_dir=None, force=False):
40+
"""Download fname from the datasets_url, and save it to target_dir,
41+
unless the file already exists, and force is False.
42+
43+
Parameters
44+
----------
45+
fname : str
46+
Name of the file to download
47+
48+
target_dir : str
49+
Directory where to store the file
50+
51+
force : bool
52+
Force downloading the file, if it already exists
53+
54+
Returns
55+
-------
56+
fname : str
57+
Full path of the downloaded file
58+
"""
59+
target_dir = target_dir or temporary_dir()
60+
target_fname = os.path.join(target_dir, fname)
61+
62+
if force or not os.path.isfile(target_fname):
63+
url = urljoin(datasets_url, fname)
64+
urlretrieve(url, target_fname)
65+
66+
return target_fname
867

9-
#from: https://github.com/datapythonista/mnist
1068

1169
def parse_idx(fd):
70+
"""Parse an IDX file, and return it as a numpy array.
71+
72+
Parameters
73+
----------
74+
fd : file
75+
File descriptor of the IDX file to parse
76+
77+
endian : str
78+
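Byte order of the IDX file. NOTE: `endian` is not an argument of parse_idx(fd); the header fields are always unpacked big-endian ('>'). See [1] for available options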
79+
80+
Returns
81+
-------
82+
data : numpy.ndarray
83+
Numpy array with the dimensions and the data in the IDX file
84+
85+
1. https://docs.python.org/3/library/struct.html
86+
#byte-order-size-and-alignment
87+
"""
1288
DATA_TYPES = {0x08: 'B', # unsigned byte
1389
0x09: 'b', # signed byte
1490
0x0b: 'h', # short (2 bytes)
@@ -18,18 +94,21 @@ def parse_idx(fd):
1894

1995
header = fd.read(4)
2096
if len(header) != 4:
21-
raise IdxDecodeError('Invalid IDX file, file empty or does not contain a full header.')
97+
raise IdxDecodeError('Invalid IDX file, '
98+
'file empty or does not contain a full header.')
2299

23100
zeros, data_type, num_dimensions = struct.unpack('>HBB', header)
24101

25102
if zeros != 0:
26-
raise IdxDecodeError('Invalid IDX file, file must start with two zero bytes. '
103+
raise IdxDecodeError('Invalid IDX file, '
104+
'file must start with two zero bytes. '
27105
'Found 0x%02x' % zeros)
28106

29107
try:
30108
data_type = DATA_TYPES[data_type]
31109
except KeyError:
32-
raise IdxDecodeError('Unknown data type 0x%02x in IDX file' % data_type)
110+
raise IdxDecodeError('Unknown data type '
111+
'0x%02x in IDX file' % data_type)
33112

34113
dimension_sizes = struct.unpack('>' + 'I' * num_dimensions,
35114
fd.read(4 * num_dimensions))
@@ -40,29 +119,89 @@ def parse_idx(fd):
40119
expected_items = functools.reduce(operator.mul, dimension_sizes)
41120
if len(data) != expected_items:
42121
raise IdxDecodeError('IDX file has wrong number of items. '
43-
'Expected: %d. Found: %d' % (expected_items, len(data)))
122+
'Expected: %d. Found: %d' % (expected_items,
123+
len(data)))
44124

45-
return np.array(data).reshape(dimension_sizes)
125+
return numpy.array(data).reshape(dimension_sizes)
46126

47127

48128
def download_and_parse_mnist_file(fname, target_dir=None, force=False):
49-
fname = 'res/' + fname
129+
"""Download the IDX file named fname from the URL specified in datasets_url
130+
and return it as a numpy array.
131+
132+
Parameters
133+
----------
134+
fname : str
135+
File name to download and parse
136+
137+
target_dir : str
138+
Directory where to store the file
139+
140+
force : bool
141+
Force downloading the file, if it already exists
142+
143+
Returns
144+
-------
145+
data : numpy.ndarray
146+
Numpy array with the dimensions and the data in the IDX file
147+
"""
148+
fname = download_file(fname, target_dir=target_dir, force=force)
50149
fopen = gzip.open if os.path.splitext(fname)[1] == '.gz' else open
51150
with fopen(fname, 'rb') as fd:
52151
return parse_idx(fd)
53152

54153

55154
def train_images():
155+
"""Return train images from Yann LeCun MNIST database as a numpy array.
156+
Download the file, if not already found in the temporary directory of
157+
the system.
158+
159+
Returns
160+
-------
161+
train_images : numpy.ndarray
162+
Numpy array with the images in the train MNIST database. The first
163+
dimension indexes each sample, while the other two index rows and
164+
columns of the image
165+
"""
56166
return download_and_parse_mnist_file('train-images-idx3-ubyte.gz')
57167

58168

59169
def test_images():
170+
"""Return test images from Yann LeCun MNIST database as a numpy array.
171+
Download the file, if not already found in the temporary directory of
172+
the system.
173+
174+
Returns
175+
-------
176+
test_images : numpy.ndarray
177+
Numpy array with the images in the test MNIST database. The first
178+
dimension indexes each sample, while the other two index rows and
179+
columns of the image
180+
"""
60181
return download_and_parse_mnist_file('t10k-images-idx3-ubyte.gz')
61182

62183

63184
def train_labels():
185+
"""Return train labels from Yann LeCun MNIST database as a numpy array.
186+
Download the file, if not already found in the temporary directory of
187+
the system.
188+
189+
Returns
190+
-------
191+
train_labels : numpy.ndarray
192+
Numpy array with the labels 0 to 9 in the train MNIST database.
193+
"""
64194
return download_and_parse_mnist_file('train-labels-idx1-ubyte.gz')
65195

66196

67197
def test_labels():
198+
"""Return test labels from Yann LeCun MNIST database as a numpy array.
199+
Download the file, if not already found in the temporary directory of
200+
the system.
201+
202+
Returns
203+
-------
204+
test_labels : numpy.ndarray
205+
Numpy array with the labels 0 to 9 in the test MNIST database.
206+
"""
68207
return download_and_parse_mnist_file('t10k-labels-idx1-ubyte.gz')

0 commit comments

Comments
 (0)