Recurrent Attention API Additions (keras-team#7980)
* Added support for passing external constants to RNN, which will pass them on to the cell

* Added class for allowing functional composition of RNN Cells, supporting constants

* put back accidentally commented-out recurrent tests

* added basic example of functional cell

* new class AttentionRNN

* restored RNN layer

* renamed constants to attended in FunctionalRNNCell, avoided duplicating outputs in wrapped model

* minor clean-up of docs

* Minor cleanup & improvements in docs, fixed PEP-breaking formatting in attention test

* removed FunctionalRNNCell and AttentionRNN, added back support for constants in RNN

* fixed PEP8 violations

* fixed minor review comments

* added test case for when both initial_state and constants are passed to RNN.__call__
andhus authored and fchollet committed Oct 25, 2017
1 parent 3c69f98 commit 3f148e4
Showing 2 changed files with 321 additions and 54 deletions.
206 changes: 152 additions & 54 deletions keras/layers/recurrent.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import numpy as np
import functools
import warnings

from .. import backend as K
@@ -200,7 +199,9 @@ class RNN(Layer):
    # Arguments
        cell: An RNN cell instance. An RNN cell is a class that has:
            - a `call(input_at_t, states_at_t)` method, returning
                `(output_at_t, states_at_t_plus_1)`.
                `(output_at_t, states_at_t_plus_1)`. The call method of the
                cell can also take the optional keyword argument `constants`;
                see the section "Note on passing external constants" below.
            - a `state_size` attribute. This can be a single integer
                (single state) in which case it is
                the size of the recurrent state
@@ -292,6 +293,14 @@ class RNN(Layer):
        `states` should be a numpy array or list of numpy arrays representing
        the initial state of the RNN layer.

    # Note on passing external constants to RNNs
        You can pass "external" constants to the cell using the `constants`
        keyword argument of `RNN.__call__` (as well as of `RNN.call`). This
        requires that the `cell.call` method accept the same keyword argument
        `constants`. Such constants can be used to condition the cell
        transformation on additional static inputs (not changing over time),
        e.g. for an attention mechanism.
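The sketch below illustrates the contract described in this note; it is not part of the diff. `MinimalConstantCell`, the layer sizes, and the import paths are illustrative assumptions against the Keras 2.0.x-era API shown in this file:

```python
import numpy as np
from keras import backend as K
from keras.engine.topology import Layer
from keras.layers import Input, RNN
from keras.models import Model


class MinimalConstantCell(Layer):
    """Toy cell whose step transformation is conditioned on a constant."""

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(MinimalConstantCell, self).__init__(**kwargs)

    def build(self, input_shape):
        # When constants are passed to RNN.__call__, the wrapper builds
        # the cell with [step_input_shape, constants_shape].
        step_shape, constant_shape = input_shape
        self.kernel = self.add_weight(name='kernel',
                                      shape=(step_shape[-1], self.units),
                                      initializer='uniform')
        self.recurrent_kernel = self.add_weight(
            name='recurrent_kernel', shape=(self.units, self.units),
            initializer='uniform')
        self.constant_kernel = self.add_weight(
            name='constant_kernel', shape=(constant_shape[-1], self.units),
            initializer='uniform')
        self.built = True

    def call(self, inputs, states, constants=None):
        [prev_output] = states
        [constant] = constants
        output = K.tanh(K.dot(inputs, self.kernel)
                        + K.dot(prev_output, self.recurrent_kernel)
                        + K.dot(constant, self.constant_kernel))
        return output, [output]

    def get_config(self):
        config = {'units': self.units}
        base_config = super(MinimalConstantCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


x = Input((None, 5))   # (batch, time, features)
c = Input((3,))        # static input: no time axis
layer = RNN(MinimalConstantCell(8))
y = layer(x, constants=c)

model = Model([x, c], y)
model.compile(optimizer='sgd', loss='mse')
model.train_on_batch([np.zeros((6, 4, 5)), np.zeros((6, 3))],
                     np.zeros((6, 8)))
```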
    # Examples

    ```python
@@ -363,12 +372,10 @@ def __init__(self, cell,

        self.supports_masking = True
        self.input_spec = [InputSpec(ndim=3)]
        if hasattr(self.cell.state_size, '__len__'):
            self.state_spec = [InputSpec(shape=(None, dim))
                               for dim in self.cell.state_size]
        else:
            self.state_spec = InputSpec(shape=(None, self.cell.state_size))
        self.state_spec = None
        self._states = None
        self.constants_spec = None
        self._num_constants = None

    @property
    def states(self):
@@ -415,19 +422,46 @@ def compute_mask(self, inputs, mask):
        return output_mask

    def build(self, input_shape):
        # Note: input_shape will be a list of shapes of the initial states
        # and constants if these are passed in __call__.
        if self._num_constants is not None:
            constants_shape = input_shape[-self._num_constants:]
        else:
            constants_shape = None

        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        batch_size = input_shape[0] if self.stateful else None
        input_dim = input_shape[-1]
        self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim))

        if self.stateful:
            self.reset_states()

        # allow cell (if layer) to build before we set or validate state_spec
        if isinstance(self.cell, Layer):
            step_input_shape = (input_shape[0],) + input_shape[2:]
            self.cell.build(step_input_shape)
            if constants_shape is not None:
                self.cell.build([step_input_shape] + constants_shape)
            else:
                self.cell.build(step_input_shape)

        # set or validate state_spec
        if hasattr(self.cell.state_size, '__len__'):
            state_size = list(self.cell.state_size)
        else:
            state_size = [self.cell.state_size]

        if self.state_spec is not None:
            # initial_state was passed in call, check compatibility
            if not [spec.shape[-1] for spec in self.state_spec] == state_size:
                raise ValueError(
                    'an initial_state was passed that is not compatible with'
                    ' cell.state_size, state_spec: {}, cell.state_size:'
                    ' {}'.format(self.state_spec, self.cell.state_size))
        else:
            self.state_spec = [InputSpec(shape=(None, dim))
                               for dim in state_size]
        if self.stateful:
            self.reset_states()
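The shape bookkeeping at the top of `build` can be re-enacted in pure Python. This is a sketch with hypothetical shapes (one initial state, one constant), not part of the diff:

```python
# Hypothetical: main input (batch, time, 5), one state of size 8 passed
# in __call__, and one constant of size 3, so _num_constants == 1.
num_constants = 1
input_shape = [(None, None, 5), (None, 8), (None, 3)]

constants_shape = input_shape[-num_constants:]          # [(None, 3)]
input_shape = input_shape[0]                            # (None, None, 5)
step_input_shape = (input_shape[0],) + input_shape[2:]  # (None, 5)

# cell.build then receives [step_input_shape] + constants_shape:
assert [step_input_shape] + constants_shape == [(None, 5), (None, 3)]
```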

    def get_initial_state(self, inputs):
        # build an all-zero tensor of shape (samples, output_dim)
@@ -440,62 +474,65 @@ def get_initial_state(self, inputs):
        else:
            return [K.tile(initial_state, [1, self.cell.state_size])]
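For a single-int `state_size`, the zero initial state built here ends up with shape `(samples, state_size)`. A NumPy re-enactment of the shape logic, with illustrative values (the middle of the method is collapsed in this diff, so this is an assumption about the backend-level steps):

```python
import numpy as np

inputs = np.zeros((2, 4, 5))   # (samples, timesteps, input_dim)
state_size = 8

initial_state = np.zeros_like(inputs)             # (samples, time, input_dim)
initial_state = initial_state.sum(axis=(1, 2))    # (samples,)
initial_state = np.expand_dims(initial_state, 1)  # (samples, 1)
initial_state = np.tile(initial_state, [1, state_size])

assert initial_state.shape == (2, 8)
```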

    def __call__(self, inputs, initial_state=None, **kwargs):
        # If there are multiple inputs, then
        # they should be the main input and `initial_state`
        # e.g. when loading model from file
        if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None:
            initial_state = inputs[1:]
            inputs = inputs[0]
    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        inputs, initial_state, constants = self._standardize_args(
            inputs, initial_state, constants)

        # If `initial_state` is specified,
        # and if it is a Keras tensor,
        # then add it to the inputs and temporarily
        # modify the input spec to include the state.
        if initial_state is None:
        if initial_state is None and constants is None:
            return super(RNN, self).__call__(inputs, **kwargs)

        if not isinstance(initial_state, (list, tuple)):
            initial_state = [initial_state]
        # If any of `initial_state` or `constants` are specified and are
        # Keras tensors, then add them to the inputs and temporarily modify
        # the input_spec to include them.

        is_keras_tensor = hasattr(initial_state[0], '_keras_history')
        for tensor in initial_state:
        additional_inputs = []
        additional_specs = []
        if initial_state is not None:
            kwargs['initial_state'] = initial_state
            additional_inputs += initial_state
            self.state_spec = [InputSpec(shape=K.int_shape(state))
                               for state in initial_state]
            additional_specs += self.state_spec
        if constants is not None:
            kwargs['constants'] = constants
            additional_inputs += constants
            self.constants_spec = [InputSpec(shape=K.int_shape(constant))
                                   for constant in constants]
            self._num_constants = len(constants)
            additional_specs += self.constants_spec
        # at this point additional_inputs cannot be empty
        is_keras_tensor = hasattr(additional_inputs[0], '_keras_history')
        for tensor in additional_inputs:
            if hasattr(tensor, '_keras_history') != is_keras_tensor:
                raise ValueError('The initial state of an RNN layer cannot be'
                                 ' specified with a mix of Keras tensors and'
                                 ' non-Keras tensors')
                raise ValueError('The initial state or constants of an RNN'
                                 ' layer cannot be specified with a mix of'
                                 ' Keras tensors and non-Keras tensors')

        if is_keras_tensor:
            # Compute the full input spec, including state
            input_spec = self.input_spec
            state_spec = self.state_spec
            if not isinstance(input_spec, list):
                input_spec = [input_spec]
            if not isinstance(state_spec, list):
                state_spec = [state_spec]
            self.input_spec = input_spec + state_spec

            # Compute the full inputs, including state
            inputs = [inputs] + list(initial_state)

            # Perform the call
            output = super(RNN, self).__call__(inputs, **kwargs)

            # Restore original input spec
            self.input_spec = input_spec
            # Compute the full input spec, including state and constants
            full_input = [inputs] + additional_inputs
            full_input_spec = self.input_spec + additional_specs
            # Perform the call with temporarily replaced input_spec
            original_input_spec = self.input_spec
            self.input_spec = full_input_spec
            output = super(RNN, self).__call__(full_input, **kwargs)
            self.input_spec = original_input_spec
            return output
        else:
            kwargs['initial_state'] = initial_state
            return super(RNN, self).__call__(inputs, **kwargs)

    def call(self, inputs, mask=None, training=None, initial_state=None):
    def call(self,
             inputs,
             mask=None,
             training=None,
             initial_state=None,
             constants=None):
        # input shape: `(samples, time (padded with zeros), input_dim)`
        # note that the .build() method of subclasses MUST define
        # self.input_spec and self.state_spec with complete input shapes.
        if isinstance(inputs, list):
            initial_state = inputs[1:]
            inputs = inputs[0]
        elif initial_state is not None:
        if initial_state is not None:
            pass
        elif self.stateful:
            initial_state = self.states
@@ -525,13 +562,27 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
                             'the time dimension by passing a `shape` '
                             'or `batch_shape` argument to your Input layer.')

        kwargs = {}
        if has_arg(self.cell.call, 'training'):
            step = functools.partial(self.cell.call, training=training)
            kwargs['training'] = training

        if constants:
            if not has_arg(self.cell.call, 'constants'):
                raise ValueError('RNN cell does not support constants')

            def step(inputs, states):
                constants = states[-self._num_constants:]
                states = states[:-self._num_constants]
                return self.cell.call(inputs, states, constants=constants,
                                      **kwargs)
        else:
            step = self.cell.call
            def step(inputs, states):
                return self.cell.call(inputs, states, **kwargs)

        last_output, outputs, states = K.rnn(step,
                                             inputs,
                                             initial_state,
                                             constants=constants,
                                             go_backwards=self.go_backwards,
                                             mask=mask,
                                             unroll=self.unroll,
@@ -560,6 +611,47 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
        else:
            return output
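Inside `K.rnn`, the constants are appended to the state list at every step, so the `step` closure defined above re-splits them before calling the cell. A pure-Python re-enactment of that slicing with dummy stand-in values (one constant), not part of the diff:

```python
num_constants = 1

def cell_call(inputs, states, constants=None):  # stand-in for cell.call
    assert constants == ['attended']
    return inputs, list(states)

def step(inputs, states):
    # as in the closure above: constants ride at the tail of `states`
    constants = states[-num_constants:]
    states = states[:-num_constants]
    return cell_call(inputs, states, constants=constants)

output, new_states = step('x_t', ['h_tm1', 'attended'])
assert output == 'x_t' and new_states == ['h_tm1']
```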

    def _standardize_args(self, inputs, initial_state, constants):
        """Brings the arguments of `__call__` that can contain input tensors
        into a standard format.

        When running a model loaded from file, the input tensors
        `initial_state` and `constants` can be passed to `RNN.__call__` as
        part of `inputs` instead of by the dedicated keyword arguments. This
        method makes sure the arguments are separated and that
        `initial_state` and `constants` are lists of tensors (or None).

        # Arguments
            inputs: tensor or list/tuple of tensors
            initial_state: tensor or list of tensors or None
            constants: tensor or list of tensors or None

        # Returns
            inputs: tensor
            initial_state: list of tensors or None
            constants: list of tensors or None
        """
        if isinstance(inputs, list):
            assert initial_state is None and constants is None
            if self._num_constants is not None:
                constants = inputs[-self._num_constants:]
                inputs = inputs[:-self._num_constants]
            if len(inputs) > 1:
                initial_state = inputs[1:]
            inputs = inputs[0]

        def to_list_or_none(x):
            if x is None or isinstance(x, list):
                return x
            if isinstance(x, tuple):
                return list(x)
            return [x]

        initial_state = to_list_or_none(initial_state)
        constants = to_list_or_none(constants)

        return inputs, initial_state, constants
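A re-enactment of the splitting for a reloaded layer with `_num_constants == 1`, using string stand-ins for tensors (a sketch, not part of the diff):

```python
num_constants = 1
inputs = ['x', 'state_0', 'constant_0']  # flat list, as at load time

constants = inputs[-num_constants:]      # ['constant_0']
inputs = inputs[:-num_constants]         # ['x', 'state_0']
initial_state = inputs[1:]               # ['state_0']
inputs = inputs[0]                       # 'x'

expected = ('x', ['state_0'], ['constant_0'])
assert (inputs, initial_state, constants) == expected
```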

    def reset_states(self, states=None):
        if not self.stateful:
            raise AttributeError('Layer must be stateful.')
@@ -618,6 +710,9 @@ def get_config(self):
                  'go_backwards': self.go_backwards,
                  'stateful': self.stateful,
                  'unroll': self.unroll}
        if self._num_constants is not None:
            config['num_constants'] = self._num_constants

        cell_config = self.cell.get_config()
        config['cell'] = {'class_name': self.cell.__class__.__name__,
                          'config': cell_config}
@@ -629,7 +724,10 @@ def from_config(cls, config, custom_objects=None):
        from . import deserialize as deserialize_layer
        cell = deserialize_layer(config.pop('cell'),
                                 custom_objects=custom_objects)
        return cls(cell, **config)
        num_constants = config.pop('num_constants', None)
        layer = cls(cell, **config)
        layer._num_constants = num_constants
        return layer
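Continuing the hypothetical `MinimalConstantCell` sketch from the docstring note above: the `num_constants` entry makes the constants split reproducible after a config round trip, and a custom cell has to be supplied via `custom_objects`:

```python
config = layer.get_config()  # includes 'num_constants': 1 once the layer
                             # has been called with one constant
restored = RNN.from_config(
    config, custom_objects={'MinimalConstantCell': MinimalConstantCell})
assert restored._num_constants == 1
```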

    @property
    def trainable_weights(self):
[remainder of the keras/layers/recurrent.py diff collapsed; diff of the second changed file not loaded]
