Recurrent layer to treat 2nd and the rest of inputs as initial_states (#7691)

* Recurrent layer to treat 2nd and the rest of inputs as initial_states

* Fix spaces

* Follow code review feedback
wanasit authored and fchollet committed Aug 24, 2017
1 parent 89b6be4 commit d440e4b
Showing 3 changed files with 51 additions and 1 deletion.
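For context: when a model that feeds Keras tensors into `initial_state` is saved and reloaded, the recurrent layer is rebuilt from a flat list of input tensors rather than from the `initial_state` keyword, which is the case this commit handles. A minimal sketch of such a model, assuming the Keras 2.x functional API (variable and file names here are illustrative, not taken from the diff):

from keras.layers import Input, LSTM
from keras.models import Model, load_model

vector_size = 8
input_length = 20

state = Input(shape=(vector_size,))      # initial state fed as a Keras tensor
x = Input(shape=(input_length, vector_size))

# Passed as a keyword here; after save/load the layer sees [x, state, state]
# as one flat list of inputs and must split the trailing tensors back out.
y = LSTM(vector_size, return_sequences=True)(x, initial_state=[state, state])

model = Model(inputs=[x, state], outputs=[y])
model.save('rnn_with_state.h5')          # illustrative path
reloaded = load_model('rnn_with_state.h5')
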
8 changes: 8 additions & 0 deletions keras/layers/recurrent.py
@@ -252,6 +252,14 @@ def preprocess_input(self, inputs, training=None):
return inputs

def __call__(self, inputs, initial_state=None, **kwargs):

# If there are multiple inputs, they should be
# the main input followed by the `initial_state` tensors,
# e.g. when loading a model from a file
if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None:
initial_state = inputs[1:]
inputs = inputs[0]

# If `initial_state` is specified,
# and if it is a Keras tensor,
# then add it to the inputs and temporarily
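In effect, the new branch lets a recurrent layer be called with a single list whose trailing tensors are taken as the initial states when no `initial_state` keyword is given. A minimal sketch under that assumption (names are illustrative):

from keras.layers import Input, LSTM

units, timesteps, dim = 4, 10, 3
x = Input(shape=(timesteps, dim))
h = Input(shape=(units,))   # initial hidden state
c = Input(shape=(units,))   # initial cell state

# Interpreted as inputs=x, initial_state=[h, c] by the branch above
out = LSTM(units)([x, h, c])
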
23 changes: 23 additions & 0 deletions tests/keras/layers/recurrent_test.py
@@ -268,6 +268,29 @@ def test_reset_states_with_values(layer_class):
layer.reset_states([1] * (len(layer.states) + 1))


@rnn_test
def test_initial_states_as_other_inputs(layer_class):
num_states = 2 if layer_class is recurrent.LSTM else 1

# Test with Keras tensor
main_inputs = Input((timesteps, embedding_dim))
initial_state = [Input((units,)) for _ in range(num_states)]
inputs = [main_inputs] + initial_state

layer = layer_class(units)
output = layer(inputs)
assert initial_state[0] in layer.inbound_nodes[0].input_tensors

model = Model(inputs, output)
model.compile(loss='categorical_crossentropy', optimizer='adam')

main_inputs = np.random.random((num_samples, timesteps, embedding_dim))
initial_state = [np.random.random((num_samples, units))
for _ in range(num_states)]
targets = np.random.random((num_samples, units))
model.train_on_batch([main_inputs] + initial_state, targets)


@rnn_test
def test_specify_state_with_masking(layer_class):
''' This test is based on a previously failing issue here:
21 changes: 20 additions & 1 deletion tests/test_model_saving.py
@@ -6,7 +6,7 @@

from keras import backend as K
from keras.models import Model, Sequential
from keras.layers import Dense, Lambda, RepeatVector, TimeDistributed
from keras.layers import Dense, Lambda, RepeatVector, TimeDistributed, LSTM
from keras.layers import Input
from keras import optimizers
from keras import losses
@@ -351,5 +351,24 @@ def test_saving_custom_activation_function():
assert_allclose(out, out2, atol=1e-05)


@keras_test
def test_saving_recurrent_layer_with_init_state():
vector_size = 8
input_length = 20

input_initial_state = Input(shape=(vector_size,))
input_x = Input(shape=(input_length, vector_size))

lstm = LSTM(vector_size, return_sequences=True)(
input_x, initial_state=[input_initial_state, input_initial_state])

model = Model(inputs=[input_x, input_initial_state], outputs=[lstm])

_, fname = tempfile.mkstemp('.h5')
model.save(fname)

loaded_model = load_model(fname)
os.remove(fname)

if __name__ == '__main__':
pytest.main([__file__])
