@@ -1,10 +1,10 @@
{
"cells": [
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": true
+ },
"source": [
"'''\n",
"A Bidirectional Reccurent Neural Network (LSTM) implementation example using TensorFlow library.\n",
@@ -18,35 +18,26 @@
},
{
"cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
- "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
- "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
- "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
"source": [
"import tensorflow as tf\n",
- "from tensorflow.models.rnn import rnn, rnn_cell\n",
+ "from tensorflow.contrib import rnn\n",
"import numpy as np\n",
"\n",
"# Import MINST data\n",
"from tensorflow.examples.tutorials.mnist import input_data\n",
- "mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)"
+ "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)"
]
},
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": true
+ },
"source": [
"'''\n",
"To classify images using a bidirectional reccurent neural network, we consider\n",
@@ -58,7 +49,9 @@
{
"cell_type": "code",
"execution_count": 2,
- "metadata": {},
+ "metadata": {
+ "collapsed": true
+ },
"outputs": [],
"source": [
"# Parameters\n",
@@ -90,7 +83,9 @@
{
"cell_type": "code",
"execution_count": 3,
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"outputs": [],
"source": [
"def BiRNN(x, weights, biases):\n",
@@ -104,20 +99,20 @@
" # Reshape to (n_steps*batch_size, n_input)\n",
" x = tf.reshape(x, [-1, n_input])\n",
" # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
- " x = tf.split(0, n_steps, x)\n",
+ " x = tf.split(x, n_steps, 0)\n",
"\n",
" # Define lstm cells with tensorflow\n",
" # Forward direction cell\n",
- " lstm_fw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)\n",
+ " lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)\n",
" # Backward direction cell\n",
- " lstm_bw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)\n",
+ " lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)\n",
"\n",
" # Get lstm cell output\n",
" try:\n",
- " outputs, _, _ = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,\n",
+ " outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,\n",
" dtype=tf.float32)\n",
" except Exception: # Old TensorFlow version only returns outputs not states\n",
- " outputs = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,\n",
+ " outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,\n",
" dtype=tf.float32)\n",
"\n",
" # Linear activation, using rnn inner loop last output\n",
@@ -126,109 +121,24 @@
"pred = BiRNN(x, weights, biases)\n",
"\n",
"# Define loss and optimizer\n",
- "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))\n",
+ "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))\n",
"optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)\n",
"\n",
"# Evaluate model\n",
"correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))\n",
"accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
"\n",
"# Initializing the variables\n",
- "init = tf.initialize_all_variables()"
+ "init = tf.global_variables_initializer()"
]
},
{
"cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Iter 1280, Minibatch Loss= 1.689740, Training Accuracy= 0.36719\n",
- "Iter 2560, Minibatch Loss= 1.477009, Training Accuracy= 0.44531\n",
- "Iter 3840, Minibatch Loss= 1.245874, Training Accuracy= 0.53125\n",
- "Iter 5120, Minibatch Loss= 0.990923, Training Accuracy= 0.64062\n",
- "Iter 6400, Minibatch Loss= 0.752950, Training Accuracy= 0.71875\n",
- "Iter 7680, Minibatch Loss= 1.023025, Training Accuracy= 0.61719\n",
- "Iter 8960, Minibatch Loss= 0.921414, Training Accuracy= 0.68750\n",
- "Iter 10240, Minibatch Loss= 0.719829, Training Accuracy= 0.75000\n",
- "Iter 11520, Minibatch Loss= 0.468657, Training Accuracy= 0.86719\n",
- "Iter 12800, Minibatch Loss= 0.654315, Training Accuracy= 0.78125\n",
- "Iter 14080, Minibatch Loss= 0.595391, Training Accuracy= 0.83594\n",
- "Iter 15360, Minibatch Loss= 0.392862, Training Accuracy= 0.83594\n",
- "Iter 16640, Minibatch Loss= 0.421122, Training Accuracy= 0.92188\n",
- "Iter 17920, Minibatch Loss= 0.311471, Training Accuracy= 0.88281\n",
- "Iter 19200, Minibatch Loss= 0.276949, Training Accuracy= 0.92188\n",
- "Iter 20480, Minibatch Loss= 0.170499, Training Accuracy= 0.94531\n",
- "Iter 21760, Minibatch Loss= 0.419481, Training Accuracy= 0.86719\n",
- "Iter 23040, Minibatch Loss= 0.183765, Training Accuracy= 0.92188\n",
- "Iter 24320, Minibatch Loss= 0.386232, Training Accuracy= 0.86719\n",
- "Iter 25600, Minibatch Loss= 0.335571, Training Accuracy= 0.88281\n",
- "Iter 26880, Minibatch Loss= 0.169092, Training Accuracy= 0.92969\n",
- "Iter 28160, Minibatch Loss= 0.247623, Training Accuracy= 0.92969\n",
- "Iter 29440, Minibatch Loss= 0.242989, Training Accuracy= 0.94531\n",
- "Iter 30720, Minibatch Loss= 0.253811, Training Accuracy= 0.92188\n",
- "Iter 32000, Minibatch Loss= 0.169660, Training Accuracy= 0.93750\n",
- "Iter 33280, Minibatch Loss= 0.291349, Training Accuracy= 0.90625\n",
- "Iter 34560, Minibatch Loss= 0.172026, Training Accuracy= 0.95312\n",
- "Iter 35840, Minibatch Loss= 0.186019, Training Accuracy= 0.93750\n",
- "Iter 37120, Minibatch Loss= 0.298480, Training Accuracy= 0.89062\n",
- "Iter 38400, Minibatch Loss= 0.158750, Training Accuracy= 0.92188\n",
- "Iter 39680, Minibatch Loss= 0.162706, Training Accuracy= 0.94531\n",
- "Iter 40960, Minibatch Loss= 0.339814, Training Accuracy= 0.86719\n",
- "Iter 42240, Minibatch Loss= 0.068817, Training Accuracy= 0.99219\n",
- "Iter 43520, Minibatch Loss= 0.188742, Training Accuracy= 0.93750\n",
- "Iter 44800, Minibatch Loss= 0.176708, Training Accuracy= 0.92969\n",
- "Iter 46080, Minibatch Loss= 0.096726, Training Accuracy= 0.96875\n",
- "Iter 47360, Minibatch Loss= 0.220973, Training Accuracy= 0.92969\n",
- "Iter 48640, Minibatch Loss= 0.226749, Training Accuracy= 0.94531\n",
- "Iter 49920, Minibatch Loss= 0.188906, Training Accuracy= 0.94531\n",
- "Iter 51200, Minibatch Loss= 0.145194, Training Accuracy= 0.95312\n",
- "Iter 52480, Minibatch Loss= 0.168948, Training Accuracy= 0.95312\n",
- "Iter 53760, Minibatch Loss= 0.069116, Training Accuracy= 0.97656\n",
- "Iter 55040, Minibatch Loss= 0.228721, Training Accuracy= 0.93750\n",
- "Iter 56320, Minibatch Loss= 0.152915, Training Accuracy= 0.95312\n",
- "Iter 57600, Minibatch Loss= 0.126974, Training Accuracy= 0.96875\n",
- "Iter 58880, Minibatch Loss= 0.078870, Training Accuracy= 0.97656\n",
- "Iter 60160, Minibatch Loss= 0.225498, Training Accuracy= 0.95312\n",
- "Iter 61440, Minibatch Loss= 0.111760, Training Accuracy= 0.97656\n",
- "Iter 62720, Minibatch Loss= 0.161434, Training Accuracy= 0.97656\n",
- "Iter 64000, Minibatch Loss= 0.207190, Training Accuracy= 0.94531\n",
- "Iter 65280, Minibatch Loss= 0.103831, Training Accuracy= 0.96094\n",
- "Iter 66560, Minibatch Loss= 0.153846, Training Accuracy= 0.93750\n",
- "Iter 67840, Minibatch Loss= 0.082717, Training Accuracy= 0.96875\n",
- "Iter 69120, Minibatch Loss= 0.199301, Training Accuracy= 0.95312\n",
- "Iter 70400, Minibatch Loss= 0.139725, Training Accuracy= 0.96875\n",
- "Iter 71680, Minibatch Loss= 0.169596, Training Accuracy= 0.95312\n",
- "Iter 72960, Minibatch Loss= 0.142444, Training Accuracy= 0.96094\n",
- "Iter 74240, Minibatch Loss= 0.145822, Training Accuracy= 0.95312\n",
- "Iter 75520, Minibatch Loss= 0.129086, Training Accuracy= 0.94531\n",
- "Iter 76800, Minibatch Loss= 0.078082, Training Accuracy= 0.97656\n",
- "Iter 78080, Minibatch Loss= 0.151803, Training Accuracy= 0.94531\n",
- "Iter 79360, Minibatch Loss= 0.050142, Training Accuracy= 0.98438\n",
- "Iter 80640, Minibatch Loss= 0.136788, Training Accuracy= 0.95312\n",
- "Iter 81920, Minibatch Loss= 0.130100, Training Accuracy= 0.94531\n",
- "Iter 83200, Minibatch Loss= 0.058298, Training Accuracy= 0.98438\n",
- "Iter 84480, Minibatch Loss= 0.120124, Training Accuracy= 0.96094\n",
- "Iter 85760, Minibatch Loss= 0.064916, Training Accuracy= 0.97656\n",
- "Iter 87040, Minibatch Loss= 0.137179, Training Accuracy= 0.93750\n",
- "Iter 88320, Minibatch Loss= 0.138268, Training Accuracy= 0.95312\n",
- "Iter 89600, Minibatch Loss= 0.072827, Training Accuracy= 0.97656\n",
- "Iter 90880, Minibatch Loss= 0.123839, Training Accuracy= 0.96875\n",
- "Iter 92160, Minibatch Loss= 0.087194, Training Accuracy= 0.96875\n",
- "Iter 93440, Minibatch Loss= 0.083489, Training Accuracy= 0.97656\n",
- "Iter 94720, Minibatch Loss= 0.131827, Training Accuracy= 0.95312\n",
- "Iter 96000, Minibatch Loss= 0.098764, Training Accuracy= 0.96875\n",
- "Iter 97280, Minibatch Loss= 0.115553, Training Accuracy= 0.94531\n",
- "Iter 98560, Minibatch Loss= 0.079704, Training Accuracy= 0.96875\n",
- "Iter 99840, Minibatch Loss= 0.064562, Training Accuracy= 0.98438\n",
- "Optimization Finished!\n",
- "Testing Accuracy: 0.992188\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
"source": [
"# Launch the graph\n",
"with tf.Session() as sess:\n",
@@ -259,6 +169,15 @@
" print \"Testing Accuracy:\", \\\n",
" sess.run(accuracy, feed_dict={x: test_data, y: test_label})"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": []
}
],
"metadata": {
@@ -270,14 +189,14 @@
"language_info": {
"codemirror_mode": {
"name": "ipython",
- "version": 2.0
+ "version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
- "version": "2.7.11"
+ "version": "2.7.13"
}
},
"nbformat": 4,
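For readers following the API migration in this diff, the sketch below collects the updated TF 1.x calls from the added lines into one self-contained snippet: tf.contrib.rnn in place of tensorflow.models.rnn, the new tf.split(value, num_or_size_splits, axis) argument order, rnn.static_bidirectional_rnn, keyword arguments for tf.nn.softmax_cross_entropy_with_logits, and tf.global_variables_initializer(). This is a minimal sketch, not the full notebook; the parameter values and the transpose/reshape preprocessing are illustrative assumptions, since the corresponding cells are collapsed in the diff.

import tensorflow as tf
from tensorflow.contrib import rnn

# Illustrative MNIST-style dimensions (the diff's "Parameters" cell is collapsed).
n_input, n_steps, n_hidden, n_classes = 28, 28, 128, 10

x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])
weights = tf.Variable(tf.random_normal([2 * n_hidden, n_classes]))  # 2*n_hidden: fw + bw outputs
biases = tf.Variable(tf.random_normal([n_classes]))

# (batch, steps, input) -> list of 'n_steps' tensors of shape (batch, input).
# tf.split now takes (value, num_or_size_splits, axis) rather than (axis, num, value).
x_seq = tf.split(tf.reshape(tf.transpose(x, [1, 0, 2]), [-1, n_input]), n_steps, 0)

# Cells now come from tf.contrib.rnn instead of tensorflow.models.rnn.rnn_cell.
lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

# static_bidirectional_rnn returns (outputs, output_state_fw, output_state_bw).
outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x_seq,
                                             dtype=tf.float32)
pred = tf.matmul(outputs[-1], weights) + biases

# softmax_cross_entropy_with_logits now takes explicit logits=/labels= keywords.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))

# Replaces the removed tf.initialize_all_variables().
init = tf.global_variables_initializer()

Run inside a tf.Session by evaluating init and then cost/optimizer ops with feed_dict, exactly as the notebook's training loop does.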