Skip to content

Commit 3395cce

Browse files
authored
Update biRNN.ipynb
1 parent 3974236 commit 3395cce

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

biRNN.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
" return iterator.get_next()\n",
4848
" return input_fn\n",
4949
"\n",
50-
"padding_info = ({'image':[28,28,1],'label':[]})\n",
50+
"padding_info = ({'image':[784,],'label':[]})\n",
5151
"test_input_fn = input_fn_maker('mnist_tfrecord/test/', 'mnist_tfrecord/data_info.csv',batch_size = 512,\n",
5252
" padding = padding_info)\n",
5353
"train_input_fn = input_fn_maker('mnist_tfrecord/train/', 'mnist_tfrecord/data_info.csv', shuffle=True, batch_size = 128,\n",
@@ -63,10 +63,11 @@
6363
"outputs": [],
6464
"source": [
6565
"def model_fn(features, mode):\n",
66+
" features['image'] = tf.reshape(features['image'],[-1,28,28,1])\n",
6667
" # shape: [None,28,28,1]\n",
6768
" # create RNN cells:\n",
68-
" rnn_fcells = [tf.nn.rnn_cell.GRUCell(dim) for dim in [128,256]]\n",
69-
" rnn_bcells = [tf.nn.rnn_cell.GRUCell(dim) for dim in [128,256]]\n",
69+
" rnn_fcells = [tf.nn.rnn_cell.GRUCell(dim,kernel_initializer=tf.orthogonal_initializer) for dim in [128,256]]\n",
70+
" rnn_bcells = [tf.nn.rnn_cell.GRUCell(dim,kernel_initializer=tf.orthogonal_initializer) for dim in [128,256]]\n",
7071
" # stack cells for multi-layers RNN\n",
7172
" multi_rnn_fcell = tf.nn.rnn_cell.MultiRNNCell(rnn_fcells)\n",
7273
" multi_rnn_bcell = tf.nn.rnn_cell.MultiRNNCell(rnn_bcells)\n",

0 commit comments

Comments
 (0)