Any plans to support tree attention mask? #924

Open
KexinFeng opened this issue Apr 21, 2024 · 7 comments

@KexinFeng

KexinFeng commented Apr 21, 2024

Tree attention masks are already supported in huggingface/transformers: huggingface/transformers#27539
They would be very helpful for speculative decoding applications. More specifically, in flash_attn/flash_attn_interface.py#flash_attn_with_kvcache, the tree attention mask would need to be specified and passed in as an argument.

Do you have any near-term plans to support it?

Thanks

Related questions: #840, #918
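
For concreteness, here is a minimal sketch (plain PyTorch; the helper name and tree encoding are just illustrative, not part of flash-attn) of what such a tree mask looks like for speculative decoding: every draft token attends to the whole cached prefix plus its ancestors in the draft tree. flash_attn_with_kvcache currently has no argument that accepts a mask like this, which is what this issue is asking for.

```python
import torch

def build_tree_attention_mask(parents, num_cached):
    """Build a [Q, K] bool mask for a tree of draft tokens (hypothetical helper).

    parents[i] is the in-tree index of draft token i's parent, or -1 if the
    token hangs directly off the cached prefix. num_cached is the number of
    tokens already in the KV cache. mask[q, k] == True means query q may
    attend to key k.
    """
    q_len = len(parents)
    kv_len = num_cached + q_len
    mask = torch.zeros(q_len, kv_len, dtype=torch.bool)
    mask[:, :num_cached] = True            # every draft token sees the full prefix
    for i, p in enumerate(parents):
        mask[i, num_cached + i] = True     # ...itself...
        while p != -1:                     # ...and all of its ancestors
            mask[i, num_cached + p] = True
            p = parents[p]
    return mask

# Prefix of 4 cached tokens; draft tree 0 -> {1, 2}, 1 -> 3.
tree_mask = build_tree_attention_mask(parents=[-1, 0, 0, 1], num_cached=4)
print(tree_mask.int())
```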

@tridao
Contributor

tridao commented Apr 21, 2024

Sure, we'll just need someone to contribute :D

@thorinf

thorinf commented Apr 29, 2024

I'm keen to try supporting a generic mask case, e.g. a [B, Q, K] bool tensor with conditional execution in the kernel. Ideally this covers quite a lot of masking cases, but I guess optimised kernels would work better for more structured masks (like a tree mask).
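
For reference, the semantics such a generic [B, Q, K] bool-mask kernel would have to reproduce can be written down in a few lines of plain PyTorch (a non-fused sketch, not flash-attn code); torch.nn.functional.scaled_dot_product_attention already accepts this kind of boolean mask broadcast over heads:

```python
import torch
import torch.nn.functional as F

B, H, Q, K, D = 2, 8, 5, 12, 64
q = torch.randn(B, H, Q, D)
k = torch.randn(B, H, K, D)
v = torch.randn(B, H, K, D)

# Generic per-batch mask: True = attend, False = masked out.
mask = torch.rand(B, Q, K) > 0.3
mask[..., 0] = True  # keep at least one visible key per query row

# Broadcast over heads to [B, 1, Q, K]; SDPA takes a boolean attn_mask.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None])

# The same thing written out explicitly as an additive -inf bias.
scores = (q @ k.transpose(-2, -1)) / D**0.5
scores = scores.masked_fill(~mask[:, None], float("-inf"))
ref = torch.softmax(scores, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-4)
```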

@KexinFeng
Author

I don't see much difference between a generic mask and a structured mask. For a tree mask, the mask argument would also be of shape [B, Q, K]. In the 4D attention mask mentioned above, it's nothing but [B, H, Q, K], H being the number of heads.

If you are able to implement a generic mask, then a structured mask comes essentially for free.

@thorinf

thorinf commented May 6, 2024

What I mean is that for a structured mask you don't necessarily have to create a bool tensor. In the causal case it can be hardcoded in the kernel to ignore j > i + k_cache, which saves a little bit of memory. If the mask is structured, the locations you'll visit are predictable.
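
A rough illustration of that memory argument (a sketch with made-up sizes): for the causal case the predicate can be evaluated on the fly per tile and fully masked key blocks can be skipped, so the q_len × kv_len bool tensor from the generic path never has to exist.

```python
import torch

q_len, cache_len = 16, 4096          # made-up sizes
kv_len = cache_len + q_len

# Generic path: materialise an explicit bool mask of q_len x kv_len entries.
i = torch.arange(q_len)[:, None]
j = torch.arange(kv_len)[None, :]
explicit_mask = j <= i + cache_len   # [q_len, kv_len] bool tensor in memory

# Structured (causal) path: the kernel can evaluate the same predicate
# `j <= i + cache_len` per tile and skip fully masked key blocks, so this
# tensor never needs to exist.
print(f"{explicit_mask.numel()} mask entries the structured path never materialises")
```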

@KexinFeng
Author

I see. Yeah, in the causal case the bool tensor mask argument indeed isn't required. For the tree attention mask, however, this argument is unavoidable. But that probably doesn't add much implementation complexity, since the causal mask is equivalent to such a tensor anyway. @thorinf Looking forward to your PR!
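
One way to see that special-case relationship (a small self-contained check in plain PyTorch, not flash-attn code): the built-in causal path gives the same result as the generic bool-mask path fed a lower-triangular mask.

```python
import torch
import torch.nn.functional as F

B, H, L, D = 1, 4, 10, 32
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# Built-in causal handling...
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# ...matches the generic bool-mask path with a lower-triangular mask.
causal_mask = torch.tril(torch.ones(L, L, dtype=torch.bool))
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

assert torch.allclose(out_causal, out_masked, atol=1e-4)
```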

@jkobject

Hello, sorry for the naive questions, but:

  1. Why do you need structured masking? Can't you do something similar with attention biases?
  2. Are you hoping that you might be able to skip blocks that are entirely masked? Or will you still compute attention over the full matrix?

It might help me understand this a bit more :)

@poedator

poedator commented Aug 6, 2024

  1. Why do you need structured masking? Can't you do something similar with attention biases?

Please check out this blog post describing 4D masks:
https://huggingface.co/blog/poedator/4d-masks
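
For anyone landing here later, a rough sketch of the transformers-side usage the post describes (the checkpoint is arbitrary, and the 1 = attend / 0 = masked convention and the [batch, 1, query_len, kv_len] shape follow the linked PR, so double-check against your transformers version): a 4D attention_mask is passed directly to the model's forward, together with position_ids consistent with the mask (for a tree, each node's depth).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any Llama-family checkpoint should work; this one is just for illustration.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tok("The quick brown fox", return_tensors="pt").input_ids
L = input_ids.shape[1]

# 4D mask: [batch, 1, query_len, kv_len] with 1 = attend, 0 = masked.
# Shown causal here; a tree mask would follow the ancestor structure instead.
mask_4d = torch.tril(torch.ones(1, 1, L, L))

# position_ids must be consistent with the mask (for a tree: each node's depth).
position_ids = torch.arange(L).unsqueeze(0)

with torch.no_grad():
    out = model(input_ids, attention_mask=mask_4d, position_ids=position_ids)
print(out.logits.shape)  # [1, L, vocab_size]
```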
