rnn_cell.linear and models.rnn are going away.
Instead of rnn_cell.linear, use tf.contrib.layers.linear (but keep in mind that its weight initializer is different; to avoid a bias term, use bias_init=None).

tf.models.rnn and tf.models.rnn_cell had been kept around for backwards compatibility. Importing them now raises an ImportError pointing to tf.nn.rnn*.
Change: 122064919
ebrevdo authored and tensorflower-gardener committed May 11, 2016
1 parent 893ebc1 commit 396b586
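
A minimal migration sketch, assuming the 0.8-era tf.contrib.layers signature referenced in the message above (bias_init as named there). The old rnn_cell.linear accepted a list of tensors and summed per-tensor matmuls, equivalent to concatenating first; with the contrib layer, concatenate explicitly:

    import tensorflow as tf

    x = tf.zeros([1, 2])
    state = tf.zeros([1, 2])

    # Before (removed by this change):
    #   from tensorflow.models.rnn import rnn_cell
    #   y = rnn_cell.linear([x, state], 2, True)

    # After: concatenate the inputs yourself, then apply the contrib layer.
    # Note the different default weight initializer, and bias_init=None to
    # drop the bias term.
    inputs = tf.concat(1, [x, state])  # tf.concat(concat_dim, values) in this era
    y = tf.contrib.layers.linear(inputs, 2)
    y_no_bias = tf.contrib.layers.linear(inputs, 2, bias_init=None)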
Showing 11 changed files with 29 additions and 42 deletions.
8 changes: 0 additions & 8 deletions tensorflow/models/rnn/README.md
@@ -11,11 +11,3 @@ File | What's in it?
 --- | ---
 `ptb/` | PTB language model, see the [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/)
 `translate/` | Translation model, see the [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/)
-`linear.py` | Basic helper functions for creating linear layers (moved to tf.nn, deprecated in favor of layers).
-`linear_test.py` | Unit tests for `linear.py` (moved into TF core).
-`rnn_cell.py` | Cells for recurrent neural networks, e.g., LSTM (moved to tf.nn.rnn_cell).
-`rnn_cell_test.py` | Unit tests for `rnn_cell.py` (moved into TF core).
-`rnn.py` | Functions for building recurrent neural networks (functions moved into in tf.nn).
-`rnn_test.py` | Unit tests for `rnn.py` (moved into TF core).
-`seq2seq.py` | Functions for building sequence-to-sequence models (moved to tf.nn.seq2seq).
-`seq2seq_test.py` | Unit tests for `seq2seq.py` (moved into TF core).
4 changes: 0 additions & 4 deletions tensorflow/models/rnn/__init__.py
@@ -17,7 +17,3 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-
-from tensorflow.models.rnn import rnn
-from tensorflow.models.rnn import rnn_cell
-from tensorflow.models.rnn import seq2seq
4 changes: 1 addition & 3 deletions tensorflow/models/rnn/linear.py
@@ -17,6 +17,4 @@
 from __future__ import division
 from __future__ import print_function
 
-import tensorflow as tf
-
-linear = tf.nn.linear
+raise ImportError("This module is deprecated. Use tf.contrib.layers.linear.")
3 changes: 0 additions & 3 deletions tensorflow/models/rnn/ptb/BUILD
@@ -45,9 +45,6 @@ py_binary(
     deps = [
         ":reader",
         "//tensorflow:tensorflow_py",
-        "//tensorflow/models/rnn",
-        "//tensorflow/models/rnn:rnn_cell",
-        "//tensorflow/models/rnn:seq2seq",
     ],
 )
 
3 changes: 1 addition & 2 deletions tensorflow/models/rnn/rnn.py
@@ -18,5 +18,4 @@
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=wildcard-import
-from tensorflow.python.ops.rnn import *
+raise ImportError("This module is deprecated. Use tf.nn.rnn_* instead.")
3 changes: 1 addition & 2 deletions tensorflow/models/rnn/rnn_cell.py
@@ -18,5 +18,4 @@
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=wildcard-import
-from tensorflow.python.ops.rnn_cell import *
+raise ImportError("This module is deprecated. Use tf.nn.rnn_cell instead.")
3 changes: 1 addition & 2 deletions tensorflow/models/rnn/seq2seq.py
@@ -18,5 +18,4 @@
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=wildcard-import
-from tensorflow.python.ops.seq2seq import *
+raise ImportError("This module is deprecated. Use tf.nn.seq2seq instead.")
1 change: 0 additions & 1 deletion tensorflow/models/rnn/translate/BUILD
@@ -37,7 +37,6 @@ py_library(
     deps = [
         ":data_utils",
         "//tensorflow:tensorflow_py",
-        "//tensorflow/models/rnn:seq2seq",
     ],
 )
 
17 changes: 11 additions & 6 deletions tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -23,27 +23,32 @@
 import numpy as np
 import tensorflow as tf
 
+# TODO(ebrevdo): Remove once _linear is fully deprecated.
+# pylint: disable=protected-access
+from tensorflow.python.ops.rnn_cell import _linear as linear
+# pylint: enable=protected-access
+
 
 class RNNCellTest(tf.test.TestCase):
 
   def testLinear(self):
     with self.test_session() as sess:
       with tf.variable_scope("root", initializer=tf.constant_initializer(1.0)):
         x = tf.zeros([1, 2])
-        l = tf.nn.rnn_cell.linear([x], 2, False)
+        l = linear([x], 2, False)
         sess.run([tf.initialize_all_variables()])
         res = sess.run([l], {x.name: np.array([[1., 2.]])})
         self.assertAllClose(res[0], [[3.0, 3.0]])
 
         # Checks prevent you from accidentally creating a shared function.
         with self.assertRaises(ValueError):
-          l1 = tf.nn.rnn_cell.linear([x], 2, False)
+          l1 = linear([x], 2, False)
 
         # But you can create a new one in a new scope and share the variables.
         with tf.variable_scope("l1") as new_scope:
-          l1 = tf.nn.rnn_cell.linear([x], 2, False)
+          l1 = linear([x], 2, False)
         with tf.variable_scope(new_scope, reuse=True):
-          tf.nn.rnn_cell.linear([l1], 2, False)
+          linear([l1], 2, False)
         self.assertEqual(len(tf.trainable_variables()), 2)
 
   def testBasicRNNCell(self):
@@ -311,8 +316,8 @@ def basic_rnn_cell(inputs, state, num_units, scope=None):
     return init_output, init_state
   else:
     with tf.variable_op_scope([inputs, state], scope, "BasicRNNCell"):
-      output = tf.tanh(tf.nn.rnn_cell.linear([inputs, state],
-                                             num_units, True))
+      output = tf.tanh(linear([inputs, state],
+                              num_units, True))
      return output, output
 
 if __name__ == "__main__":
16 changes: 8 additions & 8 deletions tensorflow/python/ops/rnn_cell.py
@@ -136,7 +136,7 @@ def output_size(self):
   def __call__(self, inputs, state, scope=None):
     """Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
     with vs.variable_scope(scope or type(self).__name__):  # "BasicRNNCell"
-      output = tanh(linear([inputs, state], self._num_units, True))
+      output = tanh(_linear([inputs, state], self._num_units, True))
     return output, output
 
 
@@ -161,11 +161,11 @@ def __call__(self, inputs, state, scope=None):
     with vs.variable_scope(scope or type(self).__name__):  # "GRUCell"
       with vs.variable_scope("Gates"):  # Reset gate and update gate.
         # We start with bias of 1.0 to not reset and not update.
-        r, u = array_ops.split(1, 2, linear([inputs, state],
-                                            2 * self._num_units, True, 1.0))
+        r, u = array_ops.split(1, 2, _linear([inputs, state],
+                                             2 * self._num_units, True, 1.0))
         r, u = sigmoid(r), sigmoid(u)
       with vs.variable_scope("Candidate"):
-        c = tanh(linear([inputs, r * state], self._num_units, True))
+        c = tanh(_linear([inputs, r * state], self._num_units, True))
       new_h = u * state + (1 - u) * c
     return new_h, new_h
 
@@ -223,7 +223,7 @@ def __call__(self, inputs, state, scope=None):
         c, h = state
       else:
         c, h = array_ops.split(1, 2, state)
-      concat = linear([inputs, h], 4 * self._num_units, True)
+      concat = _linear([inputs, h], 4 * self._num_units, True)
 
       # i = input_gate, j = new_input, f = forget_gate, o = output_gate
       i, j, f, o = array_ops.split(1, 4, concat)
@@ -485,7 +485,7 @@ def __call__(self, inputs, state, scope=None):
     output, res_state = self._cell(inputs, state)
     # Default scope: "OutputProjectionWrapper"
     with vs.variable_scope(scope or type(self).__name__):
-      projected = linear(output, self._output_size, True)
+      projected = _linear(output, self._output_size, True)
     return projected, res_state
 
 
@@ -527,7 +527,7 @@ def __call__(self, inputs, state, scope=None):
     """Run the input projection and then the cell."""
     # Default scope: "InputProjectionWrapper"
     with vs.variable_scope(scope or type(self).__name__):
-      projected = linear(inputs, self._num_proj, True)
+      projected = _linear(inputs, self._num_proj, True)
     return self._cell(projected, state)
 
 
@@ -751,7 +751,7 @@ def __call__(self, inputs, state, scope=None):
     return output, state
 
 
-def linear(args, output_size, bias, bias_start=0.0, scope=None):
+def _linear(args, output_size, bias, bias_start=0.0, scope=None):
   """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
 
   Args:
9 changes: 6 additions & 3 deletions tensorflow/python/ops/seq2seq.py
@@ -72,6 +72,9 @@
 from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variable_scope
 
+# TODO(ebrevdo): Remove once _linear is fully deprecated.
+linear = rnn_cell._linear  # pylint: disable=protected-access
+
 
 def _extract_argmax_and_embed(embedding, output_projection=None,
                               update_embedding=True):
@@ -522,7 +525,7 @@ def attention(query):
       ds = []  # Results of attention reads will be stored here.
       for a in xrange(num_heads):
         with variable_scope.variable_scope("Attention_%d" % a):
-          y = rnn_cell.linear(query, attention_vec_size, True)
+          y = linear(query, attention_vec_size, True)
           y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
           # Attention mask is a softmax of v^T * tanh(...).
           s = math_ops.reduce_sum(
@@ -555,7 +558,7 @@ def attention(query):
       input_size = inp.get_shape().with_rank(2)[1]
      if input_size.value is None:
        raise ValueError("Could not infer input size from input: %s" % inp.name)
-      x = rnn_cell.linear([inp] + attns, input_size, True)
+      x = linear([inp] + attns, input_size, True)
      # Run the RNN.
      cell_output, state = cell(x, state)
      # Run the attention mechanism.
@@ -567,7 +570,7 @@ def attention(query):
         attns = attention(state)
 
       with variable_scope.variable_scope("AttnOutputProjection"):
-        output = rnn_cell.linear([cell_output] + attns, output_size, True)
+        output = linear([cell_output] + attns, output_size, True)
       if loop_function is not None:
         prev = output
       outputs.append(output)
