diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py index 4fd4edc50f8..bfcaa0c2711 100644 --- a/keras/layers/recurrent.py +++ b/keras/layers/recurrent.py @@ -185,7 +185,9 @@ class RNN(Layer): # Arguments cell: A RNN cell instance. A RNN cell is a class that has: - a `call(input_at_t, states_at_t)` method, returning - `(output_at_t, states_at_t_plus_1)`. + `(output_at_t, states_at_t_plus_1)`. The call method of the + cell can also take the optional argument `constants`, see + section "Note on passing external constants" below. - a `state_size` attribute. This can be a single integer (single state) in which case it is the size of the recurrent state @@ -276,6 +278,14 @@ class RNN(Layer): `states` should be a numpy array or list of numpy arrays representing the initial state of the RNN layer. + # Note on passing external constants to RNNs + You can pass "external" constants to the cell using the `constants` + keyword argument of RNN.__call__ (as well as RNN.call) method. This + requires that the `cell.call` method accepts the same keyword argument + `constants`. Such constants can be used to condition the cell + transformation on additional static inputs (not changing over time) + (a.k.a. an attention mechanism). + # Examples ```python @@ -354,6 +364,8 @@ def __init__(self, cell, self.state_spec = InputSpec(shape=(None, self.cell.state_size)) self._states = None + self.external_constants_spec = None + @property def states(self): if self._states is None: @@ -399,6 +411,14 @@ def compute_mask(self, inputs, mask): return output_mask def build(self, input_shape): + # Note input_shape will be list of shapes of initial states and + # constants if these are passed in __call__. 
+ if self.external_constants_spec is not None: + # input_shape must be list + constants_shape = input_shape[-len(self.external_constants_spec):] + else: + constants_shape = None + if isinstance(input_shape, list): input_shape = input_shape[0] @@ -411,7 +431,10 @@ def build(self, input_shape): if isinstance(self.cell, Layer): step_input_shape = (input_shape[0],) + input_shape[2:] - self.cell.build(step_input_shape) + if constants_shape is not None: + self.cell.build([step_input_shape] + constants_shape) + else: + self.cell.build(step_input_shape) def get_initial_state(self, inputs): # build an all-zero tensor of shape (samples, output_dim) @@ -424,43 +447,58 @@ def get_initial_state(self, inputs): else: return [K.tile(initial_state, [1, self.cell.state_size])] - def __call__(self, inputs, initial_state=None, **kwargs): - # If there are multiple inputs, then - # they should be the main input and `initial_state` - # e.g. when loading model from file - if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None: - initial_state = inputs[1:] - inputs = inputs[0] - - # If `initial_state` is specified, - # and if it a Keras tensor, - # then add it to the inputs and temporarily - # modify the input spec to include the state. - if initial_state is None: + def __call__(self, inputs, initial_state=None, constants=None, **kwargs): + # If there are multiple inputs, then they should be the main input, + # `initial_state` and (optionally) `constants` e.g. 
when loading model + # from file # TODO ask for clarification + inputs, initial_state, constants = self._normalize_args( + inputs, initial_state, constants) + + # we need to know length of constants in build + if constants: + self.external_constants_spec = [ + InputSpec(shape=K.int_shape(constant)) + for constant in constants + ] + + if initial_state is None and constants is None: return super(RNN, self).__call__(inputs, **kwargs) - if not isinstance(initial_state, (list, tuple)): - initial_state = [initial_state] - - is_keras_tensor = hasattr(initial_state[0], '_keras_history') - for tensor in initial_state: + # If any of `initial_state` or `constants` are specified and are Keras + # tensors, then add them to the inputs and temporarily modify the + # input_spec to include them. + + check_list = [] + if initial_state: + check_list += initial_state + if constants: + check_list += constants + # at this point check_list cannot be empty + is_keras_tensor = hasattr(check_list[0], '_keras_history') + for tensor in check_list: if hasattr(tensor, '_keras_history') != is_keras_tensor: - raise ValueError('The initial state of an RNN layer cannot be' - ' specified with a mix of Keras tensors and' - ' non-Keras tensors') + raise ValueError('The initial state and constants of an RNN' + ' layer cannot be specified with a mix of' + ' Keras tensors and non-Keras tensors') if is_keras_tensor: - # Compute the full input spec, including state + # Compute the full input spec, including state and constants input_spec = self.input_spec state_spec = self.state_spec if not isinstance(input_spec, list): input_spec = [input_spec] if not isinstance(state_spec, list): state_spec = [state_spec] - self.input_spec = input_spec + state_spec - - # Compute the full inputs, including state - inputs = [inputs] + list(initial_state) + self.input_spec = input_spec + inputs = [inputs] + if initial_state: + self.input_spec += state_spec + inputs += initial_state + kwargs['initial_state'] = initial_state 
+ if constants: + self.input_spec += self.external_constants_spec + inputs += constants + kwargs['constants'] = constants # Perform the call output = super(RNN, self).__call__(inputs, **kwargs) @@ -470,16 +508,22 @@ def __call__(self, inputs, initial_state=None, **kwargs): return output else: kwargs['initial_state'] = initial_state + if constants is not None: + kwargs['constants'] = constants return super(RNN, self).__call__(inputs, **kwargs) - def call(self, inputs, mask=None, training=None, initial_state=None): + def call(self, + inputs, + mask=None, + training=None, + initial_state=None, + constants=None): # input shape: `(samples, time (padded with zeros), input_dim)` # note that the .build() method of subclasses MUST define # self.input_spec and self.state_spec with complete input shapes. if isinstance(inputs, list): - initial_state = inputs[1:] inputs = inputs[0] - elif initial_state is not None: + if initial_state is not None: pass elif self.stateful: initial_state = self.states @@ -508,9 +552,17 @@ def call(self, inputs, mask=None, training=None, initial_state=None): '- If using the functional API, specify ' 'the time dimension by passing a `shape` ' 'or `batch_shape` argument to your Input layer.') - + cell_kwargs = {} if has_arg(self.cell.call, 'training'): - step = functools.partial(self.cell.call, training=training) + cell_kwargs['training'] = training + + if constants is not None: + if not has_arg(self.cell.call, 'constants'): + raise TypeError('cell does not take keyword argument constants') + cell_kwargs['constants'] = constants + + if cell_kwargs: + step = functools.partial(self.cell.call, **cell_kwargs) else: step = self.cell.call last_output, outputs, states = K.rnn(step, @@ -544,6 +596,52 @@ def call(self, inputs, mask=None, training=None, initial_state=None): else: return output + def _normalize_args(self, inputs, initial_state=None, constants=None): + """The inputs `initial_state` and `constants` can be passed to + RNN.__call__ either by 
separate arguments or as part of `inputs`. In + this case `inputs` is a list of tensors of which the first one is the + actual (sequence) input followed by initial states, followed by + constants. + + This method separates and normalizes the different groups of inputs. + + # Arguments + inputs: tensor or list/tuple of tensors + initial_state: tensor or list of tensors or None + constants: tensor or list of tensors or None + + # Returns + inputs: tensor + initial_state: list of tensors or None + constants: list of tensors or None + """ + if isinstance(inputs, (list, tuple)): + remaining_inputs = inputs[1:] + inputs = inputs[0] + if remaining_inputs and initial_state is None: + if isinstance(self.state_spec, list): + n_states = len(self.state_spec) + else: + n_states = 1 + initial_state = remaining_inputs[:n_states] + remaining_inputs = remaining_inputs[n_states:] + if remaining_inputs and constants is None: + constants = remaining_inputs + if len(remaining_inputs) > 0: + raise ValueError('too many inputs were passed') + + def to_list_or_none(x): # TODO break out? 
+ if x is None or isinstance(x, list): + return x + if isinstance(x, tuple): + return list(x) + return [x] + + initial_state = to_list_or_none(initial_state) + constants = to_list_or_none(constants) + + return inputs, initial_state, constants + def reset_states(self, states=None): if not self.stateful: raise AttributeError('Layer must be stateful.') diff --git a/tests/keras/layers/recurrent_test.py b/tests/keras/layers/recurrent_test.py index fc328caf57d..24122e1dce9 100644 --- a/tests/keras/layers/recurrent_test.py +++ b/tests/keras/layers/recurrent_test.py @@ -564,5 +564,77 @@ def test_batch_size_equal_one(layer_class): model.train_on_batch(x, y) +def test_rnn_cell_with_constants_layer(): + + class RNNCellWithConstants(keras.layers.Layer): + + def __init__(self, units, **kwargs): + self.units = units + self.state_size = units + super(RNNCellWithConstants, self).__init__(**kwargs) + + def build(self, input_shape): + if not isinstance(input_shape, list): + raise TypeError('expects constants shape') + [input_shape, constant_shape] = input_shape + # will (and should) raise if more than one constant passed + + self.input_kernel = self.add_weight( + shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + self.constant_kernel = self.add_weight( + shape=(constant_shape[-1], self.units), + initializer='uniform', + name='constant_kernel') + self.built = True + + def call(self, inputs, states, constants): + [prev_output] = states + [constant] = constants + h_input = keras.backend.dot(inputs, self.input_kernel) + h_state = keras.backend.dot(prev_output, self.recurrent_kernel) + h_const = keras.backend.dot(constant, self.constant_kernel) + output = h_input + h_state + h_const + return output, [output] + + def get_config(self): + config = {'units': self.units} + base_config = super(RNNCellWithConstants, self).get_config() + 
return dict(list(base_config.items()) + list(config.items())) + + # Test basic case. + x = keras.Input((None, 5)) + c = keras.Input((3,)) + cell = RNNCellWithConstants(32) + layer = recurrent.RNN(cell) + y = layer(x, constants=c) + model = keras.models.Model([x, c], y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + [np.zeros((6, 5, 5)), np.zeros((6, 3))], + np.zeros((6, 32)) + ) + + # Test basic case serialization. + x_np = np.random.random((6, 5, 5)) + c_np = np.random.random((6, 3)) + y_np = model.predict([x_np, c_np]) + weights = model.get_weights() + config = layer.get_config() + with keras.utils.CustomObjectScope( + {'RNNCellWithConstants': RNNCellWithConstants}): + layer = recurrent.RNN.from_config(config) + y = layer(x, constants=c) + model = keras.models.Model([x, c], y) + model.set_weights(weights) + y_np_2 = model.predict([x_np, c_np]) + assert_allclose(y_np, y_np_2, atol=1e-4) + + if __name__ == '__main__': pytest.main([__file__])