Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA: Faster FlashAttention kernel #6374

Merged
merged 10 commits into from
Apr 2, 2024

Conversation

JohannesGaessler
Copy link
Collaborator

@JohannesGaessler JohannesGaessler commented Mar 28, 2024

This PR adds a rewritten, faster FlashAttention CUDA kernel. A notable feature is that tensor core fragments can now have different shapes depending on batch size. Performance for batch sizes 1-4 is still somewhat suboptimal, especially for head sizes 64 and 80. I think there is no way around writing a second FlashAttention kernel that does not use tensor cores (luckily this will be comparatively simple though). Performance on my system changes as follows:

Batch size t/s master jg /flash-attn-12 no FlashAttention t/s gg/flash-attn Speedup gg/flash-attn t/s jg/flash-attn-12 Speedup jg/flash-attn-12
1 112.51 104.20 0.93 110.48 0.98
2 212.14 203.85 0.96 215.48 1.02
4 338.06 349.16 1.03 365.31 1.08
8 444.67 488.77 1.10 504.61 1.13
16 450.71 519.93 1.15 529.54 1.17
32 578.73 536.66 0.93 582.85 1.01
64 1183.09 1156.06 0.98 1273.54 1.08
128 1891.63 1979.86 1.05 2120.92 1.12
256 2564.85 2861.33 1.12 2943.26 1.15
512 2934.15 3279.97 1.12 3433.73 1.17
1024 2938.55 3293.07 1.12 3445.01 1.17
2048 2945.06 3311.10 1.12 3457.93 1.17
4096 2951.00 3321.18 1.13 3468.74 1.18

Plot of the same numbers:

flash_attention_perf

Initially I tried an implementation that calculated softmax using the maximum value of multiple KQ columns. This allowed you to directly scale tensor core fragments instead of having to go through shared memory. This was faster than the kernel in this PR but unfortunately also had a ~0.1% chance to blow up the computation and give you NaN.

For my kernel I am assuming K->ne[1] % D == 0. Unless I am missing something something this should always be true for actual use in language models. However, the tests included cases where this is not the case. For this reason I have edited the tests.

@JohannesGaessler JohannesGaessler mentioned this pull request Mar 28, 2024
8 tasks
@ggerganov
Copy link
Owner

A notable feature is that tensor core fragments can now have different shapes depending on batch size.

Cool!

For my kernel I am assuming K->ne[1] % D == 0.

Yes, this can be satisfied by padding the KV cache appropriately

Can you take a look at the performance of Gemma since it has head size of 256? On my RTX 2060 the TG with this PR and Gemma 2B is faster, but the PP is ~15% slower. You can apply the following patch to satisfy the requirement for the K->ne[1] size:

diff --git a/llama.cpp b/llama.cpp
index b80080da..18f09d49 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9166,7 +9166,7 @@ static int llama_decode_internal(
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
-                kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
+                kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
                 //kv_self.n = llama_kv_cache_cell_max(kv_self);
             }
         }
@@ -13083,7 +13083,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

Also the generation is producing garbage, so there might be some issues still:

LLAMA_CUBLAS=1 make -j main && ./main -m ./models/gemma-2b/ggml-model-f16.gguf -p "I believe the meaning of life is" -ngl 99

I believe the meaning of life is<unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31>....

@JohannesGaessler
Copy link
Collaborator Author

Yes, this can be satisfied by padding the KV cache appropriately

No, I don't mean the size of the KV cache. I mean the size of the view of the KV cache used in the kernel. Unless I'm missing something that view is always a multiple of n_embd_head_k == D and I am explicitly making use of that fact.

@JohannesGaessler
Copy link
Collaborator Author

Disregard my previous post, you are right that that would be the correct patch; I interpreted the tensor shapes that I observed during regular inference incorrectly. However, this is good news because it means that I can actually write a better kernel by adjusting the padding.

@JohannesGaessler
Copy link
Collaborator Author

Gemma should produce correct results now. In addition to the padding I had forgotten to consider GQA.

@phymbert
Copy link
Collaborator

Could you please resolve conflicts in order for the ci to start

@JohannesGaessler
Copy link
Collaborator Author

It really doesn't matter. This PR is not going into master and as of right now it's not going to compile anyways because I'm using features that are not available on all compute capabilities without the appropriate checks. It's more time efficient for me to only add these at the end.

@JohannesGaessler
Copy link
Collaborator Author

Which models have n_embd_head == 64 or 80?

@phymbert
Copy link
Collaborator

I would like to see the new Benchmark job on this code. Dont you ?

@JohannesGaessler
Copy link
Collaborator Author

No, I'll need to rewrite the logic for selecting the correct number of warps anyways because my assumptions that went into writing the kernel were incorrect.

@ggerganov
Copy link
Owner

The phi-2 model has head size of 80. However, the F16 precision for this model is not enough and it will produce garbage. We need FA kernels with F32 accumulators that will be used when GGML_PREC_F32 is specified:

llama.cpp/llama.cpp

Lines 5705 to 5709 in bfe7daf

if (model.arch == LLM_ARCH_PHI2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
}

@JohannesGaessler
Copy link
Collaborator Author

I am getting correct results for tests/test-backend-ops with a head size of 80. However, the actual Phi-2 model is still producing garbage. I had a similar issue with Gemma because that model was using GQA. Does Phi-2 do anything special with its attention mechanism?

@JohannesGaessler
Copy link
Collaborator Author

I'm dumb, I forgot about the FP16 accumulator issues described just one post prior.

@JohannesGaessler
Copy link
Collaborator Author

Can you take a look at the performance of Gemma since it has head size of 256? On my RTX 2060 the TG with this PR and Gemma 2B is faster, but the PP is ~15% slower.

I am not seeing any performance regressions for Gemma 2b on my RTX 3090. Relative to which branch did you observe the regression? In any case, I have pushed a performance optimization that reduces register pressure for Gemma; for me this is only a few % faster but it may have been the issue on Turing.

@ggerganov
Copy link
Owner

ggerganov commented Mar 29, 2024

I'm comparing to the master branch. Here are the results with the following command using V100:

LLAMA_CUBLAS=1 make -j llama-bench && ./llama-bench -m /mnt/llama.cpp/models/gemma-2b/ggml-model-f16.gguf -ngl 99 -p 512,1024,2048

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes

model size params backend ngl test t/s
gemma 2B F16 (guessed) 5.64 GiB 3.03 B CUDA 99 pp 512 11308.07 ± 144.58
gemma 2B F16 (guessed) 5.64 GiB 3.03 B CUDA 99 pp 1024 11621.09 ± 82.05
gemma 2B F16 (guessed) 5.64 GiB 3.03 B CUDA 99 pp 2048 11324.10 ± 109.07
gemma 2B F16 (guessed) 5.64 GiB 3.03 B CUDA 99 tg 128 118.77 ± 0.31

build: cfde806 (2575)

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes

model size params backend ngl test t/s
gemma 2B F16 (guessed) 5.64 GiB 3.03 B CUDA 99 pp 512 9160.73 ± 69.51
gemma 2B F16 (guessed) 5.64 GiB 3.03 B CUDA 99 pp 1024 8885.20 ± 18.08
gemma 2B F16 (guessed) 5.64 GiB 3.03 B CUDA 99 pp 2048 8117.28 ± 20.89
gemma 2B F16 (guessed) 5.64 GiB 3.03 B CUDA 99 tg 128 122.19 ± 1.71

build: e4badc1 (2576)

@JohannesGaessler
Copy link
Collaborator Author

Relative to master I'm also seeing a performance regression for Gemma. But I am seeing this performance regression both for CPU and CUDA and regardless of whether or not LLAMA_FLASH_ATTENTION is defined in llama.cpp; If LLAMA_FLASH_ATTENTION is defined the performance is comparatively better. To me this suggests that the cause for the regression is some other change either on master or in gg/flash-attn; I don't have a comprehensive overview of all the things that are being changed in gg/flash-attn and which of those changes are supposed to be temporary.

@JohannesGaessler
Copy link
Collaborator Author

I just realized that I had incorrectly labeled my table. The "master" numbers were actually from jg/flash-attn-12 with FlashAttention turned off. For LLaMA 2 q4_0 I am also seeing a very slight performance regression of ~2% relative to master when FlashAttention is turned off.

@ggerganov
Copy link
Owner

I see - this branch does not have the recent optimization for skipping the computation of the unused tokens in the last layer. Likely after rebasing to latest gg/flash-attn the discrepancy will disappear

@JohannesGaessler
Copy link
Collaborator Author

JohannesGaessler commented Mar 29, 2024

I added a specialized kernel for batch size 1 and rebased on top of the most recent gg/flash-attn. This is the current performance:

GPU Model Batch size Test t/s master t/s jg/flash-attn-16 Speedup
RTX 3090 gemma 2B all F32 (guessed) 1 pp4096 73.10 73.00 1.00
RTX 3090 gemma 2B all F32 (guessed) 2 pp4096 128.68 121.51 0.94
RTX 3090 gemma 2B all F32 (guessed) 4 pp4096 252.66 240.55 0.95
RTX 3090 gemma 2B all F32 (guessed) 8 pp4096 492.13 476.25 0.97
RTX 3090 gemma 2B all F32 (guessed) 16 pp4096 940.09 943.43 1.00
RTX 3090 gemma 2B all F32 (guessed) 32 pp4096 1902.00 1841.48 0.97
RTX 3090 gemma 2B all F32 (guessed) 64 pp4096 3408.40 3349.84 0.98
RTX 3090 gemma 2B all F32 (guessed) 128 pp4096 3002.53 3036.20 1.01
RTX 3090 gemma 2B all F32 (guessed) 256 pp4096 4159.57 4343.17 1.04
RTX 3090 gemma 2B all F32 (guessed) 512 pp4096 5746.50 6003.39 1.04
RTX 3090 gemma 2B all F32 (guessed) 1024 pp4096 5821.18 6072.62 1.04
RTX 3090 gemma 2B all F32 (guessed) 2048 pp4096 5888.14 6136.36 1.04
RTX 3090 gemma 2B all F32 (guessed) 4096 pp4096 5912.42 6161.33 1.04
RTX 3090 llama 7B Q4_0 1 pp4096 112.50 117.42 1.04
RTX 3090 llama 7B Q4_0 2 pp4096 212.92 213.85 1.00
RTX 3090 llama 7B Q4_0 4 pp4096 342.69 366.36 1.07
RTX 3090 llama 7B Q4_0 8 pp4096 456.29 514.65 1.13
RTX 3090 llama 7B Q4_0 16 pp4096 466.84 553.58 1.19
RTX 3090 llama 7B Q4_0 32 pp4096 606.66 625.60 1.03
RTX 3090 llama 7B Q4_0 64 pp4096 1212.55 1311.68 1.08
RTX 3090 llama 7B Q4_0 128 pp4096 1929.57 2181.56 1.13
RTX 3090 llama 7B Q4_0 256 pp4096 2606.33 3040.31 1.17
RTX 3090 llama 7B Q4_0 512 pp4096 2975.79 3542.18 1.19
RTX 3090 llama 7B Q4_0 1024 pp4096 2981.22 3538.68 1.19
RTX 3090 llama 7B Q4_0 2048 pp4096 2983.94 3549.31 1.19
RTX 3090 llama 7B Q4_0 4096 pp4096 2986.36 3547.84 1.19
RTX 3090 phi2 3B F16 1 pp4096 100.57 71.66 0.71
RTX 3090 phi2 3B F16 2 pp4096 175.81 130.96 0.74
RTX 3090 phi2 3B F16 4 pp4096 340.64 260.00 0.76
RTX 3090 phi2 3B F16 8 pp4096 656.95 514.21 0.78
RTX 3090 phi2 3B F16 16 pp4096 1206.90 1005.32 0.83
RTX 3090 phi2 3B F16 32 pp4096 2130.25 1911.41 0.90
RTX 3090 phi2 3B F16 64 pp4096 3448.82 3191.17 0.93
RTX 3090 phi2 3B F16 128 pp4096 4756.77 4570.06 0.96
RTX 3090 phi2 3B F16 256 pp4096 5279.77 4930.82 0.93
RTX 3090 phi2 3B F16 512 pp4096 5665.75 5178.48 0.91
RTX 3090 phi2 3B F16 1024 pp4096 5673.42 5149.88 0.91
RTX 3090 phi2 3B F16 2048 pp4096 5649.56 5140.25 0.91
RTX 3090 phi2 3B F16 4096 pp4096 5669.77 5113.27 0.90

LLaMA (head size 128) is universally faster. Gemma (head size 256) is faster for the most important use cases of batch size 1 and batch size >= 512. Phi-2 (head size 80) is slower because 80 is just a terrible choice of head size for CUDA. Quite honestly I don't want to put in the effort to optimize for that particular head size because as far as I'm concerned 3b models are pretty much useless.

@JohannesGaessler
Copy link
Collaborator Author

This PR seems to be faster than gg/flash-attn for LLaMA and Gemma but slower for Phi-2:

GPU Model Batch size Test t/s gg/flash-attn t/s jg/flash-attn-16 Speedup
RTX 3090 gemma 2B all F32 (guessed) 1 pp4096 66.18 73.00 1.10
RTX 3090 gemma 2B all F32 (guessed) 2 pp4096 118.16 121.51 1.03
RTX 3090 gemma 2B all F32 (guessed) 4 pp4096 233.35 240.55 1.03
RTX 3090 gemma 2B all F32 (guessed) 8 pp4096 461.12 476.25 1.03
RTX 3090 gemma 2B all F32 (guessed) 16 pp4096 909.86 943.43 1.04
RTX 3090 gemma 2B all F32 (guessed) 32 pp4096 1514.52 1841.48 1.22
RTX 3090 gemma 2B all F32 (guessed) 64 pp4096 2801.46 3349.84 1.20
RTX 3090 gemma 2B all F32 (guessed) 128 pp4096 2821.02 3036.20 1.08
RTX 3090 gemma 2B all F32 (guessed) 256 pp4096 4075.36 4343.17 1.07
RTX 3090 gemma 2B all F32 (guessed) 512 pp4096 5816.61 6003.39 1.03
RTX 3090 gemma 2B all F32 (guessed) 1024 pp4096 5895.99 6072.62 1.03
RTX 3090 gemma 2B all F32 (guessed) 2048 pp4096 5964.79 6136.36 1.03
RTX 3090 gemma 2B all F32 (guessed) 4096 pp4096 5980.33 6161.33 1.03
RTX 3090 llama 7B Q4_0 1 pp4096 106.17 117.42 1.11
RTX 3090 llama 7B Q4_0 2 pp4096 206.06 213.85 1.04
RTX 3090 llama 7B Q4_0 4 pp4096 361.45 366.36 1.01
RTX 3090 llama 7B Q4_0 8 pp4096 513.63 514.65 1.00
RTX 3090 llama 7B Q4_0 16 pp4096 558.89 553.58 0.99
RTX 3090 llama 7B Q4_0 32 pp4096 577.57 625.60 1.08
RTX 3090 llama 7B Q4_0 64 pp4096 1191.42 1311.68 1.10
RTX 3090 llama 7B Q4_0 128 pp4096 2052.02 2181.56 1.06
RTX 3090 llama 7B Q4_0 256 pp4096 2990.26 3040.31 1.02
RTX 3090 llama 7B Q4_0 512 pp4096 3421.84 3542.18 1.04
RTX 3090 llama 7B Q4_0 1024 pp4096 3419.07 3538.68 1.03
RTX 3090 llama 7B Q4_0 2048 pp4096 3423.65 3549.31 1.04
RTX 3090 llama 7B Q4_0 4096 pp4096 3426.62 3547.84 1.04
RTX 3090 phi2 3B F16 1 pp4096 96.07 71.66 0.75
RTX 3090 phi2 3B F16 2 pp4096 163.86 130.96 0.80
RTX 3090 phi2 3B F16 4 pp4096 322.84 260.00 0.81
RTX 3090 phi2 3B F16 8 pp4096 635.40 514.21 0.81
RTX 3090 phi2 3B F16 16 pp4096 1239.63 1005.32 0.81
RTX 3090 phi2 3B F16 32 pp4096 1511.25 1911.41 1.26
RTX 3090 phi2 3B F16 64 pp4096 2886.93 3191.17 1.11
RTX 3090 phi2 3B F16 128 pp4096 4668.12 4570.06 0.98
RTX 3090 phi2 3B F16 256 pp4096 6612.29 4930.82 0.75
RTX 3090 phi2 3B F16 512 pp4096 7015.66 5178.48 0.74
RTX 3090 phi2 3B F16 1024 pp4096 7184.97 5149.88 0.72
RTX 3090 phi2 3B F16 2048 pp4096 7081.17 5140.25 0.73
RTX 3090 phi2 3B F16 4096 pp4096 7059.46 5113.27 0.72

For Phi-2 neither branch is consistently faster than master.

@JohannesGaessler
Copy link
Collaborator Author

For Phi-2 most of the performance regression vs. gg/flash-attn can be fixed by using fewer columns per kernel:

GPU Model Batch size Test t/s gg/flash-attn t/s jg/flash-attn-16 Speedup
RTX 3090 phi2 3B F16 1 pp4096 96.07 72.46 0.75
RTX 3090 phi2 3B F16 2 pp4096 163.86 132.39 0.81
RTX 3090 phi2 3B F16 4 pp4096 322.84 262.85 0.81
RTX 3090 phi2 3B F16 8 pp4096 635.40 519.46 0.82
RTX 3090 phi2 3B F16 16 pp4096 1239.63 1015.07 0.82
RTX 3090 phi2 3B F16 32 pp4096 1511.25 1934.57 1.28
RTX 3090 phi2 3B F16 64 pp4096 2886.93 3648.19 1.26
RTX 3090 phi2 3B F16 128 pp4096 4668.12 6035.87 1.29
RTX 3090 phi2 3B F16 256 pp4096 6612.29 6647.63 1.01
RTX 3090 phi2 3B F16 512 pp4096 7015.66 7162.11 1.02
RTX 3090 phi2 3B F16 1024 pp4096 7184.97 7091.45 0.99
RTX 3090 phi2 3B F16 2048 pp4096 7081.17 7068.18 1.00
RTX 3090 phi2 3B F16 4096 pp4096 7059.46 7033.53 1.00

@ggerganov
Copy link
Owner

ggerganov commented Mar 30, 2024

Here are some numbers on my systems before the Phi-2 change from just now. These are only the attention in isolation - I think it makes more sense to measure that:

# regular attention
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o ATTN -b CUDA0 perf

# new fa kernel
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf
RTX 2060: Regular batch / context
HS H N_KV B ATTN FATTN SPEEDUP
64 32 512 1 45.37 21.64 2.097
64 32 512 2 49.18 23.76 2.070
64 32 512 4 65.01 24.54 2.649
64 32 512 8 100.03 24.39 4.101
64 32 512 512 1096.23 284.33 3.855
64 32 512 1024 2184.76 550.16 3.971
64 32 512 2048 4331.77 1054.48 4.108
64 32 1024 1 76.08 48.97 1.554
64 32 1024 2 81.73 54.48 1.500
64 32 1024 4 112.55 54.37 2.070
64 32 1024 8 179.57 54.62 3.288
64 32 1024 512 2302.55 542.24 4.246
64 32 1024 1024 4581.62 1058.22 4.330
64 32 1024 2048 9124.48 2059.91 4.430
64 32 2048 1 126.84 94.49 1.342
64 32 2048 2 142.13 105.06 1.353
64 32 2048 4 199.87 105.18 1.900
64 32 2048 8 320.02 105.38 3.037
64 32 2048 512 3421.63 1065.86 3.210
64 32 2048 1024 6810.31 2091.5 3.256
64 32 2048 2048 13575.14 4058.98 3.344
64 32 4096 1 224.13 186.68 1.201
64 32 4096 2 254.6 206.62 1.232
64 32 4096 4 361.18 206.69 1.747
64 32 4096 8 616.09 207 2.976
64 32 4096 512 5751.57 2112.32 2.723
64 32 4096 1024 11441.37 4148.82 2.758
64 32 4096 2048 22746.93 8054.49 2.824
80 32 512 1 52.03 54.96 0.947
80 32 512 2 57.46 54.8 1.049
80 32 512 4 74.67 55.21 1.352
80 32 512 8 110.29 55.38 1.992
80 32 512 512 1144.73 918.45 1.246
80 32 512 1024 2277.57 1757.1 1.296
80 32 512 2048 4517.02 3576.76 1.263
80 32 1024 1 81.79 102.32 0.799
80 32 1024 2 96.3 102.61 0.939
80 32 1024 4 128.85 102.92 1.252
80 32 1024 8 196.54 103.14 1.906
80 32 1024 512 2364.8 1809.9 1.307
80 32 1024 1024 4704.85 3784.78 1.243
80 32 1024 2048 9374.22 7855.45 1.193
80 32 2048 1 140.16 202.69 0.691
80 32 2048 2 171.4 203.1 0.844
80 32 2048 4 229.73 202.74 1.133
80 32 2048 8 353.9 203.15 1.742
80 32 2048 512 3513.05 4058.7 0.866
80 32 2048 1024 6984.16 7989.61 0.874
80 32 2048 2048 13930 15695.98 0.887
80 32 4096 1 248.34 404.48 0.614
80 32 4096 2 314.59 405.2 0.776
80 32 4096 4 421.73 404.07 1.044
80 32 4096 8 674.29 405.37 1.663
80 32 4096 512 5907.92 8047.75 0.734
80 32 4096 1024 11759.38 15960.71 0.737
80 32 4096 2048 23490.11 31129.34 0.755
128 32 512 1 56.88 25.95 2.192
128 32 512 2 63.08 38.71 1.630
128 32 512 4 81.7 38.48 2.123
128 32 512 8 123.79 38.88 3.184
128 32 512 512 1270.79 533.56 2.382
128 32 512 1024 2522.43 1011.81 2.493
128 32 512 2048 4998.68 2067.57 2.418
128 32 1024 1 88.73 47.76 1.858
128 32 1024 2 103.3 71.61 1.443
128 32 1024 4 138.85 71.59 1.940
128 32 1024 8 215.1 72.08 2.984
128 32 1024 512 2500.57 1077.71 2.320
128 32 1024 1024 4965.72 2021.58 2.456
128 32 1024 2048 9891.45 4179.07 2.367
128 32 2048 1 147.63 89.5 1.649
128 32 2048 2 179.5 138.38 1.297
128 32 2048 4 244.09 138.4 1.764
128 32 2048 8 380.33 138.63 2.743
128 32 2048 512 3681.95 2112.82 1.743
128 32 2048 1024 7311.78 4047.36 1.807
128 32 2048 2048 14562.56 8199.13 1.776
128 32 4096 1 261.39 173.66 1.505
128 32 4096 2 333.04 273.14 1.219
128 32 4096 4 448.99 273.04 1.644
128 32 4096 8 728.05 273.46 2.662
128 32 4096 512 6131.54 4174.15 1.469
128 32 4096 1024 12183.61 8300.18 1.468
128 32 4096 2048 24325.8 16262.95 1.496
256 32 512 1 77.22 43.77 1.764
256 32 512 2 86.67 58.98 1.469
256 32 512 4 111.88 59.23 1.889
256 32 512 8 165.46 59.88 2.763
256 32 512 512 1673.23 1105.01 1.514
256 32 512 1024 3320.45 2156.21 1.540
256 32 512 2048 6603.12 4311.68 1.531
256 32 1024 1 120.98 83.5 1.449
256 32 1024 2 144.51 112.32 1.287
256 32 1024 4 187.68 112.11 1.674
256 32 1024 8 283.49 112.87 2.512
256 32 1024 512 3043.4 2111.85 1.441
256 32 1024 1024 6057.62 4116.67 1.471
256 32 1024 2048 12053.14 8196.68 1.470
256 32 2048 1 212.32 163.08 1.302
256 32 2048 2 263.62 217.9 1.210
256 32 2048 4 338.72 217.16 1.560
256 32 2048 8 506.75 218.03 2.324
256 32 2048 512 4506.54 4121.64 1.093
256 32 2048 1024 8951.12 8034.18 1.114
256 32 2048 2048 17812.7 15862.05 1.123
256 32 4096 1 385.78 324.07 1.190
256 32 4096 2 485.69 427.86 1.135
256 32 4096 4 617.62 427.19 1.446
256 32 4096 8 963.15 428.03 2.250
256 32 4096 512 7663.44 8136.86 0.942
256 32 4096 1024 15310.24 15847.58 0.966
256 32 4096 2048 30460.62 31380.33 0.971
RTX 2060: Small batch + very long contexts
HS H N_KV B ATTN FATTN SPEEDUP
64 32 8192 1 419.71 370.29 1.133
64 32 8192 2 482.77 375.89 1.284
64 32 16384 1 817.1 734 1.113
64 32 16384 2 1052.68 749.6 1.404
64 32 32768 1 1635.39 1466.83 1.115
64 32 32768 2 2049.98 1489.6 1.376
64 32 65536 1 3265.59 2937.52 1.112
64 32 65536 2 4114.62 2984.86 1.378
64 32 131072 1 6455.83 5860.23 1.102
64 32 131072 2 8273.58 6028.98 1.372
80 32 8192 1 466.53 807.03 0.578
80 32 8192 2 608.31 804.4 0.756
80 32 16384 1 904.33 1593.4 0.568
80 32 16384 2 1281.81 1593.92 0.804
80 32 32768 1 1793.54 3192.5 0.562
80 32 32768 2 2568.14 3183.74 0.807
80 32 65536 1 3565.5 6381.14 0.559
80 32 65536 2 5127.21 6382.32 0.803
80 32 131072 1 7036.97 12655.64 0.556
80 32 131072 2 10197.66 12714.01 0.802
128 32 8192 1 491.55 340.52 1.444
128 32 8192 2 627.54 541.39 1.159
128 32 16384 1 1020.1 673 1.516
128 32 16384 2 1371.68 1081.26 1.269
128 32 32768 1 1959.56 1335.88 1.467
128 32 32768 2 2783.25 2158.5 1.289
128 32 65536 1 3888.94 2676.52 1.453
128 32 65536 2 5504.49 4339.14 1.269
128 32 131072 1 7794.26 5372.76 1.451
128 32 131072 2 11003.42 8650.74 1.272
256 32 8192 1 728.68 659.9 1.104
256 32 8192 2 914.63 848.84 1.078
256 32 16384 1 1678.8 1337.01 1.256
256 32 16384 2 2101.38 1689.95 1.243
256 32 32768 1 3322.51 2690.17 1.235
256 32 32768 2 4204.37 3373.24 1.246
256 32 65536 1 6674.27 5431.68 1.229
256 32 65536 2 8343.98 6742.49 1.238
256 32 131072 1 13224.92 10819.69 1.222
256 32 131072 2 16743.26 13520.32 1.238
V100: Regular batch / context
HS H N_KV B ATTN FATTN SPEEDUP
64 32 512 1 34.33 18.71 1.835
64 32 512 2 36.82 18.73 1.966
64 32 512 4 48.21 18.98 2.540
64 32 512 8 58.26 18.96 3.073
64 32 512 512 411.71 159.39 2.583
64 32 512 1024 794.87 309.7 2.567
64 32 512 2048 1564.16 545.47 2.868
64 32 1024 1 53.1 40.66 1.306
64 32 1024 2 53.57 40.61 1.319
64 32 1024 4 65.49 40.8 1.605
64 32 1024 8 93.27 41.06 2.272
64 32 1024 512 740.9 300.11 2.469
64 32 1024 1024 1475.44 579.25 2.547
64 32 1024 2048 2882.67 1011.48 2.850
64 32 2048 1 82.41 78.19 1.054
64 32 2048 2 86.14 78 1.104
64 32 2048 4 109.71 78.41 1.399
64 32 2048 8 159.33 78.53 2.029
64 32 2048 512 1271.4 583.64 2.178
64 32 2048 1024 2471.33 1122.92 2.201
64 32 2048 2048 4851.66 1972.88 2.459
64 32 4096 1 135.18 152.98 0.884
64 32 4096 2 146.14 153.16 0.954
64 32 4096 4 190.13 152.54 1.246
64 32 4096 8 286.52 152.2 1.883
64 32 4096 512 2372.55 1152.83 2.058
64 32 4096 1024 4727.58 2210.4 2.139
64 32 4096 2048 9195.68 3868.34 2.377
80 32 512 1 43.09 46.16 0.933
80 32 512 2 35.09 46.19 0.760
80 32 512 4 45.1 46.15 0.977
80 32 512 8 62.17 46.26 1.344
80 32 512 512 440.38 294.48 1.495
80 32 512 1024 852.84 528.13 1.615
80 32 512 2048 1682.85 1011.24 1.664
80 32 1024 1 57.66 108.37 0.532
80 32 1024 2 58.63 107.96 0.543
80 32 1024 4 71.46 108.21 0.660
80 32 1024 8 101.08 108.53 0.931
80 32 1024 512 779.13 558.08 1.396
80 32 1024 1024 1551.65 1022.11 1.518
80 32 1024 2048 3050.4 2181.65 1.398
80 32 2048 1 90.21 214.02 0.422
80 32 2048 2 94.71 214.95 0.441
80 32 2048 4 119.03 214.56 0.555
80 32 2048 8 170.37 213.82 0.797
80 32 2048 512 1320.48 1158.92 1.139
80 32 2048 1024 2575.03 2152.88 1.196
80 32 2048 2048 5059.98 4340.1 1.166
80 32 4096 1 152.62 426.23 0.358
80 32 4096 2 160.76 425.33 0.378
80 32 4096 4 204.91 426.42 0.481
80 32 4096 8 307.04 425.08 0.722
80 32 4096 512 2430.17 2344.12 1.037
80 32 4096 1024 4798.27 4275.17 1.122
80 32 4096 2048 9497.57 8648.04 1.098
128 32 512 1 45.87 22.82 2.010
128 32 512 2 43.96 25.32 1.736
128 32 512 4 51.45 25.48 2.019
128 32 512 8 68.58 25.55 2.684
128 32 512 512 501.15 421.79 1.188
128 32 512 1024 971.19 758.84 1.280
128 32 512 2048 1917.73 1454.19 1.319
128 32 1024 1 62.16 42.39 1.466
128 32 1024 2 63.32 45.73 1.385
128 32 1024 4 76.89 45.88 1.676
128 32 1024 8 108.3 46.12 2.348
128 32 1024 512 846.38 792.2 1.068
128 32 1024 1024 1679.94 1431.42 1.174
128 32 1024 2048 3308.68 2744.28 1.206
128 32 2048 1 96.29 81.54 1.181
128 32 2048 2 99.81 86.21 1.158
128 32 2048 4 125.56 86.42 1.453
128 32 2048 8 179.62 86.4 2.079
128 32 2048 512 1392.65 1536.02 0.907
128 32 2048 1024 2728.39 2777.61 0.982
128 32 2048 2048 5357.98 5336.94 1.004
128 32 4096 1 163.5 156.38 1.046
128 32 4096 2 169.75 171.76 0.988
128 32 4096 4 215.5 170.91 1.261
128 32 4096 8 320.28 171.55 1.867
128 32 4096 512 2595.86 3015.95 0.861
128 32 4096 1024 5012.73 5460.97 0.918
128 32 4096 2048 9930.55 10503.19 0.945
256 32 512 1 56.74 28.23 2.010
256 32 512 2 57.75 33.62 1.718
256 32 512 4 67.63 34.25 1.975
256 32 512 8 88.66 34.38 2.579
256 32 512 512 687.12 1477.84 0.465
256 32 512 1024 1336.81 3016.7 0.443
256 32 512 2048 2619.79 6093.62 0.430
256 32 1024 1 82.79 53.39 1.551
256 32 1024 2 86.69 60.58 1.431
256 32 1024 4 103.5 61.31 1.688
256 32 1024 8 138.78 61.78 2.246
256 32 1024 512 1077.39 2897.77 0.372
256 32 1024 1024 2133.48 5959.1 0.358
256 32 1024 2048 4157.9 12042.37 0.345
256 32 2048 1 129.63 97.27 1.333
256 32 2048 2 142.03 112.65 1.261
256 32 2048 4 171.28 113.06 1.515
256 32 2048 8 230.75 113.16 2.039
256 32 2048 512 1727.98 5667.65 0.305
256 32 2048 1024 3397.51 11734.68 0.290
256 32 2048 2048 6644.12 23758.48 0.280
256 32 4096 1 233.47 188.69 1.237
256 32 4096 2 247.61 217.68 1.137
256 32 4096 4 298.07 216.3 1.378
256 32 4096 8 413.98 216.68 1.911
256 32 4096 512 3115.77 11376.38 0.274
256 32 4096 1024 6170.53 23505.54 0.263
256 32 4096 2048 12091.16 47484.04 0.255
V100: Small batch + very long contexts
HS H N_KV B ATTN FATTN SPEEDUP
64 32 8192 1 252.45 320.79 0.787
64 32 8192 2 264.42 301.43 0.877
64 32 16384 1 500.96 602 0.832
64 32 16384 2 526.49 601.07 0.876
64 32 32768 1 947 1201.3 0.788
64 32 32768 2 1053.66 1197.49 0.880
64 32 65536 1 1915.04 2382.67 0.804
64 32 65536 2 2055.56 2391.24 0.860
64 32 131072 1 3964.44 4750.43 0.835
64 32 131072 2 4111.5 4758.58 0.864
80 32 8192 1 277.34 846.74 0.328
80 32 8192 2 300.98 847.47 0.355
80 32 16384 1 548.61 1682.06 0.326
80 32 16384 2 594.47 1681.94 0.353
80 32 32768 1 1054.09 3361.54 0.314
80 32 32768 2 1182.08 3362 0.352
80 32 65536 1 2117.11 6727.9 0.315
80 32 65536 2 2375.56 6745.57 0.352
80 32 131072 1 4286.83 13456.1 0.319
80 32 131072 2 4770.43 13443.52 0.355
128 32 8192 1 298.63 308.19 0.969
128 32 8192 2 312.56 337 0.927
128 32 16384 1 653.26 618.28 1.057
128 32 16384 2 653.18 671.89 0.972
128 32 32768 1 1295.67 1229.77 1.054
128 32 32768 2 1273.95 1340.88 0.950
128 32 65536 1 2579.33 2443.74 1.055
128 32 65536 2 2517.98 2684.24 0.938
128 32 131072 1 5278.7 4887.34 1.080
128 32 131072 2 5004.94 5378.31 0.931
256 32 8192 1 430.57 370.2 1.163
256 32 8192 2 464.61 423.89 1.096
256 32 16384 1 1105.32 730.21 1.514
256 32 16384 2 987.74 835.25 1.183
256 32 32768 1 2243.32 1449.77 1.547
256 32 32768 2 1966.27 1663.56 1.182
256 32 65536 1 4526.39 2911.01 1.555
256 32 65536 2 3875.53 3322.87 1.166
256 32 131072 1 9239.49 5789.79 1.596
256 32 131072 2 7685.27 6647.67 1.156

The gains on V100 do not look as good as on RTX 2060. For example, the vec kernel on V100 seems to be slower than the non-vec for HS=64. Also HS=256 with B>=512 performs quite poorly.

@JohannesGaessler
Copy link
Collaborator Author

These are only the attention in isolation - I think it makes more sense to measure that

Keep in mind that a very important factor for performance is the number of attention heads. The number of CUDA blocks is proportional to the number of heads and if there are too few blocks this leads to poor GPU utilization. In the tests the number of attention heads is always 32 but Gemma for example has only 8 heads and this is one of the reasons why it performs comparatively poorly with FlashAttention.

Also if the goal is only to get something that is good enough to merge the tests only need to be sensitive enough to detect whether the code is faster/slower. As such I am using llama-bench because it requires very little manual effort from my side to get a usable table.

The gains on V100 do not look as good as on RTX 2060. For example, the vec kernel on V100 seems to be slower than the non-vec for HS=64. Also HS=256 with B>=512 performs quite poorly.

Generally speaking, GPU performance portability is very poor compared to CPU. The biggest issue here is that for FlashAttention the number of CUDA blocks is simply too low so the kernel performance is very sensitive to the number of streaming multiprocessors on a device vs. the number of CUDA blocks.

In any case, the issue for large batches with hs==256 is most likely that Volta does not have hardware support for asynchronously loading data so it needs more registers to load data into tensor cores. I pushed a patch that (I think) fixes this.

@ggerganov
Copy link
Owner

ggerganov commented Mar 30, 2024

Using llama-bench you are not measuring cases where there is already tokens in the KV cache and you now process a new large batch of prompt. For example:

HS H N_KV B ATTN FATTN SPEEDUP
128 32 4096 512 2595.86 3018.69 0.860

This is important for long prompt processing so we should be measuring it in some way

Edit: updated the tables in the previous comment with latest numbers on V100

@JohannesGaessler
Copy link
Collaborator Author

Using llama-bench you are not measuring cases where there is already tokens in the KV cache and you now process a new large batch of prompt.

I am testing pp4096, with batch sizes 1,2,...,2048,4096. Unless I am missing something that should be the correct test setup.

@JohannesGaessler
Copy link
Collaborator Author

With the latest version the FlashAttention kernel is faster than regular attention on my RTX 3090 for all test cases in tests/test-backend-ops other than batch size 1 and head size 64 or 80. There are currently still issues with shared memory limits when trying to set nwarps = 8; with that the only remaining problem may become batch size 1 and head size 80. Unfortunately that one is also the most cancerous to optimize performance for.

@JohannesGaessler
Copy link
Collaborator Author

I implemented splitting the calculations across multiple CUDA blocks for batch size 1 in order to get better GPU utilization; the same trick is also being used on the official FlashAttention repository. This is the current performance on my system:

GPU Model Batch size Test t/s master t/s jg/flash-attn-18 Speedup
RTX 3090 gemma 2B F16 1 pp4096 124.27 135.22 1.09
RTX 3090 gemma 2B F16 2 pp4096 223.37 200.10 0.90
RTX 3090 gemma 2B F16 4 pp4096 446.92 400.69 0.90
RTX 3090 gemma 2B F16 8 pp4096 861.96 792.99 0.92
RTX 3090 gemma 2B F16 16 pp4096 1588.75 1553.48 0.98
RTX 3090 gemma 2B F16 32 pp4096 3181.24 2920.65 0.92
RTX 3090 gemma 2B F16 64 pp4096 5742.32 5285.37 0.92
RTX 3090 gemma 2B F16 128 pp4096 9333.71 9239.47 0.99
RTX 3090 gemma 2B F16 256 pp4096 10957.27 11859.30 1.08
RTX 3090 gemma 2B F16 512 pp4096 11947.09 13296.03 1.11
RTX 3090 gemma 2B F16 1024 pp4096 12135.52 13492.08 1.11
RTX 3090 gemma 2B F16 2048 pp4096 12252.33 13748.61 1.12
RTX 3090 gemma 2B F16 4096 pp4096 12420.43 13626.31 1.10
RTX 3090 llama 7B Q4_0 1 pp4096 113.18 125.32 1.11
RTX 3090 llama 7B Q4_0 2 pp4096 213.52 224.22 1.05
RTX 3090 llama 7B Q4_0 4 pp4096 346.91 382.54 1.10
RTX 3090 llama 7B Q4_0 8 pp4096 462.59 534.06 1.15
RTX 3090 llama 7B Q4_0 16 pp4096 470.46 567.01 1.21
RTX 3090 llama 7B Q4_0 32 pp4096 614.74 643.33 1.05
RTX 3090 llama 7B Q4_0 64 pp4096 1219.97 1338.04 1.10
RTX 3090 llama 7B Q4_0 128 pp4096 1937.68 2277.09 1.18
RTX 3090 llama 7B Q4_0 256 pp4096 2627.53 3153.93 1.20
RTX 3090 llama 7B Q4_0 512 pp4096 3001.62 3720.68 1.24
RTX 3090 llama 7B Q4_0 1024 pp4096 3002.21 3730.97 1.24
RTX 3090 llama 7B Q4_0 2048 pp4096 3005.64 3734.92 1.24
RTX 3090 llama 7B Q4_0 4096 pp4096 3006.77 3739.41 1.24
RTX 3090 phi2 3B F16 1 pp4096 100.77 96.94 0.96
RTX 3090 phi2 3B F16 2 pp4096 176.08 165.16 0.94
RTX 3090 phi2 3B F16 4 pp4096 343.50 327.68 0.95
RTX 3090 phi2 3B F16 8 pp4096 661.02 645.63 0.98
RTX 3090 phi2 3B F16 16 pp4096 1214.18 1257.54 1.04
RTX 3090 phi2 3B F16 32 pp4096 2148.22 2250.75 1.05
RTX 3090 phi2 3B F16 64 pp4096 3493.84 3624.80 1.04
RTX 3090 phi2 3B F16 128 pp4096 4805.64 5270.91 1.10
RTX 3090 phi2 3B F16 256 pp4096 5337.40 5766.99 1.08
RTX 3090 phi2 3B F16 512 pp4096 5703.90 6465.30 1.13
RTX 3090 phi2 3B F16 1024 pp4096 5715.14 6392.56 1.12
RTX 3090 phi2 3B F16 2048 pp4096 5737.96 6419.13 1.12
RTX 3090 phi2 3B F16 4096 pp4096 5745.58 6387.34 1.11

The biggest issue is still batch size 1 + head size 80. Quite frankly I don't know what else to do in order to improve performance. @ggerganov how strict is the goal of getting consistently better performance than master? From this point on the amount of time that it will take me to find improvements for small batch sizes (if there are any at all) will be much higher. Also I will need to attend to other projects so I will not have a lot of time to work on this in the immediate future.

@JohannesGaessler
Copy link
Collaborator Author

JohannesGaessler commented Apr 1, 2024

Quite frankly I don't know what else to do in order to improve performance.

Actually, there is one thing: if the changes to the memory layout of the KV cache were reverted I may be able to write a faster kernel for batch size 1 and head size 80. But that of course has other drawbacks.

Copy link
Owner

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the delay. I think the proposed kernels are all around better compared to what is on the other branch, so we can proceed to merge.

Regarding to merging this into master - there are still some things necessary to wrap up, like logic in llama.cpp to choose when to use FA based on the backend and potentially model parameters (e.g. head size) and F32 precision variants that will work with models like Phi-2. But more importantly, we should make a decision if this will be the approach to implement the attention in general. There can be other strategies explored (see the beginning of the discussion in the other PR) - we can either explore these before or after merging this work into master. Though I'm open to opinions and recommendations

Edit: there are some compile warnings, could you clean these up:

ggml-cuda/fattn.cu(784): warning #177-D: variable "frag_m" was declared but never referenced
      const int frag_m = cols_per_block == 8 ? 32 : 16;
                ^

Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"

ggml-cuda/fattn.cu(479): warning #128-D: loop is not reachable
      for (int i0 = 0; i0 < D; i0 += nwarps*32) {
      ^
          detected during instantiation of "void flash_attn_ext_f16<D,ncols,nwarps,VKQ_stride,parallel_blocks>(const char *, const char *, const char *, const char *, float *, half2 *, float, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int) [with D=64, ncols=8, nwarps=4, VKQ_stride=64, parallel_blocks=1]" at line 792

ggml-cuda/fattn.cu(479): warning #128-D: loop is not reachable
      for (int i0 = 0; i0 < D; i0 += nwarps*32) {
      ^
          detected during instantiation of "void flash_attn_ext_f16<D,ncols,nwarps,VKQ_stride,parallel_blocks>(const char *, const char *, const char *, const char *, float *, half2 *, float, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int) [with D=64, ncols=16, nwarps=4, VKQ_stride=64, parallel_blocks=1]" at line 793

@JohannesGaessler
Copy link
Collaborator Author

I fixed the compiler warnings as well as the cmake build. While going through the code I also discovered that there luckily (?) was a bug that caused excessive shared memory loads for head sizes < 128. With the bug fixed Phi-2 performance is much better:

GPU Model Batch size Test t/s master t/s PR Speedup
RTX 3090 phi2 3B F16 1 pp4096 100.77 106.13 1.05
RTX 3090 phi2 3B F16 2 pp4096 176.08 170.33 0.97
RTX 3090 phi2 3B F16 4 pp4096 343.50 338.29 0.98
RTX 3090 phi2 3B F16 8 pp4096 661.02 669.26 1.01
RTX 3090 phi2 3B F16 16 pp4096 1214.18 1286.26 1.06
RTX 3090 phi2 3B F16 32 pp4096 2148.22 2424.68 1.13
RTX 3090 phi2 3B F16 64 pp4096 3493.84 4252.78 1.22
RTX 3090 phi2 3B F16 128 pp4096 4805.64 6858.33 1.43
RTX 3090 phi2 3B F16 256 pp4096 5337.40 7708.88 1.44
RTX 3090 phi2 3B F16 512 pp4096 5703.90 8682.26 1.52
RTX 3090 phi2 3B F16 1024 pp4096 5715.14 8488.52 1.49
RTX 3090 phi2 3B F16 2048 pp4096 5737.96 8682.73 1.51
RTX 3090 phi2 3B F16 4096 pp4096 5745.58 8433.07 1.47

From my end, this PR is now ready to merge. I'll leave merging to you since I don't know whether you want to squash prior to merging. Feel free to go through the code and request additional explanatory comments if necessary.

@JohannesGaessler
Copy link
Collaborator Author

I forgot: currently the code for running multiple parallel blocks and then merging the results in another kernel only works for batch size 1. But in principle this code could be extended to also work for batch sizes > 1. This should fix the few remaining performance regressions for batch sizes <= 8. But I will probably only get around to implementing this by the end of the week at the earliest.

@ggerganov ggerganov merged commit c63dfdf into ggerganov:gg/flash-attn Apr 2, 2024
32 of 56 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants