Commit 7685620

add all scan version
1 parent a744d2d commit 7685620


rnn_class/renet.py

Lines changed: 111 additions & 22 deletions
@@ -30,7 +30,7 @@ def relu(a):

 def y2indicator(y):
     N = len(y)
-    ind = np.zeros((N, 10))
+    ind = np.zeros((N, 10), dtype='int32')
     for i in xrange(N):
         ind[i, y[i]] = 1
     return ind
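
For reference, a tiny sketch (mine, not part of the commit) of what y2indicator builds; the diff only changes the dtype from NumPy's float64 default to int32:

import numpy as np

y = [3, 0, 1]
ind = np.zeros((len(y), 10), dtype='int32')
for i in xrange(len(y)):
    ind[i, y[i]] = 1
# ind[0] is [0 0 0 1 0 0 0 0 0 0]: one one-hot row per label
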
@@ -133,11 +133,62 @@ def output(self, x, go_backwards=False):
             sequences=x,
             outputs_info=[self.H0],
             n_steps=x.shape[0],
-            go_backwards=go_backwards
+            go_backwards=go_backwards,
+            # non_sequences=self.params,
+            # strict=True,
         )
         return h


+def renet_layer_lr_noscan(X, rnn1, rnn2, w, h, wp, hp):
+    # reference version: unroll both sweeps as plain Python loops
+    list_of_images = []
+    for i in xrange(h/hp):
+        # x = X[:,i*hp:(i*hp + hp),:].dimshuffle((2, 0, 1)).flatten().reshape((w/wp, X.shape[0]*wp*hp))
+        h_tm1 = rnn1.H0
+        hr_tm1 = rnn2.H0
+        h1 = []
+        h2 = []
+        for j in xrange(w/wp):
+            x = X[:,i*hp:(i*hp + hp),j*wp:(j*wp + wp)].flatten()
+            h_t = rnn1.recurrence(x, h_tm1)
+            h1.append(h_t)
+            h_tm1 = h_t
+
+            jr = w/wp - j - 1
+            xr = X[:,i*hp:(i*hp + hp),jr*wp:(jr*wp + wp)].flatten()
+            hr_t = rnn2.recurrence(xr, hr_tm1)  # feed the reversed patch xr; it was computed but unused before
+            h2.append(hr_t)
+            hr_tm1 = hr_t
+        img = T.concatenate([h1, h2])
+        list_of_images.append(img)
+    return T.stacklists(list_of_images).dimshuffle((1, 0, 2))
+
+
+def renet_layer_lr_allscan(X, rnn1, rnn2, w, h, wp, hp):
+    C = X.shape[0]
+    X = X.dimshuffle((1, 0, 2)).reshape((h/hp, hp*C*w)) # split the rows for the first scan
+    def rnn_pass(x):
+        x = x.reshape((hp, C, w)).dimshuffle((2, 1, 0)).reshape((w/wp, C*wp*hp))
+        h1 = rnn1.output(x)
+        h2 = rnn2.output(x, go_backwards=True)
+        img = T.concatenate([h1.T, h2.T])
+        return img
+
+    results, _ = theano.scan(
+        fn=rnn_pass,
+        sequences=X,
+        outputs_info=None,
+        n_steps=h/hp,
+    )
+    return results.dimshuffle((1, 0, 2))
+
+
+def renet_layer_ud_allscan(X, rnn1, rnn2, w, h, wp, hp):
+    # swap width and height, then reuse the left-right layer
+    return renet_layer_lr_allscan(X.dimshuffle((0, 2, 1)), rnn1, rnn2, w, h, wp, hp)
+

 # expect the input image to be K x width x height
 # def renet_layer_lr(X, Wx1, Wh1, Bh1, H01, Wx2, Wh2, Bh2, H02, w, h, wp, hp):
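
The reshape/dimshuffle plan in renet_layer_lr_allscan is easy to get wrong, so here is a small NumPy check of it (my own sketch, not from the commit; numpy's transpose stands in for Theano's dimshuffle). The outer scan should see one strip of hp rows per step, and the inner scan one flattened wp x hp patch per row:

import numpy as np

C, h, w, hp, wp = 3, 4, 6, 2, 2
X = np.arange(C*h*w).reshape(C, h, w)

# outer scan: one step per strip of hp rows, all channels flattened together
strips = X.transpose(1, 0, 2).reshape(h/hp, hp*C*w)

# inside rnn_pass: one row per wp-wide patch of the strip
patches = strips[0].reshape(hp, C, w).transpose(2, 1, 0).reshape(w/wp, C*wp*hp)

print strips.shape   # (2, 36) = (h/hp, hp*C*w)
print patches.shape  # (3, 12) = (w/wp, C*wp*hp)
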
@@ -156,7 +207,7 @@ def renet_layer_lr(X, rnn1, rnn2, w, h, wp, hp):
     # lefts = []
     # rights = []
     for i in xrange(h/hp):
-        x = X[:,i*hp:(i*hp + hp),:].dimshuffle((1, 0, 2)).flatten().reshape((w/wp, X.shape[0]*wp*hp))
+        # (width, channel, height) order, so each reshaped row is exactly one wp x hp patch
+        x = X[:,i*hp:(i*hp + hp),:].dimshuffle((2, 0, 1)).flatten().reshape((w/wp, X.shape[0]*wp*hp))
         # reshape the row into a 2-D matrix to be fed into scan
         # h1, _ = theano.scan(
         #     fn=recurrence1,
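
Why (2, 0, 1) and not (1, 0, 2): flattening in (width, channel, height) order keeps each wp x hp patch contiguous, so the reshape yields one patch per row. A quick check with C=1, hp=wp=2, w=4 (my own sketch, not from the commit):

import numpy as np

X = np.arange(8).reshape(1, 2, 4)  # one channel, one strip of 2 rows
old = X.transpose(1, 0, 2).flatten().reshape(2, 4)
new = X.transpose(2, 0, 1).flatten().reshape(2, 4)
print old  # [[0 1 2 3] [4 5 6 7]]  -> each row mixes two patches
print new  # [[0 4 1 5] [2 6 3 7]]  -> row 0 is exactly the left 2x2 patch
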
@@ -224,6 +275,7 @@ def getKaggleMNIST():
     # MNIST data:
     # column 0 is labels
     # column 1-785 is data, with values 0 .. 255
+    # total size of CSV: (42000, 785); reshaped later to (N, 1, 28, 28)
     train = pd.read_csv('../large_files/train.csv').as_matrix()
     train = shuffle(train)

@@ -239,7 +291,7 @@ def getKaggleMNIST():

 def getMNIST():
     # data shape: train (50000, 784), test (10000, 784)
-    # already scaled from 0..1
+    # already scaled to the range 0..1 and converted to float32
     datadir = '../large_files/'
     if not os.path.exists(datadir):
         datadir = ''
@@ -262,11 +314,21 @@ def getMNIST():
-    Ytrain_ind = y2indicator(Ytrain)
-    Ytest_ind = y2indicator(Ytest)
-
-    return Xtrain.reshape(50000, 1, 28, 28), Ytrain, Ytrain_ind, Xtest.reshape(10000, 1, 28, 28), Ytest, Ytest_ind
+    Xtrain, Ytrain = shuffle(Xtrain, Ytrain)
+    Xtest, Ytest = shuffle(Xtest, Ytest)
+
+    # take a smaller sample so the graph compiles and trains faster
+    Xtrain = Xtrain[0:30000]
+    Ytrain = Ytrain[0:30000]
+    Xtest = Xtest[0:1000]
+    Ytest = Ytest[0:1000]
+
+    # build the indicator matrices only after shuffling/subsampling so they stay aligned with Y
+    Ytrain_ind = y2indicator(Ytrain)
+    Ytest_ind = y2indicator(Ytest)
+
+    return Xtrain.reshape(len(Xtrain), 1, 28, 28), Ytrain, Ytrain_ind, Xtest.reshape(len(Xtest), 1, 28, 28), Ytest, Ytest_ind


-def main(ReUnit=GRU, getData=getMNIST):
+def main(ReUnit=RNNUnit, getData=getMNIST):
     t0 = datetime.now()
+    print "Start time:", t0

     Xtrain, Ytrain, Ytrain_ind, Xtest, Ytest, Ytest_ind = getData()
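
The reordering matters because sklearn.utils.shuffle permutes all of its arguments with one shared permutation, so anything derived from Y beforehand (like the indicator matrices) goes stale. A quick illustration (my own, not from the commit):

import numpy as np
from sklearn.utils import shuffle

X = np.arange(6).reshape(3, 2)
y = np.array([10, 20, 30])
Xs, ys = shuffle(X, y, random_state=0)
# rows of Xs and entries of ys are permuted identically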

@@ -282,25 +344,27 @@ def main(ReUnit=GRU, getData=getMNIST):
     M = 300
     K = 10

-    batch_sz = 100
+    batch_sz = 1
     n_batches = N / batch_sz

-    M1 = 2 # num feature maps
+    M1 = 256 # num feature maps
     rnn1 = ReUnit('1', 2, 2, C, M1)
     rnn2 = ReUnit('2', 2, 2, C, M1)

-    M2 = 2 # num feature maps
+    M2 = 256 # num feature maps
     rnn3 = ReUnit('3', 1, 1, 2*M1, M2)
     rnn4 = ReUnit('4', 1, 1, 2*M1, M2)

-    M3 = 2
+    M3 = 64
     rnn5 = ReUnit('5', 2, 2, 2*M2, M3)
     rnn6 = ReUnit('6', 2, 2, 2*M2, M3)

-    M4 = 2
+    M4 = 64
     rnn7 = ReUnit('7', 1, 1, 2*M3, M4)
     rnn8 = ReUnit('8', 1, 1, 2*M3, M4)

+    print "Finished creating rnn objects, elapsed time:", (datetime.now() - t0)
+

     # vanilla ANN weights
     W9_init = np.random.randn(2*M4*7*7, M) / np.sqrt(2*M4*7*7 + M)
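
A quick check (my own sketch, not from the commit) that the four bidirectional pairs chain into the 2*M4*7*7 input of W9: each pair doubles the channel count, and a 2x2 patch halves each spatial dimension:

w = h = 28
C = 1
for wp, hp, M in [(2, 2, 256), (1, 1, 256), (2, 2, 64), (1, 1, 64)]:
    w, h, C = w/wp, h/hp, 2*M
    print w, h, C
# last line: 7 7 128, and 128*7*7 == 2*M4*7*7 == 6272
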
@@ -311,7 +375,7 @@ def main(ReUnit=GRU, getData=getMNIST):

     # step 2: define theano variables and expressions
     X = T.tensor4('X', dtype='float32')
-    # x = T.tensor3('x', dtype='float32')
+    x = T.tensor3('x', dtype='float32') # single image, used by the temporary layer tests below
     Y = T.matrix('T')

     W9 = theano.shared(W9_init.astype(np.float32), 'W9')
@@ -322,6 +386,8 @@ def main(ReUnit=GRU, getData=getMNIST):
     for rnn in (rnn1, rnn2, rnn3, rnn4, rnn5, rnn6, rnn7, rnn8):
         params += rnn.params

+    print "Finished creating all shared vars, elapsed time:", (datetime.now() - t0)
+
     # momentum changes
     # dW1 = theano.shared(np.zeros(W1_init.shape, dtype=np.float32), 'dW1')
     # db1 = theano.shared(np.zeros(b1_init.shape, dtype=np.float32), 'db1')
@@ -332,24 +398,45 @@ def main(ReUnit=GRU, getData=getMNIST):
     # dW4 = theano.shared(np.zeros(W4_init.shape, dtype=np.float32), 'dW4')
     # db4 = theano.shared(np.zeros(b4_init.shape, dtype=np.float32), 'db4')

+    # temporary single-image sanity check for the new layer functions:
+    # Z_tmp = renet_layer_lr_allscan(x, rnn1, rnn2, 28, 28, 2, 2)
+    # # Z_tmp = renet_layer_lr_noscan(x, rnn1, rnn2, 28, 28, 2, 2)
+    # tmp_op = theano.function(
+    #     inputs=[x],
+    #     outputs=Z_tmp,
+    # )
+    # print "Xtrain[0].shape:", Xtrain[0].shape
+    # out = tmp_op(Xtrain[0])
+    # print "Z_tmp.shape:", out.shape
+    # exit()
+
     def forward(x):
         # forward pass
-        Z1 = renet_layer_lr(x, rnn1, rnn2, 28, 28, 2, 2)
-        Z2 = renet_layer_ud(Z1, rnn3, rnn4, 14, 14, 1, 1)
-        Z3 = renet_layer_lr(Z2, rnn5, rnn6, 14, 14, 2, 2)
-        Z4 = renet_layer_ud(Z3, rnn7, rnn8, 7, 7, 1, 1)
+        Z1 = renet_layer_lr_allscan(x, rnn1, rnn2, 28, 28, 2, 2)
+        Z2 = renet_layer_ud_allscan(Z1, rnn3, rnn4, 14, 14, 1, 1)
+        Z3 = renet_layer_lr_allscan(Z2, rnn5, rnn6, 14, 14, 2, 2)
+        Z4 = renet_layer_ud_allscan(Z3, rnn7, rnn8, 7, 7, 1, 1)
         Z5 = relu(Z4.flatten().dot(W9) + b9)
         pY = T.nnet.softmax(Z5.dot(W10) + b10)
         return pY

-    batch_forward_out3, _ = theano.scan(
-        fn=forward,
-        sequences=X,
-        n_steps=X.shape[0]
-    )
+    if True: # TODO: only scan when batch_sz > 1
+        # scan over the batch dimension: one image per step
+        batch_forward_out3, _ = theano.scan(
+            fn=forward,
+            sequences=X,
+            n_steps=X.shape[0],
+            # non_sequences=params,
+            # strict=True,
+        )
+    else:
+        batch_forward_out3 = forward(X[0])
+
+    print "Finished creating output scan, elapsed time:", (datetime.now() - t0)
     batch_forward_out = batch_forward_out3.flatten(ndim=2) # the output will be (N, 1, 10)

+    print "Finished reshaping output, elapsed time:", (datetime.now() - t0)
+
     ## TMP: just test the first/second layer ##
     # tmp_op = theano.function(
     #     inputs=[X],
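
The batch code above relies on theano.scan with only sequences given behaving like a map over the leading axis. A minimal standalone sketch (my own, not from the commit):

import numpy as np
import theano
import theano.tensor as T

X = T.tensor3('X')  # (batch, rows, cols)
out, _ = theano.scan(fn=lambda x: x.sum(), sequences=X, n_steps=X.shape[0])
f = theano.function([X], out)
print f(np.ones((4, 2, 3)))  # one sum per batch element: [ 6.  6.  6.  6.]
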
@@ -404,6 +491,8 @@ def forward(x):
     # step 3: training expressions and functions
     updates = [(param, param - lr*T.grad(cost, param)) for param in params]

+    print "Finished creating update expressions, elapsed time:", (datetime.now() - t0)
+
     # update weight changes
     # update_dW1 = mu*dW1 - lr*T.grad(cost, W1)
     # update_db1 = mu*db1 - lr*T.grad(cost, b1)
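
For reference, the momentum scheme that the commented-out update_dW1/update_db1 lines gesture at could be written generically like this (a hedged sketch using hypothetical velocity variables dparams; mu, lr, cost, and params as defined elsewhere in the script; not code from the commit):

dparams = [theano.shared(np.zeros_like(p.get_value())) for p in params]
updates = []
for p, dp in zip(params, dparams):
    dp_new = mu*dp - lr*T.grad(cost, p)  # velocity update
    updates.append((dp, dp_new))
    updates.append((p, p + dp_new))      # parameter step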
