Recurrent Attention API for keras #11172

Closed
gabrieldemarmiesse opened this issue Sep 18, 2018 · 19 comments
Labels: type:feature (The user is asking for a new feature.)

@gabrieldemarmiesse (Contributor)

This issue was opened to host a discussion about a recurrent attention API for Keras.
Related issues:

#11142.
#8296.
#7633.

@gabrieldemarmiesse (Contributor, Author)

From @fchollet in #11142:

I don't think this PR is a good fit to implement attention models at this point.

I believe the Model subclassing API would make this use case far easier.

We should come up with side by side code examples of similar models to figure out whether that is the case (the code doesn't have to work, it would just show the workflow with each API, the API in this PR and the model subclassing API).

Possibly, but first we should make a decision: what's the best way to implement attention models in the Keras API today, and would this PR provide a better API?

To make this decision we need to write quick code examples showing how to build common attention models 1) with the model subclassing API, 2) with the API introduced by #8296, and figure out what looks best.

@gabrieldemarmiesse (Contributor, Author)

From @farizrahman4u in #11142:

Shouldn't we first see what an attention model written in Keras (as of now) looks like? Sometimes you can get away with a well-documented example (like @fchollet did with seq2seq) instead of making huge changes to the code base to fit niche cases. I would suggest writing an example using a simpler attention model (not MixGaussian.. it has way too many internal knobs) using existing Keras features and see where it goes.

@andhus (Contributor) commented Sep 28, 2018

Great to see this moving 👍 I'm happy to continue the work as well.

This was the result of my "requirements analysis" (based on common use cases and authoritative papers), as per the API doc:

  1. The attention mechanism should be implemented separately from the core cell so that it can be reused with any standard keras RNN Cells.
  2. Besides the attended, the attention mechanism should have access to the core cell’s states as well as the input at each timestep for computing the attention representation at that timestep.
  3. The attention mechanism should be allowed to be recurrent, i.e. “have state(s) on its own” that is forwarded to the next attention computation.
  4. One should (optionally) be able to get the sequence of attention representations from each timestep as output, for use in later stacked (recurrent) layers.
  5. It should be easily configurable whether the attention is applied “before or after” (in sequence/time dimension) the core recurrent transformation at each timestep.
  6. Masking should be supported and have the same behaviour as for regular keras recurrent layers.
  7. The attended should be allowed to consist of multiple inputs (tensors) of different shape.
  8. It should require minimal effort for users to write custom attention mechanisms that fulfil requirements 1-7.

These should be useful as a reference when discussing tradeoffs and priorities.
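
For concreteness, a minimal sketch of how requirements 1-3 and 7 could map onto an overridable hook of a cell-level base class. The class name RNNAttentionCell appears in this discussion, but the method name attention_call and the exact signature are illustrative assumptions, not an existing Keras API:

from keras.layers import Layer

# Hypothetical sketch only; the method name and signature are illustrative.
class RNNAttentionCell(Layer):
    """Wraps any standard RNN cell and adds an attention mechanism (req. 1)."""

    def attention_call(self,
                       inputs,             # input at the current timestep (req. 2)
                       cell_states,        # states of the wrapped core cell (req. 2)
                       attended,           # list of attended tensors (req. 7)
                       attention_states):  # the mechanism's own recurrent states (req. 3)
        """Should return (attention_encoding, new_attention_states).

        Subclasses would override only this method; combining the encoding with
        the input and stepping the wrapped cell would be handled by the base class.
        """
        raise NotImplementedError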

@bhack (Contributor) commented Sep 28, 2018

Are we talking about Attention only in recurrent models?

@douglas125

Should we consider only the more common dot-product attention, or support everything else via an overridable method that computes the weights?

@bhack (Contributor) commented Oct 1, 2018

I think it is important not to be limited to the RNN design, e.g. see https://openreview.net/forum?id=HyGBdo0qFm

@andhus (Contributor) commented Oct 12, 2018

@douglas125

Should we consider only the more common dot-product attention, or support everything else via an overridable method that computes the weights?

The latter is what's achieved by this method of the suggested base class for recurrent attention mechanisms, as explained here.
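
To illustrate: with such a base class, dot-product attention would just be one possible override of that method. This is a sketch building on the hypothetical RNNAttentionCell interface above, and assumes the query state and the attended have matching dimensions:

from keras import backend as K

class DotProductAttention(RNNAttentionCell):  # hypothetical base class, see above

    def attention_call(self, inputs, cell_states, attended, attention_states):
        attended_seq = attended[0]  # (batch, timesteps, units)
        query = cell_states[0]      # (batch, units); must match the attended units
        scores = K.batch_dot(attended_seq, query, axes=(2, 1))     # (batch, timesteps)
        weights = K.softmax(scores)
        context = K.batch_dot(weights, attended_seq, axes=(1, 1))  # (batch, units)
        return context, []          # no recurrent attention states of its own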

@andhus (Contributor) commented Oct 12, 2018

@bhack

I think it is important not to be limited to the RNN design, e.g. see https://openreview.net/forum?id=HyGBdo0qFm

Generally, I agree. The different approach taken in e.g. Attention Is All You Need is of great importance and should be considered for any seq2seq problem. However (!) I'd claim that any feedforward architecture can already be supported by the existing Keras API. You might have to write several custom layers, but feedforward attention can be implemented by reusing, and without duplicating, major parts of the existing API (along the lines of @farizrahman4u's comment cited above).

Recurrent attention is different; if you want to implement it in a new layer/model you need to reimplement the majority of the RNN logic or wrap the cell as is suggested in the previous PR.

That said, it might still make sense to add layers or models to the API for non-recurrent attention. But I think it is still a high priority to support recurrent attention, as was initially stated in the "request for contributions".

@andhus (Contributor) commented Oct 12, 2018

In a sense, I think recurrent attention "is also supported" since we added support for constants in the RNN: you can quite easily write your own cell wrapper. To me, it's just a question of how standardized and simple you want to make this, and whether (and which) ready-made (cell-wrapper) attention mechanisms should be added to the API.
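
To make this concrete, here is a rough sketch of a hand-written cell wrapper using only existing Keras features (RNN, LSTMCell and the constants mechanism). The wrapper class itself is hypothetical, uses plain dot-product attention (so it assumes the attended units match the cell units), and glosses over details such as masking and exposing the wrapped cell's trainable weights:

from keras import backend as K
from keras.layers import Layer, LSTMCell, RNN

class SimpleAttentionCellWrapper(Layer):  # illustrative only, not an existing class

    def __init__(self, cell, **kwargs):
        super(SimpleAttentionCellWrapper, self).__init__(**kwargs)
        self.cell = cell  # any standard RNN cell, e.g. LSTMCell or GRUCell
        self.state_size = cell.state_size
        self.output_size = cell.units

    def build(self, input_shape):
        # with constants, the cell's build receives [step_input_shape, attended_shape]
        step_input_shape, attended_shape = input_shape[0], input_shape[1]
        self.cell.build((step_input_shape[0],
                         step_input_shape[-1] + attended_shape[-1]))
        self.built = True

    def call(self, inputs, states, constants):
        attended = constants[0]  # (batch, timesteps, units), passed via RNN(..., constants=...)
        query = states[0]        # previous output state of the wrapped cell
        scores = K.batch_dot(attended, query, axes=(2, 1))
        weights = K.softmax(scores)
        context = K.batch_dot(weights, attended, axes=(1, 1))
        cell_input = K.concatenate([inputs, context], axis=-1)
        return self.cell.call(cell_input, states)

# usage: the attended sequence is passed as a constant to the RNN layer
# h = RNN(SimpleAttentionCellWrapper(LSTMCell(64)),
#         return_sequences=True)(query_seq, constants=attended_seq)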

The main current limitation with this approach is that there is no option to return "state sequences" from the RNN, which is required to feed the attention encoding from one layer to subsequent layers (see point 3 here).
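
One possible workaround (a sketch under the assumption of a hypothetical attentive cell whose per-step output is concatenate([h_t, attention_t])): with return_sequences=True both are then exposed and can be split for subsequent layers.

from keras.layers import RNN, Lambda

# attentive_cell, x and attended are assumed to be defined as in the sketches above
units = 64  # output size of the wrapped core cell; remaining channels are the attention encoding
outputs = RNN(attentive_cell, return_sequences=True)(x, constants=attended)
h_seq = Lambda(lambda o: o[:, :, :units])(outputs)          # regular output sequence
attention_seq = Lambda(lambda o: o[:, :, units:])(outputs)  # attention encodings, one per timestep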

@andhus (Contributor) commented Oct 14, 2018

Ok, so usage workflow examples have been requested. Since the heading of this and all preceding issues/PRs has been recurrent attention, I'll focus on this and repeat/clarify the workflow of the only concrete suggestion so far. I'll use the architecture in this paper for handwriting synthesis as the use case (but the workflow would be the same for e.g. this kind of image captioning):

from keras.layers import Input, RNN, LSTMCell, TimeDistributed, Dense
from keras.models import Model
# MixtureOfGaussian1DAttention is the proposed attentive cell (not part of Keras)

xy = Input((None, 2))  # coordinates of handwriting trajectory
text = Input((None, n_characters))  # sequence of characters to be synthesized

cell = MixtureOfGaussian1DAttention(LSTMCell(64), components=3, heads=2)
attention_lstm = RNN(cell, return_sequences=True)
h = attention_lstm(xy, constants=text)
xy_pred = TimeDistributed(Dense(2, activation=None))(h)
# MoG output is used instead of basic regression in the original paper, but this has nothing to do with the attention mechanism and is left out for brevity.

model = Model(inputs=[xy, text], outputs=xy_pred)
model.compile(optimizer='adam', loss='mse')
model.fit([xy_data[:, :-1], text_data], xy_data[:, 1:])

I think that this workflow is perfectly aligned with the Keras guiding principles. Note that no modification of existing classes is required; we've just defined a new RNN cell. For the sake of modularity it reuses/wraps the LSTMCell, which can be replaced by e.g. the GRUCell. This use case and workflow were among the main drivers for breaking out the RNNCell and adding support for constants in the RNN.

I honestly can't come up with a reasonable alternative based on the Model subclassing API for this use case. I guess these are the other options:

attention_lstm = MixtureOfGaussian1DAttentionRNN(
    cell=LSTMCell(64),
    components=3,
    heads=2,
    return_sequences=True
)
h = attention_lstm(xy, attended=text)

Here, MixtureOfGaussian1DAttentionRNN would be a standalone layer where the attention mechanism is coupled tightly with the general RNN logic (which is why I think it is a bad option).

Alternatively something like:

attention_lstm = AttentionRNN(
    cell=LSTMCell(64),
    attention_mechanism=MixtureOfGaussian1DAttention(components=3, heads=2),
    return_sequences=True  # return_state_sequences or similar could also be supported here
)
h = attention_lstm(xy, attended=text)

Where both the core cell and the attention mechanism are injected into a new class that connects them. But this requires new interfaces both for the attention_mechanism and for the AttentionRNN, which seems unnecessary since the existing RNNCell interface already supports attentive cells (thanks to the addition of constants).

Bottom Line

  1. Is this use case at all relevant?
  2. Is there any problem with the suggested approach? If so, please be concrete and/or provide an alternative!
  3. If we can agree that recurrent attention is already supported by using an appropriate (attentive) RNNCell - do you feel that there is a need to add "standard" attentive cells, such as MixtureOfGaussian1DAttention in this example, to the official keras API?

@fchollet @farizrahman4u @gabrieldemarmiesse

@farizrahman4u (Contributor)

  • Thinking in terms of reducing the number of mental models (cognitive load), I vote for the first one.

  • Even in that case, I think the "MixtureOfGaussian1DAttention" wrapper is too niche to make its way into the core API (I might be wrong).

  • We definitely need a standard way to do attention in Keras. I think we should come up with an end-to-end example, including the attention wrapper definition, that users can easily extend. For simplicity, use the simplest attention mechanism possible (MixtureOfGaussian1DAttention has too many moving parts).

@andhus (Contributor) commented Oct 15, 2018

I have a feeling that some uncertainty comes from the (motivated!) buzz around non-recurrent attention mechanisms. If we were to add support for the transformer architecture in Attention Is All You Need (or the recent BERT), I definitely think the Model subclassing API is a good place to start, because there are many intricate parts that have to be combined the right way. It would look something like:

input_sequence = Input((None, 1))
target_sequence_tm1 = Input((None, 1))
transformer = Transformer(
    input_tokens=n_input_tokens,  # must be provided if embeddings created internally
    target_tokens=n_target_tokens,  # must be provided if embeddings created internally
    units_model=512,
    units_ff=2048,
    units_keys=64,
    units_values=64,
    layers=6,
    heads=3
)
target_sequence_pred = transformer([input_sequence, target_sequence_tm1])

Where Transformer would subclass Model. It could also make sense to expose some of the internals as separate sub-models or layers, which raises the question of whether the same "attention layers" can be reused in both the feedforward and recurrent settings. This would be possible (as per my previous suggestion) by defining cell transformations using the functional API, e.g.:

# define complete attentive cell using functional API
units = 32
xy_t = Input((2,))
text = Input((None, n_characters))  # note that this is a sequence
h_tm1 = Input((units,))
c_tm1 = Input((units,))
h_att_t = AttentionMechanism(inputs=concatenate([xy_t, h_tm1]), attended=text)
x_t = concatenate([xy_t, h_att_t])
h_t, c_t = LSTMCell(units)(inputs=x_t, states=[h_tm1, c_tm1])
cell = CellModel(  # creates a valid cell implementation based on functional definition
    inputs=xy_t,
    outputs=h_t,
    input_states=[h_tm1, c_tm1],
    output_states=[h_t, c_t],
    constants=text
)

But we should probably avoid considering this for now and make as few additions as possible. This is why it was decided not to add the RNNAttentionCell base class to the API. In the end, the only API addition in the previous PR was a first attentive RNN cell.

@andhus (Contributor) commented Oct 15, 2018

Thanks @farizrahman4u, makes sense. So you think something like the RNNAttentionCell base class should be added to the API, to simplify custom implementations? Please let me know if you have suggestions for a simpler, yet relevant, first attention mechanism and corresponding use-case.

@andhus (Contributor) commented Oct 15, 2018

I think this version of machine translation (Bengio 2016) would serve as a good end-to-end example. A Keras implementation of the paper would look like this:

from keras.layers import (Input, Embedding, Bidirectional, LSTM, LSTMCell,
                          RNN, TimeDistributed, Dense)
# DenseAnnotationAttention is the proposed attentive cell wrapper (not part of Keras)

input_sentence = Input((None,))  # sequence of word idxs, language A
target_sentence = Input((None,))  # sequence of word idxs, language B
input_embeddings = Embedding(n_input_tokens, 620)(input_sentence)
target_embeddings = Embedding(n_target_tokens, 620)(target_sentence)
input_encoding = Bidirectional(LSTM(1000, return_sequences=True))(input_embeddings)  # one annotation per source timestep is needed as the attended
h = RNN(
    DenseAnnotationAttention(LSTMCell(1000)),
    return_sequences=True
)(target_embeddings, constants=input_encoding)
target_sentence_pred = TimeDistributed(Dense(n_target_tokens, activation=None))(h)
# NOTE the paper uses "deep output (Pascanu et al., 2014) with a single maxout hidden layer"

Where DenseAnnotationAttention subclasses RNNAttentionCell. It computes the attention weight for each input_encoding_t using a single-hidden-layer MLP that takes h_lstm_t and input_encoding_t as input, and is thus considerably simpler than MixtureOfGaussian1DAttention.
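
Roughly, the scoring step would look like this (a sketch only; the function and weight names are illustrative, and in the actual cell the weights would be created in build()):

from keras import backend as K

# additive / "dense annotation" scoring: e_tj = v^T tanh(W_h h_t + W_s s_j)
# h_t: (batch, units), annotations: (batch, timesteps, annotation_units)
# W_h: (units, hidden), W_s: (annotation_units, hidden), v: (hidden, 1)
def dense_annotation_context(h_t, annotations, W_h, W_s, v):
    e = K.dot(K.tanh(K.expand_dims(K.dot(h_t, W_h), axis=1) + K.dot(annotations, W_s)), v)
    alpha = K.softmax(K.squeeze(e, axis=-1))                 # weights over source timesteps
    context = K.batch_dot(alpha, annotations, axes=(1, 1))   # weighted sum of annotations
    return context, alpha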

Sounds good, @farizrahman4u @fchollet? If so, I'll implement the attention mechanism and the end-to-end example.

@farizrahman4u (Contributor)

Fair enough. I think you can write the whole thing in the example (including the RNNAttentionCell class), submit a PR, and discuss with @fchollet what parts can be moved into the Keras API and what should stay in the example.

@andhus (Contributor) commented Oct 17, 2018

As per @farizrahman4u's suggestion above, please see #11421, @fchollet.

@leemengtw

@andhus thanks for making the PR #11421!

As a Keras fan eager to implement attention mechanisms, do you think it's okay for me to start using the recurrent_attention_machine_translation.py you provided as an example, or should I wait until the PR is merged?

Thanks in advance.

@bhack (Contributor) commented Jan 17, 2019

Attention for Dense Networks on Keras RFC:
tensorflow/community#54

@Gilthans

Is there an update on this? The RFC has been approved for over a year.
