20
20
from scipy .io import loadmat
21
21
from sklearn .utils import shuffle
22
22
23
- from benchmark import get_data , y2indicator , error_rate
23
+ from benchmark import get_data , error_rate
24
24
25
25
26
26
def convpool (X , W , b ):
@@ -61,12 +61,10 @@ def main():
61
61
# print len(Ytrain)
62
62
del train
63
63
Xtrain , Ytrain = shuffle (Xtrain , Ytrain )
64
- Ytrain_ind = y2indicator (Ytrain )
65
64
66
65
Xtest = rearrange (test ['X' ])
67
66
Ytest = test ['y' ].flatten () - 1
68
67
del test
69
- Ytest_ind = y2indicator (Ytest )
70
68
71
69
# gradient descent params
72
70
max_iter = 6
@@ -81,7 +79,6 @@ def main():
81
79
Ytrain = Ytrain [:73000 ]
82
80
Xtest = Xtest [:26000 ,]
83
81
Ytest = Ytest [:26000 ]
84
- Ytest_ind = Ytest_ind [:26000 ,]
85
82
# print "Xtest.shape:", Xtest.shape
86
83
# print "Ytest.shape:", Ytest.shape
87
84
@@ -108,7 +105,7 @@ def main():
108
105
# define variables and expressions
109
106
# the batch dimension is fixed to batch_sz: using None (variable size) as the first shape element takes up too much RAM
110
107
X = tf .placeholder (tf .float32 , shape = (batch_sz , 32 , 32 , 3 ), name = 'X' )
111
- T = tf .placeholder (tf .float32 , shape = (batch_sz , K ), name = 'T' )
108
+ T = tf .placeholder (tf .int32 , shape = (batch_sz ,), name = 'T' )
112
109
W1 = tf .Variable (W1_init .astype (np .float32 ))
113
110
b1 = tf .Variable (b1_init .astype (np .float32 ))
114
111
W2 = tf .Variable (W2_init .astype (np .float32 ))
@@ -126,7 +123,7 @@ def main():
126
123
Yish = tf .matmul (Z3 , W4 ) + b4
127
124
128
125
cost = tf .reduce_sum (
129
- tf .nn .softmax_cross_entropy_with_logits (
126
+ tf .nn .sparse_softmax_cross_entropy_with_logits (
130
127
logits = Yish ,
131
128
labels = T
132
129
)
@@ -139,14 +136,16 @@ def main():
139
136
140
137
t0 = datetime .now ()
141
138
LL = []
139
+ W1_val = None
140
+ W2_val = None
142
141
init = tf .global_variables_initializer ()
143
142
with tf .Session () as session :
144
143
session .run (init )
145
144
146
145
for i in range (max_iter ):
147
146
for j in range (n_batches ):
148
147
Xbatch = Xtrain [j * batch_sz :(j * batch_sz + batch_sz ),]
149
- Ybatch = Ytrain_ind [j * batch_sz :(j * batch_sz + batch_sz ),]
148
+ Ybatch = Ytrain [j * batch_sz :(j * batch_sz + batch_sz ),]
150
149
151
150
if len (Xbatch ) == batch_sz :
152
151
session .run (train_op , feed_dict = {X : Xbatch , T : Ybatch })
@@ -157,17 +156,59 @@ def main():
157
156
prediction = np .zeros (len (Xtest ))
158
157
for k in range (len (Xtest ) // batch_sz ):
159
158
Xtestbatch = Xtest [k * batch_sz :(k * batch_sz + batch_sz ),]
160
- Ytestbatch = Ytest_ind [k * batch_sz :(k * batch_sz + batch_sz ),]
159
+ Ytestbatch = Ytest [k * batch_sz :(k * batch_sz + batch_sz ),]
161
160
test_cost += session .run (cost , feed_dict = {X : Xtestbatch , T : Ytestbatch })
162
161
prediction [k * batch_sz :(k * batch_sz + batch_sz )] = session .run (
163
162
predict_op , feed_dict = {X : Xtestbatch })
164
163
err = error_rate (prediction , Ytest )
165
164
print ("Cost / err at iteration i=%d, j=%d: %.3f / %.3f" % (i , j , test_cost , err ))
166
165
LL .append (test_cost )
166
+
167
+ W1_val = W1 .eval ()
168
+ W2_val = W2 .eval ()
167
169
print ("Elapsed time:" , (datetime .now () - t0 ))
168
170
plt .plot (LL )
169
171
plt .show ()
170
172
171
173
174
+ W1_val = W1_val .transpose (3 , 2 , 0 , 1 )
175
+ W2_val = W2_val .transpose (3 , 2 , 0 , 1 )
176
+
177
+
178
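+ # visualize W1, indexed as (20, 3, 5, 5) after the transpose above: each 5x5 filter slice becomes one tile of an 8x8-tile grid, filled column by column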
179
+ # W1_val = W1.get_value()
180
+ grid = np .zeros ((8 * 5 , 8 * 5 ))
181
+ m = 0
182
+ n = 0
183
+ for i in range (20 ):
184
+ for j in range (3 ):
185
+ filt = W1_val [i ,j ]
186
+ grid [m * 5 :(m + 1 )* 5 ,n * 5 :(n + 1 )* 5 ] = filt
187
+ m += 1
188
+ if m >= 8 :
189
+ m = 0
190
+ n += 1
191
+ plt .imshow (grid , cmap = 'gray' )
192
+ plt .title ("W1" )
193
+ plt .show ()
194
+
195
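+ # visualize W2, indexed as (50, 20, 5, 5) after the transpose above: each 5x5 filter slice becomes one tile of a 32x32-tile grid, filled column by column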
196
+ # W2_val = W2.get_value()
197
+ grid = np .zeros ((32 * 5 , 32 * 5 ))
198
+ m = 0
199
+ n = 0
200
+ for i in range (50 ):
201
+ for j in range (20 ):
202
+ filt = W2_val [i ,j ]
203
+ grid [m * 5 :(m + 1 )* 5 ,n * 5 :(n + 1 )* 5 ] = filt
204
+ m += 1
205
+ if m >= 32 :
206
+ m = 0
207
+ n += 1
208
+ plt .imshow (grid , cmap = 'gray' )
209
+ plt .title ("W2" )
210
+ plt .show ()
211
+
212
+
172
213
if __name__ == '__main__' :
173
214
main ()
0 commit comments