From d440e4bf7a2dc84bb7ccfed023d0532fa2f9db83 Mon Sep 17 00:00:00 2001
From: Wanasit Tanakitrungruang
Date: Fri, 25 Aug 2017 06:55:45 +0900
Subject: [PATCH] Recurrent layer to treat 2nd and the rest of inputs as
 initial_states (#7691)

* Recurrent layer to treat 2nd and the rest of inputs as initial_states

* Fix spaces

* Follow code review feedback
---
 keras/layers/recurrent.py            |  8 ++++++++
 tests/keras/layers/recurrent_test.py | 23 +++++++++++++++++++++++
 tests/test_model_saving.py           | 21 ++++++++++++++++++++-
 3 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py
index a31416de638..dfc3d44a95d 100644
--- a/keras/layers/recurrent.py
+++ b/keras/layers/recurrent.py
@@ -252,6 +252,14 @@ def preprocess_input(self, inputs, training=None):
         return inputs
 
     def __call__(self, inputs, initial_state=None, **kwargs):
+
+        # If there are multiple inputs, then
+        # they should be the main input and `initial_state`
+        # e.g. when loading model from file
+        if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None:
+            initial_state = inputs[1:]
+            inputs = inputs[0]
+
         # If `initial_state` is specified,
         # and if it a Keras tensor,
         # then add it to the inputs and temporarily
diff --git a/tests/keras/layers/recurrent_test.py b/tests/keras/layers/recurrent_test.py
index 820f3f642db..9c653782366 100644
--- a/tests/keras/layers/recurrent_test.py
+++ b/tests/keras/layers/recurrent_test.py
@@ -268,6 +268,29 @@ def test_reset_states_with_values(layer_class):
         layer.reset_states([1] * (len(layer.states) + 1))
 
 
+@rnn_test
+def test_initial_states_as_other_inputs(layer_class):
+    num_states = 2 if layer_class is recurrent.LSTM else 1
+
+    # Test with Keras tensor
+    main_inputs = Input((timesteps, embedding_dim))
+    initial_state = [Input((units,)) for _ in range(num_states)]
+    inputs = [main_inputs] + initial_state
+
+    layer = layer_class(units)
+    output = layer(inputs)
+    assert initial_state[0] in layer.inbound_nodes[0].input_tensors
+
+    model = Model(inputs, output)
+    model.compile(loss='categorical_crossentropy', optimizer='adam')
+
+    main_inputs = np.random.random((num_samples, timesteps, embedding_dim))
+    initial_state = [np.random.random((num_samples, units))
+                     for _ in range(num_states)]
+    targets = np.random.random((num_samples, units))
+    model.train_on_batch([main_inputs] + initial_state, targets)
+
+
 @rnn_test
 def test_specify_state_with_masking(layer_class):
     ''' This test based on a previously failing issue here:
diff --git a/tests/test_model_saving.py b/tests/test_model_saving.py
index 7a8a558c975..e1821ba130f 100644
--- a/tests/test_model_saving.py
+++ b/tests/test_model_saving.py
@@ -6,7 +6,7 @@
 
 from keras import backend as K
 from keras.models import Model, Sequential
-from keras.layers import Dense, Lambda, RepeatVector, TimeDistributed
+from keras.layers import Dense, Lambda, RepeatVector, TimeDistributed, LSTM
 from keras.layers import Input
 from keras import optimizers
 from keras import losses
@@ -351,5 +351,24 @@ def test_saving_custom_activation_function():
     assert_allclose(out, out2, atol=1e-05)
 
 
+@keras_test
+def test_saving_recurrent_layer_with_init_state():
+    vector_size = 8
+    input_length = 20
+
+    input_initial_state = Input(shape=(vector_size,))
+    input_x = Input(shape=(input_length, vector_size))
+
+    lstm = LSTM(vector_size, return_sequences=True)(
+        input_x, initial_state=[input_initial_state, input_initial_state])
+
+    model = Model(inputs=[input_x, input_initial_state], outputs=[lstm])
+
+    _, fname = tempfile.mkstemp('.h5')
+    model.save(fname)
+
+    loaded_model = load_model(fname)
+    os.remove(fname)
+
 if __name__ == '__main__':
     pytest.main([__file__])