Recurrent Attention API: Cell wrapper base class #8296
Conversation
keras/layers/attention.py (outdated)

```python
# TODO should it be made private like some other base classes? The idea is that
# it should be used to implement custom attention mechanisms though...
class RecurrentAttentionCellWrapperABC(Layer):
```
better/shorter name?
Checking this out now. Thank you for your patience. What's the absolute simplest implementation of Attention you could produce (in a standalone way, without subclassing anything that isn't already in the public API)? This PR has a lot going on.
No worries, yes there is a fair amount going on... Basically all the questions and "suggested improvements" can be addressed later. To offer easy-to-use attention mechanisms, we just need to add what's in the […]. As I wrote, the most "urgent" missing feature is the ability to feed the attention encoding inferred from a first layer to later layers in a stack of RNNs, but I think this should really be solved by making it possible to return "state sequences" from the base RNN - which does not affect the API of the attention mechanism itself.
...I think it would be just silly to merge the base class
Thank you for the info. So let's start with the base class (made private, i.e. […]). Other base classes in Keras don't have the […]
Ok, I've done the fixes as discussed. I've also added some better docs of the canonical MoG attention example. Please give feedback @fchollet, and I'll add tests...
I need to prioritize reviewing this. Since it is a large PR, it has been difficult. Sorry for the delay. If anyone wants to help giving feedback, that's very much appreciated.
Hi @andhus, thanks for the great PR! It saves a lot of time compared to writing one's own custom attention wrappers. The first thing I've noticed when trying this implementation is that, under an encoder-decoder setting, it seems that we need to provide additional initial states for the wrapper. For instance:

```python
encoder_input = Input(shape=(None, 8))
context, *states = LSTM(32, return_sequences=True, return_state=True)(encoder_input)

# have to provide initial states for the attention wrapper:
# 32 dims for `attention_h` and 2 dims for `mu`
states.append(Lambda(lambda x: K.zeros_like(x))(states[0]))
states.append(Lambda(lambda x: K.zeros_like(x)[:, :2])(states[0]))

decoder_input = Input(shape=(None, 8))
out = RNN(MixtureOfGaussian1DAttention(LSTMCell(32), 2))(decoder_input,
                                                         initial_state=states,
                                                         constants=context)
model = Model([encoder_input, decoder_input], out)
```

Having to provide additional states like this is a bit redundant and unexpected IMO (especially if the attention wrapper is simple enough where […]). Is there any good way (or plan) to avoid this behavior? Thanks.
@myutwo150 Hi! Yes, that's a very good point, I've thought about the same thing. I don't have any great general solution for this, but one option could be to allow entries of the initial states list to be `None`.

I've ended up using this approach sometimes to reduce clutter (from creating the additional initial states). You still need to know how many states are required and pass a list of the same length, but it can contain `None` for the states you don't want to specify.

...and of course, you could extend the solution above to allow passing a list of `initial_states` that is shorter than the total number of states as well, and append default initial states internally for those not specified. @fchollet do you think it makes sense to add support for this in the […]?
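To make the idea concrete, here is a minimal pure-NumPy sketch of what such a defaulting helper could look like. The function name `complete_initial_states`, the state sizes, and the batch-size argument are all assumptions for illustration; in an actual Keras integration the sizes would come from the cell's `state_size` attribute and the zeros would be symbolic tensors.

```python
import numpy as np

def complete_initial_states(initial_states, state_sizes, batch_size):
    """Replace `None` entries in an initial-state list with zero states.

    `state_sizes` holds the per-state feature dimension (hypothetically
    taken from the wrapped cell's `state_size`).
    """
    if len(initial_states) != len(state_sizes):
        raise ValueError('need one entry (possibly None) per state')
    return [
        s if s is not None else np.zeros((batch_size, size))
        for s, size in zip(initial_states, state_sizes)
    ]

# encoder provides h and c for the wrapped LSTM cell; the wrapper's own
# states (attention_h: 32 dims, mu: 2 dims) default to zeros
h = np.ones((4, 32))
c = np.ones((4, 32))
states = complete_initial_states([h, c, None, None], [32, 32, 32, 2], batch_size=4)
print([s.shape for s in states])  # [(4, 32), (4, 32), (4, 32), (4, 2)]
```

This would let the user write `initial_state=[h, c, None, None]` instead of manually building zero tensors with `Lambda` layers, while still making the total number of states explicit.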
Hi, is there any progress here? I wonder if the attention API can support varying-length input, just like the RNN layers, i.e. the input shape can be (None, None, features). Thanks.
Hey @LeZhengThu! I'm not working actively on it atm. I guess we're basically waiting for relevant reviews/reviewers... I've only received positive feedback so far and believe the API is pretty solid and general :D Regarding shapes, the suggested API has no limitations on the number of dimensions or which of these are explicitly specified when building the model/graph. It all comes down to the implementation of […]
...and wrt. the current implementation of […]
There is however no explicit handling of masks for varying-length sequences (or varying-size images etc.), but this could be added.
@andhus Fabulous work. I'll try to explore your code myself. Hope Keras can release it in the near future. Thank you for your time.
Hi @andhus, is there any example of how your model should be used at prediction time? As far as I understand, the only possible solution now is to predict each symbol with a new run of the entire model, feeding it the previous output label and state, am I right? Also, I think it would be beneficial to add that prediction part to the example, if you don't mind.
I tried to use it like this:

```python
...
cell = MixtureOfGaussian1DAttention(LSTMCell(64), components=2, heads=2)
attention_lstm = RNN(cell, return_sequences=True)
# here attended is not an Input, but received from the encoder:
# attended = encoder_model(input)
x = attention_lstm(output_embedded, constants=[attended])
x = Dense(len(char_code) + 1)(x)
...
```

and received the warning (after the first epoch ended). After that the training continued, but the loss increased significantly and accuracy fell to zero (not because of exploding gradients; I clipped them and reduced the learning rate, but the problem remains). Might it be caused by feeding […]?
It's necessary and important work. Hope Keras can release it in the near future.
@bezigon @fchollet what about "attention-is-all-you-need-keras" - https://github.com/Lsdefine/attention-is-all-you-need-keras
Looking forward to this being released. It will help many.
Hi, what is the status of this?
I can work on it (commits or code review). But some help would be appreciated since this PR looks huge.
@farizrahman4u I propose that you review it and make comments. If we don't get any other commits or responses from @andhus after a few days, I'll open a new PR and use those commits as a base.
```python
             attention_states,
             training=None):
    # only one attended sequence for now (verified in build)
    [attended] = attended
```
Is this a check? If so please be explicit: `if len(attended) != 1: ...`
```python
model.compile(optimizer='Adam', loss='categorical_crossentropy')
model.fit(x=[input_labels_data, attended_data], y=target_labels_data, epochs=5)
output_data = model.predict([input_labels_data, attended_data])
```
We won't have `input_labels_data` during prediction.
You have to:

```python
unfiltered_data = get_data()  # shape: (batch_size, num_timesteps, num_classes), contains one-hots and empty vectors
previous_output = np.zeros((batch_size, 1, num_classes))
filtered = []
num_required = ...  # number of one-hots we want to generate (we have no way of knowing how many timesteps are empty vectors in the given data)
for t in range(num_required):
    output = model.predict([previous_output, unfiltered_data])  # also have to feed back the previous internal state here?
    filtered.append(output)
    previous_output = output
filtered = np.concatenate(filtered, axis=1)  # shape: (num_samples, num_required, num_classes)
# convert probabilities to one-hots
filtered = ...
```
Sure, that makes the example more complete 👍
```python
        return dict(list(base_config.items()) + list(config.items()))


class MixtureOfGaussian1DAttention(_RNNAttentionCell):
```
Need unit tests for `MixtureOfGaussian1DAttention`
absolutely, was waiting for feedback on the overall approach.
### Summary

This refactoring will allow the simplification of some code in #8296

### Related Issues

### PR Overview

- [ ] This PR requires new unit tests [y/n] (make sure tests are included)
- [ ] This PR requires to update the documentation [y/n] (make sure the docs are up-to-date)
- [x] This PR is backwards compatible [y/n]
- [ ] This PR changes the current API [y/n] (all API changes need to be approved by fchollet)
Closing this PR in favor of #11172 for organization purposes. This PR can be reopened later on if necessary.
Intro
This PR implements step 3 (and an example of step 4) as laid out in the API doc. See the PR for step 1 (step 2 is put on hold) and the API issue for further background.
The main addition is an abstract base class for implementing attention cell wrappers: `RecurrentAttentionCellWrapperABC`. This class is pretty solid and I've tried to make the docs really detailed to clarify the important concepts. There is also an implementation (of the base class) for Mixture of Gaussian 1D (i.e. sequence) attention with an example included - but there are some more points to discuss:
1) Should we add a distribution module?
I will try to make a case for this: firstly, it is generally relevant to make it easy to implement "mixture of density" networks, and secondly, "predicting distributions" is an important part of attention mechanisms and the same logic could be reused. The main idea is to keep all parameters of any distribution in a "flat output vector" and gather the activation function for the parameters and the corresponding loss (typically `-log(pdf)`) in a class for each distribution. I'm happy to make a separate API suggestion for this.
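The "flat output vector" idea above can be sketched in a few lines. The following is a hypothetical, NumPy-only illustration of what such a distribution class might look like (the class name `MixtureOfGaussian1D`, the parameter layout, and the method names are assumptions, not an API from this PR): the flat vector holds `[logits, means, log-scales]` for K components, `activation` maps it to valid parameters, and `loss` is the negative log-likelihood.

```python
import numpy as np

class MixtureOfGaussian1D:
    """Hypothetical distribution helper: all parameters live in a flat
    vector [logits (K), means (K), log-scales (K)] for K components."""

    def __init__(self, components):
        self.components = components
        self.num_params = 3 * components  # size of the flat parameter vector

    def activation(self, flat):
        K = self.components
        logits, mu, log_sigma = flat[..., :K], flat[..., K:2 * K], flat[..., 2 * K:]
        weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
        weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax -> mixture weights
        return weights, mu, np.exp(log_sigma)  # exp -> positive scales

    def loss(self, y, flat):
        # negative log-likelihood, i.e. -log(pdf), per sample
        weights, mu, sigma = self.activation(flat)
        pdf = np.exp(-0.5 * ((y[..., None] - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
        return -np.log((weights * pdf).sum(axis=-1))

dist = MixtureOfGaussian1D(components=2)
flat = np.zeros((5, dist.num_params))  # e.g. the raw output of a Dense layer
nll = dist.loss(np.zeros(5), flat)     # all-zero params -> standard normal
```

The same class could then serve both roles mentioned above: as the output head of a mixture-of-density network and as the location predictor inside an attention mechanism.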
2) Should we have (make it easy to write) layers with several wrapped layers (and/or its own weights) internally?
So far there are only a few cases in Keras where a layer is injected into another: `TimeDistributed`, `Bidirectional`, and most recently RNN cell -> `RNN`. In all of these cases there are no additional parameters (weights) in the receiving layer. For such a case it is enough to extend the `Wrapper` class to get most of the required API (like `get/set_weights` etc.) "for free".

With recurrent attention it is different: we wrap a core RNN cell with additional learnable transforms that have their own weights. Moreover, an attention mechanism can be quite complex and consist of "several layers".
The pain point is this: normally, when we want to define complex transforms, we use the functional API and compose them from atomic layers. But this is not possible (or at least not straightforward) when defining an RNN cell transformation. This is why I suggested step 2, the `FunctionalRNNCell`.

If we instead take the standard approach of using low-level `add_weight` for all internal transformations of a layer, we end up with a lot of boilerplate and configuration code. This is e.g. the case with the `MixtureOfGaussian1DAttention` in this PR: the `__init__` method takes 12 parameters (all the regularizers, initializers etc.) when all I really want to do is inject a `Dense` layer, which already implements this and adequately groups the parameters. I think one could motivate something like a `MultiLayerWrapper` class that takes care of the boilerplate for these cases (`get/set_weights` etc.). Something like https://github.com/andhus/keras/blob/205d057b7078bfe885b967f39ab6324271c764fa/keras/layers/attention.py#L17 (@fchollet if you recall, something similar was mentioned in my original attention API suggestion...)

3) Should we add support for returning state sequences from RNN (not only final states)?
The most important thing not supported in this PR is that one cannot get out the "attention encoding" or the "attention location parameters" (where is it looking) for the full sequence processed. The first is a problem because typically one wants to stack multiple recurrent layers, let the first "drive/interact with" the attention mechanism, but then feed the attention encoding to all later layers ("cascading"). The second is a problem because it is very useful to be able to analyse/visualise what part of the input is attended to at any point in time. There are ugly/hacky ways of solving this (by concatenating this info to the wrapped cell output), but that causes a lot of additional complexity - and if we could just get out the full state sequences, the information would basically already be there.
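To make the "state sequences" proposal concrete, here is a pure-Python/NumPy sketch of a scan loop with a hypothetical `return_state_sequences` flag (the flag name is a suggestion; this is not the Keras backend loop). Instead of returning only the final states, it stacks every intermediate state - which is exactly where a wrapper's attention encoding and location parameters would become visible for all timesteps.

```python
import numpy as np

def rnn_scan(step, inputs, initial_states, return_state_sequences=False):
    """Sketch of an RNN scan; `step(x_t, states) -> (output_t, new_states)`."""
    states = list(initial_states)
    outputs = []
    state_seqs = [[] for _ in states]
    for x_t in inputs:  # iterate over the time dimension
        out, states = step(x_t, states)
        outputs.append(out)
        for seq, s in zip(state_seqs, states):
            seq.append(s)
    if return_state_sequences:
        # full per-timestep state trajectories, e.g. the attention
        # encoding / location parameters at every step
        return np.stack(outputs), [np.stack(seq) for seq in state_seqs]
    return np.stack(outputs), states  # only the final states, as today

# toy cell whose single state is a running sum of the inputs
def step(x, states):
    s = states[0] + x
    return s, [s]

outs, state_seqs = rnn_scan(step, np.ones((3, 2)), [np.zeros(2)],
                            return_state_sequences=True)
print(state_seqs[0][:, 0])  # [1. 2. 3.]
```

With such a flag on the base `RNN`, a stacked setup could read the first layer's attention-encoding sequence and feed it to later layers without concatenating it into the cell output.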
Other TODOs

- `RecurrentAttentionCellWrapperABC`