Recurrent Attention API: Cell wrapper base class #8296

Conversation

@andhus (Contributor) commented Oct 29, 2017

Intro

This PR implements step 3 (and an example of step 4) as laid out in the API doc. See the PR for step 1 (step 2 is on hold) and the API issue for further background.

The main addition is an abstract base class for implementing attention cell wrappers: RecurrentAttentionCellWrapperABC. This class is pretty solid, and I've tried to make the docs detailed enough to clarify the important concepts.

There is also an implementation (of the base class) of Mixture of Gaussian 1D (i.e. sequence) attention, with an example included - but here there are a few more points to discuss:

1) Should we add a distribution module?

I will try to make a case for this: firstly, it is generally useful to make it easy to implement "mixture density" networks, and secondly, "predicting distributions" is an important part of attention mechanisms and the same logic could be reused. The main idea is to keep all parameters of any distribution in a "flat output vector" and to gather the activation functions for the parameters and the corresponding loss (typically -log(pdf)) in one class per distribution. I'm happy to make a separate API suggestion for this.
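To make the idea concrete, here is a rough, hypothetical sketch of what such a distribution class could look like (the class name and method names are mine, not part of this PR): all parameters live in one flat vector, and the class bundles the per-parameter activations with the corresponding loss.

import numpy as np
from keras import backend as K


class MixtureOfGaussian1D(object):
    """Hypothetical sketch: bundles activations and loss for a 1D MoG."""

    def __init__(self, components):
        self.components = components

    @property
    def param_size(self):
        # mixture weights, means and scales, flattened into one vector
        return 3 * self.components

    def _split(self, flat_params):
        c = self.components
        return (flat_params[:, :c],
                flat_params[:, c:2 * c],
                flat_params[:, 2 * c:])

    def activations(self, flat_params):
        # apply the appropriate activation to each parameter group
        weights, mu, log_sigma = self._split(flat_params)
        return K.concatenate(
            [K.softmax(weights), mu, K.exp(log_sigma)], axis=-1)

    def loss(self, y_true, params):
        # negative log-likelihood of a scalar target under the mixture;
        # `params` is assumed to already have passed through `activations`
        weights, mu, sigma = self._split(params)
        norm = 1. / (sigma * np.sqrt(2. * np.pi))
        pdf = K.sum(
            weights * norm * K.exp(-0.5 * K.square((y_true - mu) / sigma)),
            axis=-1)
        return -K.log(pdf + K.epsilon())

A mixture density head would then use param_size output units, activations as the output activation and loss when compiling, and an attention mechanism could reuse the same class for its location parameters.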

2) Should we have (make it easy to write) layers with several wrapped layers (and/or their own weights) internally?

So far there are only a few cases in Keras where a layer is injected into another: TimeDistributed, Bidirectional and, most recently, RNN cell -> RNN. In all of these cases there are no additional parameters (weights) in the receiving layer. For such a case it is enough to extend the Wrapper class to get most of the required API (get/set_weights etc.) "for free".

With recurrent attention it is different: we wrap a core RNN cell with additional learnable transforms that have their own weights. Moreover, an attention mechanism can be quite complex and consist of "several layers".

The pain point is this: normally, when we want to define complex transforms, we use the functional API and compose them from atomic layers. But this is not possible (or at least not straightforward) when defining an RNN cell transformation. This is why I suggested step 2, the FunctionalRNNCell.

If we instead take the standard approach of using the low-level add_weight for all internal transformations of a layer, we end up with a lot of boilerplate and configuration code. This is e.g. the case with the MixtureOfGaussian1DAttention in this PR: the __init__ method takes 12 parameters (all the regularizers, initializers etc.) when all I really want to do is inject a Dense layer, which already implements this and adequately groups the parameters. I think one could motivate something like a MultiLayerWrapper class that takes care of the boilerplate for these cases (get/set_weights etc.). Something like https://github.com/andhus/keras/blob/205d057b7078bfe885b967f39ab6324271c764fa/keras/layers/attention.py#L17 (@fchollet if you recall, something similar was mentioned in my original attention API suggestion...)
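For reference, a very rough sketch of the kind of helper I have in mind (entirely hypothetical, not code from this PR): a layer that owns a list of injected sub-layers and exposes their weights, so that get/set_weights come "for free" and the outer layer does not have to re-expose every initializer/regularizer argument.

from keras.engine.topology import Layer


class MultiLayerWrapper(Layer):
    """Hypothetical sketch: a layer owning several injected sub-layers."""

    def __init__(self, layers, **kwargs):
        super(MultiLayerWrapper, self).__init__(**kwargs)
        self.layers = list(layers)

    @property
    def trainable_weights(self):
        return [w for layer in self.layers for w in layer.trainable_weights]

    @property
    def non_trainable_weights(self):
        return [w for layer in self.layers
                for w in layer.non_trainable_weights]

Concrete subclasses would still have to build the sub-layers and implement get_config; the point is only that weight tracking (and thereby get/set_weights) is handled in one place, much like Bidirectional does for its two wrapped copies.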

3) Should we add support for returning state sequences from RNN (not only final states)?

The most important thing that is not supported in this PR is that one cannot get the "attention encoding" or the "attention location parameters" (where is it looking) for the full processed sequence. The first is a problem because typically one wants to stack multiple recurrent layers, let the first "drive/interact with" the attention mechanism, but then feed the attention encoding to all later layers ("cascading"). The second is a problem because it is very useful to be able to analyse/visualise which part of the input is attended to at any point in time. There are ugly/hacky ways of solving this (by concatenating this info to the wrapped cell output), but it causes a lot of additional complexity - and if we could just get the full state sequences out, the information would basically already be there. Like this:

y, [h, c] = LSTM(32, return_sequences=True, return_state_sequences=True)(x)
# h.shape[1] == c.shape[1] == y.shape[1] == x.shape[1]

Other TODOs


# TODO should it be made private like some other base classes? The idea is that
# it should be used to implement custom attention mechanisms though...
class RecurrentAttentionCellWrapperABC(Layer):
@andhus (Contributor, author) commented:
better/shorter name?

@fchollet (Member) commented Nov 8, 2017

Checking this out now. Thank you for your patience.

What's the absolute simplest implementation of Attention you could produce (in a standalone way, without subclassing anything that isn't already in the public API)? This PR has a lot going on.

@andhus (Contributor, author) commented Nov 12, 2017

No worries - yes, there is a fair amount going on... Basically, all the questions and "suggested improvements" can be addressed later.

To offer easy-to-use attention mechanisms, we just need to add what's in the layers.attention module. We can make RecurrentAttentionCellWrapperABC hidden, and then there will only be one new public class, MixtureOfGaussian1DAttention, which wraps any standard RNN cell for sequence attention (I'll add a 2D version for images later as well). It can easily be refactored to not make use of the distributions module, and it is only parameterised by standard Python and Keras objects.

As I wrote, the most "urgent" missing feature is being able to feed the attention encoding inferred by a first layer to later layers in a stack of RNNs, but I think this should really be solved by making it possible to return "state sequences" from the base RNN - which does not affect the API of the attention mechanism itself.
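To illustrate the cascading use case, assuming the hypothetical return_state_sequences flag sketched in the intro and a state layout of cell states followed by the wrapper's attention states (here attention_h and mu), it could look roughly like this:

from keras.layers import Input, LSTM, LSTMCell, RNN, concatenate
from keras.layers.attention import MixtureOfGaussian1DAttention  # from this PR

x = Input(shape=(None, 8))
attended = Input(shape=(None, 16))

# first layer drives the attention mechanism
y1, [h_seq, c_seq, att_h_seq, mu_seq] = RNN(
    MixtureOfGaussian1DAttention(LSTMCell(32), components=2),
    return_sequences=True,
    return_state_sequences=True  # hypothetical flag, not yet in Keras
)(x, constants=attended)

# later layers consume the attention encoding without driving the attention,
# and mu_seq can be used to visualise where the model attended over time
y2 = LSTM(32, return_sequences=True)(concatenate([y1, att_h_seq]))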

@andhus (Contributor, author) commented Nov 12, 2017

...I think it would just be silly to merge the base class RecurrentAttentionCellWrapperABC with the specific implementation MixtureOfGaussian1DAttention, as it describes and implements a highly relevant and general abstraction level for attention mechanisms. As soon as we add a MoG2DAttention, for example, there would be a lot of code duplication if it's not there. But the base class does not have to be part of the public API (let's hide it for now).

@fchollet (Member) commented:

Thank you for the info. So let's start with the base class (made private, i.e. _ in front of the name -- we may make it public later on), and let's add a single attention cell wrapper that subclasses it (as a proof of concept and API example).

Other base classes in Keras don't have the ABC suffix. Let's call this one RNNAttentionCell?

@andhus (Contributor, author) commented Nov 19, 2017

Ok, I've done the fixes as discussed: MixtureOfGaussian1DAttention is now a standalone, relevant PoC of the _RNNAttentionCell base class.

I've also added better docs for the canonical MoG attention example.

Please give feedback @fchollet, and I'll add tests...

@fchollet (Member) commented:

I need to prioritize reviewing this. Since it is a large PR, it has been difficult. Sorry for the delay. If anyone wants to help by giving feedback, that's very much appreciated.

@yuyang-huang (Contributor) commented:

Hi @andhus, thanks for the great PR! It saves a lot of time compared to writing one's own custom attention wrapper.

The first thing I noticed when trying this implementation is that, in an encoder-decoder setting, it seems we need to provide additional initial_state entries for the states used by the attention wrapper.

For instance,

encoder_input = Input(shape=(None, 8))
context, *states = LSTM(32, return_sequences=True, return_state=True)(encoder_input)

# have to provide initial states for the attention wrapper
# 32 dims for `attention_h` and 2 dims for `mu`
states.append(Lambda(lambda x: K.zeros_like(x))(states[0]))
states.append(Lambda(lambda x: K.zeros_like(x)[:, :2])(states[0]))

decoder_input = Input(shape=(None, 8))
out = RNN(MixtureOfGaussian1DAttention(LSTMCell(32), 2))(decoder_input,
                                                         initial_state=states,
                                                         constants=context)
model = Model([encoder_input, decoder_input], out)

Having to provide additional states like this is a bit redundant and unexpected IMO (especially when the attention wrapper is simple enough that attention_states is nothing but [attention_h]).

Is there any good way (or plan) to avoid this behavior? Thanks.

@andhus (Contributor, author) commented Jan 23, 2018

@myutwo150 Hi! Yes, that's a very good point; I've thought about the same thing. I don't have a great general solution for this, but one option could be to allow initial states to be None (falling back on the defaults), as in the solution below:

from keras.layers import RNN as _RNN

class RNN(_RNN):
    """Allows subset of initial states to be None"""
    # TODO make keras PR for this
    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        if isinstance(initial_state, list):
            if any([s is None for s in initial_state]):
                default_initial_state = Lambda(self.get_initial_state)(inputs)
                initial_state = [
                    s if s is not None else ds
                    for s, ds in zip(initial_state, default_initial_state)
                ]

        return super(RNN, self).__call__(
            inputs,
            initial_state=initial_state,
            constants=constants,
            **kwargs
        )

I've ended up using this approach sometimes to reduce clutter (from creating the additional initial states). You still need to know how many states are required and pass a list of the same length, but it can contain None placeholders.
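With that subclass, the encoder-decoder example above could then be written roughly like this (just a sketch; the two trailing None entries stand in for the wrapper's attention_h and mu initial states):

from keras.layers import Input, LSTM, LSTMCell
from keras.models import Model
from keras.layers.attention import MixtureOfGaussian1DAttention  # from this PR

encoder_input = Input(shape=(None, 8))
context, *encoder_states = LSTM(32, return_sequences=True,
                                return_state=True)(encoder_input)

decoder_input = Input(shape=(None, 8))
# `RNN` here is the subclass defined above, not keras.layers.RNN
out = RNN(MixtureOfGaussian1DAttention(LSTMCell(32), 2))(
    decoder_input,
    # None entries fall back on the default (zero) initial states
    initial_state=encoder_states + [None, None],
    constants=context)
model = Model([encoder_input, decoder_input], out)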

@andhus (Contributor, author) commented Jan 24, 2018

...and of course, you could also extend the solution above to allow passing an `initial_state` list that is shorter than the total number of states, and append default initial states internally for those not specified.
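A rough sketch of that extension, building on the subclass above (untested, just to illustrate the idea):

from keras.layers import Lambda
from keras.layers import RNN as _RNN


class RNN(_RNN):
    """Allows `initial_state` to be a shorter list (or to contain None);
    missing entries fall back on the default initial states."""

    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        if isinstance(initial_state, list):
            state_size = self.cell.state_size
            num_states = (len(state_size)
                          if hasattr(state_size, '__len__') else 1)
            # pad with None placeholders for unspecified trailing states
            initial_state = (initial_state +
                             [None] * (num_states - len(initial_state)))
            if any(s is None for s in initial_state):
                default_initial_state = Lambda(self.get_initial_state)(inputs)
                initial_state = [
                    s if s is not None else ds
                    for s, ds in zip(initial_state, default_initial_state)
                ]
        return super(RNN, self).__call__(inputs,
                                         initial_state=initial_state,
                                         constants=constants,
                                         **kwargs)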

@fchollet do you think it makes sense to add support for this in the RNN class?

yuyang-huang added a commit to yuyang-huang/keras that referenced this pull request Jan 31, 2018
@LeZhengThu commented:
Hi, is there any progress here? I wonder if the attention API can support variable-length input, just like the RNN layers, where the input shape can be (None, None, features). Thanks.

@andhus (Contributor, author) commented Apr 2, 2018

Hey @LeZhengThu! I'm not working actively on it atm. I guess we're basically waiting for relevant reviews/reviewers... I've only received positive feedback so far and believe the API is pretty solid and general :D

Regarding shapes, the suggested API has no limitations on the number of dimensions or on which of them are explicitly specified when building the model/graph. It all comes down to the implementation of attention_build and attention_call in (extensions of) the RNNAttentionCell.
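For example, nothing prevents building the graph with unspecified time dimensions (a sketch based on the snippets earlier in this thread; the module path is the one used in this PR):

from keras.layers import Input, LSTMCell, RNN
from keras.layers.attention import MixtureOfGaussian1DAttention  # from this PR

decoder_input = Input(shape=(None, 8))   # variable-length input sequence
attended = Input(shape=(None, 16))       # variable-length attended sequence

cell = MixtureOfGaussian1DAttention(LSTMCell(32), components=2)
outputs = RNN(cell, return_sequences=True)(decoder_input, constants=attended)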

@andhus (Contributor, author) commented Apr 2, 2018

...and regarding the current implementation of MixtureOfGaussian1DAttention: as far as I recall, it supports what you're asking for by reading the shape here.

@andhus (Contributor, author) commented Apr 2, 2018

There is, however, no explicit handling of masks for variable-length sequences (or variable-size images etc.), but this could be added.

@LeZhengThu commented:
@andhus Fabulous work. I'll try to explore your code myself. Hope Keras can release it in the near future. Thank you for your time.

@asmekal commented Apr 11, 2018

Hi @andhus, is there any example of how your model should be used at prediction time? As far as I understand, the only possible solution at the moment is to predict each symbol with a new run of the entire model, feeding it the previous output label and state - am I right? Also, I think it would be beneficial to add that prediction part to the example, if you don't mind.

@asmekal commented Apr 12, 2018

I tried to use MixtureOfGaussian1DAttention in the following way (actually exactly the same as in the example provided):

    ...
    cell = MixtureOfGaussian1DAttention(LSTMCell(64), components=2, heads=2)
    attention_lstm = RNN(cell, return_sequences=True)
    # here attended is not an Input, but received from the encoder
    # attended = encoder_model(input)
    x = attention_lstm(output_embedded, constants=[attended])
    x = Dense(len(char_code)+1)(x)
    ...

and received the following warning (after the end of the first epoch):
/home/azharkov/anaconda2/lib/python2.7/site-packages/keras/engine/topology.py:2364: UserWarning: Layer rnn_1 was passed non-serializable keyword arguments: {'constants': [<tf.Tensor 'dropout_5/cond/Merge:0' shape=(?, ?, 100) dtype=float32>]}. They will not be included in the serialized model (and thus will be missing at deserialization time). str(node.arguments) + '. They will not be included '

After that the training continued, but the loss increased significantly and accuracy fell to zero (not because of exploding gradients - I clipped them and reduced the learning rate, but the problem remains).

Could it be caused by feeding an attended tensor that does not come from an Input? Or does anyone know what else might be the reason? By the way, the provided example worked without that warning, so the problem does not seem to be related to the Keras version.

@bezigon commented Jun 6, 2018

This is necessary and important work. I hope Keras can release it in the near future.

@loretoparisi commented:
@bezigon @fchollet what about "attention-is-all-you-need-keras"? https://github.com/Lsdefine/attention-is-all-you-need-keras

@pigubaoza commented:
Looking forward to this being released - it will help many people.

@andhus andhus changed the title Recurrent Attention API: Cell wrapper base class [work in progress] Recurrent Attention API: Cell wrapper base class Sep 4, 2018
@farizrahman4u (Contributor) commented:
Hi, what is the status of this?

@gabrieldemarmiesse (Contributor) commented:
I can work on it (commits or code review), but some help would be appreciated since this PR looks huge.

@gabrieldemarmiesse (Contributor) commented:
@farizrahman4u I propose that you review it and make comments. If we don't get any other commits or responses from @andhus after a few days, I'll open a new PR and use those commits as a base.

attention_states,
training=None):
# only one attended sequence for now (verified in build)
[attended] = attended
A reviewer (Contributor) commented:

Is this a check? If so, please be explicit: if len(attended) != 1: ...
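For instance, the explicit check being asked for could look something like this (just a sketch of the suggestion):

if len(attended) != 1:
    raise ValueError('Only a single attended sequence is supported '
                     'for now, got {}'.format(len(attended)))
[attended] = attended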


model.compile(optimizer='Adam', loss='categorical_crossentropy')
model.fit(x=[input_labels_data, attended_data], y=target_labels_data, epochs=5)
output_data = model.predict([input_labels_data, attended_data])
A reviewer (Contributor) commented:

We won't have input_labels_data during prediction. You have to do something like:

import numpy as np

unfiltered_data = get_data()  # shape: (batch_size, num_timesteps, num_classes), contains one-hots and empty vectors

previous_output = np.zeros((batch_size, 1, num_classes))

filtered = []

num_required = ...  # number of one-hots we want to generate (we have no way of knowing how many timesteps are empty vectors in the given data)
for t in range(num_required):
    output = model.predict([previous_output, unfiltered_data])  # also have to feed back the previous internal state here?
    filtered.append(output)
    previous_output = output

filtered = np.concatenate(filtered, axis=1)  # shape: (num_samples, num_required, num_classes)

# convert probabilities to one-hots
filtered = ...

@andhus (Contributor, author) replied:

Sure, that makes the example more complete 👍

return dict(list(base_config.items()) + list(config.items()))


class MixtureOfGaussian1DAttention(_RNNAttentionCell):
A reviewer (Contributor) commented:

Need unit tests for MixtureOfGaussian1DAttention
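Something like this would probably do as a first smoke test (a sketch only; the module path and constructor signature are assumed from this PR):

import numpy as np
from keras.layers import Input, LSTMCell, RNN
from keras.layers.attention import MixtureOfGaussian1DAttention  # from this PR
from keras.models import Model


def test_mixture_of_gaussian_1d_attention_smoke():
    x = Input(shape=(5, 4))
    attended = Input(shape=(7, 3))
    cell = MixtureOfGaussian1DAttention(LSTMCell(8), components=2)
    y = RNN(cell, return_sequences=True)(x, constants=attended)
    model = Model([x, attended], y)

    x_data = np.random.random((2, 5, 4))
    attended_data = np.random.random((2, 7, 3))
    y_data = model.predict([x_data, attended_data])
    # one output vector per input timestep; feature size depends on the cell
    assert y_data.shape[:2] == (2, 5)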

@andhus (Contributor, author) replied:

Absolutely - I was waiting for feedback on the overall approach.

Dref360 pushed a commit that referenced this pull request Sep 14, 2018
### Summary

This refactoring will allow the simplification of some code in #8296

### Related Issues

### PR Overview

- [ ] This PR requires new unit tests [y/n] (make sure tests are included)
- [ ] This PR requires to update the documentation [y/n] (make sure the docs are up-to-date)
- [x] This PR is backwards compatible [y/n]
- [ ] This PR changes the current API [y/n] (all API changes need to be approved by fchollet)
@gabrieldemarmiesse (Contributor) commented Sep 18, 2018

Closing this PR in favor of #11172 for organization purposes. This PR can be reopened later on if necessary.
