CUDA: Faster FlashAttention kernel #6374
Conversation
Cool!
Yes, this can be satisfied by padding the KV cache appropriately. Can you take a look at the performance of Gemma, since it has a head size of 256? On my RTX 2060 the TG with this PR and Gemma 2B is faster, but the PP is ~15% slower. You can apply the following patch to satisfy the padding requirement:

diff --git a/llama.cpp b/llama.cpp
index b80080da..18f09d49 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9166,7 +9166,7 @@ static int llama_decode_internal(
             // a heuristic, to avoid attending the full cache if it is not yet utilized
             // after enough generations, the benefit from this heuristic disappears
             // if we start defragmenting the cache, the benefit from this will be more important
-            kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
+            kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
             //kv_self.n = llama_kv_cache_cell_max(kv_self);
         }
     }
@@ -13083,7 +13083,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);

     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

Also the generation is producing garbage, so there might be some issues still:

LLAMA_CUBLAS=1 make -j main && ./main -m ./models/gemma-2b/ggml-model-f16.gguf -p "I believe the meaning of life is" -ngl 99
I believe the meaning of life is<unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31>....
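For reference, a minimal sketch (my own helper and example numbers, not ggml's `GGML_PAD` macro or the actual llama.cpp code) of the round-up-to-a-multiple behavior the patch relies on:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Round x up to the next multiple of n (the behavior the patch wants with n = 256,
// so that the KV cache view length is always divisible by Gemma's head size of 256).
static uint32_t pad_to_multiple(uint32_t x, uint32_t n) {
    return ((x + n - 1)/n)*n;
}

int main() {
    const uint32_t kv_size    = 4096; // total KV cache cells   (example value)
    const uint32_t cells_used = 1000; // currently used cells   (example value)

    // mirrors the patched heuristic: std::min(kv_self.size, std::max(256u, GGML_PAD(..., 256)))
    const uint32_t kv_n = std::min(kv_size, std::max(256u, pad_to_multiple(cells_used, 256)));
    printf("kv_self.n = %u\n", kv_n); // prints 1024
    return 0;
}
```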
No, I don't mean the size of the KV cache. I mean the size of the view of the KV cache used in the kernel. Unless I'm missing something, that view is always a multiple of n_embd_head_k == D and I am explicitly making use of that fact.
Disregard my previous post, you are right that that would be the correct patch; I interpreted the tensor shapes that I observed during regular inference incorrectly. However, this is good news because it means that I can actually write a better kernel by adjusting the padding.
Gemma should produce correct results now. In addition to the padding I had forgotten to consider GQA.
Could you please resolve the conflicts in order for the CI to start.
Force-pushed from 48d9f25 to 79bcf52.
It really doesn't matter. This PR is not going into master and as of right now it's not going to compile anyways because I'm using features that are not available on all compute capabilities without the appropriate checks. It's more time efficient for me to only add these at the end.
Which models have n_embd_head == 64 or 80?
I would like to see the new …
No, I'll need to rewrite the logic for selecting the correct number of warps anyways because my assumptions that went into writing the kernel were incorrect.
The … (lines 5705 to 5709 in bfe7daf)
I am getting correct results for …
I'm dumb, I forgot about the FP16 accumulator issues described just one post prior.
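As an aside, here is a tiny standalone demo (my own, not code from the PR) of the kind of problem an FP16 accumulator runs into: once the running sum is large enough, a small addend falls at or below half of the FP16 ULP and the sum stops growing, while an FP32 accumulator keeps the exact value. Compile as a .cu file with nvcc:

```cuda
#include <cstdio>
#include <cuda_fp16.h>

int main() {
    __half acc_h = __float2half(0.0f); // FP16 accumulator: value is rounded back to half after every add
    float  acc_f = 0.0f;               // FP32 accumulator for comparison

    for (int i = 0; i < 100000; ++i) {
        acc_h = __float2half(__half2float(acc_h) + 0.25f);
        acc_f += 0.25f;
    }

    // The half-precision sum gets stuck around 512 (where the FP16 spacing reaches 0.5),
    // far below the true value of 25000 that the FP32 accumulator reports.
    printf("FP16 accumulator: %.1f\n", __half2float(acc_h));
    printf("FP32 accumulator: %.1f\n", acc_f);
    return 0;
}
```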
I am not seeing any performance regressions for Gemma 2b on my RTX 3090. Relative to which branch did you observe the regression? In any case, I have pushed a performance optimization that reduces register pressure for Gemma; for me this is only a few % faster but it may have been the issue on Turing.
I'm comparing to the …

LLAMA_CUBLAS=1 make -j llama-bench && ./llama-bench -m /mnt/llama.cpp/models/gemma-2b/ggml-model-f16.gguf -ngl 99 -p 512,1024,2048

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
build: cfde806 (2575)

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
build: e4badc1 (2576)
Relative to master I'm also seeing a performance regression for Gemma. But I am seeing this performance regression both for CPU and CUDA and regardless of whether or not …
I just realized that I had incorrectly labeled my table. The "master" numbers were actually from …
I see - this branch does not have the recent optimization for skipping the computation of the unused tokens in the last layer. Likely this will be resolved after rebasing to latest master.
Force-pushed from e4badc1 to 912a6aa.
I added a specialized kernel for batch size 1 and rebased on top of the most recent master.
LLaMA (head size 128) is universally faster. Gemma (head size 256) is faster for the most important use cases of batch size 1 and batch size >= 512. Phi-2 (head size 80) is slower because 80 is just a terrible choice of head size for CUDA. Quite honestly I don't want to put in the effort to optimize for that particular head size because as far as I'm concerned 3b models are pretty much useless.
This PR seems to be faster than …
For Phi-2 neither branch is consistently faster than master.
For Phi-2 most of the performance regression vs. …
Here are some numbers on my systems before the Phi-2 change from just now. These are only the attention in isolation - I think it makes more sense to measure that:

# regular attention
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o ATTN -b CUDA0 perf

# new fa kernel
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf

RTX 2060: Regular batch / context
RTX 2060: Small batch + very long contexts
V100: Regular batch / context
V100: Small batch + very long contexts

The gains on V100 do not look as good as on RTX 2060. For example, the vec kernel on V100 seems to be slower than the non-vec kernel for HS=64. Also HS=256 with B>=512 performs quite poorly.
Keep in mind that a very important factor for performance is the number of attention heads. The number of CUDA blocks is proportional to the number of heads and if there are too few blocks this leads to poor GPU utilization. In the tests the number of attention heads is always 32, but Gemma for example has only 8 heads and this is one of the reasons why it performs comparatively poorly with FlashAttention. Also, if the goal is only to get something that is good enough to merge, the tests only need to be sensitive enough to detect whether the code is faster/slower. As such I am using …
Generally speaking, GPU performance portability is very poor compared to CPU. The biggest issue here is that for FlashAttention the number of CUDA blocks is simply too low, so the kernel performance is very sensitive to the number of streaming multiprocessors on a device vs. the number of CUDA blocks. In any case, the issue for large batches with hs==256 is most likely that Volta does not have hardware support for asynchronously loading data, so it needs more registers to load data into tensor cores. I pushed a patch that (I think) fixes this.
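To make the block-count point concrete, a small sketch with assumed numbers (not code from this PR): at batch size 1 the grid is roughly one block per attention head, which for a model like Gemma with 8 heads can be far below the SM count of the device.

```cuda
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int n_sm = 0;
    cudaDeviceGetAttribute(&n_sm, cudaDevAttrMultiProcessorCount, /*device =*/ 0);
    if (n_sm <= 0) {
        printf("no CUDA device found\n");
        return 1;
    }

    const int n_heads  = 8;       // Gemma has only 8 attention heads
    const int n_blocks = n_heads; // simplifying assumption: ~1 CUDA block per head at batch size 1

    printf("SMs: %d, CUDA blocks: %d -> at most %.0f%% of the SMs receive a block\n",
           n_sm, n_blocks, 100.0*fmin(1.0, (double) n_blocks/n_sm));
    return 0;
}
```

On an RTX 3090 (82 SMs) or a V100 (80 SMs) such a grid leaves most of the chip idle, which is why splitting the work for a single head across multiple blocks helps.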
Using …
This is important for long prompt processing so we should be measuring it in some way. Edit: updated the tables in the previous comment with the latest numbers on V100.
I am testing pp4096, with batch sizes 1,2,...,2048,4096. Unless I am missing something, that should be the correct test setup.
With the latest version the FlashAttention kernel is faster than regular attention on my RTX 3090 for all test cases in test-backend-ops.
I implemented splitting the calculations across multiple CUDA blocks for batch size 1 in order to get better GPU utilization; the same trick is also being used in the official FlashAttention repository (a rough sketch of the merge step follows below). This is the current performance on my system:
The biggest issue is still batch size 1 + head size 80. Quite frankly I don't know what else to do in order to improve performance. @ggerganov how strict is the goal of getting consistently better performance than master? From this point on the amount of time that it will take me to find improvements for small batch sizes (if there are any at all) will be much higher. Also I will need to attend to other projects, so I will not have a lot of time to work on this in the immediate future.
Actually, there is one thing: if the changes to the memory layout of the KV cache were reverted I may be able to write a faster kernel for batch size 1 and head size 80. But that of course has other drawbacks.
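For the record, a hedged sketch of the "split across blocks, then merge in a second kernel" idea for batch size 1 (a simplified illustration; the names, memory layout, and launch configuration are not the PR's actual code). Each partial result carries a running maximum `m`, a softmax denominator `s`, and an un-normalized accumulator `o`; the merge rescales everything to a common maximum before combining:

```cuda
#include <math.h>

__global__ void merge_partials(
        const float * __restrict__ o_partial, // [nblocks, D] un-normalized V accumulators
        const float * __restrict__ m_partial, // [nblocks]    per-block running max of KQ
        const float * __restrict__ s_partial, // [nblocks]    per-block sum of exp(KQ - m)
        float       * __restrict__ dst,       // [D]          final attention output
        const int nblocks, const int D) {
    const int d = blockIdx.x*blockDim.x + threadIdx.x;
    if (d >= D) {
        return;
    }

    // global maximum over all partial blocks (identical for every output element)
    float m_all = -INFINITY;
    for (int b = 0; b < nblocks; ++b) {
        m_all = fmaxf(m_all, m_partial[b]);
    }

    // rescale each partial result to the common maximum, then combine and normalize
    float o_all = 0.0f;
    float s_all = 0.0f;
    for (int b = 0; b < nblocks; ++b) {
        const float scale = expf(m_partial[b] - m_all);
        o_all += o_partial[b*D + d]*scale;
        s_all += s_partial[b]*scale;
    }

    dst[d] = o_all / s_all;
}
```

Because every rescale factor exp(m_b - m_all) is at most 1, the combined sums stay well-behaved numerically regardless of how the KV cache was split across blocks.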
Apologies for the delay. I think the proposed kernels are all around better compared to what is on the other branch, so we can proceed to merge.
Regarding merging this into master - there are still some things necessary to wrap up, like logic in llama.cpp to choose when to use FA based on the backend and potentially model parameters (e.g. head size), and F32 precision variants that will work with models like Phi-2. But more importantly, we should make a decision whether this will be the approach to implement the attention in general. There can be other strategies explored (see the beginning of the discussion in the other PR) - we can either explore these before or after merging this work into master. Though I'm open to opinions and recommendations.
Edit: there are some compile warnings, could you clean these up:
ggml-cuda/fattn.cu(784): warning #177-D: variable "frag_m" was declared but never referenced
const int frag_m = cols_per_block == 8 ? 32 : 16;
^
Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"
ggml-cuda/fattn.cu(479): warning #128-D: loop is not reachable
for (int i0 = 0; i0 < D; i0 += nwarps*32) {
^
detected during instantiation of "void flash_attn_ext_f16<D,ncols,nwarps,VKQ_stride,parallel_blocks>(const char *, const char *, const char *, const char *, float *, half2 *, float, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int) [with D=64, ncols=8, nwarps=4, VKQ_stride=64, parallel_blocks=1]" at line 792
ggml-cuda/fattn.cu(479): warning #128-D: loop is not reachable
for (int i0 = 0; i0 < D; i0 += nwarps*32) {
^
detected during instantiation of "void flash_attn_ext_f16<D,ncols,nwarps,VKQ_stride,parallel_blocks>(const char *, const char *, const char *, const char *, float *, half2 *, float, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int) [with D=64, ncols=16, nwarps=4, VKQ_stride=64, parallel_blocks=1]" at line 793
I fixed the compiler warnings as well as the CMake build. While going through the code I also discovered that there luckily (?) was a bug that caused excessive shared memory loads for head sizes < 128. With the bug fixed, Phi-2 performance is much better:
From my end, this PR is now ready to merge. I'll leave merging to you since I don't know whether you want to squash prior to merging. Feel free to go through the code and request additional explanatory comments if necessary.
I forgot: currently the code for running multiple parallel blocks and then merging the results in another kernel only works for batch size 1. But in principle this code could be extended to also work for batch sizes > 1. This should fix the few remaining performance regressions for batch sizes <= 8. But I will probably only get around to implementing this by the end of the week at the earliest.
This PR adds a rewritten, faster FlashAttention CUDA kernel. A notable feature is that tensor core fragments can now have different shapes depending on batch size. Performance for batch sizes 1-4 is still somewhat suboptimal, especially for head sizes 64 and 80. I think there is no way around writing a second FlashAttention kernel that does not use tensor cores (luckily this will be comparatively simple though). Performance on my system changes as follows:
(Table: master vs. jg/flash-attn-12 vs. no FlashAttention. Plot of the same numbers.)
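To make the "different fragment shapes depending on batch size" point concrete, here is a rough sketch of my own; only the `cols_per_block == 8 ? 32 : 16` choice is taken from the compiler warning quoted earlier in the thread, the rest is illustrative and not the PR's actual code. With only 8 KQ columns per block, a 32x8x16 tile wastes fewer columns than the default 16x16x16 shape. Requires nvcc with -arch=sm_70 or newer:

```cuda
#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

// Select the WMMA fragment geometry from the number of KQ columns handled per block.
template <int cols_per_block>
struct fattn_frag {
    static constexpr int m = cols_per_block == 8 ? 32 : 16;
    static constexpr int n = cols_per_block == 8 ?  8 : 16;
    static constexpr int k = 16;

    typedef wmma::fragment<wmma::matrix_a,    m, n, k, half, wmma::row_major> a; // first operand tile
    typedef wmma::fragment<wmma::matrix_b,    m, n, k, half, wmma::col_major> b; // second operand tile
    typedef wmma::fragment<wmma::accumulator, m, n, k, half>                  c; // accumulator (e.g. a KQ^T tile)
};

// Minimal device-side usage: zero an accumulator fragment of the selected shape.
template <int cols_per_block>
__global__ void zero_kq_tile() {
    typename fattn_frag<cols_per_block>::c KQ_c;
    wmma::fill_fragment(KQ_c, __float2half(0.0f));
}

template __global__ void zero_kq_tile<8>();  // small batches: 32x8x16 fragments
template __global__ void zero_kq_tile<16>(); // larger batches: 16x16x16 fragments
```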
Initially I tried an implementation that calculated softmax using the maximum value of multiple KQ columns. This allowed you to directly scale tensor core fragments instead of having to go through shared memory. This was faster than the kernel in this PR but unfortunately also had a ~0.1% chance to blow up the computation and give you NaN.
For my kernel I am assuming K->ne[1] % D == 0. Unless I am missing something, this should always be true for actual use in language models. However, the tests included cases where this is not the case. For this reason I have edited the tests.
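A small sketch of that assumption written down as an explicit check (the helper and its placement are hypothetical, just to make the precondition concrete):

```cpp
#include <cassert>
#include <cstdint>

// Precondition of the kernel as described above: the KV sequence length seen by the
// kernel (K->ne[1]) must be a multiple of the head size D, which the KV cache padding
// discussed earlier in the thread is meant to guarantee.
static void check_flash_attn_precondition(int64_t kv_len, int64_t head_size) {
    assert(kv_len % head_size == 0 && "KV length must be padded to a multiple of the head size");
}
```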