Cleanup and consistency for variable handling in RNNCells.
In the run-up to TF 1.0, we are making RNNCells' variable names compatible
with those of tf layers.

This is a breaking change for those who wish to reload their old RNN model
checkpoints into newly created graphs.  After this change is in, variables
created with RNNCells will have slightly different names than before;
loading old checkpoints into newly created graphs requires renaming at
load time.

Loading and executing old graphs with old checkpoints will continue to work
without any problems.  Creating and loading new checkpoints with graphs
created after this change is in will also work without any problems.  The only
people affected by this change are those who want to load old RNN model
checkpoints into graphs created after this change is in.

Renaming on checkpoint load can be performed with
tf.contrib.framework.variables.assign_from_checkpoint.  Example usage
is available at [1] if you use Saver and/or Supervisor, and at [2] if you
are using the newer tf.learn classes.  A minimal renaming sketch is shown
after the examples below.

Examples of renamed parameters:

LSTMCell without sharding:
my_scope/LSTMCell/W_0 -> my_scope/lstm_cell/weights
my_scope/LSTMCell/W_F_diag -> my_scope/lstm_cell/w_f_diag
my_scope/LSTMCell/B -> my_scope/lstm_cell/biases

LSTMCell with sharding:
my_scope/LSTMCell/W_0 -> my_scope/lstm_cell/weights/part_0
my_scope/LSTMCell/W_1 -> my_scope/lstm_cell/weights/part_1
my_scope/LSTMCell/W_2 -> my_scope/lstm_cell/weights/part_2
my_scope/LSTMCell/W_F_diag -> my_scope/lstm_cell/w_f_diag
my_scope/LSTMCell/B -> my_scope/lstm_cell/biases

BasicLSTMCell:
my_scope/BasicLSTMCell/Linear/Matrix -> my_scope/basic_lstm_cell/weights
my_scope/BasicLSTMCell/Linear/Bias -> my_scope/basic_lstm_cell/biases

MultiRNNCell:
my_scope/MultiRNNCell/Cell0/LSTMCell/W_0 -> my_scope/multi_rnn_cell/cell_0/lstm_cell/weights
my_scope/MultiRNNCell/Cell0/LSTMCell/W_F_diag -> my_scope/multi_rnn_cell/cell_0/lstm_cell/w_f_diag
my_scope/MultiRNNCell/Cell0/LSTMCell/B -> my_scope/multi_rnn_cell/cell_0/lstm_cell/biases
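
For illustration only (not part of this commit): a minimal sketch of the
rename-on-load approach using a plain Saver.  It assumes the new-style graph
has already been built in the current process; the scope, variable names, and
checkpoint path below are placeholders mirroring the examples above.

  import tensorflow as tf

  # Build the new-style model first so variables with the new names exist,
  # then index the graph's variables by name.
  new_vars = {v.op.name: v for v in tf.global_variables()}

  # Keys are the names stored in the *old* checkpoint; values are the
  # variables in the *new* graph that should receive those values.
  old_to_new = {
      "my_scope/LSTMCell/W_0": new_vars["my_scope/lstm_cell/weights"],
      "my_scope/LSTMCell/W_F_diag": new_vars["my_scope/lstm_cell/w_f_diag"],
      "my_scope/LSTMCell/B": new_vars["my_scope/lstm_cell/biases"],
  }

  # A Saver built from a dict restores each checkpoint entry (dict key)
  # into the mapped variable (dict value).
  saver = tf.train.Saver(var_list=old_to_new)
  with tf.Session() as sess:
    saver.restore(sess, "/path/to/old_model.ckpt")  # placeholder path

A name-to-variable mapping of this kind is also what the
assign_from_checkpoint helper mentioned above operates on.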

[1] https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/README.md

[2] https://github.com/tensorflow/tensorflow/blob/86f5ab7474825da756838b34e1b4eac93f5fc68a/tensorflow/contrib/framework/python/ops/variables_test.py#L810
Change: 140060366
ebrevdo authored and tensorflower-gardener committed Nov 23, 2016
1 parent 75254c3 commit 92da8ab
Showing 9 changed files with 255 additions and 273 deletions.
8 changes: 4 additions & 4 deletions tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
@@ -81,9 +81,9 @@ def testCompatibleNames(self):
      basic_names = {v.name: v.get_shape() for v in tf.trainable_variables()}

    with self.test_session(use_gpu=self._use_gpu, graph=tf.Graph()):
-      cell = tf.contrib.rnn.LSTMBlockCell(10, use_compatible_names=True)
+      cell = tf.contrib.rnn.LSTMBlockCell(10)
      pcell = tf.contrib.rnn.LSTMBlockCell(
-          10, use_peephole=True, use_compatible_names=True)
+          10, use_peephole=True)
      inputs = [tf.zeros([4, 5])] * 6
      tf.nn.rnn(cell, inputs, dtype=tf.float32, scope="basic")
      tf.nn.rnn(pcell, inputs, dtype=tf.float32, scope="peephole")
@@ -93,8 +93,8 @@ def testCompatibleNames(self):
      cell = tf.contrib.rnn.LSTMBlockFusedCell(10)
      pcell = tf.contrib.rnn.LSTMBlockFusedCell(10, use_peephole=True)
      inputs = [tf.zeros([4, 5])] * 6
-      cell(inputs, dtype=tf.float32, scope="basic/LSTMCell")
-      pcell(inputs, dtype=tf.float32, scope="peephole/LSTMCell")
+      cell(inputs, dtype=tf.float32, scope="basic/lstm_cell")
+      pcell(inputs, dtype=tf.float32, scope="peephole/lstm_cell")
      fused_names = {v.name: v.get_shape() for v in tf.trainable_variables()}

    self.assertEqual(basic_names, block_names)
2 changes: 1 addition & 1 deletion tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py
@@ -383,7 +383,7 @@ def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
      # check that all the variables names starts with the proper scope.
      tf.global_variables_initializer()
      all_vars = tf.all_variables()
-      prefix = prefix or "StackRNN"
+      prefix = prefix or "stack_bidirectional_rnn"
      scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
      tf.logging.info("StackRNN with scope: %s (%s)"
                      % (prefix, "scope" if use_outer_scope else "str"))
54 changes: 21 additions & 33 deletions tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -334,44 +334,31 @@ class LSTMBlockCell(rnn_cell.RNNCell):
  Unlike `rnn_cell.LSTMCell`, this is a monolithic op and should be much faster.
  The weight and bias matrixes should be compatible as long as the variable
-  scope matches, and you use `use_compatible_names=True`.
+  scope matches.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
-               use_peephole=False,
-               use_compatible_names=False):
+               use_peephole=False):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      use_peephole: Whether to use peephole connections or not.
-      use_compatible_names: If True, use the same variable naming as
-        rnn_cell.LSTMCell
    """
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._use_peephole = use_peephole
-    if use_compatible_names:
-      self._names = {
-          "W": "W_0",
-          "b": "B",
-          "wci": "W_I_diag",
-          "wco": "W_O_diag",
-          "wcf": "W_F_diag",
-          "scope": "LSTMCell"
-      }
-    else:
-      self._names = {
-          "W": "W",
-          "b": "b",
-          "wci": "wci",
-          "wco": "wco",
-          "wcf": "wcf",
-          "scope": "LSTMBlockCell"
-      }
+    self._names = {
+        "W": "weights",
+        "b": "biases",
+        "wci": "w_i_diag",
+        "wco": "w_o_diag",
+        "wcf": "w_f_diag",
+        "scope": "lstm_cell"
+    }

  @property
  def state_size(self):
@@ -385,15 +372,15 @@ def __call__(self, x, states_prev, scope=None):
    """Long short-term memory cell (LSTM)."""
    with vs.variable_scope(scope or self._names["scope"]):
      x_shape = x.get_shape().with_rank(2)
-      if not x_shape[1]:
-        raise ValueError("Expecting x_shape[1] to be sets: %s" % str(x_shape))
+      if not x_shape[1].value:
+        raise ValueError("Expecting x_shape[1] to be set: %s" % str(x_shape))
      if len(states_prev) != 2:
        raise ValueError("Expecting states_prev to be a tuple with length 2.")
-      input_size = x_shape[1]
+      input_size = x_shape[1].value
      w = vs.get_variable(self._names["W"], [input_size + self._num_units,
                                             self._num_units * 4])
      b = vs.get_variable(
-          self._names["b"], [w.get_shape().with_rank(2)[1]],
+          self._names["b"], [w.get_shape().with_rank(2)[1].value],
          initializer=init_ops.constant_initializer(0.0))
      if self._use_peephole:
        wci = vs.get_variable(self._names["wci"], [self._num_units])
@@ -490,7 +477,7 @@ def __call__(self,
    Raises:
      ValueError: in case of shape mismatches
    """
-    with vs.variable_scope(scope or type(self).__name__):
+    with vs.variable_scope(scope or "lstm_block_wrapper"):
      is_list = isinstance(inputs, list)
      if is_list:
        inputs = array_ops.pack(inputs)
@@ -634,15 +621,16 @@ def _call_cell(self, inputs, initial_cell_state, initial_output, dtype,
    time_len = array_ops.shape(inputs)[0]
    input_size = inputs_shape[2].value
    w = vs.get_variable(
-        "W_0", [input_size + self._num_units, self._num_units * 4], dtype=dtype)
+        "weights",
+        [input_size + self._num_units, self._num_units * 4], dtype=dtype)
    b = vs.get_variable(
-        "B", [w.get_shape().with_rank(2)[1]],
+        "biases", [w.get_shape().with_rank(2)[1]],
        initializer=init_ops.constant_initializer(0.0),
        dtype=dtype)
    if self._use_peephole:
-      wci = vs.get_variable("W_I_diag", [self._num_units], dtype=dtype)
-      wco = vs.get_variable("W_O_diag", [self._num_units], dtype=dtype)
-      wcf = vs.get_variable("W_F_diag", [self._num_units], dtype=dtype)
+      wci = vs.get_variable("w_i_diag", [self._num_units], dtype=dtype)
+      wco = vs.get_variable("w_o_diag", [self._num_units], dtype=dtype)
+      wcf = vs.get_variable("w_f_diag", [self._num_units], dtype=dtype)
    else:
      wci = wco = wcf = array_ops.zeros([self._num_units], dtype=dtype)

11 changes: 6 additions & 5 deletions tensorflow/contrib/rnn/python/ops/rnn.py
@@ -95,7 +95,7 @@ def stack_bidirectional_rnn(cells_fw,
  states_bw = []
  prev_layer = inputs

-  with vs.variable_scope(scope or "StackRNN"):
+  with vs.variable_scope(scope or "stack_bidirectional_rnn"):
    for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
      initial_state_fw = None
      initial_state_bw = None
@@ -104,15 +104,16 @@ def stack_bidirectional_rnn(cells_fw,
      if initial_states_bw:
        initial_state_bw = initial_states_bw[i]

-      with vs.variable_scope("Layer%d" % i):
+      with vs.variable_scope("cell_%d" % i) as cell_scope:
        prev_layer, state_fw, state_bw = tf.nn.bidirectional_rnn(
            cell_fw,
            cell_bw,
            prev_layer,
            initial_state_fw=initial_state_fw,
            initial_state_bw=initial_state_bw,
            sequence_length=sequence_length,
-            dtype=dtype)
+            dtype=dtype,
+            scope=cell_scope)
        states_fw.append(state_fw)
        states_bw.append(state_bw)

@@ -192,7 +193,7 @@ def stack_bidirectional_dynamic_rnn(cells_fw,
  states_bw = []
  prev_layer = inputs

-  with vs.variable_scope(scope or "StackRNN"):
+  with vs.variable_scope(scope or "stack_bidirectional_rnn"):
    for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
      initial_state_fw = None
      initial_state_bw = None
@@ -201,7 +202,7 @@ def stack_bidirectional_dynamic_rnn(cells_fw,
      if initial_states_bw:
        initial_state_bw = initial_states_bw[i]

-      with vs.variable_scope("Layer%d" % i):
+      with vs.variable_scope("cell_%d" % i):
        outputs, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw,
            cell_bw,
99 changes: 50 additions & 49 deletions tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -200,8 +200,8 @@ def __call__(self, inputs, state, scope=None):
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
-    with vs.variable_scope(scope or type(self).__name__,
-                           initializer=self._initializer):  # "LSTMCell"
+    with vs.variable_scope(scope or "coupled_input_forget_gate_lstm_cell",
+                           initializer=self._initializer):
      concat_w = _get_concat_variable(
          "W", [input_size.value + num_proj, 3 * self._num_units],
          dtype, self._num_unit_shards)
@@ -328,7 +328,7 @@ def __call__(self, inputs, state, scope=None):
    freq_inputs = self._make_tf_features(inputs)
    dtype = inputs.dtype
    actual_input_size = freq_inputs[0].get_shape().as_list()[1]
-    with vs.variable_scope(scope or type(self).__name__,
+    with vs.variable_scope(scope or "time_freq_lstm_cell",
                           initializer=self._initializer):  # "TimeFreqLSTMCell"
      concat_w = _get_concat_variable(
          "W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
@@ -546,7 +546,7 @@ def __call__(self, inputs, state, scope=None):
    """
    batch_size = int(inputs.get_shape()[0])
    freq_inputs = self._make_tf_features(inputs)
-    with vs.variable_scope(scope or type(self).__name__,
+    with vs.variable_scope(scope or "grid_lstm_cell",
                           initializer=self._initializer):  # "GridLSTMCell"
      m_out_lst = []
      state_out_lst = []
@@ -968,29 +968,29 @@ def __call__(self, inputs, state, scope=None):
      bwd_inputs = fwd_inputs

    # Forward processing
-    with vs.variable_scope((scope or type(self).__name__) + "/fwd",
-                           initializer=self._initializer):
-      fwd_m_out_lst = []
-      fwd_state_out_lst = []
-      for block in range(len(fwd_inputs)):
-        fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
-            fwd_inputs[block], block, state, batch_size,
-            state_prefix="fwd_state", state_is_tuple=True)
-        fwd_m_out_lst.extend(fwd_m_out_lst_current)
-        fwd_state_out_lst.extend(fwd_state_out_lst_current)
-    # Backward processing
-    bwd_m_out_lst = []
-    bwd_state_out_lst = []
-    with vs.variable_scope((scope or type(self).__name__) + "/bwd",
+    with vs.variable_scope(scope or "bidirectional_grid_lstm_cell",
                           initializer=self._initializer):
-      for block in range(len(bwd_inputs)):
-        # Reverse the blocks
-        bwd_inputs_reverse = bwd_inputs[block][::-1]
-        bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
-            bwd_inputs_reverse, block, state, batch_size,
-            state_prefix="bwd_state", state_is_tuple=True)
-        bwd_m_out_lst.extend(bwd_m_out_lst_current)
-        bwd_state_out_lst.extend(bwd_state_out_lst_current)
+      with vs.variable_scope("fwd"):
+        fwd_m_out_lst = []
+        fwd_state_out_lst = []
+        for block in range(len(fwd_inputs)):
+          fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
+              fwd_inputs[block], block, state, batch_size,
+              state_prefix="fwd_state", state_is_tuple=True)
+          fwd_m_out_lst.extend(fwd_m_out_lst_current)
+          fwd_state_out_lst.extend(fwd_state_out_lst_current)
+      # Backward processing
+      bwd_m_out_lst = []
+      bwd_state_out_lst = []
+      with vs.variable_scope("bwd"):
+        for block in range(len(bwd_inputs)):
+          # Reverse the blocks
+          bwd_inputs_reverse = bwd_inputs[block][::-1]
+          bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
+              bwd_inputs_reverse, block, state, batch_size,
+              state_prefix="bwd_state", state_is_tuple=True)
+          bwd_m_out_lst.extend(bwd_m_out_lst_current)
+          bwd_state_out_lst.extend(bwd_state_out_lst_current)
    state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
    # Outputs are always concated as it is never used separately.
    m_out = array_ops.concat(1, fwd_m_out_lst + bwd_m_out_lst)
@@ -1071,7 +1071,7 @@ def output_size(self):

  def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell with attention (LSTMA)."""
-    with vs.variable_scope(scope or type(self).__name__):
+    with vs.variable_scope(scope or "attention_cell_wrapper"):
      if self._state_is_tuple:
        state, attns, attn_states = state
      else:
@@ -1094,7 +1094,7 @@ def __call__(self, inputs, state, scope=None):
      else:
        new_state_cat = new_state
      new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
-      with vs.variable_scope("AttnOutputProjection"):
+      with vs.variable_scope("attn_output_projection"):
        output = _linear([lstm_output, new_attns], self._attn_size, True)
      new_attn_states = array_ops.concat(1, [new_attn_states,
                                             array_ops.expand_dims(output, 1)])
@@ -1111,9 +1111,10 @@ def _attention(self, query, attn_states):
    softmax = nn_ops.softmax
    tanh = math_ops.tanh

-    with vs.variable_scope("Attention"):
-      k = vs.get_variable("AttnW", [1, 1, self._attn_size, self._attn_vec_size])
-      v = vs.get_variable("AttnV", [self._attn_vec_size])
+    with vs.variable_scope("attention"):
+      k = vs.get_variable(
+          "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
+      v = vs.get_variable("attn_v", [self._attn_vec_size])
      hidden = array_ops.reshape(attn_states,
                                 [-1, self._attn_length, 1, self._attn_size])
      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
@@ -1191,30 +1192,30 @@ def output_size(self):
    return self._num_units

  def _norm(self, inp, scope):
-    with vs.variable_scope(scope) as scope:
-      shape = inp.get_shape()[-1:]
-      gamma_init = init_ops.constant_initializer(self._g)
-      beta_init = init_ops.constant_initializer(self._b)
-      gamma = vs.get_variable("gamma", shape=shape, initializer=gamma_init)  # pylint: disable=unused-variable
-      beta = vs.get_variable("beta", shape=shape, initializer=beta_init)  # pylint: disable=unused-variable
-      normalized = layers.layer_norm(inp, reuse=True, scope=scope)
-      return normalized
-
-  def _linear(self, args, scope="linear"):
+    shape = inp.get_shape()[-1:]
+    gamma_init = init_ops.constant_initializer(self._g)
+    beta_init = init_ops.constant_initializer(self._b)
+    with vs.variable_scope(scope):
+      # Initialize beta and gamma for use by layer_norm.
+      vs.get_variable("gamma", shape=shape, initializer=gamma_init)
+      vs.get_variable("beta", shape=shape, initializer=beta_init)
+    normalized = layers.layer_norm(inp, reuse=True, scope=scope)
+    return normalized
+
+  def _linear(self, args):
    out_size = 4 * self._num_units
    proj_size = args.get_shape()[-1]
-    with vs.variable_scope(scope) as scope:
-      weights = vs.get_variable("weights", [proj_size, out_size])
-      out = math_ops.matmul(args, weights)
-      if not self._layer_norm:
-        bias = vs.get_variable("b", [out_size])
-        out += bias
-      return out
+    weights = vs.get_variable("weights", [proj_size, out_size])
+    out = math_ops.matmul(args, weights)
+    if not self._layer_norm:
+      bias = vs.get_variable("biases", [out_size])
+      out = nn_ops.bias_add(out, bias)
+    return out

  def __call__(self, inputs, state, scope=None):
    """LSTM cell with layer normalization and recurrent dropout."""

-    with vs.variable_scope(scope or type(self).__name__) as scope:  # LayerNormBasicLSTMCell  # pylint: disable=unused-variables
+    with vs.variable_scope(scope or "layer_norm_basic_lstm_cell"):
      c, h = state
      args = array_ops.concat(1, [inputs, h])
      concat = self._linear(args)
(Diffs for the remaining 4 changed files are not shown.)
