Skip to content

Commit b980473

Browse files
Update bidirectional rnn for TF1.0
Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>
1 parent 2f0c1ba commit b980473

File tree

1 file changed

+44
-125
lines changed

1 file changed

+44
-125
lines changed

notebooks/3_NeuralNetworks/bidirectional_rnn.ipynb

Lines changed: 44 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
{
22
"cells": [
33
{
4-
"cell_type": "code",
5-
"execution_count": null,
6-
"metadata": {},
7-
"outputs": [],
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"collapsed": true
7+
},
88
"source": [
99
"'''\n",
1010
"A Bidirectional Reccurent Neural Network (LSTM) implementation example using TensorFlow library.\n",
@@ -18,35 +18,26 @@
1818
},
1919
{
2020
"cell_type": "code",
21-
"execution_count": 1,
22-
"metadata": {},
23-
"outputs": [
24-
{
25-
"name": "stdout",
26-
"output_type": "stream",
27-
"text": [
28-
"Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
29-
"Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
30-
"Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
31-
"Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
32-
]
33-
}
34-
],
21+
"execution_count": null,
22+
"metadata": {
23+
"collapsed": false
24+
},
25+
"outputs": [],
3526
"source": [
3627
"import tensorflow as tf\n",
37-
"from tensorflow.models.rnn import rnn, rnn_cell\n",
28+
"from tensorflow.contrib import rnn\n",
3829
"import numpy as np\n",
3930
"\n",
4031
"# Import MINST data\n",
4132
"from tensorflow.examples.tutorials.mnist import input_data\n",
42-
"mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)"
33+
"mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)"
4334
]
4435
},
4536
{
46-
"cell_type": "code",
47-
"execution_count": null,
48-
"metadata": {},
49-
"outputs": [],
37+
"cell_type": "markdown",
38+
"metadata": {
39+
"collapsed": true
40+
},
5041
"source": [
5142
"'''\n",
5243
"To classify images using a bidirectional reccurent neural network, we consider\n",
@@ -58,7 +49,9 @@
5849
{
5950
"cell_type": "code",
6051
"execution_count": 2,
61-
"metadata": {},
52+
"metadata": {
53+
"collapsed": true
54+
},
6255
"outputs": [],
6356
"source": [
6457
"# Parameters\n",
@@ -90,7 +83,9 @@
9083
{
9184
"cell_type": "code",
9285
"execution_count": 3,
93-
"metadata": {},
86+
"metadata": {
87+
"collapsed": false
88+
},
9489
"outputs": [],
9590
"source": [
9691
"def BiRNN(x, weights, biases):\n",
@@ -104,20 +99,20 @@
10499
" # Reshape to (n_steps*batch_size, n_input)\n",
105100
" x = tf.reshape(x, [-1, n_input])\n",
106101
" # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
107-
" x = tf.split(0, n_steps, x)\n",
102+
" x = tf.split(x, n_steps, 0)\n",
108103
"\n",
109104
" # Define lstm cells with tensorflow\n",
110105
" # Forward direction cell\n",
111-
" lstm_fw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)\n",
106+
" lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)\n",
112107
" # Backward direction cell\n",
113-
" lstm_bw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)\n",
108+
" lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)\n",
114109
"\n",
115110
" # Get lstm cell output\n",
116111
" try:\n",
117-
" outputs, _, _ = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,\n",
112+
" outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,\n",
118113
" dtype=tf.float32)\n",
119114
" except Exception: # Old TensorFlow version only returns outputs not states\n",
120-
" outputs = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,\n",
115+
" outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,\n",
121116
" dtype=tf.float32)\n",
122117
"\n",
123118
" # Linear activation, using rnn inner loop last output\n",
@@ -126,109 +121,24 @@
126121
"pred = BiRNN(x, weights, biases)\n",
127122
"\n",
128123
"# Define loss and optimizer\n",
129-
"cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))\n",
124+
"cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))\n",
130125
"optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)\n",
131126
"\n",
132127
"# Evaluate model\n",
133128
"correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))\n",
134129
"accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
135130
"\n",
136131
"# Initializing the variables\n",
137-
"init = tf.initialize_all_variables()"
132+
"init = tf.global_variables_initializer()"
138133
]
139134
},
140135
{
141136
"cell_type": "code",
142-
"execution_count": 4,
143-
"metadata": {},
144-
"outputs": [
145-
{
146-
"name": "stdout",
147-
"output_type": "stream",
148-
"text": [
149-
"Iter 1280, Minibatch Loss= 1.689740, Training Accuracy= 0.36719\n",
150-
"Iter 2560, Minibatch Loss= 1.477009, Training Accuracy= 0.44531\n",
151-
"Iter 3840, Minibatch Loss= 1.245874, Training Accuracy= 0.53125\n",
152-
"Iter 5120, Minibatch Loss= 0.990923, Training Accuracy= 0.64062\n",
153-
"Iter 6400, Minibatch Loss= 0.752950, Training Accuracy= 0.71875\n",
154-
"Iter 7680, Minibatch Loss= 1.023025, Training Accuracy= 0.61719\n",
155-
"Iter 8960, Minibatch Loss= 0.921414, Training Accuracy= 0.68750\n",
156-
"Iter 10240, Minibatch Loss= 0.719829, Training Accuracy= 0.75000\n",
157-
"Iter 11520, Minibatch Loss= 0.468657, Training Accuracy= 0.86719\n",
158-
"Iter 12800, Minibatch Loss= 0.654315, Training Accuracy= 0.78125\n",
159-
"Iter 14080, Minibatch Loss= 0.595391, Training Accuracy= 0.83594\n",
160-
"Iter 15360, Minibatch Loss= 0.392862, Training Accuracy= 0.83594\n",
161-
"Iter 16640, Minibatch Loss= 0.421122, Training Accuracy= 0.92188\n",
162-
"Iter 17920, Minibatch Loss= 0.311471, Training Accuracy= 0.88281\n",
163-
"Iter 19200, Minibatch Loss= 0.276949, Training Accuracy= 0.92188\n",
164-
"Iter 20480, Minibatch Loss= 0.170499, Training Accuracy= 0.94531\n",
165-
"Iter 21760, Minibatch Loss= 0.419481, Training Accuracy= 0.86719\n",
166-
"Iter 23040, Minibatch Loss= 0.183765, Training Accuracy= 0.92188\n",
167-
"Iter 24320, Minibatch Loss= 0.386232, Training Accuracy= 0.86719\n",
168-
"Iter 25600, Minibatch Loss= 0.335571, Training Accuracy= 0.88281\n",
169-
"Iter 26880, Minibatch Loss= 0.169092, Training Accuracy= 0.92969\n",
170-
"Iter 28160, Minibatch Loss= 0.247623, Training Accuracy= 0.92969\n",
171-
"Iter 29440, Minibatch Loss= 0.242989, Training Accuracy= 0.94531\n",
172-
"Iter 30720, Minibatch Loss= 0.253811, Training Accuracy= 0.92188\n",
173-
"Iter 32000, Minibatch Loss= 0.169660, Training Accuracy= 0.93750\n",
174-
"Iter 33280, Minibatch Loss= 0.291349, Training Accuracy= 0.90625\n",
175-
"Iter 34560, Minibatch Loss= 0.172026, Training Accuracy= 0.95312\n",
176-
"Iter 35840, Minibatch Loss= 0.186019, Training Accuracy= 0.93750\n",
177-
"Iter 37120, Minibatch Loss= 0.298480, Training Accuracy= 0.89062\n",
178-
"Iter 38400, Minibatch Loss= 0.158750, Training Accuracy= 0.92188\n",
179-
"Iter 39680, Minibatch Loss= 0.162706, Training Accuracy= 0.94531\n",
180-
"Iter 40960, Minibatch Loss= 0.339814, Training Accuracy= 0.86719\n",
181-
"Iter 42240, Minibatch Loss= 0.068817, Training Accuracy= 0.99219\n",
182-
"Iter 43520, Minibatch Loss= 0.188742, Training Accuracy= 0.93750\n",
183-
"Iter 44800, Minibatch Loss= 0.176708, Training Accuracy= 0.92969\n",
184-
"Iter 46080, Minibatch Loss= 0.096726, Training Accuracy= 0.96875\n",
185-
"Iter 47360, Minibatch Loss= 0.220973, Training Accuracy= 0.92969\n",
186-
"Iter 48640, Minibatch Loss= 0.226749, Training Accuracy= 0.94531\n",
187-
"Iter 49920, Minibatch Loss= 0.188906, Training Accuracy= 0.94531\n",
188-
"Iter 51200, Minibatch Loss= 0.145194, Training Accuracy= 0.95312\n",
189-
"Iter 52480, Minibatch Loss= 0.168948, Training Accuracy= 0.95312\n",
190-
"Iter 53760, Minibatch Loss= 0.069116, Training Accuracy= 0.97656\n",
191-
"Iter 55040, Minibatch Loss= 0.228721, Training Accuracy= 0.93750\n",
192-
"Iter 56320, Minibatch Loss= 0.152915, Training Accuracy= 0.95312\n",
193-
"Iter 57600, Minibatch Loss= 0.126974, Training Accuracy= 0.96875\n",
194-
"Iter 58880, Minibatch Loss= 0.078870, Training Accuracy= 0.97656\n",
195-
"Iter 60160, Minibatch Loss= 0.225498, Training Accuracy= 0.95312\n",
196-
"Iter 61440, Minibatch Loss= 0.111760, Training Accuracy= 0.97656\n",
197-
"Iter 62720, Minibatch Loss= 0.161434, Training Accuracy= 0.97656\n",
198-
"Iter 64000, Minibatch Loss= 0.207190, Training Accuracy= 0.94531\n",
199-
"Iter 65280, Minibatch Loss= 0.103831, Training Accuracy= 0.96094\n",
200-
"Iter 66560, Minibatch Loss= 0.153846, Training Accuracy= 0.93750\n",
201-
"Iter 67840, Minibatch Loss= 0.082717, Training Accuracy= 0.96875\n",
202-
"Iter 69120, Minibatch Loss= 0.199301, Training Accuracy= 0.95312\n",
203-
"Iter 70400, Minibatch Loss= 0.139725, Training Accuracy= 0.96875\n",
204-
"Iter 71680, Minibatch Loss= 0.169596, Training Accuracy= 0.95312\n",
205-
"Iter 72960, Minibatch Loss= 0.142444, Training Accuracy= 0.96094\n",
206-
"Iter 74240, Minibatch Loss= 0.145822, Training Accuracy= 0.95312\n",
207-
"Iter 75520, Minibatch Loss= 0.129086, Training Accuracy= 0.94531\n",
208-
"Iter 76800, Minibatch Loss= 0.078082, Training Accuracy= 0.97656\n",
209-
"Iter 78080, Minibatch Loss= 0.151803, Training Accuracy= 0.94531\n",
210-
"Iter 79360, Minibatch Loss= 0.050142, Training Accuracy= 0.98438\n",
211-
"Iter 80640, Minibatch Loss= 0.136788, Training Accuracy= 0.95312\n",
212-
"Iter 81920, Minibatch Loss= 0.130100, Training Accuracy= 0.94531\n",
213-
"Iter 83200, Minibatch Loss= 0.058298, Training Accuracy= 0.98438\n",
214-
"Iter 84480, Minibatch Loss= 0.120124, Training Accuracy= 0.96094\n",
215-
"Iter 85760, Minibatch Loss= 0.064916, Training Accuracy= 0.97656\n",
216-
"Iter 87040, Minibatch Loss= 0.137179, Training Accuracy= 0.93750\n",
217-
"Iter 88320, Minibatch Loss= 0.138268, Training Accuracy= 0.95312\n",
218-
"Iter 89600, Minibatch Loss= 0.072827, Training Accuracy= 0.97656\n",
219-
"Iter 90880, Minibatch Loss= 0.123839, Training Accuracy= 0.96875\n",
220-
"Iter 92160, Minibatch Loss= 0.087194, Training Accuracy= 0.96875\n",
221-
"Iter 93440, Minibatch Loss= 0.083489, Training Accuracy= 0.97656\n",
222-
"Iter 94720, Minibatch Loss= 0.131827, Training Accuracy= 0.95312\n",
223-
"Iter 96000, Minibatch Loss= 0.098764, Training Accuracy= 0.96875\n",
224-
"Iter 97280, Minibatch Loss= 0.115553, Training Accuracy= 0.94531\n",
225-
"Iter 98560, Minibatch Loss= 0.079704, Training Accuracy= 0.96875\n",
226-
"Iter 99840, Minibatch Loss= 0.064562, Training Accuracy= 0.98438\n",
227-
"Optimization Finished!\n",
228-
"Testing Accuracy: 0.992188\n"
229-
]
230-
}
231-
],
137+
"execution_count": null,
138+
"metadata": {
139+
"collapsed": false
140+
},
141+
"outputs": [],
232142
"source": [
233143
"# Launch the graph\n",
234144
"with tf.Session() as sess:\n",
@@ -259,6 +169,15 @@
259169
" print \"Testing Accuracy:\", \\\n",
260170
" sess.run(accuracy, feed_dict={x: test_data, y: test_label})"
261171
]
172+
},
173+
{
174+
"cell_type": "code",
175+
"execution_count": null,
176+
"metadata": {
177+
"collapsed": true
178+
},
179+
"outputs": [],
180+
"source": []
262181
}
263182
],
264183
"metadata": {
@@ -270,14 +189,14 @@
270189
"language_info": {
271190
"codemirror_mode": {
272191
"name": "ipython",
273-
"version": 2.0
192+
"version": 2
274193
},
275194
"file_extension": ".py",
276195
"mimetype": "text/x-python",
277196
"name": "python",
278197
"nbconvert_exporter": "python",
279198
"pygments_lexer": "ipython2",
280-
"version": "2.7.11"
199+
"version": "2.7.13"
281200
}
282201
},
283202
"nbformat": 4,

0 commit comments

Comments
 (0)