
SelfAttention misses Linear after attention, wrong for Conformer, Transformer #221

Closed

albertz opened this issue Oct 17, 2022 · 15 comments
@albertz
Member

albertz commented Oct 17, 2022

[image: diagram of the attention block, showing the linear projection after the attention]

There is a linear projection after the attention.

ESPNet MultiHeadedAttention has it.
PyTorch torch.nn.MultiheadAttention does not have it.
Keras tf.keras.layers.MultiHeadAttention has it.
torchaudio.models.wav2vec2.components.SelfAttention has it.
Fairseq MultiheadAttention has it.

Our nn.GenericSelfAttention (and thus nn.SelfAttention) does not have it.
The RETURNN SelfAttentionLayer also does not have it.

But we also don't add it in our ConformerEncoderLayer, so the projection is clearly missing there.
Likewise, we don't add it in our Transformer, so it is missing there as well.

So, should we change nn.GenericSelfAttention? Always include it? Or include it optionally? We could make it a required argument so there is no confusion about it, e.g. out_dim: Optional[nn.Dim] (without default): if the user sets it to None, there is no linear transformation at the end; otherwise there is.

If we don't change nn.GenericSelfAttention, we must fix the Transformer and Conformer.
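
To make the proposal concrete, here is a minimal PyTorch-style sketch (for illustration only, not the returnn_common implementation; the class name and exact signature are made up) of what a required Optional projection argument could look like:

```python
from typing import Optional

import torch


class SelfAttentionSketch(torch.nn.Module):
    """Self-attention with an explicit, required output-projection argument."""

    def __init__(self, in_dim: int, *, num_heads: int, key_dim_total: int,
                 value_dim_total: int, proj_dim: Optional[int]):
        # proj_dim is required (no default): None means no final Linear,
        # any other value adds the projection after the attention.
        super().__init__()
        assert key_dim_total % num_heads == 0 and value_dim_total % num_heads == 0
        self.num_heads = num_heads
        self.key_dim_total = key_dim_total
        self.value_dim_total = value_dim_total
        self.qkv = torch.nn.Linear(in_dim, 2 * key_dim_total + value_dim_total)
        self.proj = (torch.nn.Linear(value_dim_total, proj_dim)
                     if proj_dim is not None else None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, in_dim)
        batch, time, _ = x.shape
        q, k, v = self.qkv(x).split(
            [self.key_dim_total, self.key_dim_total, self.value_dim_total], dim=-1)

        def split_heads(y: torch.Tensor) -> torch.Tensor:
            return y.view(batch, time, self.num_heads, -1).transpose(1, 2)

        att = torch.nn.functional.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v))  # (batch, heads, time, head_dim)
        att = att.transpose(1, 2).reshape(batch, time, self.value_dim_total)
        return self.proj(att) if self.proj is not None else att
```

With such a signature, proj_dim=None would reproduce the current behaviour (no final Linear), while any other value adds the projection as described in the original Transformer paper.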

@albertz albertz added this to the first-release milestone Oct 17, 2022
@albertz albertz changed the title from "Conformer self-attention misses Linear after attention" to "SelfAttention misses Linear after attention, wrong for Conformer, Transformer" Oct 17, 2022
@JackTemaki
Contributor

It should not be part of nn.SelfAttention, but definitely be part of the ConformerEncoderLayer and other related layers.

@albertz
Member Author

albertz commented Oct 17, 2022

It should not be part of nn.SelfAttention

Why? This might confuse everyone else (everyone who has maybe not used the RETURNN SelfAttentionLayer before). It also already led to these quite serious bugs, so it even confused ourselves.

@JackTemaki
Contributor

JackTemaki commented Oct 17, 2022

Now that I have looked at the code, I think the whole naming is flawed right now. In our case, "GenericSelfAttention" is not generic at all, but is fixed to dot attention as far as I can see. Our current GenericSelfAttention is rather a MultiHeadedDotSelfAttention, and then yes, it should include the Linear layer to be correct (or let's say compatible with ESPNet and what people expect from the layer).

@albertz
Member Author

albertz commented Oct 17, 2022

GenericSelfAttention is called "generic" because it is the base class for SelfAttention and CausalSelfAttention.

@JackTemaki
Contributor

JackTemaki commented Oct 17, 2022

Yes, but this is highly misleading. Let's say I now implement an MLP-based self-attention instead of a dot-based one; what would I call that then? And I would expect it to inherit from GenericSelfAttention as a template, but that would be incorrect.

@albertz
Member Author

albertz commented Oct 17, 2022

I think self-attention always uses the dot product. I have never seen anything else, so this is pretty standard. You can of course always implement something custom, but then this would be a separate implementation. Why would you expect to inherit from something else if you would not share any code with the standard self-attention? I don't think this is a problem.

Look at any other framework. It's always just called SelfAttention for self-attention, or MultiHeadAttention (or the like) for dot attention.
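
For reference, "dot attention" here means the standard scaled dot-product attention from the original Transformer paper:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

followed by concatenating the heads and applying the final linear projection $W^O$, which is exactly the projection this issue is about.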

@albertz
Member Author

albertz commented Oct 17, 2022

Change nn.GenericSelfAttention? Make it a required argument so there is no confusion about it, e.g. out_dim: Optional[nn.Dim] (without default): if the user sets it to None, there is no linear transformation at the end; otherwise there is.

I tend to prefer this variant, i.e. having it explicit. The majority of other code seems to have it included, just like the original paper describes it, but there are some exceptions.

@JackTemaki
Contributor

I think self-attention always uses the dot product.

I just wanted to clarify that this is a hidden implication.

Why would you expect to inherit from something else if you would not share any code with the standard self-attention?

If dot_attention were not called explicitly but were e.g. an overridable function, then GenericSelfAttention would be more "generic". All I am saying is that it makes more fixed assumptions than the name hints at.
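
As a hypothetical sketch of that idea (not the current returnn_common code), the attention kernel could be an overridable method:

```python
import torch


class GenericSelfAttentionSketch(torch.nn.Module):
    """Hypothetical base class where the attention kernel is overridable."""

    def attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Default kernel: scaled dot-product attention. A subclass (e.g. an
        # MLP-based variant) would override only this method and reuse the
        # rest of the module (qkv projections, head handling, etc.).
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)
```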

@albertz
Member Author

albertz commented Oct 17, 2022

I just wanted to clarify that this is a hidden implication.

Yes, but if something is standard (as the dot product is for self-attention), then I would leave it out of the name for simplicity. So I would not add "dot" to the name of a self-attention function/module. For a similar reason, I also left out "multi-head".

@albertz
Member Author

albertz commented Oct 17, 2022

That said, the names GenericSelfAttention, SelfAttention and CausalSelfAttention might still be bad. Maybe you have better suggestions? I also thought about SelfAttention (as the base), NonCausalSelfAttention and CausalSelfAttention, but I preferred the other variant. Or maybe BaseSelfAttention instead of GenericSelfAttention? Or SelfAttentionBase? Or CausalAndNonCausalSelfAttention?

@albertz
Member Author

albertz commented Oct 17, 2022

Anyway, the naming is somewhat orthogonal to (i.e. independent of) this specific issue. What about my suggestion?

@albertz
Member Author

albertz commented Oct 17, 2022

Also note that many other frameworks don't actually provide self-attention directly as a function/module, but instead have a MultiheadAttention or similar, where the user must pass three separate arguments for query, key and value. The Fairseq MultiheadAttention, for example, supports the self-attention case as an option (here).
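
For example, with PyTorch's torch.nn.MultiheadAttention, self-attention is just the special case of passing the same tensor for query, key and value:

```python
import torch

mha = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)      # (batch, time, feature)
out, att_weights = mha(x, x, x)  # self-attention: query = key = value = x
```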

@albertz
Member Author

albertz commented Oct 17, 2022

Somewhat related is also relative positional encoding, and we will need a separate self-attention implementation just for that, for efficiency reasons. See #132, #74.

@albertz
Member Author

albertz commented Oct 17, 2022

Change nn.GenericSelfAttention? Make it a required argument so there is no confusion about it, e.g. out_dim: Optional[nn.Dim] (without default): if the user sets it to None, there is no linear transformation at the end; otherwise there is.

I tend to prefer this variant, i.e. having it explicit. The majority of other code seems to have it included, just like the original paper describes it, but there are some exceptions.

This is what I have implemented now. I named it proj_dim instead of out_dim to avoid confusion.
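
A hypothetical usage sketch (the dim helper and all argument names other than proj_dim are assumptions based on this discussion; check the actual returnn_common signature):

```python
from returnn_common import nn

model_dim = nn.FeatureDim("model", 512)  # assumed dim helper
self_att = nn.SelfAttention(
    in_dim=model_dim,
    proj_dim=model_dim,        # the final Linear after attention; None would disable it
    key_dim_total=model_dim,   # assumed parameter name
    value_dim_total=model_dim, # assumed parameter name
    num_heads=8,
)
```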

@albertz
Member Author

albertz commented Oct 17, 2022

This fixes the issue here. If you think there are other issues (on naming or anything else), please open separate issues, perhaps referring to the initial discussion here.
