Commit a744d2d: add batches and downcast
Parent: f8a7327

1 file changed, +56 -124 lines

rnn_class/renet.py (56 additions, 124 deletions)
@@ -279,61 +279,25 @@ def main(ReUnit=GRU, getData=getMNIST):
 
     N = Xtrain.shape[0]
     C = Xtrain.shape[1]
-    M = 4096
+    M = 300
     K = 10
 
-    # New
-    wp, hp = 2, 2
-
-    M1 = 256 # num feature maps
-    # Wx1_shape = (M1, Xtrain.shape[1]*wp*hp)
-    # Wx1_init = init_filter(Wx1_shape)
-    # Wh1_init = init_filter( (M1,M1) )
-    # bh1_init = np.zeros((M1,), dtype=np.float32)
-    # H01_init = init_filter( (M1,) )
-    # Wx2_init = init_filter(Wx1_shape)
-    # Wh2_init = init_filter( (M1,M1) )
-    # bh2_init = np.zeros((M1,), dtype=np.float32)
-    # H02_init = init_filter( (M1,) )
+    batch_sz = 100
+    n_batches = N / batch_sz
+
+    M1 = 2 # num feature maps
     rnn1 = ReUnit('1', 2, 2, C, M1)
     rnn2 = ReUnit('2', 2, 2, C, M1)
 
-    M2 = 256 # num feature maps
-    # Wx3_shape = (M2, 2*M1*1*1)
-    # Wx3_init = init_filter(Wx3_shape)
-    # Wh3_init = init_filter( (M2,M2) )
-    # bh3_init = np.zeros((M2,), dtype=np.float32)
-    # H03_init = init_filter( (M2,) )
-    # Wx4_init = init_filter(Wx3_shape)
-    # Wh4_init = init_filter( (M2,M2) )
-    # bh4_init = np.zeros((M2,), dtype=np.float32)
-    # H04_init = init_filter( (M2,) )
+    M2 = 2 # num feature maps
     rnn3 = ReUnit('3', 1, 1, 2*M1, M2)
     rnn4 = ReUnit('4', 1, 1, 2*M1, M2)
 
-    M3 = 64
-    # Wx5_shape = (M3, 2*M2*wp*hp)
-    # Wx5_init = init_filter(Wx5_shape)
-    # Wh5_init = init_filter( (M3,M3) )
-    # bh5_init = np.zeros((M3,), dtype=np.float32)
-    # H05_init = init_filter( (M3,) )
-    # Wx6_init = init_filter(Wx5_shape)
-    # Wh6_init = init_filter( (M3,M3) )
-    # bh6_init = np.zeros((M3,), dtype=np.float32)
-    # H06_init = init_filter( (M3,) )
+    M3 = 2
     rnn5 = ReUnit('5', 2, 2, 2*M2, M3)
     rnn6 = ReUnit('6', 2, 2, 2*M2, M3)
 
-    M4 = 64
-    # Wx7_shape = (M4, 2*M3*1*1)
-    # Wx7_init = init_filter(Wx7_shape)
-    # Wh7_init = init_filter( (M4,M4) )
-    # bh7_init = np.zeros((M4,), dtype=np.float32)
-    # H07_init = init_filter( (M4,) )
-    # Wx8_init = init_filter(Wx7_shape)
-    # Wh8_init = init_filter( (M4,M4) )
-    # bh8_init = np.zeros((M4,), dtype=np.float32)
-    # H08_init = init_filter( (M4,) )
+    M4 = 2
     rnn7 = ReUnit('7', 1, 1, 2*M3, M4)
     rnn8 = ReUnit('8', 1, 1, 2*M3, M4)
 
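Note on the hunk above: the new batch bookkeeping relies on Python 2 integer division in n_batches = N / batch_sz (the file's xrange and print statements confirm the Python 2 target). A minimal sketch of the arithmetic, with an assumed N, written portably via //:

    # Assumed value: N = 60000 (the usual MNIST training-set size); the diff
    # itself takes N from Xtrain.shape[0].
    N = 60000
    batch_sz = 100
    n_batches = N // batch_sz  # the source's "/" does the same under Python 2
    print(n_batches)           # -> 600
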
@@ -346,59 +310,15 @@ def main(ReUnit=GRU, getData=getMNIST):
 
 
     # step 2: define theano variables and expressions
-    X = T.tensor3('X', dtype='float32')
+    X = T.tensor4('X', dtype='float32')
+    # x = T.tensor3('x', dtype='float32')
     Y = T.matrix('T')
-    # Wx1 = theano.shared(Wx1_init, 'Wx1')
-    # Wh1 = theano.shared(Wh1_init, 'Wh1')
-    # bh1 = theano.shared(bh1_init, 'bh1')
-    # H01 = theano.shared(H01_init, 'H01')
-    # Wx2 = theano.shared(Wx2_init, 'Wx2')
-    # Wh2 = theano.shared(Wh2_init, 'Wh2')
-    # bh2 = theano.shared(bh2_init, 'bh2')
-    # H02 = theano.shared(H02_init, 'H02')
-
-    # Wx3 = theano.shared(Wx3_init, 'Wx3')
-    # Wh3 = theano.shared(Wh3_init, 'Wh3')
-    # bh3 = theano.shared(bh3_init, 'bh3')
-    # H03 = theano.shared(H03_init, 'H03')
-    # Wx4 = theano.shared(Wx4_init, 'Wx4')
-    # Wh4 = theano.shared(Wh4_init, 'Wh4')
-    # bh4 = theano.shared(bh4_init, 'bh4')
-    # H04 = theano.shared(H04_init, 'H04')
-
-    # Wx5 = theano.shared(Wx5_init, 'Wx5')
-    # Wh5 = theano.shared(Wh5_init, 'Wh5')
-    # bh5 = theano.shared(bh5_init, 'bh5')
-    # H05 = theano.shared(H05_init, 'H05')
-    # Wx6 = theano.shared(Wx6_init, 'Wx6')
-    # Wh6 = theano.shared(Wh6_init, 'Wh6')
-    # bh6 = theano.shared(bh6_init, 'bh6')
-    # H06 = theano.shared(H06_init, 'H06')
-
-    # Wx7 = theano.shared(Wx7_init, 'Wx7')
-    # Wh7 = theano.shared(Wh7_init, 'Wh7')
-    # bh7 = theano.shared(bh7_init, 'bh7')
-    # H07 = theano.shared(H07_init, 'H07')
-    # Wx8 = theano.shared(Wx8_init, 'Wx8')
-    # Wh8 = theano.shared(Wh8_init, 'Wh8')
-    # bh8 = theano.shared(bh8_init, 'bh8')
-    # H08 = theano.shared(H08_init, 'H08')
 
     W9 = theano.shared(W9_init.astype(np.float32), 'W9')
     b9 = theano.shared(b9_init, 'b9')
     W10 = theano.shared(W10_init.astype(np.float32), 'W10')
     b10 = theano.shared(b10_init, 'b10')
-    params = [
-        # Wx1, Wh1, bh1, H01,
-        # Wx2, Wh2, bh2, H02,
-        # Wx3, Wh3, bh3, H03,
-        # Wx4, Wh4, bh4, H04,
-        # Wx5, Wh5, bh5, H05,
-        # Wx6, Wh6, bh6, H06,
-        # Wx7, Wh7, bh7, H07,
-        # Wx8, Wh8, bh8, H08,
-        W9, b9, W10, b10,
-    ]
+    params = [W9, b9, W10, b10]
     for rnn in (rnn1, rnn2, rnn3, rnn4, rnn5, rnn6, rnn7, rnn8):
         params += rnn.params
 
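Note: replacing the hand-built list with params = [W9, b9, W10, b10] plus the params += rnn.params loop flattens every trainable variable into a single list, which is what lets the generic update rule later in this diff iterate over one collection. A sketch of the pattern, with Unit standing in for the repo's GRU/ReUnit and strings standing in for theano.shared variables:

    # Illustrative stand-ins only; not the repo's classes.
    class Unit(object):
        def __init__(self, name):
            self.params = [name + '_Wx', name + '_Wh', name + '_bh']

    params = ['W9', 'b9', 'W10', 'b10']
    for rnn in [Unit(str(i)) for i in range(1, 9)]:
        params += rnn.params
    print(len(params))  # 4 output-layer params + 8 units * 3 each = 28
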
@@ -412,9 +332,23 @@ def main(ReUnit=GRU, getData=getMNIST):
     # dW4 = theano.shared(np.zeros(W4_init.shape, dtype=np.float32), 'dW4')
     # db4 = theano.shared(np.zeros(b4_init.shape, dtype=np.float32), 'db4')
 
-    # forward pass
-    # Z1 = renet_layer_lr(X, Wx1, Wh1, bh1, H01, Wx2, Wh2, bh2, H02, 28, 28, wp, hp)
-    Z1 = renet_layer_lr(X, rnn1, rnn2, 28, 28, wp, hp)
+    def forward(x):
+        # forward pass
+        Z1 = renet_layer_lr(x, rnn1, rnn2, 28, 28, 2, 2)
+        Z2 = renet_layer_ud(Z1, rnn3, rnn4, 14, 14, 1, 1)
+        Z3 = renet_layer_lr(Z2, rnn5, rnn6, 14, 14, 2, 2)
+        Z4 = renet_layer_ud(Z3, rnn7, rnn8, 7, 7, 1, 1)
+        Z5 = relu(Z4.flatten().dot(W9) + b9)
+        pY = T.nnet.softmax( Z5.dot(W10) + b10)
+        return pY
+
+    batch_forward_out3, _ = theano.scan(
+        fn=forward,
+        sequences=X,
+        # outputs_info=[self.H0],
+        n_steps=X.shape[0]
+    )
+    batch_forward_out = batch_forward_out3.flatten(ndim=2) # the output will be (N, 1, 10)
 
     ## TMP: just test the first/second layer ##
     # tmp_op = theano.function(
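This hunk is the core of the batching change: forward builds the single-sample graph, and theano.scan maps it over the leading axis of the now 4-D input X, producing a (N, 1, 10) stack that flatten(ndim=2) reshapes to (N, 10). A NumPy analogue of the map-then-flatten shape bookkeeping (the toy forward below is a placeholder, not the ReNet layer stack):

    import numpy as np

    def forward(x):                        # x: one sample, shape (1, 28, 28)
        logits = x.reshape(1, -1)[:, :10]  # placeholder for the real layers
        e = np.exp(logits - logits.max())
        return e / e.sum()                 # (1, 10), like T.nnet.softmax

    X = np.random.randn(5, 1, 28, 28).astype(np.float32)
    out3 = np.stack([forward(x) for x in X])  # (5, 1, 10), like scan's output
    out = out3.reshape(out3.shape[0], -1)     # (5, 10), like flatten(ndim=2)
    print(out.shape)                          # -> (5, 10)
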
@@ -426,8 +360,7 @@ def main(ReUnit=GRU, getData=getMNIST):
     # print "Z1.shape:", out.shape
     # exit()
 
-    # Z2 = renet_layer_ud(Z1, Wx3, Wh3, bh3, H03, Wx4, Wh4, bh4, H04, 14, 14, 1, 1)
-    Z2 = renet_layer_ud(Z1, rnn3, rnn4, 14, 14, 1, 1)
+
 
     # tmp_op2 = theano.function(
     #     inputs=[X],
@@ -437,9 +370,7 @@ def main(ReUnit=GRU, getData=getMNIST):
     # print "Z2.shape:", out.shape
     # exit()
 
-
-    # Z3 = renet_layer_lr(Z2, Wx5, Wh5, bh5, H05, Wx6, Wh6, bh6, H06, 14, 14, wp, hp)
-    Z3 = renet_layer_lr(Z2, rnn5, rnn6, 14, 14, wp, hp)
+
 
     # tmp_op3 = theano.function(
     #     inputs=[X],
@@ -449,11 +380,7 @@ def main(ReUnit=GRU, getData=getMNIST):
     # print "Z3.shape:", out.shape
     # exit()
 
-    # Z4 = renet_layer_ud(Z3, Wx7, Wh7, bh7, H07, Wx8, Wh8, bh8, H08, 7, 7, 1, 1)
-    Z4 = renet_layer_ud(Z3, rnn7, rnn8, 7, 7, 1, 1)
-
-    Z5 = relu(Z4.flatten().dot(W9) + b9)
-    pY = T.nnet.softmax( Z5.dot(W10) + b10)
+
 
     # tmp_op4 = theano.function(
     #     inputs=[X],
@@ -463,21 +390,18 @@ def main(ReUnit=GRU, getData=getMNIST):
     # print "Z4.shape:", out.shape
     # exit()
 
+    # tmp_op_out = theano.function(inputs=[X], outputs=batch_forward_out)
+    # out = tmp_op_out(Xtest[0:50,])
+    # print "out.shape:", out.shape
+    # exit()
+
     # define the cost function and prediction
     # params = (W1, b1, W2, b2, W3, b3, W4, b4)
     reg_cost = reg*np.sum((param*param).sum() for param in params)
-    cost = -(Y * T.log(pY)).sum() + reg_cost
-    prediction = T.argmax(pY, axis=1)
+    cost = -(Y * T.log(batch_forward_out)).sum() + reg_cost
+    prediction = T.argmax(batch_forward_out, axis=1)
 
     # step 3: training expressions and functions
-    # update_W1 = W1 + mu*dW1 - lr*T.grad(cost, W1)
-    # update_b1 = b1 + mu*db1 - lr*T.grad(cost, b1)
-    # update_W2 = W2 + mu*dW2 - lr*T.grad(cost, W2)
-    # update_b2 = b2 + mu*db2 - lr*T.grad(cost, b2)
-    # update_W3 = W3 + mu*dW3 - lr*T.grad(cost, W3)
-    # update_b3 = b3 + mu*db3 - lr*T.grad(cost, b3)
-    # update_W4 = W4 + mu*dW4 - lr*T.grad(cost, W4)
-    # update_b4 = b4 + mu*db4 - lr*T.grad(cost, b4)
     updates = [(param, param - lr*T.grad(cost, param)) for param in params]
 
     # update weight changes
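With the batched forward output in place, the cost becomes full-batch categorical cross-entropy against one-hot targets plus L2 regularization, and the eight hand-written per-weight updates collapse into one generic param - lr*T.grad(cost, param) rule. A small worked NumPy example of the cost expression (all numbers made up):

    import numpy as np

    reg = 0.01                              # assumed value
    pY = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1]])        # softmax outputs, (N=2, K=3)
    Y = np.array([[1., 0., 0.],
                  [0., 1., 0.]])            # one-hot targets
    params = [np.ones(3), np.ones(3)]       # stand-ins for W9, b9, W10, b10
    reg_cost = reg * sum((p * p).sum() for p in params)
    cost = -(Y * np.log(pY)).sum() + reg_cost  # -log(0.7) - log(0.8) + 0.06
    prediction = np.argmax(pY, axis=1)         # -> array([0, 1])
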
@@ -493,34 +417,42 @@ def main(ReUnit=GRU, getData=getMNIST):
     train = theano.function(
         inputs=[X, Y],
         updates=updates,
+        allow_input_downcast=True,
     )
 
     # create another function for this because we want it over the whole dataset
     get_prediction = theano.function(
         inputs=[X, Y],
         outputs=[cost, prediction],
+        allow_input_downcast=True,
     )
 
     print "Setup elapsed time:", (datetime.now() - t0)
+
+    # test it
+    # print get_prediction(Xtest, Ytest_ind)
+    # exit()
+
     t0 = datetime.now()
     LL = []
     t1 = t0
     for i in xrange(max_iter):
         print "i:", i
-        for j in xrange(N):
+        for j in xrange(n_batches):
             # print "j:", j
-            Xbatch = Xtrain[j,:]
-            Ybatch = Ytrain_ind[j:j+1,:]
+            Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),:]
+            Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),:]
 
             train(Xbatch, Ybatch)
             if j % print_period == 0:
-                cost_val = 0
-                prediction_val = np.zeros(len(Ytest))
-                for k in xrange(len(Ytest)):
-                    c, p = get_prediction(Xtest[k], Ytest_ind[k:k+1,:])
-                    cost_val += c
-                    prediction_val[k] = p[0]
-                    # print "pred:", p[0], type(p[0]), "target:", Ytest[k], type(Ytest[k])
+                cost_val, prediction_val = get_prediction(Xtest, Ytest_ind)
+                # cost_val = 0
+                # prediction_val = np.zeros(len(Ytest))
+                # for k in xrange(len(Ytest)):
+                #     c, p = get_prediction(Xtest[k], Ytest_ind[k:k+1,:])
+                #     cost_val += c
+                #     prediction_val[k] = p[0]
+                #     # print "pred:", p[0], type(p[0]), "target:", Ytest[k], type(Ytest[k])
             err = error_rate(prediction_val, Ytest)
             print "Cost / err at iteration i=%d, j=%d: %.3f / %.2f" % (i, j, cost_val / len(Ytest), err)
             t2 = datetime.now()

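Two notes on the training-loop hunk above (the "batches and downcast" of the commit message): allow_input_downcast=True lets float64 NumPy inputs be cast silently to the graph's float32, and the loop now walks contiguous batch_sz-row windows, so any N % batch_sz leftover rows are never visited. A toy sketch of the slicing, with train as a placeholder for the compiled Theano function:

    import numpy as np

    Xtrain = np.arange(10).reshape(10, 1)
    batch_sz = 4
    n_batches = Xtrain.shape[0] // batch_sz  # = 2; rows 8-9 are dropped

    def train(Xbatch):                       # placeholder, not theano.function
        print(Xbatch.ravel())

    for j in range(n_batches):
        Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz), :]
        train(Xbatch)                        # -> [0 1 2 3], then [4 5 6 7]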