Commit 28795f1
Added support for passing external constants to RNN, which will pass them on to the cell
andhus committed Sep 24, 2017
1 parent 710898f commit 28795f1
Showing 2 changed files with 203 additions and 33 deletions.
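For orientation, here is a minimal sketch of the usage this commit enables. It reuses the `RNNCellWithConstants` cell defined in the new test below; names and shapes follow that test.

```python
import numpy as np
import keras
from keras.layers import recurrent

# `RNNCellWithConstants` is the custom cell from the new test below; its
# `call(inputs, states, constants)` signature opts in to receiving constants.
x = keras.Input((None, 5))   # the sequence input
c = keras.Input((3,))        # a static per-sample input, constant over time
layer = recurrent.RNN(RNNCellWithConstants(32))
y = layer(x, constants=c)    # `c` is forwarded to the cell at every timestep

model = keras.models.Model([x, c], y)
model.compile(optimizer='rmsprop', loss='mse')
model.train_on_batch([np.zeros((6, 5, 5)), np.zeros((6, 3))],
                     np.zeros((6, 32)))
```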
164 changes: 131 additions & 33 deletions keras/layers/recurrent.py
@@ -185,7 +185,9 @@ class RNN(Layer):
# Arguments
cell: An RNN cell instance. An RNN cell is a class that has:
- a `call(input_at_t, states_at_t)` method, returning
`(output_at_t, states_at_t_plus_1)`.
`(output_at_t, states_at_t_plus_1)`. The cell's `call` method
can also take an optional keyword argument `constants`; see the
"Note on passing external constants" section below.
- a `state_size` attribute. This can be a single integer
(single state) in which case it is
the size of the recurrent state
@@ -276,6 +278,14 @@ class RNN(Layer):
`states` should be a numpy array or list of numpy arrays representing
the initial state of the RNN layer.
# Note on passing external constants to RNNs
You can pass "external" constants to the cell using the `constants`
keyword argument of RNN.__call__ (as well as RNN.call) method. This
requires that the `cell.call` method accepts the same keyword argument
`constants`. Such constants can be used to condition the cell
transformation on additional static inputs (not changing over time)
(a.k.a. an attention mechanism).
# Examples
```python
@@ -354,6 +364,8 @@ def __init__(self, cell,
self.state_spec = InputSpec(shape=(None, self.cell.state_size))
self._states = None

self.external_constants_spec = None

@property
def states(self):
if self._states is None:
@@ -399,6 +411,14 @@ def compute_mask(self, inputs, mask):
return output_mask

def build(self, input_shape):
# Note: `input_shape` will be a list of shapes (the input's, then any
# initial states' and constants') if states or constants are passed
# in `__call__`.
if self.external_constants_spec is not None:
# input_shape must be list
constants_shape = input_shape[-len(self.external_constants_spec):]
else:
constants_shape = None

if isinstance(input_shape, list):
input_shape = input_shape[0]

@@ -411,7 +431,10 @@ def build(self, input_shape):

if isinstance(self.cell, Layer):
step_input_shape = (input_shape[0],) + input_shape[2:]
self.cell.build(step_input_shape)
if constants_shape is not None:
self.cell.build([step_input_shape] + constants_shape)
else:
self.cell.build(step_input_shape)
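
To make the shape bookkeeping above concrete, here is a small sketch of how `build` slices a combined shape list when one initial state and one constant were passed (all shapes are illustrative):

```python
# Hypothetical combined shape list as assembled by __call__ below:
# main input, one initial state, one constant.
input_shape = [(None, 10, 5), (None, 32), (None, 3)]
external_constants_spec = [object()]  # stand-in: one constant was passed

# The constants' shapes are the trailing entries of the list:
constants_shape = input_shape[-len(external_constants_spec):]  # [(None, 3)]
input_shape = input_shape[0]                                   # (None, 10, 5)

# The per-timestep input shape drops the time axis:
step_input_shape = (input_shape[0],) + input_shape[2:]         # (None, 5)
```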

def get_initial_state(self, inputs):
# build an all-zero tensor of shape (samples, output_dim)
@@ -424,43 +447,58 @@ def get_initial_state(self, inputs):
else:
return [K.tile(initial_state, [1, self.cell.state_size])]

def __call__(self, inputs, initial_state=None, **kwargs):
# If there are multiple inputs, then
# they should be the main input and `initial_state`
# e.g. when loading model from file
if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None:
initial_state = inputs[1:]
inputs = inputs[0]

# If `initial_state` is specified,
# and if it a Keras tensor,
# then add it to the inputs and temporarily
# modify the input spec to include the state.
if initial_state is None:
def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
# If there are multiple inputs, they should be the main input,
# `initial_state` and (optionally) `constants`, e.g. when loading a
# model from file.  # TODO: ask for clarification
inputs, initial_state, constants = self._normalize_args(
inputs, initial_state, constants)

# We need to know the number of constants in `build`.
if constants:
self.external_constants_spec = [
InputSpec(shape=K.int_shape(constant))
for constant in constants
]

if initial_state is None and constants is None:
return super(RNN, self).__call__(inputs, **kwargs)

if not isinstance(initial_state, (list, tuple)):
initial_state = [initial_state]

is_keras_tensor = hasattr(initial_state[0], '_keras_history')
for tensor in initial_state:
# If any of `initial_state` or `constants` are specified and are Keras
# tensors, then add them to the inputs and temporarily modify the
# input_spec to include them.

check_list = []
if initial_state:
check_list += initial_state
if constants:
check_list += constants
# at this point check_list cannot be empty
is_keras_tensor = hasattr(check_list[0], '_keras_history')
for tensor in check_list:
if hasattr(tensor, '_keras_history') != is_keras_tensor:
raise ValueError('The initial state of an RNN layer cannot be'
' specified with a mix of Keras tensors and'
' non-Keras tensors')
raise ValueError('The initial state and constants of an RNN'
' layer cannot be specified with a mix of'
' Keras tensors and non-Keras tensors')

if is_keras_tensor:
# Compute the full input spec, including state
# Compute the full input spec, including state and constants
input_spec = self.input_spec
state_spec = self.state_spec
if not isinstance(input_spec, list):
input_spec = [input_spec]
if not isinstance(state_spec, list):
state_spec = [state_spec]
self.input_spec = input_spec + state_spec

# Compute the full inputs, including state
inputs = [inputs] + list(initial_state)
self.input_spec = input_spec
inputs = [inputs]
if initial_state:
self.input_spec += state_spec
inputs += initial_state
kwargs['initial_state'] = initial_state
if constants:
self.input_spec += self.external_constants_spec
inputs += constants
kwargs['constants'] = constants

# Perform the call
output = super(RNN, self).__call__(inputs, **kwargs)
@@ -470,16 +508,22 @@ def __call__(self, inputs, initial_state=None, **kwargs):
return output
else:
kwargs['initial_state'] = initial_state
if constants is not None:
kwargs['constants'] = constants
return super(RNN, self).__call__(inputs, **kwargs)
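
Schematically, in the Keras-tensor branch above the layer ends up being called on one flat list whose structure mirrors the extended `input_spec` (a sketch with hypothetical tensors `x`, `h0` and `c`):

```python
# For a call like `layer(x, initial_state=h0, constants=c)`, where all
# three are Keras tensors, the branch above assembles (schematically):
#
#   inputs          -> [x, h0, c]
#   self.input_spec -> [input_spec] + state_spec + external_constants_spec
#   kwargs          -> {'initial_state': [h0], 'constants': [c]}
#
# so that Layer.__call__ can shape-check every tensor in the flat list.
```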

def call(self, inputs, mask=None, training=None, initial_state=None):
def call(self,
inputs,
mask=None,
training=None,
initial_state=None,
constants=None):
# input shape: `(samples, time (padded with zeros), input_dim)`
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
if isinstance(inputs, list):
initial_state = inputs[1:]
inputs = inputs[0]
elif initial_state is not None:
if initial_state is not None:
pass
elif self.stateful:
initial_state = self.states
@@ -508,9 +552,17 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
'- If using the functional API, specify '
'the time dimension by passing a `shape` '
'or `batch_shape` argument to your Input layer.')

cell_kwargs = {}
if has_arg(self.cell.call, 'training'):
step = functools.partial(self.cell.call, training=training)
cell_kwargs['training'] = training

if constants is not None:
if not has_arg(self.cell.call, 'constants'):
raise TypeError('cell does not take keyword argument constants')
cell_kwargs['constants'] = constants

if cell_kwargs:
step = functools.partial(self.cell.call, **cell_kwargs)
else:
step = self.cell.call
last_output, outputs, states = K.rnn(step,
@@ -544,6 +596,52 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
else:
return output

def _normalize_args(self, inputs, initial_state=None, constants=None):
"""The inputs `initial_state` and `constants` can be passed to
RNN.__call__ either by separate arguments or as part of `inputs`. In
this case `inputs` is a list of tensors of which the first one is the
actual (sequence) input followed by initial states, followed by
constants.
This method separates and noramlizes the different groups of inputs.
# Arguments
inputs: tensor or list/tuple of tensors
initial_state: tensor or list of tensors or None
constants: tensor or list of tensors or None
# Returns
inputs: tensor
initial_state: list of tensors or None
constants: list of tensors or None
"""
if isinstance(inputs, (list, tuple)):
remaining_inputs = inputs[1:]
inputs = inputs[0]
if remaining_inputs and initial_state is None:
if isinstance(self.state_spec, list):
n_states = len(self.state_spec)
else:
n_states = 1
initial_state = remaining_inputs[:n_states]
remaining_inputs = remaining_inputs[n_states:]
if remaining_inputs and constants is None:
    constants = remaining_inputs
    remaining_inputs = []
if len(remaining_inputs) > 0:
    raise ValueError('too many inputs were passed')

def to_list_or_none(x): # TODO break out?
if x is None or isinstance(x, list):
return x
if isinstance(x, tuple):
return list(x)
return [x]

initial_state = to_list_or_none(initial_state)
constants = to_list_or_none(constants)

return inputs, initial_state, constants
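
For example, with a single-state cell the normalization above maps the different call styles onto one canonical form (the tensors here are hypothetical placeholders):

```python
# layer([x, h0, c])            -> (x, [h0], [c])     # everything in `inputs`
# layer(x, initial_state=h0)   -> (x, [h0], None)    # kwargs, single tensor
# layer(x, constants=(c1, c2)) -> (x, None, [c1, c2])

# `to_list_or_none` (copied from above) does the tuple/singleton wrapping:
def to_list_or_none(x):
    if x is None or isinstance(x, list):
        return x
    if isinstance(x, tuple):
        return list(x)
    return [x]

assert to_list_or_none(None) is None
assert to_list_or_none('h0') == ['h0']
assert to_list_or_none(('c1', 'c2')) == ['c1', 'c2']
```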

def reset_states(self, states=None):
if not self.stateful:
raise AttributeError('Layer must be stateful.')
72 changes: 72 additions & 0 deletions tests/keras/layers/recurrent_test.py
@@ -564,5 +564,77 @@ def test_batch_size_equal_one(layer_class):
model.train_on_batch(x, y)


def test_rnn_cell_with_constants_layer():

class RNNCellWithConstants(keras.layers.Layer):

def __init__(self, units, **kwargs):
self.units = units
self.state_size = units
super(RNNCellWithConstants, self).__init__(**kwargs)

def build(self, input_shape):
if not isinstance(input_shape, list):
raise TypeError('expects constants shape')
[input_shape, constant_shape] = input_shape
# will (and should) raise if more than one constant passed

self.input_kernel = self.add_weight(
shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.constant_kernel = self.add_weight(
shape=(constant_shape[-1], self.units),
initializer='uniform',
name='constant_kernel')
self.built = True

def call(self, inputs, states, constants):
[prev_output] = states
[constant] = constants
h_input = keras.backend.dot(inputs, self.input_kernel)
h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
h_const = keras.backend.dot(constant, self.constant_kernel)
output = h_input + h_state + h_const
return output, [output]

def get_config(self):
config = {'units': self.units}
base_config = super(RNNCellWithConstants, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

# Test basic case.
x = keras.Input((None, 5))
c = keras.Input((3,))
cell = RNNCellWithConstants(32)
layer = recurrent.RNN(cell)
y = layer(x, constants=c)
model = keras.models.Model([x, c], y)
model.compile(optimizer='rmsprop', loss='mse')
model.train_on_batch(
[np.zeros((6, 5, 5)), np.zeros((6, 3))],
np.zeros((6, 32))
)

# Test basic case serialization.
x_np = np.random.random((6, 5, 5))
c_np = np.random.random((6, 3))
y_np = model.predict([x_np, c_np])
weights = model.get_weights()
config = layer.get_config()
with keras.utils.CustomObjectScope(
{'RNNCellWithConstants': RNNCellWithConstants}):
layer = recurrent.RNN.from_config(config)
y = layer(x, constants=c)
model = keras.models.Model([x, c], y)
model.set_weights(weights)
y_np_2 = model.predict([x_np, c_np])
assert_allclose(y_np, y_np_2, atol=1e-4)


if __name__ == '__main__':
pytest.main([__file__])
