
Recurrent Attention API: Support constants in RNN #7980

Merged

Changes from 13 commits
210 changes: 156 additions & 54 deletions keras/layers/recurrent.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import numpy as np
import functools
import warnings

from .. import backend as K
@@ -200,7 +199,9 @@ class RNN(Layer):
# Arguments
cell: A RNN cell instance. A RNN cell is a class that has:
- a `call(input_at_t, states_at_t)` method, returning
`(output_at_t, states_at_t_plus_1)`.
`(output_at_t, states_at_t_plus_1)`. The call method of the
cell can also take the optional argument `constants`, see
section "Note on passing external constants" below.
- a `state_size` attribute. This can be a single integer
(single state) in which case it is
the size of the recurrent state
@@ -292,6 +293,14 @@ class RNN(Layer):
`states` should be a numpy array or list of numpy arrays representing
the initial state of the RNN layer.

# Note on passing external constants to RNNs
You can pass "external" constants to the cell using the `constants`
keyword argument of RNN.__call__ (as well as RNN.call) method. This
Review comment (Member): Use backticks around code keywords.

requires that the `cell.call` method accepts the same keyword argument
`constants`. Such constants can be used to condition the cell
transformation on additional static inputs (not changing over time)
(a.k.a. an attention mechanism).
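
For illustration, here is a minimal sketch of how such a constant would be wired in. The cell class name is hypothetical; any cell whose `call` accepts a `constants` keyword argument works (the test added in this PR contains a complete cell implementation):

```python
import keras
from keras.layers.recurrent import RNN

x = keras.Input((None, 5))      # sequence input: (batch, time, features)
c = keras.Input((3,))           # static conditioning tensor, fixed over time
cell = MyCellWithConstants(32)  # hypothetical cell; its call() accepts `constants`
y = RNN(cell)(x, constants=c)   # c is forwarded to cell.call at every timestep
model = keras.models.Model([x, c], y)
```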

# Examples

```python
@@ -363,12 +372,10 @@ def __init__(self, cell,

self.supports_masking = True
self.input_spec = [InputSpec(ndim=3)]
if hasattr(self.cell.state_size, '__len__'):
self.state_spec = [InputSpec(shape=(None, dim))
for dim in self.cell.state_size]
else:
self.state_spec = InputSpec(shape=(None, self.cell.state_size))
self.state_spec = None
Review comment (Contributor Author): For more advanced cells (typically attention cells such as this one: https://github.com/andhus/keras/pull/3/files#diff-9c4188ddc4dd173d80f64feed5b89412R258) `cell.state_size` is not defined until the cell is built. But that's not a problem: `state_spec` is just set a bit later in `build` (or in `__call__` if `initial_state` is passed to the RNN).

self._states = None
self.constants_spec = None
self._n_constants = None

@property
def states(self):
@@ -415,19 +422,46 @@ def compute_mask(self, inputs, mask):
return output_mask

def build(self, input_shape):
# Note input_shape will be list of shapes of initial states and
Review comment (andhus, Contributor Author, Oct 24, 2017): Note: `input_shape` will be a list of the shape of the input followed by the shapes of the initial states and constants, if the latter are passed in `__call__`.

# constants if these are passed in __call__.
if self._n_constants is not None:
constants_shape = input_shape[-self._n_constants:]
else:
constants_shape = None

if isinstance(input_shape, list):
input_shape = input_shape[0]

batch_size = input_shape[0] if self.stateful else None
input_dim = input_shape[-1]
self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim))

if self.stateful:
self.reset_states()

# allow cell (if layer) to build before we set or validate state_spec
if isinstance(self.cell, Layer):
step_input_shape = (input_shape[0],) + input_shape[2:]
self.cell.build(step_input_shape)
if constants_shape is not None:
self.cell.build([step_input_shape] + constants_shape)
else:
self.cell.build(step_input_shape)

# set or validate state_spec
if hasattr(self.cell.state_size, '__len__'):
state_size = list(self.cell.state_size)
else:
state_size = [self.cell.state_size]

if self.state_spec is not None:
# initial_state was passed in call, check compatibility
if not [spec.shape[-1] for spec in self.state_spec] == state_size:
raise ValueError(
'an initial_state was passed that is not compatible with'
' cell.state_size, state_spec: {}, cell.state_size:'
' {}'.format(self.state_spec, self.cell.state_size))
else:
self.state_spec = [InputSpec(shape=(None, dim))
for dim in state_size]
if self.stateful:
self.reset_states()
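
As a worked example of the shape bookkeeping in `build` (shapes hypothetical), suppose a single constant was passed in `__call__`, so `input_shape` arrives as a list:

```python
# Sketch of the slicing performed above, with one constant (_n_constants == 1):
input_shape = [(None, 10, 5), (None, 3)]      # [sequence input, constant]
n_constants = 1                               # recorded earlier in __call__
constants_shape = input_shape[-n_constants:]  # [(None, 3)]
input_shape = input_shape[0]                  # (None, 10, 5)
step_input_shape = (input_shape[0],) + input_shape[2:]  # (None, 5): one timestep
# the cell is then built on [step_input_shape] + constants_shape
```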

def get_initial_state(self, inputs):
# build an all-zero tensor of shape (samples, output_dim)
@@ -440,62 +474,68 @@ def get_initial_state(self, inputs):
else:
return [K.tile(initial_state, [1, self.cell.state_size])]

def __call__(self, inputs, initial_state=None, **kwargs):
# If there are multiple inputs, then
# they should be the main input and `initial_state`
# e.g. when loading model from file
if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None:
initial_state = inputs[1:]
inputs = inputs[0]
def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
inputs, initial_state, constants = self._normalize_args(
Review comment (Member): Prefer "standardize" (also used elsewhere), since "normalize" has a specific meaning elsewhere.

inputs, initial_state, constants)

# If `initial_state` is specified,
# and if it a Keras tensor,
# then add it to the inputs and temporarily
# modify the input spec to include the state.
if initial_state is None:
if initial_state is None and constants is None:
return super(RNN, self).__call__(inputs, **kwargs)

if not isinstance(initial_state, (list, tuple)):
initial_state = [initial_state]
# If any of `initial_state` or `constants` are specified and are Keras
# tensors, then add them to the inputs and temporarily modify the
# input_spec to include them.

is_keras_tensor = hasattr(initial_state[0], '_keras_history')
for tensor in initial_state:
check_list = []
Review comment (Member): Unclear what is meant by `check_list`. Maybe `input_list` would be better?

if initial_state is not None:
kwargs['initial_state'] = initial_state
check_list += initial_state
self.state_spec = [InputSpec(shape=K.int_shape(state))
for state in initial_state]
if constants is not None:
kwargs['constants'] = constants
check_list += constants
self.constants_spec = [InputSpec(shape=K.int_shape(constant))
for constant in constants]
self._n_constants = len(constants)
Review comment (Member): In line with naming conventions in this API, this should be `_num_constants`.

# at this point check_list cannot be empty
is_keras_tensor = hasattr(check_list[0], '_keras_history')
for tensor in check_list:
if hasattr(tensor, '_keras_history') != is_keras_tensor:
raise ValueError('The initial state of an RNN layer cannot be'
' specified with a mix of Keras tensors and'
' non-Keras tensors')
raise ValueError('The initial state and constants of an RNN'
Review comment (Member): "and" -> "or".

Review comment (Contributor Author): Sure, I think I wanted to refer to the fact that they "jointly" must contain only one kind of tensor (that's what is checked for now)... I don't have an opinion on the error msg, but just to verify/understand: shouldn't all inputs be Keras tensors or not? And why not then include inputs as well in the check (not only constants and initial_state)?

' layer cannot be specified with a mix of'
' Keras tensors and non-Keras tensors')

if is_keras_tensor:
# Compute the full input spec, including state
input_spec = self.input_spec
state_spec = self.state_spec
if not isinstance(input_spec, list):
input_spec = [input_spec]
if not isinstance(state_spec, list):
state_spec = [state_spec]
self.input_spec = input_spec + state_spec

# Compute the full inputs, including state
inputs = [inputs] + list(initial_state)

# Perform the call
output = super(RNN, self).__call__(inputs, **kwargs)

# Restore original input spec
self.input_spec = input_spec
# Compute the full input spec, including state and constants
full_input = [inputs]
full_input_spec = self.input_spec
if initial_state:
full_input += initial_state
full_input_spec += self.state_spec
if constants:
full_input += constants
full_input_spec += self.constants_spec
# Perform the call with temporarily replaced input_spec
original_input_spec = self.input_spec
self.input_spec = full_input_spec
output = super(RNN, self).__call__(full_input, **kwargs)
self.input_spec = original_input_spec
return output
else:
kwargs['initial_state'] = initial_state
return super(RNN, self).__call__(inputs, **kwargs)
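
To make the Keras-tensor branch concrete, here is a small sketch of the spec bookkeeping (shapes hypothetical). The point is that constants become real graph inputs, so the layer's `input_spec` must temporarily cover them too:

```python
from keras.engine import InputSpec

input_spec = [InputSpec(ndim=3)]               # the sequence input
constants_spec = [InputSpec(shape=(None, 3))]  # one constant of feature dim 3
full_input_spec = input_spec + constants_spec  # validated together
# __call__ swaps full_input_spec in, calls the layer on [x, c],
# then restores the original input_spec.
```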

def call(self, inputs, mask=None, training=None, initial_state=None):
def call(self,
inputs,
mask=None,
training=None,
initial_state=None,
constants=None):
# input shape: `(samples, time (padded with zeros), input_dim)`
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
if isinstance(inputs, list):
initial_state = inputs[1:]
inputs = inputs[0]
elif initial_state is not None:
if initial_state is not None:
pass
elif self.stateful:
initial_state = self.states
@@ -525,13 +565,27 @@ def call(self, mask=None, training=None, initial_state=None):
'the time dimension by passing a `shape` '
'or `batch_shape` argument to your Input layer.')

kwargs = {}
if has_arg(self.cell.call, 'training'):
step = functools.partial(self.cell.call, training=training)
kwargs['training'] = training

if constants:
if not has_arg(self.cell.call, 'constants'):
raise ValueError('RNN cell does not support constants')

def step(inputs, states):
constants = states[-self._n_constants:]
states = states[:-self._n_constants]
return self.cell.call(inputs, states, constants=constants,
**kwargs)
else:
step = self.cell.call
def step(inputs, states):
return self.cell.call(inputs, states, **kwargs)

last_output, outputs, states = K.rnn(step,
inputs,
initial_state,
constants=constants,
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
@@ -560,6 +614,48 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
else:
return output
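
The `step` wrapper above relies on how `K.rnn` forwards its `constants` argument: the constants are appended to the state list handed to the step function at every timestep. A sketch of that contract (values hypothetical, assuming this appending behavior; `cell` is the wrapped RNN cell):

```python
# With recurrent state [h] and constants [c], K.rnn calls step(x_t, [h, c]):
n_constants = 1

def step(x_t, states):
    constants = states[-n_constants:]  # [c], identical at every timestep
    states = states[:-n_constants]     # [h], the true recurrent state
    return cell.call(x_t, states, constants=constants)
```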

def _normalize_args(self, inputs, initial_state, constants):
"""When running a model loaded from file, the input tensors
Review comment (Member): Nit: first line of docstring (one-line summary) should fit in one line and end with a period.

`initial_state` and `constants` can be passed to RNN.__call__ as part
Review comment (Member): Use backticks around code keywords.

of `inputs` instead of by the dedicated keyword arguments. In this
case `inputs` is a list of tensors of which the first one is the
actual (sequence) input followed by initial states, followed by
constants.

This method makes sure initial_states and constants are separated from
inputs and that the are lists of tensors (or None).
Review comment (Member): "the" -> "they".


# Arguments
inputs: tensor of list/tuple of tensors
initial_state: tensor or list of tensors or None
constants: tensor or list of tensors or None

# Returns
inputs: tensor
initial_state: list of tensors or None
constants: list of tensors or None
"""
if isinstance(inputs, list):
assert initial_state is None and constants is None
if self._n_constants is not None:
constants = inputs[-self._n_constants:]
inputs = inputs[:-self._n_constants]
if len(inputs) > 1:
initial_state = inputs[1:]
inputs = inputs[0]

def to_list_or_none(x): # TODO break out?
Review comment (Member): Unless you want to use it elsewhere, it is fine as it is; you can remove this comment.

if x is None or isinstance(x, list):
return x
if isinstance(x, tuple):
return list(x)
return [x]

initial_state = to_list_or_none(initial_state)
constants = to_list_or_none(constants)

return inputs, initial_state, constants
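
A worked example of this unpacking (tensors hypothetical): when a model is reloaded from file, the layer may be called with everything packed into a single list, e.g. `layer([x, h0, c])` with one constant:

```python
# Assuming layer._n_constants == 1 (restored from the config):
inputs, initial_state, constants = layer._normalize_args([x, h0, c], None, None)
# constants     == [c]   (split off the tail, using _n_constants)
# initial_state == [h0]  (whatever remains after the main input)
# inputs        == x
```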

def reset_states(self, states=None):
if not self.stateful:
raise AttributeError('Layer must be stateful.')
Expand Down Expand Up @@ -618,6 +714,9 @@ def get_config(self):
'go_backwards': self.go_backwards,
'stateful': self.stateful,
'unroll': self.unroll}
if self._n_constants is not None:
config['_n_constants'] = self._n_constants
Review comment (Member): Please call it num_constants in the config.


cell_config = self.cell.get_config()
config['cell'] = {'class_name': self.cell.__class__.__name__,
'config': cell_config}
@@ -629,7 +728,10 @@ def from_config(cls, config, custom_objects=None):
from . import deserialize as deserialize_layer
cell = deserialize_layer(config.pop('cell'),
custom_objects=custom_objects)
return cls(cell, **config)
n_constants = config.pop('_n_constants', None)
layer = cls(cell, **config)
layer._n_constants = n_constants
return layer
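
A round-trip sketch of how `_n_constants` survives serialization; since it is not a constructor argument, `from_config` pops it from the config and sets it on the new instance directly:

```python
config = layer.get_config()         # includes '_n_constants' if constants were used
restored = RNN.from_config(config)  # pass custom_objects for custom cells
assert restored._n_constants == layer._n_constants
```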

@property
def trainable_weights(self):
72 changes: 72 additions & 0 deletions tests/keras/layers/recurrent_test.py
@@ -568,5 +568,77 @@ def test_batch_size_equal_one(layer_class):
model.train_on_batch(x, y)


def test_rnn_cell_with_constants_layer():

class RNNCellWithConstants(keras.layers.Layer):

def __init__(self, units, **kwargs):
self.units = units
self.state_size = units
super(RNNCellWithConstants, self).__init__(**kwargs)

def build(self, input_shape):
if not isinstance(input_shape, list):
raise TypeError('expects constants shape')
[input_shape, constant_shape] = input_shape
# will (and should) raise if more than one constant passed

self.input_kernel = self.add_weight(
shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.constant_kernel = self.add_weight(
shape=(constant_shape[-1], self.units),
initializer='uniform',
name='constant_kernel')
self.built = True

def call(self, inputs, states, constants):
[prev_output] = states
[constant] = constants
h_input = keras.backend.dot(inputs, self.input_kernel)
h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
h_const = keras.backend.dot(constant, self.constant_kernel)
output = h_input + h_state + h_const
return output, [output]

def get_config(self):
config = {'units': self.units}
base_config = super(RNNCellWithConstants, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

# Test basic case.
x = keras.Input((None, 5))
c = keras.Input((3,))
cell = RNNCellWithConstants(32)
layer = recurrent.RNN(cell)
y = layer(x, constants=c)
model = keras.models.Model([x, c], y)
model.compile(optimizer='rmsprop', loss='mse')
model.train_on_batch(
[np.zeros((6, 5, 5)), np.zeros((6, 3))],
np.zeros((6, 32))
)

# Test basic case serialization.
x_np = np.random.random((6, 5, 5))
c_np = np.random.random((6, 3))
y_np = model.predict([x_np, c_np])
weights = model.get_weights()
config = layer.get_config()
custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
with keras.utils.CustomObjectScope(custom_objects):
layer = recurrent.RNN.from_config(config)
y = layer(x, constants=c)
model = keras.models.Model([x, c], y)
model.set_weights(weights)
y_np_2 = model.predict([x_np, c_np])
assert_allclose(y_np, y_np_2, atol=1e-4)


if __name__ == '__main__':
pytest.main([__file__])