-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Recurrent Attention API: Support constants in RNN #7980
Merged
fchollet
merged 15 commits into
keras-team:master
from
andhus:recurrent_attention_api_constants_support
Oct 25, 2017
+321
−54
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
28795f1
Added support for passing external constants to RNN, which will pass …
andhus 03b8fda
Added class for allowing functional composition of RNN Cells, support…
andhus 53deca2
put back accidentally commented out recurrent tests
andhus 568fd2e
added basic example of functional cell
andhus 2f9f6f0
new class AttentionRNN
andhus 1b90731
restored RNN layer
andhus 67bd184
Merge branch 'master' of github.com:fchollet/keras into recurrent_att…
andhus e74b125
renamed constants to attended in FunctionRNNCell, avoided duplicating…
andhus fb91e4e
minor clean-up of docs
andhus fcc854c
Minor cleanup & improvments in docs, fixed PEP breaking formatting in…
andhus 2775b2f
Merge branch 'master' of github.com:fchollet/keras into recurrent_att…
andhus ab89c6a
removed FunctionalRNNCell and AttentionRNN, added back support for co…
andhus 95c2359
fixed PEP8 violations
andhus 86fdd93
fixed minor review comments
andhus d33d919
added test case for when both inital_state and constants are passed t…
andhus File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,6 @@ | ||
# -*- coding: utf-8 -*- | ||
from __future__ import absolute_import | ||
import numpy as np | ||
import functools | ||
import warnings | ||
|
||
from .. import backend as K | ||
|
@@ -200,7 +199,9 @@ class RNN(Layer): | |
# Arguments | ||
cell: A RNN cell instance. A RNN cell is a class that has: | ||
- a `call(input_at_t, states_at_t)` method, returning | ||
`(output_at_t, states_at_t_plus_1)`. | ||
`(output_at_t, states_at_t_plus_1)`. The call method of the | ||
cell can also take the optional argument `constants`, see | ||
section "Note on passing external constants" below. | ||
- a `state_size` attribute. This can be a single integer | ||
(single state) in which case it is | ||
the size of the recurrent state | ||
|
@@ -292,6 +293,14 @@ class RNN(Layer): | |
`states` should be a numpy array or list of numpy arrays representing | ||
the initial state of the RNN layer. | ||
|
||
# Note on passing external constants to RNNs | ||
You can pass "external" constants to the cell using the `constants` | ||
keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This | ||
requires that the `cell.call` method accepts the same keyword argument | ||
`constants`. Such constants can be used to condition the cell | ||
transformation on additional static inputs (not changing over time), | ||
a.k.a. an attention mechanism. | ||
|
||
# Examples | ||
|
||
```python | ||
|
@@ -363,12 +372,10 @@ def __init__(self, cell, | |
|
||
self.supports_masking = True | ||
self.input_spec = [InputSpec(ndim=3)] | ||
if hasattr(self.cell.state_size, '__len__'): | ||
self.state_spec = [InputSpec(shape=(None, dim)) | ||
for dim in self.cell.state_size] | ||
else: | ||
self.state_spec = InputSpec(shape=(None, self.cell.state_size)) | ||
self.state_spec = None | ||
self._states = None | ||
self.constants_spec = None | ||
self._num_constants = None | ||
|
||
@property | ||
def states(self): | ||
|
@@ -415,19 +422,46 @@ def compute_mask(self, inputs, mask): | |
return output_mask | ||
|
||
def build(self, input_shape): | ||
# Note input_shape will be list of shapes of initial states and | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
# constants if these are passed in __call__. | ||
if self._num_constants is not None: | ||
constants_shape = input_shape[-self._num_constants:] | ||
else: | ||
constants_shape = None | ||
|
||
if isinstance(input_shape, list): | ||
input_shape = input_shape[0] | ||
|
||
batch_size = input_shape[0] if self.stateful else None | ||
input_dim = input_shape[-1] | ||
self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim)) | ||
|
||
if self.stateful: | ||
self.reset_states() | ||
|
||
# allow cell (if layer) to build before we set or validate state_spec | ||
if isinstance(self.cell, Layer): | ||
step_input_shape = (input_shape[0],) + input_shape[2:] | ||
self.cell.build(step_input_shape) | ||
if constants_shape is not None: | ||
self.cell.build([step_input_shape] + constants_shape) | ||
else: | ||
self.cell.build(step_input_shape) | ||
|
||
# set or validate state_spec | ||
if hasattr(self.cell.state_size, '__len__'): | ||
state_size = list(self.cell.state_size) | ||
else: | ||
state_size = [self.cell.state_size] | ||
|
||
if self.state_spec is not None: | ||
# initial_state was passed in call, check compatibility | ||
if not [spec.shape[-1] for spec in self.state_spec] == state_size: | ||
raise ValueError( | ||
'an initial_state was passed that is not compatible with' | ||
' cell.state_size, state_spec: {}, cell.state_size:' | ||
' {}'.format(self.state_spec, self.cell.state_size)) | ||
else: | ||
self.state_spec = [InputSpec(shape=(None, dim)) | ||
for dim in state_size] | ||
if self.stateful: | ||
self.reset_states() | ||
|
||
def get_initial_state(self, inputs): | ||
# build an all-zero tensor of shape (samples, output_dim) | ||
|
@@ -440,62 +474,65 @@ def get_initial_state(self, inputs): | |
else: | ||
return [K.tile(initial_state, [1, self.cell.state_size])] | ||
|
||
def __call__(self, inputs, initial_state=None, **kwargs): | ||
# If there are multiple inputs, then | ||
# they should be the main input and `initial_state` | ||
# e.g. when loading model from file | ||
if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None: | ||
initial_state = inputs[1:] | ||
inputs = inputs[0] | ||
def __call__(self, inputs, initial_state=None, constants=None, **kwargs): | ||
inputs, initial_state, constants = self._standardize_args( | ||
inputs, initial_state, constants) | ||
|
||
# If `initial_state` is specified, | ||
# and if it a Keras tensor, | ||
# then add it to the inputs and temporarily | ||
# modify the input spec to include the state. | ||
if initial_state is None: | ||
if initial_state is None and constants is None: | ||
return super(RNN, self).__call__(inputs, **kwargs) | ||
|
||
if not isinstance(initial_state, (list, tuple)): | ||
initial_state = [initial_state] | ||
# If any of `initial_state` or `constants` are specified and are Keras | ||
# tensors, then add them to the inputs and temporarily modify the | ||
# input_spec to include them. | ||
|
||
is_keras_tensor = hasattr(initial_state[0], '_keras_history') | ||
for tensor in initial_state: | ||
additional_inputs = [] | ||
additional_specs = [] | ||
if initial_state is not None: | ||
kwargs['initial_state'] = initial_state | ||
additional_inputs += initial_state | ||
self.state_spec = [InputSpec(shape=K.int_shape(state)) | ||
for state in initial_state] | ||
additional_specs += self.state_spec | ||
if constants is not None: | ||
kwargs['constants'] = constants | ||
additional_inputs += constants | ||
self.constants_spec = [InputSpec(shape=K.int_shape(constant)) | ||
for constant in constants] | ||
self._num_constants = len(constants) | ||
additional_specs += self.constants_spec | ||
# at this point additional_inputs cannot be empty | ||
is_keras_tensor = hasattr(additional_inputs[0], '_keras_history') | ||
for tensor in additional_inputs: | ||
if hasattr(tensor, '_keras_history') != is_keras_tensor: | ||
raise ValueError('The initial state of an RNN layer cannot be' | ||
' specified with a mix of Keras tensors and' | ||
' non-Keras tensors') | ||
raise ValueError('The initial state or constants of an RNN' | ||
' layer cannot be specified with a mix of' | ||
' Keras tensors and non-Keras tensors') | ||
|
||
if is_keras_tensor: | ||
# Compute the full input spec, including state | ||
input_spec = self.input_spec | ||
state_spec = self.state_spec | ||
if not isinstance(input_spec, list): | ||
input_spec = [input_spec] | ||
if not isinstance(state_spec, list): | ||
state_spec = [state_spec] | ||
self.input_spec = input_spec + state_spec | ||
|
||
# Compute the full inputs, including state | ||
inputs = [inputs] + list(initial_state) | ||
|
||
# Perform the call | ||
output = super(RNN, self).__call__(inputs, **kwargs) | ||
|
||
# Restore original input spec | ||
self.input_spec = input_spec | ||
# Compute the full input spec, including state and constants | ||
full_input = [inputs] + additional_inputs | ||
full_input_spec = self.input_spec + additional_specs | ||
# Perform the call with temporarily replaced input_spec | ||
original_input_spec = self.input_spec | ||
self.input_spec = full_input_spec | ||
output = super(RNN, self).__call__(full_input, **kwargs) | ||
self.input_spec = original_input_spec | ||
return output | ||
else: | ||
kwargs['initial_state'] = initial_state | ||
return super(RNN, self).__call__(inputs, **kwargs) | ||
|
||
def call(self, inputs, mask=None, training=None, initial_state=None): | ||
def call(self, | ||
inputs, | ||
mask=None, | ||
training=None, | ||
initial_state=None, | ||
constants=None): | ||
# input shape: `(samples, time (padded with zeros), input_dim)` | ||
# note that the .build() method of subclasses MUST define | ||
# self.input_spec and self.state_spec with complete input shapes. | ||
if isinstance(inputs, list): | ||
initial_state = inputs[1:] | ||
inputs = inputs[0] | ||
elif initial_state is not None: | ||
if initial_state is not None: | ||
pass | ||
elif self.stateful: | ||
initial_state = self.states | ||
|
@@ -525,13 +562,27 @@ def call(self, inputs, mask=None, training=None, initial_state=None): | |
'the time dimension by passing a `shape` ' | ||
'or `batch_shape` argument to your Input layer.') | ||
|
||
kwargs = {} | ||
if has_arg(self.cell.call, 'training'): | ||
step = functools.partial(self.cell.call, training=training) | ||
kwargs['training'] = training | ||
|
||
if constants: | ||
if not has_arg(self.cell.call, 'constants'): | ||
raise ValueError('RNN cell does not support constants') | ||
|
||
def step(inputs, states): | ||
constants = states[-self._num_constants:] | ||
states = states[:-self._num_constants] | ||
return self.cell.call(inputs, states, constants=constants, | ||
**kwargs) | ||
else: | ||
step = self.cell.call | ||
def step(inputs, states): | ||
return self.cell.call(inputs, states, **kwargs) | ||
|
||
last_output, outputs, states = K.rnn(step, | ||
inputs, | ||
initial_state, | ||
constants=constants, | ||
go_backwards=self.go_backwards, | ||
mask=mask, | ||
unroll=self.unroll, | ||
|
@@ -560,6 +611,47 @@ def call(self, inputs, mask=None, training=None, initial_state=None): | |
else: | ||
return output | ||
|
||
def _standardize_args(self, inputs, initial_state, constants): | ||
"""Brings the arguments of `__call__` that can contain input tensors to | ||
standard format. | ||
|
||
When running a model loaded from file, the input tensors | ||
`initial_state` and `constants` can be passed to `RNN.__call__` as part | ||
of `inputs` instead of by the dedicated keyword arguments. This method | ||
makes sure the arguments are separated and that `initial_state` and | ||
`constants` are lists of tensors (or None). | ||
|
||
# Arguments | ||
inputs: tensor or list/tuple of tensors | ||
initial_state: tensor or list of tensors or None | ||
constants: tensor or list of tensors or None | ||
|
||
# Returns | ||
inputs: tensor | ||
initial_state: list of tensors or None | ||
constants: list of tensors or None | ||
""" | ||
if isinstance(inputs, list): | ||
assert initial_state is None and constants is None | ||
if self._num_constants is not None: | ||
constants = inputs[-self._num_constants:] | ||
inputs = inputs[:-self._num_constants] | ||
if len(inputs) > 1: | ||
initial_state = inputs[1:] | ||
inputs = inputs[0] | ||
|
||
def to_list_or_none(x): | ||
if x is None or isinstance(x, list): | ||
return x | ||
if isinstance(x, tuple): | ||
return list(x) | ||
return [x] | ||
|
||
initial_state = to_list_or_none(initial_state) | ||
constants = to_list_or_none(constants) | ||
|
||
return inputs, initial_state, constants | ||
|
||
def reset_states(self, states=None): | ||
if not self.stateful: | ||
raise AttributeError('Layer must be stateful.') | ||
|
@@ -618,6 +710,9 @@ def get_config(self): | |
'go_backwards': self.go_backwards, | ||
'stateful': self.stateful, | ||
'unroll': self.unroll} | ||
if self._num_constants is not None: | ||
config['num_constants'] = self._num_constants | ||
|
||
cell_config = self.cell.get_config() | ||
config['cell'] = {'class_name': self.cell.__class__.__name__, | ||
'config': cell_config} | ||
|
@@ -629,7 +724,10 @@ def from_config(cls, config, custom_objects=None): | |
from . import deserialize as deserialize_layer | ||
cell = deserialize_layer(config.pop('cell'), | ||
custom_objects=custom_objects) | ||
return cls(cell, **config) | ||
num_constants = config.pop('num_constants', None) | ||
layer = cls(cell, **config) | ||
layer._num_constants = num_constants | ||
return layer | ||
|
||
@property | ||
def trainable_weights(self): | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For more advanced cells (typically attention cells such as this one: https://github.com/andhus/keras/pull/3/files#diff-9c4188ddc4dd173d80f64feed5b89412R258) the
cell.state_size
is not defined until the cell is built. But that's not a problem,state_spec
is just set a bit later in build (or in__call__
ifinitial_state
is passed toRNN
)