Recurrent Attention API: Cell wrapper base class [WIP] #4

Open · wants to merge 23 commits into base: master

Commits (23)
28795f1  Added support for passing external constants to RNN, which will pass …  (andhus, Sep 24, 2017)
c886e84  added base class for attention cell wrapper  (andhus, Oct 15, 2017)
5213a16  Merge branch 'master' of github.com:fchollet/keras into recurrent_att…  (andhus, Oct 15, 2017)
e0dfb6a  added MoG1D attention and MultiLayerWrapperMixin  (andhus, Oct 15, 2017)
767df54  added alignment example, debugging  (andhus, Oct 15, 2017)
08f0a04  fixed dimension bug  (andhus, Oct 15, 2017)
b5dfc3f  started refactoring constants handling  (andhus, Oct 20, 2017)
d2470b8  fixed step_function wrapping, cleaned up multi layer wrapper, removed…  (andhus, Oct 20, 2017)
4bf62f7  fixed state_spec bug  (andhus, Oct 20, 2017)
205d057  added training flag to attention cell  (andhus, Oct 20, 2017)
7265794  removed multi layer wrapper mixin and refctored MoG attention cell ac…  (andhus, Oct 20, 2017)
5fb3c1b  Merge branch 'master' of github.com:fchollet/keras into recurrent_att…  (andhus, Oct 20, 2017)
afd6ae4  added error msg  (andhus, Oct 21, 2017)
7480a83  merged master  (andhus, Oct 28, 2017)
0f6219e  merged master, added TODOs  (andhus, Oct 28, 2017)
3b2753b  detailed docs of attention, WIP  (andhus, Oct 29, 2017)
21e007b  complted docs of attention base class and some cleanup  (andhus, Oct 29, 2017)
9ccdc38  removed dependence of distribution module  (andhus, Nov 19, 2017)
e48f5cd  added support for multiple heads, added class docs  (andhus, Nov 19, 2017)
e5d965b  completed majority of docs, added sigma_epsilon and removed use of ad…  (andhus, Nov 19, 2017)
c0c8968  Merge branch 'master' of github.com:fchollet/keras into recurrent_att…  (andhus, Nov 19, 2017)
a90f0b6  improved docs of recurrent_attention example  (andhus, Nov 19, 2017)
ff1a8b5  Merge branch 'master' into recurrent_attention_api_cell_wrapper_base  (farizrahman4u, Sep 14, 2018)
102 changes: 102 additions & 0 deletions examples/recurrent_attention.py
@@ -0,0 +1,102 @@
'''Canonical example of using attention for sequence-to-sequence problems.

This script demonstrates how to use an RNNAttentionCell to implement attention
mechanisms. In the example, the model only has to learn to filter the attended
input to obtain the target: essentially, it has to learn to "parse" the attended
input sequence and output only the relevant parts.

# Explanation of data:

One sample of input data consists of a sequence of one-hot-vectors separated
by randomly added "extra" zero-vectors:

0 0 0 0 1 0 0 0 0 0
1 0 0 1 0 0 0 0 0 1
0 0 0 0 0 0 1 0 0 0
0 0 1 0 0 0 0 0 0 0
^ ^
| |
| extra zero-vector
one-hot vector

The goal is to retrieve the one-hot-vector sequence _without_ the extra zeros:

0 0 0 1 0 0
1 0 1 0 0 1
0 0 0 0 1 0
0 1 0 0 0 0

# Summary of the algorithm

The task is carried out by letting a Mixture-of-Gaussians 1D attention mechanism
attend to the input sequence (with the extra zeros) and select what information
should be passed to the wrapped LSTM cell.

# Attention vs. Encoder-Decoder approach
This is a good example of a task for which attention mechanisms are well suited.
In this case attention clearly outperforms e.g. encoder-decoder approaches.
TODO add this comparison to the script
TODO add comparison heads=1 vs heads=2 (the latter converges faster)
'''

from __future__ import division, print_function

import random

import numpy as np

from keras import Input
from keras.engine import Model
from keras.layers import Dense, TimeDistributed, LSTMCell, RNN

from keras.layers.attention import MixtureOfGaussian1DAttention


def get_training_data(n_samples,
                      n_labels,
                      n_timesteps_attended,
                      n_timesteps_labels):
    labels = np.random.randint(
        n_labels,
        size=(n_samples, n_timesteps_labels)
    )
    attended_time_idx = range(n_timesteps_attended)
    label_time_idx = range(1, n_timesteps_labels + 1)

    labels_one_hot = np.zeros((n_samples, n_timesteps_labels + 1, n_labels))
    attended = np.zeros((n_samples, n_timesteps_attended, n_labels))
    for i in range(n_samples):
        labels_one_hot[i][label_time_idx, labels[i]] = 1
        positions = sorted(random.sample(attended_time_idx, n_timesteps_labels))
        attended[i][positions, labels[i]] = 1

    return labels_one_hot, attended


n_samples = 10000
n_timesteps_labels = 10
n_timesteps_attended = 30
n_labels = 4

input_labels = Input((n_timesteps_labels, n_labels))
attended = Input((n_timesteps_attended, n_labels))

cell = MixtureOfGaussian1DAttention(LSTMCell(64), components=3, heads=2)
attention_lstm = RNN(cell, return_sequences=True)

attention_lstm_output = attention_lstm(input_labels, constants=attended)
output_layer = TimeDistributed(Dense(n_labels, activation='softmax'))
output = output_layer(attention_lstm_output)

model = Model(inputs=[input_labels, attended], outputs=output)

labels_data, attended_data = get_training_data(n_samples,
                                               n_labels,
                                               n_timesteps_attended,
                                               n_timesteps_labels)
input_labels_data = labels_data[:, :-1, :]
target_labels_data = labels_data[:, 1:, :]

model.compile(optimizer='Adam', loss='categorical_crossentropy')
model.fit(x=[input_labels_data, attended_data], y=target_labels_data, epochs=5)
output_data = model.predict([input_labels_data, attended_data])
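
The example above relies on the Mixture-of-Gaussians 1D attention described in the docstring. As a rough illustration of the idea (a simplified NumPy sketch with assumed parameter names, not the MixtureOfGaussian1DAttention implementation from this PR), each attended timestep is weighted by a small mixture of Gaussian bumps:

import numpy as np

def mog_1d_attention_weights(n_timesteps, pi, mu, sigma):
    # One Gaussian bump per mixture component, centred at mu with width sigma,
    # summed with mixture weights pi; the normalisation is only for readability.
    t = np.arange(n_timesteps)[:, None]                        # (n_timesteps, 1)
    w = np.sum(pi * np.exp(-(t - mu) ** 2 / (2.0 * sigma ** 2)), axis=-1)
    return w / w.sum()

weights = mog_1d_attention_weights(n_timesteps=30,
                                   pi=np.array([0.6, 0.4]),
                                   mu=np.array([4.0, 12.0]),
                                   sigma=np.array([1.0, 2.0]))
attended_seq = np.random.rand(30, 4)           # (n_timesteps_attended, n_labels)
context = weights @ attended_seq               # weighted summary fed to the wrapped cell

In the actual cell wrapper the mixture parameters would be predicted from the recurrent state at every output step (the usual Graves-style formulation), which lets the read position move along the attended sequence as decoding progresses.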
4 changes: 4 additions & 0 deletions keras/activations.py
@@ -158,6 +158,10 @@ def linear(x):
    return x


def exponential(x):
    return K.exp(x)


def serialize(activation):
    return activation.__name__

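
The small addition to keras/activations.py above makes an exponential activation available, presumably to keep the Gaussian scale parameters (sigma) positive. A quick check of the new function, assuming the usual lookup-by-name machinery applies to it like the existing activations:

from keras import activations
from keras import backend as K

exp_fn = activations.get('exponential')      # resolved from keras.activations by name
x = K.constant([[-1.0, 0.0, 1.0]])
print(K.eval(exp_fn(x)))                     # approx. [[0.368, 1.0, 2.718]]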