Are there any plans for supporting an explicit attention mask? #840

Open
Avelina9X opened this issue Feb 20, 2024 · 5 comments

Comments

@Avelina9X

I've noticed that the Triton implementation supports an explicit attention bias, which can be used to express arbitrary mask shapes via large negative values. However, is there any planned support for explicit (boolean) masks in the CUDA implementation?

I've noticed some requests for features like off-diagonal attention, but an explicit attention mask would be able to facilitate this and any other arbitrary masking scheme, such as XLNet, attention sinks, or landmark attention, without needing to hardcode each attention scheme and enable it with an argument or a separate Python interface.
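
As a concrete illustration of the bias trick mentioned above, here is a minimal sketch of turning an arbitrary boolean mask into an additive bias of large negative values (the bias shape and broadcasting rules any particular kernel expects are assumptions here):

```python
# Minimal sketch: convert a boolean mask (True = attend, False = masked out)
# into an additive attention bias of 0 / large-negative values.
# How such a bias must be shaped or broadcast is kernel-specific and assumed here.
import torch

def bool_mask_to_bias(mask: torch.Tensor, dtype: torch.dtype = torch.float16) -> torch.Tensor:
    """mask: [..., q_len, k_len] boolean -> additive bias of the same shape."""
    bias = torch.zeros(mask.shape, dtype=dtype, device=mask.device)
    return bias.masked_fill(~mask, torch.finfo(dtype).min)

# Example: causal attention plus two always-visible "sink" tokens.
seq_len, n_sinks = 16, 2
mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
mask[:, :n_sinks] = True
bias = bool_mask_to_bias(mask)  # add to the pre-softmax attention scores
```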

@normster commented Mar 5, 2024

It seems like the PyTorch attention implementation supports custom attention masks and also uses FlashAttention-2: https://twitter.com/StasBekman/status/1736083447658225665. Though I'm not sure whether passing in an attention mask causes the op to dispatch to a non-FA2 kernel.
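
One way to check this empirically is to restrict SDPA to the FlashAttention backend only and see whether a masked call still succeeds. A minimal sketch, assuming a CUDA device and PyTorch 2.x (torch.backends.cuda.sdp_kernel is the pre-2.3 context manager; shapes and dtypes are illustrative):

```python
# Minimal sketch: force SDPA to the FlashAttention backend and test whether
# it accepts an explicit attn_mask, instead of silently falling back to the
# math or memory-efficient kernels.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.ones(1, 1, 1024, 1024, device="cuda", dtype=torch.bool).tril()

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    try:
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        print("FlashAttention backend handled the explicit mask")
    except RuntimeError as err:
        print("FlashAttention backend rejected the explicit mask:", err)
```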

@tridao commented Mar 5, 2024

If there's an attention mask, PyTorch does not dispatch to the FA2 kernel, but rather to the kernel from xFormers.
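
For anyone who needs an explicit mask today, that xFormers kernel can also be called directly with an additive bias. A rough sketch, assuming xformers is installed (the shapes and the -inf bias convention are illustrative, not a flash-attn API):

```python
# Minimal sketch: call xFormers memory-efficient attention directly with an
# additive bias tensor (0 = attend, -inf = masked). Inputs are [B, S, H, D].
import torch
import xformers.ops as xops

B, S, H, D = 2, 1024, 8, 64
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

keep = torch.ones(S, S, device="cuda", dtype=torch.bool).tril()   # causal example
bias = torch.zeros(B, H, S, S, device="cuda", dtype=q.dtype)
bias = bias.masked_fill(~keep, float("-inf"))

out = xops.memory_efficient_attention(q, k, v, attn_bias=bias)    # [B, S, H, D]
```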

@abdulfatir

Thanks for the info @tridao! Is support for arbitrary attention masks on your roadmap? This would be incredibly useful for some encoder-decoder and prefixLM models. Mandatory thank you for your amazing work!
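
To illustrate the prefix-LM case: a minimal sketch of the mask such models need, i.e. bidirectional attention over the prefix and causal attention over the rest, which a boolean attn_mask argument could express directly (the helper below is hypothetical, not an existing flash-attn API):

```python
# Hypothetical helper: build the boolean mask a prefix-LM needs.
# True = attention allowed, False = masked out.
import torch

def prefix_lm_mask(seq_len: int, prefix_len: int, device: str = "cpu") -> torch.Tensor:
    """[seq_len, seq_len] mask: bidirectional over the prefix, causal afterwards."""
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
    mask[:, :prefix_len] = True   # every token can see every prefix token
    return mask

mask = prefix_lm_mask(seq_len=8, prefix_len=3)
```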

@xiabingquan

> If there's an attention mask, PyTorch does not dispatch to the FA2 kernel, but rather to the kernel from xFormers.

Thanks for this valuable tip. No wonder torch.nn.functional.scaled_dot_product_attention doesn't bring any speedup in my case.

@lin-ht commented Sep 10, 2024

I'm looking for bias/mask support too, in FA2 and ideally FA3. Is there a roadmap for this? Thank you~
