Skip to content

Commit f3e6051

Browse files
Refactor recurrent network for TF1.0
Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>
1 parent 63abe61 commit f3e6051

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

examples/3_NeuralNetworks/recurrent_network.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from __future__ import print_function
1111

1212
import tensorflow as tf
13-
from tensorflow.python.ops import rnn, rnn_cell
13+
from tensorflow.contrib import rnn
1414

1515
# Import MNIST data
1616
from tensorflow.examples.tutorials.mnist import input_data
@@ -58,29 +58,29 @@ def RNN(x, weights, biases):
5858
# Reshaping to (n_steps*batch_size, n_input)
5959
x = tf.reshape(x, [-1, n_input])
6060
# Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
61-
x = tf.split(0, n_steps, x)
61+
x = tf.split(x, n_steps, 0)
6262

6363
# Define a lstm cell with tensorflow
64-
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
64+
lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
6565

6666
# Get lstm cell output
67-
outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
67+
outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
6868

6969
# Linear activation, using rnn inner loop last output
7070
return tf.matmul(outputs[-1], weights['out']) + biases['out']
7171

7272
pred = RNN(x, weights, biases)
7373

7474
# Define loss and optimizer
75-
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
75+
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
7676
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
7777

7878
# Evaluate model
7979
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
8080
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
8181

8282
# Initializing the variables
83-
init = tf.initialize_all_variables()
83+
init = tf.global_variables_initializer()
8484

8585
# Launch the graph
8686
with tf.Session() as sess:

0 commit comments

Comments
 (0)