-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Recurrent Attention API: Support constants in RNN #7980
Changes from 12 commits
28795f1
03b8fda
53deca2
568fd2e
2f9f6f0
1b90731
67bd184
e74b125
fb91e4e
fcc854c
2775b2f
ab89c6a
95c2359
86fdd93
d33d919
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,6 @@ | ||
# -*- coding: utf-8 -*- | ||
from __future__ import absolute_import | ||
import numpy as np | ||
import functools | ||
import warnings | ||
|
||
from .. import backend as K | ||
|
@@ -200,7 +199,9 @@ class RNN(Layer): | |
# Arguments | ||
cell: A RNN cell instance. A RNN cell is a class that has: | ||
- a `call(input_at_t, states_at_t)` method, returning | ||
`(output_at_t, states_at_t_plus_1)`. | ||
`(output_at_t, states_at_t_plus_1)`. The call method of the | ||
cell can also take the optional argument `constants`, see | ||
section "Note on passing external constants" below. | ||
- a `state_size` attribute. This can be a single integer | ||
(single state) in which case it is | ||
the size of the recurrent state | ||
|
@@ -292,6 +293,14 @@ class RNN(Layer): | |
`states` should be a numpy array or list of numpy arrays representing | ||
the initial state of the RNN layer. | ||
|
||
# Note on passing external constants to RNNs | ||
You can pass "external" constants to the cell using the `constants` | ||
keyword argument of RNN.__call__ (as well as RNN.call) method. This | ||
requires that the `cell.call` method accepts the same keyword argument | ||
`constants`. Such constants can be used to condition the cell | ||
transformation on additional static inputs (not changing over time) | ||
(a.k.a. as attention mechanism). | ||
|
||
# Examples | ||
|
||
```python | ||
|
@@ -363,13 +372,11 @@ def __init__(self, cell, | |
|
||
self.supports_masking = True | ||
self.input_spec = [InputSpec(ndim=3)] | ||
if hasattr(self.cell.state_size, '__len__'): | ||
self.state_spec = [InputSpec(shape=(None, dim)) | ||
for dim in self.cell.state_size] | ||
else: | ||
self.state_spec = InputSpec(shape=(None, self.cell.state_size)) | ||
self.state_spec = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For more advanced cells (typically attention cells such as this one: https://github.com/andhus/keras/pull/3/files#diff-9c4188ddc4dd173d80f64feed5b89412R258) the |
||
self._states = None | ||
|
||
self.constants_spec = None | ||
self._n_constants = None # used for splitting inputs after | ||
# serialization of layer | ||
@property | ||
def states(self): | ||
if self._states is None: | ||
|
@@ -415,19 +422,46 @@ def compute_mask(self, inputs, mask): | |
return output_mask | ||
|
||
def build(self, input_shape): | ||
# Note input_shape will be list of shapes of initial states and | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
# constants if these are passed in __call__. | ||
if self._n_constants is not None: | ||
constants_shape = input_shape[-self._n_constants:] | ||
else: | ||
constants_shape = None | ||
|
||
if isinstance(input_shape, list): | ||
input_shape = input_shape[0] | ||
|
||
batch_size = input_shape[0] if self.stateful else None | ||
input_dim = input_shape[-1] | ||
self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim)) | ||
|
||
if self.stateful: | ||
self.reset_states() | ||
|
||
# allow cell (if layer) to build before we set or validate state_spec | ||
if isinstance(self.cell, Layer): | ||
step_input_shape = (input_shape[0],) + input_shape[2:] | ||
self.cell.build(step_input_shape) | ||
if constants_shape is not None: | ||
self.cell.build([step_input_shape] + constants_shape) | ||
else: | ||
self.cell.build(step_input_shape) | ||
|
||
# set or validate state_spec | ||
if hasattr(self.cell.state_size, '__len__'): | ||
state_size = list(self.cell.state_size) | ||
else: | ||
state_size = [self.cell.state_size] | ||
|
||
if self.state_spec is not None: | ||
# initial_state was passed in call, check compatibility | ||
if not [spec.shape[-1] for spec in self.state_spec] == state_size: | ||
raise ValueError( | ||
'an initial_state was passed that is not compatible with' | ||
' cell.state_size, state_spec: {}, cell.state_size:' | ||
' {}'.format(self.state_spec, self.cell.state_size)) | ||
else: | ||
self.state_spec = [InputSpec(shape=(None, dim)) | ||
for dim in state_size] | ||
if self.stateful: | ||
self.reset_states() | ||
|
||
def get_initial_state(self, inputs): | ||
# build an all-zero tensor of shape (samples, output_dim) | ||
|
@@ -440,62 +474,68 @@ def get_initial_state(self, inputs): | |
else: | ||
return [K.tile(initial_state, [1, self.cell.state_size])] | ||
|
||
def __call__(self, inputs, initial_state=None, **kwargs): | ||
# If there are multiple inputs, then | ||
# they should be the main input and `initial_state` | ||
# e.g. when loading model from file | ||
if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None: | ||
initial_state = inputs[1:] | ||
inputs = inputs[0] | ||
def __call__(self, inputs, initial_state=None, constants=None, **kwargs): | ||
inputs, initial_state, constants = self._normalize_args( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Prefer "standardize" (also used elsewhere) since "normalize" has a specific meaning elsewhere. |
||
inputs, initial_state, constants) | ||
|
||
# If `initial_state` is specified, | ||
# and if it a Keras tensor, | ||
# then add it to the inputs and temporarily | ||
# modify the input spec to include the state. | ||
if initial_state is None: | ||
if initial_state is None and constants is None: | ||
return super(RNN, self).__call__(inputs, **kwargs) | ||
|
||
if not isinstance(initial_state, (list, tuple)): | ||
initial_state = [initial_state] | ||
# If any of `initial_state` or `constants` are specified and are Keras | ||
# tensors, then add them to the inputs and temporarily modify the | ||
# input_spec to include them. | ||
|
||
is_keras_tensor = hasattr(initial_state[0], '_keras_history') | ||
for tensor in initial_state: | ||
check_list = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unclear what is meant by |
||
if initial_state is not None: | ||
kwargs['initial_state'] = initial_state | ||
check_list += initial_state | ||
self.state_spec = [InputSpec(shape=K.int_shape(state)) | ||
for state in initial_state] | ||
if constants is not None: | ||
kwargs['constants'] = constants | ||
check_list += constants | ||
self.constants_spec = [InputSpec(shape=K.int_shape(constant)) | ||
for constant in constants] | ||
self._n_constants = len(constants) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In line with naming conventions in this API, this should be |
||
# at this point check_list cannot be empty | ||
is_keras_tensor = hasattr(check_list[0], '_keras_history') | ||
for tensor in check_list: | ||
if hasattr(tensor, '_keras_history') != is_keras_tensor: | ||
raise ValueError('The initial state of an RNN layer cannot be' | ||
' specified with a mix of Keras tensors and' | ||
' non-Keras tensors') | ||
raise ValueError('The initial state and constants of an RNN' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and -> or There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I think I wanted to refer to that they "jointly" must contain only one kind of tensors (that's what checked for now)... I don't have any opinion on the error msg, but just to verify/understand: shouldn't all inputs be keras tensors or not - and why not include |
||
' layer cannot be specified with a mix of' | ||
' Keras tensors and non-Keras tensors') | ||
|
||
if is_keras_tensor: | ||
# Compute the full input spec, including state | ||
input_spec = self.input_spec | ||
state_spec = self.state_spec | ||
if not isinstance(input_spec, list): | ||
input_spec = [input_spec] | ||
if not isinstance(state_spec, list): | ||
state_spec = [state_spec] | ||
self.input_spec = input_spec + state_spec | ||
|
||
# Compute the full inputs, including state | ||
inputs = [inputs] + list(initial_state) | ||
|
||
# Perform the call | ||
output = super(RNN, self).__call__(inputs, **kwargs) | ||
|
||
# Restore original input spec | ||
self.input_spec = input_spec | ||
# Compute the full input spec, including state and constants | ||
full_input = [inputs] | ||
full_input_spec = self.input_spec | ||
if initial_state: | ||
full_input += initial_state | ||
full_input_spec += self.state_spec | ||
if constants: | ||
full_input += constants | ||
full_input_spec += self.constants_spec | ||
# Perform the call with temporarily replaced input_spec | ||
original_input_spec = self.input_spec | ||
self.input_spec = full_input_spec | ||
output = super(RNN, self).__call__(full_input, **kwargs) | ||
self.input_spec = original_input_spec | ||
return output | ||
else: | ||
kwargs['initial_state'] = initial_state | ||
return super(RNN, self).__call__(inputs, **kwargs) | ||
|
||
def call(self, inputs, mask=None, training=None, initial_state=None): | ||
def call(self, | ||
inputs, | ||
mask=None, | ||
training=None, | ||
initial_state=None, | ||
constants=None): | ||
# input shape: `(samples, time (padded with zeros), input_dim)` | ||
# note that the .build() method of subclasses MUST define | ||
# self.input_spec and self.state_spec with complete input shapes. | ||
if isinstance(inputs, list): | ||
initial_state = inputs[1:] | ||
inputs = inputs[0] | ||
elif initial_state is not None: | ||
if initial_state is not None: | ||
pass | ||
elif self.stateful: | ||
initial_state = self.states | ||
|
@@ -525,13 +565,27 @@ def call(self, inputs, mask=None, training=None, initial_state=None): | |
'the time dimension by passing a `shape` ' | ||
'or `batch_shape` argument to your Input layer.') | ||
|
||
kwargs = {} | ||
if has_arg(self.cell.call, 'training'): | ||
step = functools.partial(self.cell.call, training=training) | ||
kwargs['training'] = training | ||
|
||
if constants: | ||
if not has_arg(self.cell.call, 'constants'): | ||
raise ValueError('RNN cell does not support constants') | ||
|
||
def step(inputs, states): | ||
constants = states[-self._n_constants:] | ||
states = states[:-self._n_constants] | ||
return self.cell.call(inputs, states, constants=constants, | ||
**kwargs) | ||
else: | ||
step = self.cell.call | ||
def step(inputs, states): | ||
return self.cell.call(inputs, states, **kwargs) | ||
|
||
last_output, outputs, states = K.rnn(step, | ||
inputs, | ||
initial_state, | ||
constants=constants, | ||
go_backwards=self.go_backwards, | ||
mask=mask, | ||
unroll=self.unroll, | ||
|
@@ -560,6 +614,48 @@ def call(self, inputs, mask=None, training=None, initial_state=None): | |
else: | ||
return output | ||
|
||
def _normalize_args(self, inputs, initial_state, constants): | ||
"""When running a model loaded from file, the input tensors | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: first line of docstring (one-line summary) should fit in one line and end with a period. |
||
`initial_state` and `constants` can be passed to RNN.__call__ as part | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ` around code keywords |
||
of `inputs` in stead of by the dedicated keyword argumetes. In this | ||
case `inputs` is a list of tensors of which the first one is the | ||
actual (sequence) input followed by initial states, followed by | ||
constants. | ||
|
||
This method makes sure initial_states and constants are separated from | ||
inputs and that the are lists of tensors (or None). | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the -> they |
||
|
||
# Arguments | ||
inputs: tensor of list/tuple of tensors | ||
initial_state: tensor or list of tensors or None | ||
constants: tensor or list of tensors or None | ||
|
||
# Returns | ||
inputs: tensor | ||
initial_state: list of tensors or None | ||
constants: list of tensors or None | ||
""" | ||
if isinstance(inputs, list): | ||
assert initial_state is None and constants is None | ||
if self._n_constants is not None: | ||
constants = inputs[-self._n_constants:] | ||
inputs = inputs[:-self._n_constants] | ||
if len(inputs) > 1: | ||
initial_state = inputs[1:] | ||
inputs = inputs[0] | ||
|
||
def to_list_or_none(x): # TODO break out? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unless you want to use it elsewhere, it is fine as it is, you can remove this comment |
||
if x is None or isinstance(x, list): | ||
return x | ||
if isinstance(x, tuple): | ||
return list(x) | ||
return [x] | ||
|
||
initial_state = to_list_or_none(initial_state) | ||
constants = to_list_or_none(constants) | ||
|
||
return inputs, initial_state, constants | ||
|
||
def reset_states(self, states=None): | ||
if not self.stateful: | ||
raise AttributeError('Layer must be stateful.') | ||
|
@@ -618,6 +714,9 @@ def get_config(self): | |
'go_backwards': self.go_backwards, | ||
'stateful': self.stateful, | ||
'unroll': self.unroll} | ||
if self._n_constants is not None: | ||
config['_n_constants'] = self._n_constants | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please call it |
||
|
||
cell_config = self.cell.get_config() | ||
config['cell'] = {'class_name': self.cell.__class__.__name__, | ||
'config': cell_config} | ||
|
@@ -629,7 +728,10 @@ def from_config(cls, config, custom_objects=None): | |
from . import deserialize as deserialize_layer | ||
cell = deserialize_layer(config.pop('cell'), | ||
custom_objects=custom_objects) | ||
return cls(cell, **config) | ||
n_constants = config.pop('_n_constants', None) | ||
layer = cls(cell, **config) | ||
layer._n_constants = n_constants | ||
return layer | ||
|
||
@property | ||
def trainable_weights(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
` around code keywords