1
+ '''
2
+ Python utilities from https://github.com/datapythonista/mnist
3
+ '''
4
+
1
5
import os
2
6
import functools
3
7
import operator
4
8
import gzip
5
9
import struct
6
10
import array
7
- import numpy as np
11
+ import tempfile
12
+
13
+ try :
14
+ from urllib .request import urlretrieve
15
+ except ImportError :
16
+ from urllib import urlretrieve # py2
17
+ try :
18
+ from urllib .parse import urljoin
19
+ except ImportError :
20
+ from urlparse import urljoin
21
+ import numpy
22
+
23
+
24
+ __version__ = '0.2.2'
25
+
26
+
27
+ # `datasets_url` and `temporary_dir` can be set by the user using:
28
+ # >>> mnist.datasets_url = 'http://my.mnist.url'
29
+ # >>> mnist.temporary_dir = lambda: '/tmp/mnist'
30
+ datasets_url = 'http://yann.lecun.com/exdb/mnist/'
31
+ temporary_dir = tempfile .gettempdir
32
+
33
+
34
+ class IdxDecodeError (ValueError ):
35
+ """Raised when an invalid idx file is parsed."""
36
+ pass
37
+
38
+
39
+ def download_file (fname , target_dir = None , force = False ):
40
+ """Download fname from the datasets_url, and save it to target_dir,
41
+ unless the file already exists, and force is False.
42
+
43
+ Parameters
44
+ ----------
45
+ fname : str
46
+ Name of the file to download
47
+
48
+ target_dir : str
49
+ Directory where to store the file
50
+
51
+ force : bool
52
+ Force downloading the file, if it already exists
53
+
54
+ Returns
55
+ -------
56
+ fname : str
57
+ Full path of the downloaded file
58
+ """
59
+ target_dir = target_dir or temporary_dir ()
60
+ target_fname = os .path .join (target_dir , fname )
61
+
62
+ if force or not os .path .isfile (target_fname ):
63
+ url = urljoin (datasets_url , fname )
64
+ urlretrieve (url , target_fname )
65
+
66
+ return target_fname
8
67
9
- #from: https://github.com/datapythonista/mnist
10
68
11
69
def parse_idx (fd ):
70
+ """Parse an IDX file, and return it as a numpy array.
71
+
72
+ Parameters
73
+ ----------
74
+ fd : file
75
+ File descriptor of the IDX file to parse
76
+
77
+ endian : str
78
+ Byte order of the IDX file. See [1] for available options
79
+
80
+ Returns
81
+ -------
82
+ data : numpy.ndarray
83
+ Numpy array with the dimensions and the data in the IDX file
84
+
85
+ 1. https://docs.python.org/3/library/struct.html
86
+ #byte-order-size-and-alignment
87
+ """
12
88
DATA_TYPES = {0x08 : 'B' , # unsigned byte
13
89
0x09 : 'b' , # signed byte
14
90
0x0b : 'h' , # short (2 bytes)
@@ -18,18 +94,21 @@ def parse_idx(fd):
18
94
19
95
header = fd .read (4 )
20
96
if len (header ) != 4 :
21
- raise IdxDecodeError ('Invalid IDX file, file empty or does not contain a full header.' )
97
+ raise IdxDecodeError ('Invalid IDX file, '
98
+ 'file empty or does not contain a full header.' )
22
99
23
100
zeros , data_type , num_dimensions = struct .unpack ('>HBB' , header )
24
101
25
102
if zeros != 0 :
26
- raise IdxDecodeError ('Invalid IDX file, file must start with two zero bytes. '
103
+ raise IdxDecodeError ('Invalid IDX file, '
104
+ 'file must start with two zero bytes. '
27
105
'Found 0x%02x' % zeros )
28
106
29
107
try :
30
108
data_type = DATA_TYPES [data_type ]
31
109
except KeyError :
32
- raise IdxDecodeError ('Unknown data type 0x%02x in IDX file' % data_type )
110
+ raise IdxDecodeError ('Unknown data type '
111
+ '0x%02x in IDX file' % data_type )
33
112
34
113
dimension_sizes = struct .unpack ('>' + 'I' * num_dimensions ,
35
114
fd .read (4 * num_dimensions ))
@@ -40,29 +119,89 @@ def parse_idx(fd):
40
119
expected_items = functools .reduce (operator .mul , dimension_sizes )
41
120
if len (data ) != expected_items :
42
121
raise IdxDecodeError ('IDX file has wrong number of items. '
43
- 'Expected: %d. Found: %d' % (expected_items , len (data )))
122
+ 'Expected: %d. Found: %d' % (expected_items ,
123
+ len (data )))
44
124
45
- return np .array (data ).reshape (dimension_sizes )
125
+ return numpy .array (data ).reshape (dimension_sizes )
46
126
47
127
48
128
def download_and_parse_mnist_file (fname , target_dir = None , force = False ):
49
- fname = 'res/' + fname
129
+ """Download the IDX file named fname from the URL specified in dataset_url
130
+ and return it as a numpy array.
131
+
132
+ Parameters
133
+ ----------
134
+ fname : str
135
+ File name to download and parse
136
+
137
+ target_dir : str
138
+ Directory where to store the file
139
+
140
+ force : bool
141
+ Force downloading the file, if it already exists
142
+
143
+ Returns
144
+ -------
145
+ data : numpy.ndarray
146
+ Numpy array with the dimensions and the data in the IDX file
147
+ """
148
+ fname = download_file (fname , target_dir = target_dir , force = force )
50
149
fopen = gzip .open if os .path .splitext (fname )[1 ] == '.gz' else open
51
150
with fopen (fname , 'rb' ) as fd :
52
151
return parse_idx (fd )
53
152
54
153
55
154
def train_images ():
155
+ """Return train images from Yann LeCun MNIST database as a numpy array.
156
+ Download the file, if not already found in the temporary directory of
157
+ the system.
158
+
159
+ Returns
160
+ -------
161
+ train_images : numpy.ndarray
162
+ Numpy array with the images in the train MNIST database. The first
163
+ dimension indexes each sample, while the other two index rows and
164
+ columns of the image
165
+ """
56
166
return download_and_parse_mnist_file ('train-images-idx3-ubyte.gz' )
57
167
58
168
59
169
def test_images ():
170
+ """Return test images from Yann LeCun MNIST database as a numpy array.
171
+ Download the file, if not already found in the temporary directory of
172
+ the system.
173
+
174
+ Returns
175
+ -------
176
+ test_images : numpy.ndarray
177
+ Numpy array with the images in the train MNIST database. The first
178
+ dimension indexes each sample, while the other two index rows and
179
+ columns of the image
180
+ """
60
181
return download_and_parse_mnist_file ('t10k-images-idx3-ubyte.gz' )
61
182
62
183
63
184
def train_labels ():
185
+ """Return train labels from Yann LeCun MNIST database as a numpy array.
186
+ Download the file, if not already found in the temporary directory of
187
+ the system.
188
+
189
+ Returns
190
+ -------
191
+ train_labels : numpy.ndarray
192
+ Numpy array with the labels 0 to 9 in the train MNIST database.
193
+ """
64
194
return download_and_parse_mnist_file ('train-labels-idx1-ubyte.gz' )
65
195
66
196
67
197
def test_labels ():
198
+ """Return test labels from Yann LeCun MNIST database as a numpy array.
199
+ Download the file, if not already found in the temporary directory of
200
+ the system.
201
+
202
+ Returns
203
+ -------
204
+ test_labels : numpy.ndarray
205
+ Numpy array with the labels 0 to 9 in the train MNIST database.
206
+ """
68
207
return download_and_parse_mnist_file ('t10k-labels-idx1-ubyte.gz' )
0 commit comments