1
- """Functions for downloading and reading MNIST data."""
2
- from __future__ import print_function
3
- import gzip
4
- import os
5
- import urllib
6
- import numpy
7
- SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
8
- def maybe_download (filename , work_directory ):
9
- """Download the data from Yann's website, unless it's already here."""
10
- if not os .path .exists (work_directory ):
11
- os .mkdir (work_directory )
12
- filepath = os .path .join (work_directory , filename )
13
- if not os .path .exists (filepath ):
14
- filepath , _ = urllib .urlretrieve (SOURCE_URL + filename , filepath )
15
- statinfo = os .stat (filepath )
16
- print ('Succesfully downloaded' , filename , statinfo .st_size , 'bytes.' )
17
- return filepath
18
- def _read32 (bytestream ):
19
- dt = numpy .dtype (numpy .uint32 ).newbyteorder ('>' )
20
- return numpy .frombuffer (bytestream .read (4 ), dtype = dt )
21
- def extract_images (filename ):
22
- """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
23
- print ('Extracting' , filename )
24
- with gzip .open (filename ) as bytestream :
25
- magic = _read32 (bytestream )
26
- if magic != 2051 :
27
- raise ValueError (
28
- 'Invalid magic number %d in MNIST image file: %s' %
29
- (magic , filename ))
30
- num_images = _read32 (bytestream )
31
- rows = _read32 (bytestream )
32
- cols = _read32 (bytestream )
33
- buf = bytestream .read (rows * cols * num_images )
34
- data = numpy .frombuffer (buf , dtype = numpy .uint8 )
35
- data = data .reshape (num_images , rows , cols , 1 )
36
- return data
37
- def dense_to_one_hot (labels_dense , num_classes = 10 ):
38
- """Convert class labels from scalars to one-hot vectors."""
39
- num_labels = labels_dense .shape [0 ]
40
- index_offset = numpy .arange (num_labels ) * num_classes
41
- labels_one_hot = numpy .zeros ((num_labels , num_classes ))
42
- labels_one_hot .flat [index_offset + labels_dense .ravel ()] = 1
43
- return labels_one_hot
44
- def extract_labels (filename , one_hot = False ):
45
- """Extract the labels into a 1D uint8 numpy array [index]."""
46
- print ('Extracting' , filename )
47
- with gzip .open (filename ) as bytestream :
48
- magic = _read32 (bytestream )
49
- if magic != 2049 :
50
- raise ValueError (
51
- 'Invalid magic number %d in MNIST label file: %s' %
52
- (magic , filename ))
53
- num_items = _read32 (bytestream )
54
- buf = bytestream .read (num_items )
55
- labels = numpy .frombuffer (buf , dtype = numpy .uint8 )
56
- if one_hot :
57
- return dense_to_one_hot (labels )
58
- return labels
59
- class DataSet (object ):
60
- def __init__ (self , images , labels , fake_data = False ):
61
- if fake_data :
62
- self ._num_examples = 10000
63
- else :
64
- assert images .shape [0 ] == labels .shape [0 ], (
65
- "images.shape: %s labels.shape: %s" % (images .shape ,
66
- labels .shape ))
67
- self ._num_examples = images .shape [0 ]
68
- # Convert shape from [num examples, rows, columns, depth]
69
- # to [num examples, rows*columns] (assuming depth == 1)
70
- assert images .shape [3 ] == 1
71
- images = images .reshape (images .shape [0 ],
72
- images .shape [1 ] * images .shape [2 ])
73
- # Convert from [0, 255] -> [0.0, 1.0].
74
- images = images .astype (numpy .float32 )
75
- images = numpy .multiply (images , 1.0 / 255.0 )
76
- self ._images = images
77
- self ._labels = labels
78
- self ._epochs_completed = 0
79
- self ._index_in_epoch = 0
80
- @property
81
- def images (self ):
82
- return self ._images
83
- @property
84
- def labels (self ):
85
- return self ._labels
86
- @property
87
- def num_examples (self ):
88
- return self ._num_examples
89
- @property
90
- def epochs_completed (self ):
91
- return self ._epochs_completed
92
- def next_batch (self , batch_size , fake_data = False ):
93
- """Return the next `batch_size` examples from this data set."""
94
- if fake_data :
95
- fake_image = [1.0 for _ in xrange (784 )]
96
- fake_label = 0
97
- return [fake_image for _ in xrange (batch_size )], [
98
- fake_label for _ in xrange (batch_size )]
99
- start = self ._index_in_epoch
100
- self ._index_in_epoch += batch_size
101
- if self ._index_in_epoch > self ._num_examples :
102
- # Finished epoch
103
- self ._epochs_completed += 1
104
- # Shuffle the data
105
- perm = numpy .arange (self ._num_examples )
106
- numpy .random .shuffle (perm )
107
- self ._images = self ._images [perm ]
108
- self ._labels = self ._labels [perm ]
109
- # Start next epoch
110
- start = 0
111
- self ._index_in_epoch = batch_size
112
- assert batch_size <= self ._num_examples
113
- end = self ._index_in_epoch
114
- return self ._images [start :end ], self ._labels [start :end ]
115
- def read_data_sets (train_dir , fake_data = False , one_hot = False ):
116
- class DataSets (object ):
117
- pass
118
- data_sets = DataSets ()
119
- if fake_data :
120
- data_sets .train = DataSet ([], [], fake_data = True )
121
- data_sets .validation = DataSet ([], [], fake_data = True )
122
- data_sets .test = DataSet ([], [], fake_data = True )
123
- return data_sets
124
- TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
125
- TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
126
- TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
127
- TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
128
- VALIDATION_SIZE = 5000
129
- local_file = maybe_download (TRAIN_IMAGES , train_dir )
130
- train_images = extract_images (local_file )
131
- local_file = maybe_download (TRAIN_LABELS , train_dir )
132
- train_labels = extract_labels (local_file , one_hot = one_hot )
133
- local_file = maybe_download (TEST_IMAGES , train_dir )
134
- test_images = extract_images (local_file )
135
- local_file = maybe_download (TEST_LABELS , train_dir )
136
- test_labels = extract_labels (local_file , one_hot = one_hot )
137
- validation_images = train_images [:VALIDATION_SIZE ]
138
- validation_labels = train_labels [:VALIDATION_SIZE ]
139
- train_images = train_images [VALIDATION_SIZE :]
140
- train_labels = train_labels [VALIDATION_SIZE :]
141
- data_sets .train = DataSet (train_images , train_labels )
142
- data_sets .validation = DataSet (validation_images , validation_labels )
143
- data_sets .test = DataSet (test_images , test_labels )
1
+ """Functions for downloading and reading MNIST data."""
2
+
3
+ import gzip
4
+ import os
5
+ import urllib . request , urllib . parse , urllib . error
6
+ import numpy
7
+ SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
8
def maybe_download(filename, work_directory):
    """Download the data from Yann's website, unless it's already here.

    Args:
      filename: Name of the file; appended to SOURCE_URL to form the URL and
        joined onto `work_directory` to form the local path.
      work_directory: Directory that holds the downloaded files. It is
        created, including missing parent directories, when absent.

    Returns:
      The local path of the file.

    Raises:
      OSError: If the directory cannot be created or the download fails
        (urllib's URLError is an OSError subclass).
    """
    # makedirs handles nested paths and tolerates the directory appearing
    # between a check and the creation, which a bare os.mkdir does not.
    os.makedirs(work_directory, exist_ok=True)
    filepath = os.path.join(work_directory, filename)
    if os.path.exists(filepath):
        return filepath

    # Fetch into a side file and rename it into place only once complete, so
    # an interrupted transfer is never mistaken for a finished download.
    partial_path = filepath + '.part'
    try:
        urllib.request.urlretrieve(SOURCE_URL + filename, partial_path)
        os.replace(partial_path, filepath)
    finally:
        # Only still present when the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    statinfo = os.stat(filepath)
    print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
    return filepath
18
def _read32(bytestream):
    """Read the next four bytes of `bytestream` as one big-endian uint32 scalar."""
    raw = bytestream.read(4)
    big_endian_u32 = numpy.dtype('>u4')
    # Indexing element 0 yields a scalar rather than a one-element array.
    return numpy.frombuffer(raw, dtype=big_endian_u32)[0]
21
def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth].

    Args:
      filename: Path of a gzip-compressed image file whose header is four
        big-endian 32-bit words: magic number 2051, image count, rows, cols.

    Returns:
      A uint8 array of shape (num_images, rows, cols, 1).

    Raises:
      ValueError: If the magic number is not 2051, or the file holds fewer
        pixel bytes than its header declares.
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        # Convert to Python ints: the product below must not be computed in
        # fixed-width 32-bit arithmetic, where it could wrap around.
        num_images = int(_read32(bytestream))
        rows = int(_read32(bytestream))
        cols = int(_read32(bytestream))
        expected_bytes = rows * cols * num_images
        buf = bytestream.read(expected_bytes)
    if len(buf) != expected_bytes:
        raise ValueError(
            'MNIST image file %s is truncated: expected %d pixel bytes, got %d'
            % (filename, expected_bytes, len(buf)))
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    return data.reshape(num_images, rows, cols, 1)
37
def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert class labels from scalars to one-hot vectors.

    Args:
      labels_dense: Numpy array of integer class labels, one per example
        (shape (n,) or (n, 1)).
      num_classes: Number of classes, i.e. the width of each one-hot row.

    Returns:
      A float array of shape (n, num_classes) with a single 1 per row.

    Raises:
      ValueError: If there is not exactly one label per row, or a label
        lies outside [0, num_classes).
    """
    num_labels = labels_dense.shape[0]
    labels = labels_dense.ravel()
    if labels.size != num_labels:
        raise ValueError(
            'expected one label per example, got shape %s'
            % (labels_dense.shape,))
    # Reject out-of-range labels up front: with flat-offset arithmetic such a
    # label would land in a neighbouring row and corrupt it silently.
    if num_labels and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(
            'labels must lie in [0, %d); got values from %d to %d'
            % (num_classes, labels.min(), labels.max()))
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot[numpy.arange(num_labels), labels] = 1
    return labels_one_hot
44
def extract_labels(filename, one_hot=False):
    """Extract the labels into a 1D uint8 numpy array [index].

    Args:
      filename: Path of a gzip-compressed label file whose header is two
        big-endian 32-bit words: magic number 2049 and the item count.
      one_hot: When true, return the labels converted by dense_to_one_hot
        instead of the raw 1D array.

    Returns:
      A 1D uint8 array of labels, or its one-hot conversion when `one_hot`
      is set.

    Raises:
      ValueError: If the magic number is not 2049, or the file holds fewer
        label bytes than its header declares.
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = int(_read32(bytestream))
        buf = bytestream.read(num_items)
    # Without this check a cut-off file would quietly yield a shorter label
    # array than the header promises.
    if len(buf) != num_items:
        raise ValueError(
            'MNIST label file %s is truncated: expected %d labels, got %d'
            % (filename, num_items, len(buf)))
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
        return dense_to_one_hot(labels)
    return labels
59
class DataSet(object):
    """Images and labels held in memory and handed out in sequential batches.

    Real data arrives as [num examples, rows, columns, depth] with depth 1;
    it is stored flattened to [num examples, rows*columns] as float32 scaled
    from [0, 255] to [0.0, 1.0]. With `fake_data` set, the inputs are stored
    untouched and the example count is fixed at 10000.
    """

    def __init__(self, images, labels, fake_data=False):
        if not fake_data:
            example_count = images.shape[0]
            assert example_count == labels.shape[0], (
                "images.shape: %s labels.shape: %s" % (images.shape,
                                                       labels.shape))
            self._num_examples = example_count
            # Flattening below assumes a single channel.
            assert images.shape[3] == 1
            flattened = images.reshape(
                example_count, images.shape[1] * images.shape[2])
            # Rescale [0, 255] -> [0.0, 1.0].
            images = numpy.multiply(
                flattened.astype(numpy.float32), 1.0 / 255.0)
        else:
            self._num_examples = 10000
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False):
        """Return the next `batch_size` examples from this data set.

        With `fake_data` set, returns `batch_size` copies of an all-ones
        784-element image and `batch_size` zero labels without touching any
        state. Otherwise returns the next slice of images and labels; when
        the slice would run past the last example, the epoch counter is
        bumped, the data is reshuffled and the batch is taken from the start.
        """
        if fake_data:
            constant_image = [1.0 for _ in range(784)]
            constant_label = 0
            batch_images = [constant_image for _ in range(batch_size)]
            batch_labels = [constant_label for _ in range(batch_size)]
            return batch_images, batch_labels

        begin = self._index_in_epoch
        finish = begin + batch_size
        self._index_in_epoch = finish
        if finish > self._num_examples:
            # The current epoch is exhausted.
            self._epochs_completed += 1
            # Apply one random permutation to images and labels alike.
            order = numpy.arange(self._num_examples)
            numpy.random.shuffle(order)
            self._images = self._images[order]
            self._labels = self._labels[order]
            # Restart from the head of the reshuffled data.
            begin, finish = 0, batch_size
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        return self._images[begin:finish], self._labels[begin:finish]
115
def read_data_sets(train_dir, fake_data=False, one_hot=False):
    """Build the train, validation and test DataSet objects.

    Args:
      train_dir: Directory the four archive files are downloaded into (or
        reused from) via maybe_download.
      fake_data: When true, skip all file access and return three DataSet
        objects constructed in fake-data mode.
      one_hot: Forwarded to extract_labels for both label files.

    Returns:
      An object with `train`, `validation` and `test` attributes, each a
      DataSet. The first 5000 training examples form the validation set and
      the remainder form the training set.
    """
    class DataSets(object):
        pass

    data_sets = DataSets()

    if fake_data:
        for split_name in ('train', 'validation', 'test'):
            setattr(data_sets, split_name, DataSet([], [], fake_data=True))
        return data_sets

    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    VALIDATION_SIZE = 5000

    def load_images(archive_name):
        # Fetch the archive if needed, then decode it.
        return extract_images(maybe_download(archive_name, train_dir))

    def load_labels(archive_name):
        return extract_labels(
            maybe_download(archive_name, train_dir), one_hot=one_hot)

    # Same fetch/decode order as always: train images, train labels,
    # test images, test labels.
    all_train_images = load_images(TRAIN_IMAGES)
    all_train_labels = load_labels(TRAIN_LABELS)
    test_images = load_images(TEST_IMAGES)
    test_labels = load_labels(TEST_LABELS)

    # The head of the training data is held out for validation.
    data_sets.train = DataSet(all_train_images[VALIDATION_SIZE:],
                              all_train_labels[VALIDATION_SIZE:])
    data_sets.validation = DataSet(all_train_images[:VALIDATION_SIZE],
                                   all_train_labels[:VALIDATION_SIZE])
    data_sets.test = DataSet(test_images, test_labels)
    return data_sets
0 commit comments