Skip to content

Commit 53a7228

Browse files
Refactor bidirectional rnn for TF1.0
Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>
1 parent f3e6051 commit 53a7228

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

examples/3_NeuralNetworks/bidirectional_rnn.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from __future__ import print_function
1111

1212
import tensorflow as tf
13-
from tensorflow.python.ops import rnn, rnn_cell
13+
from tensorflow.contrib import rnn
1414
import numpy as np
1515

1616
# Import MNIST data
@@ -60,20 +60,20 @@ def BiRNN(x, weights, biases):
6060
# Reshape to (n_steps*batch_size, n_input)
6161
x = tf.reshape(x, [-1, n_input])
6262
# Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
63-
x = tf.split(0, n_steps, x)
63+
x = tf.split(x, n_steps, 0)
6464

6565
# Define lstm cells with tensorflow
6666
# Forward direction cell
67-
lstm_fw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
67+
lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
6868
# Backward direction cell
69-
lstm_bw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
69+
lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
7070

7171
# Get lstm cell output
7272
try:
73-
outputs, _, _ = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
73+
outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
7474
dtype=tf.float32)
7575
except Exception: # Old TensorFlow version only returns outputs not states
76-
outputs = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
76+
outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
7777
dtype=tf.float32)
7878

7979
# Linear activation, using rnn inner loop last output
@@ -82,15 +82,15 @@ def BiRNN(x, weights, biases):
8282
pred = BiRNN(x, weights, biases)
8383

8484
# Define loss and optimizer
85-
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
85+
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
8686
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
8787

8888
# Evaluate model
8989
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
9090
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
9191

9292
# Initializing the variables
93-
init = tf.initialize_all_variables()
93+
init = tf.global_variables_initializer()
9494

9595
# Launch the graph
9696
with tf.Session() as sess:

0 commit comments

Comments
 (0)