Recurrent Attention API: Support constants in RNN #7980

Merged

Changes from 14 commits
206 changes: 152 additions & 54 deletions keras/layers/recurrent.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import numpy as np
import functools
import warnings

from .. import backend as K
@@ -200,7 +199,9 @@ class RNN(Layer):
# Arguments
cell: An RNN cell instance. An RNN cell is a class that has:
- a `call(input_at_t, states_at_t)` method, returning
`(output_at_t, states_at_t_plus_1)`.
`(output_at_t, states_at_t_plus_1)`. The call method of the
cell can also take the optional keyword argument `constants`;
see the section "Note on passing external constants" below.
- a `state_size` attribute. This can be a single integer
(single state) in which case it is
the size of the recurrent state
@@ -292,6 +293,14 @@ class RNN(Layer):
`states` should be a numpy array or list of numpy arrays representing
the initial state of the RNN layer.

# Note on passing external constants to RNNs
You can pass "external" constants to the cell using the `constants`
keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
requires that the `cell.call` method accepts the same keyword argument
`constants`. Such constants can be used to condition the cell
transformation on additional static inputs (not changing over time),
a.k.a. an attention mechanism.
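A minimal usage sketch (added for illustration, not part of the original
docstring; `MyCellWithConstants` is a hypothetical stand-in for any cell
whose `call` accepts `constants`, such as `RNNCellWithConstants` in the
test diff below):

```python
from keras.layers import Input, RNN
from keras.models import Model

x = Input((None, 5))            # sequence input: (batch, timesteps, 5)
c = Input((3,))                 # static input, constant across timesteps
cell = MyCellWithConstants(32)  # hypothetical cell accepting `constants`
y = RNN(cell)(x, constants=c)   # `c` reaches cell.call at every step
model = Model([x, c], y)
```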

# Examples

```python
@@ -363,12 +372,10 @@ def __init__(self, cell,

self.supports_masking = True
self.input_spec = [InputSpec(ndim=3)]
if hasattr(self.cell.state_size, '__len__'):
self.state_spec = [InputSpec(shape=(None, dim))
for dim in self.cell.state_size]
else:
self.state_spec = InputSpec(shape=(None, self.cell.state_size))
self.state_spec = None
@andhus (Contributor Author) commented:
For more advanced cells (typically attention cells such as this one: https://github.com/andhus/keras/pull/3/files#diff-9c4188ddc4dd173d80f64feed5b89412R258), `cell.state_size` is not defined until the cell is built. But that's not a problem: `state_spec` is just set a bit later, in `build` (or in `__call__` if `initial_state` is passed to `RNN`).
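A hedged sketch (invented for illustration, not code from the linked PR) of the kind of cell the comment describes: its state size depends on shapes that are only known at build time, so `state_size` cannot be set in `__init__`.

```python
import keras
from keras import backend as K


class LateBoundCell(keras.layers.Layer):
    """Toy cell whose state_size depends on the constants' shape,
    so it can only be set in build(), not in __init__()."""

    def build(self, input_shape):
        input_shape, constant_shape = input_shape
        # The state concatenates the output with the attended constant,
        # so its size is only known once constant_shape is known.
        self.state_size = input_shape[-1] + constant_shape[-1]
        self.built = True

    def call(self, inputs, states, constants):
        # Dummy transformation, just to make the contract concrete.
        new_state = K.concatenate([inputs, constants[0]])
        return inputs, [new_state]
```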

self._states = None
self.constants_spec = None
self._num_constants = None

@property
def states(self):
@@ -415,19 +422,46 @@ def compute_mask(self, inputs, mask):
return output_mask

def build(self, input_shape):
@andhus (Contributor Author) commented on Oct 24, 2017:
Note that `input_shape` will be a list of the shape of the input, followed by the shapes of the initial states and constants if the latter are passed in `__call__`.

# Note input_shape will be list of shapes of initial states and
# constants if these are passed in __call__.
if self._num_constants is not None:
constants_shape = input_shape[-self._num_constants:]
else:
constants_shape = None
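# e.g. with one initial state and one constant (illustrative comment,
# not in the original diff):
#   input_shape == [(batch, timesteps, input_dim),
#                   (batch, state_dim),
#                   (batch, constant_dim)]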

if isinstance(input_shape, list):
input_shape = input_shape[0]

batch_size = input_shape[0] if self.stateful else None
input_dim = input_shape[-1]
self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim))

if self.stateful:
self.reset_states()

# allow cell (if layer) to build before we set or validate state_spec
if isinstance(self.cell, Layer):
step_input_shape = (input_shape[0],) + input_shape[2:]
self.cell.build(step_input_shape)
if constants_shape is not None:
self.cell.build([step_input_shape] + constants_shape)
else:
self.cell.build(step_input_shape)
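# (illustrative comment, not in the original diff) When constants are
# present the cell is built with a *list* of shapes; compare
# RNNCellWithConstants.build in the test below, which unpacks
# [input_shape, constant_shape].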

# set or validate state_spec
if hasattr(self.cell.state_size, '__len__'):
state_size = list(self.cell.state_size)
else:
state_size = [self.cell.state_size]

if self.state_spec is not None:
# initial_state was passed in call, check compatibility
if not [spec.shape[-1] for spec in self.state_spec] == state_size:
raise ValueError(
'an initial_state was passed that is not compatible with'
' cell.state_size, state_spec: {}, cell.state_size:'
' {}'.format(self.state_spec, self.cell.state_size))
else:
self.state_spec = [InputSpec(shape=(None, dim))
for dim in state_size]
if self.stateful:
self.reset_states()

def get_initial_state(self, inputs):
# build an all-zero tensor of shape (samples, output_dim)
@@ -440,62 +474,65 @@ def get_initial_state(self, inputs):
else:
return [K.tile(initial_state, [1, self.cell.state_size])]

def __call__(self, inputs, initial_state=None, **kwargs):
# If there are multiple inputs, then
# they should be the main input and `initial_state`
# e.g. when loading model from file
if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None:
initial_state = inputs[1:]
inputs = inputs[0]
def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
inputs, initial_state, constants = self._standardize_args(
inputs, initial_state, constants)

# If `initial_state` is specified,
# and if it is a Keras tensor,
# then add it to the inputs and temporarily
# modify the input spec to include the state.
if initial_state is None:
if initial_state is None and constants is None:
return super(RNN, self).__call__(inputs, **kwargs)

if not isinstance(initial_state, (list, tuple)):
initial_state = [initial_state]
# If any of `initial_state` or `constants` are specified and are Keras
# tensors, then add them to the inputs and temporarily modify the
# input_spec to include them.

is_keras_tensor = hasattr(initial_state[0], '_keras_history')
for tensor in initial_state:
additional_inputs = []
additional_specs = []
if initial_state is not None:
kwargs['initial_state'] = initial_state
additional_inputs += initial_state
self.state_spec = [InputSpec(shape=K.int_shape(state))
for state in initial_state]
additional_specs += self.state_spec
if constants is not None:
kwargs['constants'] = constants
additional_inputs += constants
self.constants_spec = [InputSpec(shape=K.int_shape(constant))
for constant in constants]
self._num_constants = len(constants)
additional_specs += self.constants_spec
# at this point additional_inputs cannot be empty
is_keras_tensor = hasattr(additional_inputs[0], '_keras_history')
for tensor in additional_inputs:
if hasattr(tensor, '_keras_history') != is_keras_tensor:
raise ValueError('The initial state of an RNN layer cannot be'
' specified with a mix of Keras tensors and'
' non-Keras tensors')
raise ValueError('The initial state or constants of an RNN'
' layer cannot be specified with a mix of'
' Keras tensors and non-Keras tensors')

if is_keras_tensor:
# Compute the full input spec, including state
input_spec = self.input_spec
state_spec = self.state_spec
if not isinstance(input_spec, list):
input_spec = [input_spec]
if not isinstance(state_spec, list):
state_spec = [state_spec]
self.input_spec = input_spec + state_spec

# Compute the full inputs, including state
inputs = [inputs] + list(initial_state)

# Perform the call
output = super(RNN, self).__call__(inputs, **kwargs)

# Restore original input spec
self.input_spec = input_spec
# Compute the full input spec, including state and constants
full_input = [inputs] + additional_inputs
full_input_spec = self.input_spec + additional_specs
# Perform the call with temporarily replaced input_spec
original_input_spec = self.input_spec
self.input_spec = full_input_spec
output = super(RNN, self).__call__(full_input, **kwargs)
self.input_spec = original_input_spec
return output
else:
kwargs['initial_state'] = initial_state
return super(RNN, self).__call__(inputs, **kwargs)
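# (summary comment, not in the original diff) Keras-tensor states and
# constants become extra graph inputs via full_input above; plain
# tensors are instead forwarded to `call` through kwargs here.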

def call(self, inputs, mask=None, training=None, initial_state=None):
def call(self,
inputs,
mask=None,
training=None,
initial_state=None,
constants=None):
# input shape: `(samples, time (padded with zeros), input_dim)`
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
if isinstance(inputs, list):
initial_state = inputs[1:]
inputs = inputs[0]
elif initial_state is not None:
if initial_state is not None:
pass
elif self.stateful:
initial_state = self.states
@@ -525,13 +562,27 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
'the time dimension by passing a `shape` '
'or `batch_shape` argument to your Input layer.')

kwargs = {}
if has_arg(self.cell.call, 'training'):
step = functools.partial(self.cell.call, training=training)
kwargs['training'] = training

if constants:
if not has_arg(self.cell.call, 'constants'):
raise ValueError('RNN cell does not support constants')

def step(inputs, states):
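# (comment added for clarity) K.rnn, when given `constants=...`,
# appends them to the state list it passes to `step`; split them
# back off before calling the cell.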
constants = states[-self._num_constants:]
states = states[:-self._num_constants]
return self.cell.call(inputs, states, constants=constants,
**kwargs)
else:
step = self.cell.call
def step(inputs, states):
return self.cell.call(inputs, states, **kwargs)

last_output, outputs, states = K.rnn(step,
inputs,
initial_state,
constants=constants,
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
@@ -560,6 +611,47 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
else:
return output

def _standardize_args(self, inputs, initial_state, constants):
"""Brings the arguments of `__call__` that can contain input tensors to
standard format.

When running a model loaded from file, the input tensors
`initial_state` and `constants` can be passed to `RNN.__call__` as part
of `inputs` instead of by the dedicated keyword arguments. This method
makes sure the arguments are separated and that `initial_state` and
`constants` are lists of tensors (or None).

# Arguments
inputs: tensor or list/tuple of tensors
initial_state: tensor or list of tensors or None
constants: tensor or list of tensors or None

# Returns
inputs: tensor
initial_state: list of tensors or None
constants: list of tensors or None
"""
if isinstance(inputs, list):
assert initial_state is None and constants is None
if self._num_constants is not None:
constants = inputs[-self._num_constants:]
inputs = inputs[:-self._num_constants]
if len(inputs) > 1:
initial_state = inputs[1:]
inputs = inputs[0]

def to_list_or_none(x):
if x is None or isinstance(x, list):
return x
if isinstance(x, tuple):
return list(x)
return [x]

initial_state = to_list_or_none(initial_state)
constants = to_list_or_none(constants)

return inputs, initial_state, constants
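A hedged illustration of the two call forms this method normalizes; the tensor names are invented for the example:

```python
# Keyword form (typical user code):
#   x, state, consts = layer._standardize_args(x, initial_state=s, constants=c)
#   # -> x, [s], [c]
#
# Flat-list form (e.g. a model reloaded from file, _num_constants == 1):
#   x, state, consts = layer._standardize_args([x, s, c], None, None)
#   # -> x, [s], [c]   (constants are split off the tail first)
```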

def reset_states(self, states=None):
if not self.stateful:
raise AttributeError('Layer must be stateful.')
@@ -618,6 +710,9 @@ def get_config(self):
'go_backwards': self.go_backwards,
'stateful': self.stateful,
'unroll': self.unroll}
if self._num_constants is not None:
config['num_constants'] = self._num_constants

cell_config = self.cell.get_config()
config['cell'] = {'class_name': self.cell.__class__.__name__,
'config': cell_config}
@@ -629,7 +724,10 @@ def from_config(cls, config, custom_objects=None):
from . import deserialize as deserialize_layer
cell = deserialize_layer(config.pop('cell'),
custom_objects=custom_objects)
return cls(cell, **config)
num_constants = config.pop('num_constants', None)
layer = cls(cell, **config)
layer._num_constants = num_constants
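# (comment added for clarity) `_num_constants` must be restored by
# direct assignment -- `__init__` does not accept it -- so that
# `_standardize_args` can split constants back out of the flat input
# list when the reloaded model is called.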
return layer

@property
def trainable_weights(self):
72 changes: 72 additions & 0 deletions tests/keras/layers/recurrent_test.py
@@ -568,5 +568,77 @@ def test_batch_size_equal_one(layer_class):
model.train_on_batch(x, y)


def test_rnn_cell_with_constants_layer():

class RNNCellWithConstants(keras.layers.Layer):

def __init__(self, units, **kwargs):
self.units = units
self.state_size = units
super(RNNCellWithConstants, self).__init__(**kwargs)

def build(self, input_shape):
if not isinstance(input_shape, list):
raise TypeError('expects constants shape')
[input_shape, constant_shape] = input_shape
# will (and should) raise if more than one constant passed

self.input_kernel = self.add_weight(
shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.constant_kernel = self.add_weight(
shape=(constant_shape[-1], self.units),
initializer='uniform',
name='constant_kernel')
self.built = True

def call(self, inputs, states, constants):
[prev_output] = states
[constant] = constants
h_input = keras.backend.dot(inputs, self.input_kernel)
h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
h_const = keras.backend.dot(constant, self.constant_kernel)
output = h_input + h_state + h_const
return output, [output]

def get_config(self):
config = {'units': self.units}
base_config = super(RNNCellWithConstants, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

# Test basic case.
x = keras.Input((None, 5))
c = keras.Input((3,))
cell = RNNCellWithConstants(32)
layer = recurrent.RNN(cell)
y = layer(x, constants=c)
model = keras.models.Model([x, c], y)
model.compile(optimizer='rmsprop', loss='mse')
model.train_on_batch(
[np.zeros((6, 5, 5)), np.zeros((6, 3))],
np.zeros((6, 32))
)

# Test basic case serialization.
x_np = np.random.random((6, 5, 5))
c_np = np.random.random((6, 3))
y_np = model.predict([x_np, c_np])
weights = model.get_weights()
config = layer.get_config()
custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
with keras.utils.CustomObjectScope(custom_objects):
layer = recurrent.RNN.from_config(config)
y = layer(x, constants=c)
model = keras.models.Model([x, c], y)
model.set_weights(weights)
y_np_2 = model.predict([x_np, c_np])
assert_allclose(y_np, y_np_2, atol=1e-4)


if __name__ == '__main__':
pytest.main([__file__])