CUDA: faster FlashAttention for batch sizes > 1 #6646
Conversation
Maybe there's something I'm missing here, but doesn't the mention of ~100 t/s prompt processing for batch size 1 seem inaccurate? I think you should easily be able to hit 1000 t/s on a q4_0 7b at bs1 on master with a 3090, unless "batch size" here refers to the number of tokens processed at once rather than to concurrent requests? I think I'm just mixing up my definitions of "batching" and it's not about concurrent requests.
On my machine, hosting with the llama.cpp server, I'm able to get a couple of coherent tokens, sometimes up to a short paragraph in length, before garbage output consistently comes out for Cohere's Command-R 35b at q5_K_S:
It always degenerates into "Nameservice" repeatedly.
The numbers are from running something like
So they are representative of a single concurrent context with a varying number of tokens processed in parallel. You would essentially get these speeds for prompt processing if you were to set this batch size; batch size 1 is equivalent to token generation.
I can reproduce the issue. It's a bug in the batch size 1 kernel that is already on
The problem seemed to be a race condition that, depending on the specifics, could sometimes result in garbage outputs. @kalomaze can you confirm that the fix works?
Works fine on my end now! Responses are perfectly normal on regen / fresh context ingest / etc. I ran a full offload of a q6_K 20b model (not a new or noteworthy one, it was just convenient) to test out different
Seems great @JohannesGaessler. Is there anything else left for FlashAttention to be merged into mainline
@JohannesGaessler Thanks for the work here. I am not up to date on the flash attention branches, but I would like to understand why the bench workflow is slower: 272 iterations here versus 426 on master. It runs on phi-2 Q4_0. Thanks
You'll have to ask @ggerganov about that.
For this PR the bot reports 241.82 t/s pp and 42.58 t/s tg. In the post you linked it's 118.38 t/s pp and 24.06 t/s tg. Aren't those the relevant numbers?
EDIT: No, these are not the relevant numbers.
No, the bench has been fixed since then; it is now using the metrics from the server side, not the client side.
I don't have a comprehensive overview of the changes in gg/flash-attn, and I don't know the details of how the server benchmark presented here works either. If the issue is due to the CUDA kernels and not some other changes, it could just be that I tested and optimized the code on Ampere and Ada Lovelace and that for whatever reason the performance is simply bad on Turing. It could also be an issue specific to Phi-2, since that particular model has a janky head size of 80. If you have access to T4s, do a quick benchmark on them to check whether the issue is specific to T4s, to the server, or to Phi-2. Or tell me how I can run this benchmark locally on my own hardware.
Thanks @JohannesGaessler for your explanation.
The
It would be cool if you could confirm; you can have a look at the bench README.md:

```shell
cd examples/server/bench
mkdir models
LLAMA_SERVER_BIN_PATH=../../../build/bin/server python bench.py \
    --runner-label local \
    --name local \
    --branch `git rev-parse --abbrev-ref HEAD` \
    --commit `git rev-parse HEAD` \
    --scenario script.js \
    --duration 10m \
    --hf-repo ggml-org/models \
    --hf-file phi-2/ggml-model-q4_0.gguf \
    --model-path-prefix models \
    --parallel 8 \
    -ngl 33 \
    --batch-size 2048 \
    --ubatch-size 256 \
    --ctx-size 16384 \
    --n-prompts 1000 \
    --max-prompt-tokens 1024 \
    --max-tokens 2048
```

As the current branch is behind master, it's better to look at the `/metrics` endpoint directly after the test to have accurate metrics:

```shell
curl http://localhost:8080/metrics
```
The benchmark crashes at the end, so I cannot retrieve any data from the endpoint, but I was able to reproduce master being faster than this PR on an RTX 4090 in this specific benchmark.

Console out master
Console out PR
If you compare the logs you'll see that on master the slots are used asynchronously while with this PR they are used synchronously. I don't know whether this is the result of changes on gg/flash-attn or whether that branch simply lacks some performance optimizations from master, but in any case this has nothing to do with the CUDA kernels.
Thanks for taking the time to confirm this is not related to the CUDA kernels. Let's wait for master to be synchronized then.
The master commit immediately preceding the last merge into gg/flash-attn is still fast, so the issue has to be some commit on that branch.
While 35b and 20b were functional, I also tried a 70b at q4_K_M; it's still affected by the bug and puts out pure gibberish. (alto alto alto alto...)
Not sure if this is relevant, but I notice that these are much smaller too:
when compared to mainline:
Oddly enough, the generation speed is also slower on the PR for the 70b, which wasn't the case for the other (smaller) model sizes I ran successfully, where it was a consistent speedup.
I cannot reproduce the issue with LLaMA 2 70b q4_K_M. I'll download Miqu and see if there is a difference.
Could be related to rotary embeddings, maybe, if it only happens on Miqu.
I can reproduce the issue with Miqu. It could be the same issue as with Phi-2 where (according to Georgi) IEEE 754 half precision floats are not sufficient. One solution would be to instead use bfloat16 or FP32, but bfloat16 only has hardware support since Ampere and FP32 needs more memory. I'll revisit Miqu once there is a solution for Phi-2 and check whether it works then.
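As a minimal sketch of the Ampere constraint mentioned above (illustrative only, not related to this PR's code), bfloat16 math in device code is typically guarded on the compute capability:

```cuda
#include <cuda_bf16.h>

// Add two values using bf16 hardware math where available (Ampere, compute capability 8.0,
// and newer); fall back to plain FP32 arithmetic on older architectures.
static __device__ __forceinline__ float add_maybe_bf16(const float a, const float b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __bfloat162float(__hadd(__float2bfloat16(a), __float2bfloat16(b)));
#else
    return a + b;
#endif
}
```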
I suspect that the benchmark using phi-2 is invalid because of the precision issues: likely each submitted request keeps generating garbage tokens without ever hitting EOS. Will try to confirm this and move the FA branch forward this week.
@ggerganov I think it's not an issue with the precision but rather with the numerical range. Do you know which parts of the calculation specifically are problematic?
On RTX 2060, using:

```shell
LLAMA_CUBLAS=1 make -j main && ./main -m ./models/phi-2/ggml-model-f16.gguf -p "I believe the meaning of life is" -s 1 -n 32 -ngl 99
```

Generation is fine:
Now, disable the F32 precision override for KQ:

```diff
diff --git a/llama.cpp b/llama.cpp
index cf95cea1..9882b3b8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -5994,7 +5994,7 @@ static struct ggml_tensor * llm_build_kqv(
     if (model.arch == LLM_ARCH_PHI2) {
         // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
         // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+        //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
     }

     if (model.arch == LLM_ARCH_GROK) {
```

Generation is now garbage:
So this makes me think that the FA kernels need to start respecting the precision requested for KQ (e.g. via ggml_mul_mat_set_prec).
If you add a check at the end of the FP16 kernel to write back 0.0f instead of NaN, the results from Phi-2 and Miqu are correct. So presumably it will be possible to fix the outputs without having to write a kernel with higher precision, because the problem is underflows rather than overflows. I'll investigate where exactly the problem is and how to best fix it.
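A minimal sketch of the kind of write-back guard described above (not the PR's actual code; the helper name is made up):

```cuda
#include <cmath>

// Hypothetical helper: sanitize a value before writing it back from the FP16 kernel,
// replacing NaN (e.g. from a 0.0f/0.0f softmax normalization) with 0.0f.
static __device__ __forceinline__ float sanitize_output(const float val) {
    return isnan(val) ? 0.0f : val;
}
```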
Btw, this #6685 (comment) reminded me that the
I pushed a partial fix for the numerical precision issues. One of the problems was that if the model produces KQ values with a wide range, the FP16 exponentiation can result in arithmetic underflow. Unfortunately the result is then not 0 but NaN. This can be fixed by flushing the results below a certain threshold to 0. I chose a difference of more than 20 to the max value; this is equivalent to flushing all post-exponentiation values ~

In addition to that, if you were to then also add a check at the end that avoids NaNs from 0.0f/0.0f division, you could get Phi-2 to produce coherent outputs. With the way I did the implementation, NaN values from the KQ matrix multiplication also get set to 0, so if you were to do this you would essentially be ignoring those problematic values. However, this severely affects quality: FP16 perplexity becomes roughly equivalent to q4_K_S. So I don't think this is a good way of handling the Phi-2 precision issues.
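A minimal sketch of the threshold-flushing idea described above, with illustrative names (`softmax_term`, `KQ_val`, `KQ_max` are assumptions, not the PR's actual identifiers):

```cuda
#include <cuda_fp16.h>

// Flush softmax terms whose difference to the row maximum exceeds 20 to zero instead of
// exponentiating them in FP16, where exp() of such a large negative value can underflow
// to NaN rather than 0.
static __device__ __forceinline__ half softmax_term(const half KQ_val, const half KQ_max) {
    const half diff = __hsub(KQ_val, KQ_max); // always <= 0 by construction
    if (__half2float(diff) <= -20.0f) {
        return __float2half(0.0f);
    }
    return hexp(diff);
}
```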
I forgot: I can confirm that the server performance issues are mostly caused by NaN outputs; with the hacky Phi-2 fix the performance is much better (but still slower than master).
Hello, I just started a server with 32 slots on a llama 70b arch with this branch. It generates garbage: "â-...â-...â-...". Does it support continuous or parallel batching? Note: Added
I pushed an implementation for calculating KQ and the corresponding softmax at FP32 precision; Phi-2 should work now. The KQ values produced by Phi-2 are just a mess: they frequently fall outside of the maximum representable range of IEEE 754 half precision floats, while the values produced by e.g. LLaMA 2 are significantly smaller. In any case, it works now. The performance impact of using some FP32 is most notable at large batch sizes, where the performance difference is 3-5% (it may be a lot more on Volta/Turing where there is less shared memory per SM). In any case this is the current level of performance on my systems:

Vs. master
Vs. gg/flash-attn
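As a small, self-contained illustration of the half-precision range issue mentioned above (not part of the PR): values beyond the IEEE 754 half-precision maximum of 65504 overflow to infinity when converted, which then propagates through the subsequent math.

```cuda
#include <cuda_fp16.h>
#include <cstdio>

int main() {
    const float kq_val = 1.0e5f; // hypothetical KQ value outside the representable half range
    const half  h      = __float2half(kq_val);
    printf("%f -> %f\n", kq_val, __half2float(h)); // the round-trip yields inf
    return 0;
}
```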
@phymbert Does it also produce garbage using
Yes, tested just now with the latest commit here, a llama 70b on 2 A100, only generating
The target branch only fixes the V cache defrag, which I thought was causing the problem that you observed. However, if you observe the garbage using

I tested 70B llama with
```diff
-    const float * Q_f = (const float *) (Q + nb02* blockIdx.y);
+    const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
     const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
     const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
```
I just remembered: for calculating the offsets I am not using `blockIdx.z` and `nb03`/`nb13` because I didn't understand what the purpose was and in the test cases `ne3` was always 1. Are they used for continuous batching? If so, what is the expected memory layout?
I think dim 3 is not used in `llama.cpp` during inference (I think it's used only during training and in other `ggml` projects). But it would be nice to take it into account. The easiest thing to do is add tests in `test-backend-ops` for that.

But in any case, I don't think this is causing the garbage output with 70B that @phymbert reports.
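A hedged sketch of what taking dim 3 into account could look like in the pointer setup quoted above, assuming `blockIdx.z` would index that dimension (illustrative only, not the PR's code):

```cuda
#include <cstdint>

// ggml strides (nb01..nb03) are in bytes, so offsets are applied to a char pointer.
// i1/i2/i3 index dims 1..3; in the kernel above these would correspond to ic0,
// blockIdx.y and (hypothetically) blockIdx.z.
static __device__ __forceinline__ const char * tensor_offset_3d(
        const char * base, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int i1, const int i2, const int i3) {
    return base + nb03*i3 + nb02*i2 + nb01*i1;
}
```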
@JohannesGaessler Maybe I am facing a precision issue also. I am trying to force

But I am facing
Looking at the documentation, it seems that that particular instruction is indeed not available in CUDA 11.6. But if you're going to force FP32 precision anyway, you can just delete those lines since they will not be used. Conversely, they may be the fix for running your model at FP16 precision, so it would be worthwhile to also test with CUDA 12.
I did a quick reimplementation of the function for CUDA 11; it should compile now.
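The conversation doesn't name the missing intrinsic, so purely as an illustration of the version-guard pattern (assuming, hypothetically, a half2 max that older toolkits lack; the 11070 cutoff is also an assumption for the sketch):

```cuda
#include <cuda_fp16.h>

// Use the toolkit's half2 max where the CUDA version provides it, otherwise fall back
// to a manual elementwise reimplementation.
static __device__ __forceinline__ half2 hmax2_compat(const half2 a, const half2 b) {
#if CUDART_VERSION >= 11070
    return __hmax2(a, b);
#else
    half2 r;
    r.x = __hgt(a.x, b.x) ? a.x : b.x;
    r.y = __hgt(a.y, b.y) ? a.y : b.y;
    return r;
#endif
}
```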
On RTX 2060 and V100, the Phi-2 F16 model generates garbage (bf6a496):

```shell
LLAMA_CUBLAS=1 make -j main && ./main -m ./models-mnt/phi-2/ggml-model-f16.gguf -p "I believe the meaning of life is" -s 200 -n 64 -ngl 99
```
@JohannesGaessler Does it work on your end?
@ggerganov it looks like you need CUDA 12. My issue disappears on this branch with CUDA 12. So good to go if we consider this a breaking change.
Hm, I think I already use CUDA 12.3. Will double-check later.
The issue has nothing to do with CUDA 12. The code is working correctly for
My bad then, sorry for the confusion and thanks for the explanation.
I pushed a fix; now it should work.
@JohannesGaessler would it be possible to resolve the conflicts here? I would like to test the server.
7999f78 to 44ca576
Nice, merge at will
Merged 87968de into ggerganov:gg/flash-attn
This PR does the following:

- `parallel_blocks`: the value is chosen based on the number of streaming multiprocessors; as long as the increased number of blocks still finishes in a single wave it's basically always worthwhile to do (see the sketch at the end of this description).
- Adds defines `FP16_AVAILABLE` and `FP16_MMA_AVAILABLE` that can be used to determine the availability of general FP16 intrinsics and FP16 tensor cores in device code.

Performance relative to current FlashAttention kernels:

Performance relative to master:

On my systems FlashAttention now seems to be universally faster than master.
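A hedged sketch of the single-wave heuristic described in the first bullet above (the parameters `n_base_blocks`, `blocks_per_sm` and `max_parallel_blocks` are illustrative assumptions, not the PR's actual interface):

```cuda
#include <cuda_runtime.h>

// Keep doubling parallel_blocks while the resulting grid still fits into a single wave
// across all streaming multiprocessors; the extra parallelism is then essentially free.
static int choose_parallel_blocks(const int n_base_blocks, const int blocks_per_sm,
                                  const int device, const int max_parallel_blocks) {
    int nsm = 0;
    cudaDeviceGetAttribute(&nsm, cudaDevAttrMultiProcessorCount, device);

    int parallel_blocks = 1;
    while (2*parallel_blocks <= max_parallel_blocks &&
           2*parallel_blocks*n_base_blocks <= nsm*blocks_per_sm) {
        parallel_blocks *= 2;
    }
    return parallel_blocks;
}
```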