-
Notifications
You must be signed in to change notification settings - Fork 2
/
stl10_input.py
128 lines (98 loc) · 3.79 KB
/
stl10_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import print_function
import os
import sys
import tarfile
import matplotlib.pyplot as plt
import numpy as np
from keras.utils.data_utils import get_file
if sys.version_info >= (3, 0, 0):
import urllib.request as urllib # ugly but works
else:
import urllib
print(sys.version_info)
# Dimensions of one image: height, width and number of colour channels.
HEIGHT, WIDTH, DEPTH = 96, 96, 3

# Number of values that make up a single image; the readers in this module
# consume the data as uint8, so this is also the image size in bytes.
SIZE = DEPTH * HEIGHT * WIDTH

# Name of the directory that holds the binary data files.
DATA_DIR = 'stl10_binary'

# Location of the archive containing the binary data.
DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'
def read_labels(path_to_labels):
    """
    Load an STL-10 label file into memory.

    :param path_to_labels: path to the binary file containing labels from the STL-10 dataset
    :return: a 1-D uint8 array holding one entry per byte of the file
    """
    # The file is opened in binary mode and interpreted as a flat run of
    # unsigned bytes.
    with open(path_to_labels, 'rb') as label_file:
        return np.fromfile(label_file, dtype=np.uint8)
def read_all_images(path_to_data):
    """
    Load every image stored in an STL-10 binary image file.

    :param path_to_data: the file containing the binary images from the STL-10 dataset
    :return: a uint8 array of shape (number_of_images, 96, 96, 3)
    """
    # Pull the entire file into memory as a flat sequence of uint8 values.
    with open(path_to_data, 'rb') as image_file:
        raw = np.fromfile(image_file, dtype=np.uint8)

    # Each image occupies 3 * 96 * 96 consecutive values: the images are
    # stored in "column-major order", i.e. "the first 96*96 values are the
    # red channel, the next 96*96 are green, and the last are blue."
    # Using -1 for the leading axis lets numpy work out how many images the
    # file holds.
    channel_first = raw.reshape((-1, 3, 96, 96))

    # Reverse the last three axes so the channel axis ends up last, which is
    # the layout expected by e.g. matplotlib.imshow. Skip or undo this step
    # if the consumer (for instance a CNN) wants the channels kept separate.
    return channel_first.transpose((0, 3, 2, 1))
def read_single_image(image_file):
    """
    Read the next image from an already opened STL-10 binary image file.

    CAREFUL! - the argument is an open file object rather than a path, so
    the read position advances and stays advanced after this call returns.

    :param image_file: the open file containing the images
    :return: a single image as a uint8 array of shape (96, 96, 3)
    """
    # `count` limits the read to the SIZE uint8 values that form one image.
    raw = np.fromfile(image_file, dtype=np.uint8, count=SIZE)

    # Shape the flat values into channel-first form, then reverse the axes
    # to get the standard channel-last image layout. Skip or undo the
    # transpose if the consumer (for instance a CNN) wants the channels
    # kept separate.
    return raw.reshape((3, 96, 96)).transpose((2, 1, 0))
def plot_image(image):
    """
    Display one image with matplotlib.

    :param image: the image to be plotted in a 3-D matrix format
    :return: None
    """
    # Draw the image on the current axes, then hand control to matplotlib
    # to show the figure.
    plt.imshow(image)
    plt.show()
def load_data():
    """
    Fetch the STL-10 binary data (if needed) and load the train/test splits.

    The shape of every loaded array is printed as a quick check that the
    whole dataset was read correctly.

    :return: the tuple pair (x_train, y_train), (x_test, y_test)
    """
    # Download and unpack the archive if needed; the returned path is the
    # directory the binary files are read from.
    root = get_file(DATA_DIR, origin=DATA_URL, untar=True)

    def _load(reader, filename):
        # Read one binary file from the data directory and report its shape.
        loaded = reader(os.path.join(root, filename))
        print(loaded.shape)
        return loaded

    # Train images and labels first, then test images and labels.
    x_train = _load(read_all_images, 'train_X.bin')
    y_train = _load(read_labels, 'train_y.bin')
    x_test = _load(read_all_images, 'test_X.bin')
    y_test = _load(read_labels, 'test_y.bin')

    return (x_train, y_train), (x_test, y_test)