Commit a510390

[RELNOTES] Refactor RNNs to rely on atomic cells. (keras-team#7943)
* Refactor RNNs to rely on atomic cells.

* Add RNN docstrings back.

* Fix Theano/CNTK RNN dropout.

* Disable dropout in CNTK dynamic RNNs.

* Standardize input dropout masks in RNNs.

* Skip RNN dropout test for CNTK.

* Remove legacy constraints.

* Increase stacked RNN test coverage.
fchollet authored Sep 22, 2017
1 parent 71a791c commit a510390
Showing 8 changed files with 1,762 additions and 540 deletions.
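
For context on the commit title: after this refactor, recurrent layers are thin wrappers around atomic cell objects, and a list of cells can be wrapped in a single `RNN` layer to form a stacked RNN. A minimal sketch of that usage, assuming the post-refactor `keras.layers` names (`RNN`, `LSTMCell`); illustrative only, not code from this diff:

    from keras.models import Sequential
    from keras.layers import RNN, LSTMCell, Dense

    # Passing a list of cells to RNN yields a stacked recurrent layer,
    # replacing two separate LSTM layers (input: 10 timesteps, 8 features).
    model = Sequential()
    model.add(RNN([LSTMCell(32), LSTMCell(32)], input_shape=(10, 8)))
    model.add(Dense(1))
    model.compile(optimizer='rmsprop', loss='mse')
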
88 changes: 0 additions & 88 deletions examples/lstm_benchmark.py

This file was deleted.

16 changes: 16 additions & 0 deletions keras/backend/cntk_backend.py
@@ -1191,6 +1191,8 @@ def _static_rnn(step_function, inputs, initial_states,
shape = int_shape(inputs)
dims = len(shape)

uses_learning_phase = False

if dims < 3:
raise ValueError('Input should be at least 3D.')

@@ -1226,6 +1228,8 @@ def _static_rnn(step_function, inputs, initial_states,

output, new_states = step_function(
current, tuple(states) + tuple(constants))
if getattr(output, '_uses_learning_phase', False):
uses_learning_phase = True

if mask is not None:
mask_slice = C.ops.slice(mask, time_axis, i, i + 1)
@@ -1254,6 +1258,8 @@ def _static_rnn(step_function, inputs, initial_states,

output, new_states = step_function(
current, tuple(states) + tuple(constants))
if getattr(output, '_uses_learning_phase', False):
uses_learning_phase = True

if mask is not None:
mask_slice = C.ops.slice(mask, time_axis, i, i + 1)
@@ -1285,6 +1291,7 @@ def _static_rnn(step_function, inputs, initial_states,
last_output = outputs[i]
i += 1

last_output._uses_learning_phase = uses_learning_phase
return last_output, final_output, states


@@ -1295,6 +1302,9 @@ def rnn(step_function, inputs, initial_states,
shape = int_shape(inputs)
dims = len(shape)

global uses_learning_phase
uses_learning_phase = False

if dims < 3:
raise ValueError('CNTK Backend: the input of rnn has only rank %d '
'Need at least rank 3 to run RNN.' % dims)
@@ -1370,6 +1380,11 @@ def _recurrence(x, states, m):
past_values.append(C.sequence.past_value(p, s))
new_output, new_states = step_function(
x, tuple(past_values) + tuple(constants))

if getattr(new_output, '_uses_learning_phase', False):
global uses_learning_phase
uses_learning_phase = True

if m is not None:
new_states = [C.element_select(m, n, s) for n, s in zip(new_states, past_values)]
n_s = []
@@ -1398,6 +1413,7 @@ def _recurrence(x, states, m):
else:
f_stats.append(l_s)

last_output._uses_learning_phase = uses_learning_phase
return last_output, final_output, f_stats


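The recurring pattern in the hunks above: a timestep output computed under the learning phase (e.g. through dropout) carries a `_uses_learning_phase` marker, the loop folds those markers into a single flag, and the flag is finally stamped onto `last_output` so downstream code knows to feed the learning phase. A short sketch of where the marker originates, assuming the Keras 2.0-era backend API:

    import keras.backend as K

    x = K.placeholder(shape=(2, 3))
    # in_train_phase() switches on the symbolic learning phase and tags its
    # result; the backend rnn() loops then forward that tag step by step.
    dropped = K.in_train_phase(K.dropout(x, 0.5), x)
    assert getattr(dropped, '_uses_learning_phase', False)
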
14 changes: 14 additions & 0 deletions keras/backend/tensorflow_backend.py
@@ -2378,6 +2378,9 @@ def rnn(step_function, inputs, initial_states,
if constants is None:
constants = []

global uses_learning_phase
uses_learning_phase = False

if unroll:
if not inputs.get_shape()[0]:
raise ValueError('Unrolling requires a '
@@ -2397,6 +2400,8 @@

for inp, mask_t in zip(input_list, mask_list):
output, new_states = step_function(inp, states + constants)
if getattr(output, '_uses_learning_phase', False):
uses_learning_phase = True

# tf.where needs its condition tensor
# to be the same shape as its two
@@ -2435,6 +2440,8 @@
else:
for inp in input_list:
output, states = step_function(inp, states + constants)
if getattr(output, '_uses_learning_phase', False):
uses_learning_phase = True
successive_outputs.append(output)
successive_states.append(states)
last_output = successive_outputs[-1]
@@ -2493,6 +2500,9 @@ def _step(time, output_ta_t, *states):
output, new_states = step_function(current_input,
tuple(states) +
tuple(constants))
if getattr(output, '_uses_learning_phase', False):
global uses_learning_phase
uses_learning_phase = True
for state, new_state in zip(states, new_states):
new_state.set_shape(state.get_shape())
tiled_mask_t = tf.tile(mask_t,
@@ -2517,6 +2527,9 @@ def _step(time, output_ta_t, *states):
output, new_states = step_function(current_input,
tuple(states) +
tuple(constants))
if getattr(output, '_uses_learning_phase', False):
global uses_learning_phase
uses_learning_phase = True
for state, new_state in zip(states, new_states):
new_state.set_shape(state.get_shape())
output_ta_t = output_ta_t.write(time, output)
@@ -2537,6 +2550,7 @@ def _step(time, output_ta_t, *states):

axes = [1, 0] + list(range(2, len(outputs.get_shape())))
outputs = tf.transpose(outputs, axes)
last_output._uses_learning_phase = uses_learning_phase
return last_output, outputs, new_states


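A note on the `global uses_learning_phase` pattern above: `_step` is a nested function, so a bare assignment inside it would bind a new local rather than update the flag in `rnn`, and the `nonlocal` statement was not an option while Keras still supported Python 2. A mutable container gives the same effect without module-level state; a hypothetical sketch (all names here are invented for illustration):

    def rnn_sketch(step_function, inputs, states):
        flag = [False]  # one-element list acts as a mutable cell

        def _step(x, states):
            output, new_states = step_function(x, states)
            if getattr(output, '_uses_learning_phase', False):
                flag[0] = True  # mutation, not rebinding: no `global` needed
            return output, new_states

        output = None
        for x in inputs:
            output, states = _step(x, states)
        # Assumes tensor-like outputs that accept new attributes.
        output._uses_learning_phase = flag[0]
        return output, states
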
17 changes: 16 additions & 1 deletion keras/backend/theano_backend.py
@@ -1307,6 +1307,9 @@ def rnn(step_function, inputs, initial_states,
if constants is None:
constants = []

global uses_learning_phase
uses_learning_phase = False

if mask is not None:
if mask.ndim == ndim - 1:
mask = expand_dims(mask)
@@ -1323,6 +1326,8 @@
states = initial_states
for i in indices:
output, new_states = step_function(inputs[i], states + constants)
if getattr(output, '_uses_learning_phase', False):
uses_learning_phase = True

if len(successive_outputs) == 0:
prev_output = zeros_like(output)
@@ -1352,6 +1357,9 @@

def _step(inputs, mask, output_tm1, *states):
outputs, new_states = step_function(inputs, states)
if getattr(outputs, '_uses_learning_phase', False):
global uses_learning_phase
uses_learning_phase = True
# output previous output if masked.
outputs = T.switch(mask, outputs, output_tm1)
return_states = []
@@ -1384,6 +1392,8 @@ def _step(inputs, mask, output_tm1, *states):
states = initial_states
for i in indices:
outputs, states = step_function(inputs[i], states + constants)
if getattr(outputs, '_uses_learning_phase', False):
uses_learning_phase = True
successive_outputs.append(outputs)
successive_states.append(states)
outputs = T.stack(*successive_outputs)
@@ -1394,9 +1404,13 @@ def _step(inputs, mask, output_tm1, *states):
else:
def _step(inputs, *states):
outputs, new_states = step_function(inputs, states)
if getattr(outputs, '_uses_learning_phase', False):
global uses_learning_phase
uses_learning_phase = True
return [outputs] + new_states

# Theano likes to make shape==1 dimensions
# in the initial states (outputs_info) broadcastable
if len(initial_states) > 0:
initial_states[0] = T.unbroadcast(initial_states[0], 1)

@@ -1421,6 +1435,7 @@ def _step(inputs, *states):
axes = [1, 0] + list(range(2, outputs.ndim))
outputs = outputs.dimshuffle(axes)
states = [T.squeeze(state[-1]) for state in states]
last_output._uses_learning_phase = uses_learning_phase
return last_output, outputs, states
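
The masking rule in `_step` above ("output previous output if masked") can be sketched in plain NumPy, purely for illustration:

    import numpy as np

    def masked_step(output_t, output_tm1, mask_t):
        # mask_t: (batch, 1) array of 0/1, broadcast over the feature axis.
        # Masked samples keep their previous output, mirroring
        # T.switch(mask, outputs, output_tm1) in the Theano hunk above.
        return np.where(mask_t, output_t, output_tm1)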


(Diff truncated: the remaining changed files are not shown here.)
