CUDA: 4D FlashAttention support #14628
Conversation
Tests are passing on RTX 2060
Force-pushed from f23950a to ab82dc2
Force-pushed from 326e4e2 to 2f9b295
There was an issue with the WMMA kernel (now fixed); merge whenever it is convenient for you.
Merged c43f275 into ggml-org:gg/llama-high-throughput
* CUDA: 4D FlashAttention support
* CUDA: fix WMMA FA kernel
Something is wrong, I'm getting a ton of failures on a 3090 Ti (CUDA 12.9):
You are testing master. This was merged in another branch.
Ah, LOL, sorry. :) Why is master failing, though?
If master is failing, can you do a git bisect to determine since when?
It's failing the `mask->ne[2] != 1` tests. These are not relevant.
This PR adds 4-dimensional CUDA FlashAttention support for #14363. The data layout for the fixup was changed, but there should be no change in performance. As discussed in #14505 (comment), the CUDA code requires `mask->ne[2] == 1`; otherwise it would require additional complexity to ensure that the GQA-specific optimizations in `fattn-mma-f16.cuh` produce correct results.
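
For illustration, a minimal sketch of the kind of check this requirement implies; the helper name `cuda_fa_mask_supported` is an assumption for this sketch, not the actual function used in the CUDA backend:

```cpp
// Minimal sketch (not the actual llama.cpp code): a support check that rejects
// FlashAttention calls whose mask has a non-trivial third dimension, matching
// the mask->ne[2] == 1 requirement described above. ggml_tensor and its ne[]
// extents come from ggml.h; the helper name is illustrative.
#include "ggml.h"

static bool cuda_fa_mask_supported(const ggml_tensor * mask) {
    if (mask == nullptr) {
        return true; // no mask at all poses no problem
    }
    // The GQA-specific optimizations assume a single mask slice that is
    // broadcast along dimension 2, so anything else is rejected here.
    return mask->ne[2] == 1;
}
```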