llama-graph: fix for MLA with FA causing extra overhead for small batches #14198
Conversation
The first run has finished and it looks like the ….

The original threshold that @fairydreaming used in his code (which he tested on the CPU backend) did this permutation flip at ….

I've set the ….
I'm not seeing an improvement with an RTX 4090, Deepseek V2 Lite q4_0, and ….
What quant is your …? This was tested with:
CUDA_VISIBLE_DEVICES=0 ~/llama.cpp/build/bin/llama-batched-bench \
--model ~/models/gguf/Deepseek-V3-0324-Q4_K_XL.gguf \
--ctx_size 32768 \
--n-gpu-layers 99 \
--flash-attn \
--numa distribute \
--threads 80 \
--override-tensor exps=CPU \
-pps \
-npp 0 \
-npl 1,2,3,4,5,6,7,8,12,16,32,64 \
-ntg 1

This PR
Main
(Setting ….) If I don't use ….
It looks like the use of … (which is kind of expected, as the only time it would trigger for the spec-decoding tests that produce this graph is for the PP part, which runs once at …).
From what I can tell, the problem is that BF16 and FP32 in particular do not have support for batched matrix multiplications (batch sizes > 1) in the CUDA backend. If you do a matrix-vector multiplication, that triggers the kernel for batch size 1, which does have support for batching, and it ends up being faster. I think a better change would be to just add the missing support in the CUDA backend rather than try to work around it.
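For reference, cuBLAS itself exposes strided-batched GEMM with BF16 inputs via cublasGemmStridedBatchedEx, so the missing support could plausibly be built on a call along these lines. This is a minimal standalone sketch with placeholder sizes, strides, and FP32 accumulation, not llama.cpp code:

```cpp
// Minimal standalone sketch (not llama.cpp code): a BF16 strided-batched GEMM via cuBLAS.
// Computes C[i] = A[i] * B[i] for batchCount independent column-major matrices.
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    const int m = 128, n = 8, k = 512;  // placeholder sizes (n = a small batch of tokens)
    const int batchCount = 16;          // e.g. one GEMM per attention head

    __nv_bfloat16 *dA = nullptr, *dB = nullptr;
    float *dC = nullptr;
    cudaMalloc((void **) &dA, sizeof(__nv_bfloat16) * (size_t)m*k*batchCount);
    cudaMalloc((void **) &dB, sizeof(__nv_bfloat16) * (size_t)k*n*batchCount);
    cudaMalloc((void **) &dC, sizeof(float)         * (size_t)m*n*batchCount);

    cublasHandle_t handle;
    cublasCreate(&handle);

    const float alpha = 1.0f, beta = 0.0f;

    // BF16 A and B, FP32 accumulation and output, one fixed stride per batch element
    cublasStatus_t st = cublasGemmStridedBatchedEx(
        handle, CUBLAS_OP_N, CUBLAS_OP_N,
        m, n, k,
        &alpha,
        dA, CUDA_R_16BF, m, (long long)m*k,
        dB, CUDA_R_16BF, k, (long long)k*n,
        &beta,
        dC, CUDA_R_32F,  m, (long long)m*n,
        batchCount,
        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);

    printf("cublasGemmStridedBatchedEx status: %d\n", (int)st);

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}
```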
I see. So it looks like this is the dispatch in question:

```cpp
if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
    // the custom F16 vector kernel can be used over batched cuBLAS GEMM
    // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
    ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_vec_q) {
    ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_q) {
    ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
    !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
    // general KQ + KQV multi-batch without FlashAttention
    ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_mul_mat_vec) {
    ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec_q) {
    ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
    ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
} else {
    ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
}
```

but for larger batch sizes it's slipping past all of these tests, and then the final ggml_cuda_op_mul_mat_cublas fallback gets used.
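To make that fall-through concrete, here is a small standalone mirror of the branch conditions above (heavily simplified: the split path, the slow-FP16 and transposed checks, and the real use_* heuristics are omitted, and the Tensor struct is a placeholder). It only illustrates that, with the vector/quantized kernels ruled out, an F16 src0 with a batched src1 takes the batched-cuBLAS branch while a BF16 src0 drops to the generic cuBLAS fallback:

```cpp
// Standalone illustration (not llama.cpp code): a simplified mirror of the dispatch above.
#include <cstdint>
#include <cstdio>

enum Type { F16, BF16, F32 };

struct Tensor {
    Type    type;
    int64_t ne[4]; // ggml-style dimensions: ne[2]*ne[3] is the "batch" size
};

const char * pick_branch(const Tensor & src0, const Tensor & src1,
                         bool use_mul_mat_vec, bool use_mul_mat_vec_q, bool use_mul_mat_q) {
    const bool batched = src1.ne[2] * src1.ne[3] > 1;
    if (use_mul_mat_vec)            return "mul_mat_vec (custom vector kernel)";
    if (use_mul_mat_vec_q)          return "mul_mat_vec_q";
    if (use_mul_mat_q)              return "mul_mat_q";
    if (src0.type == F16 && batched) return "mul_mat_batched_cublas";
    return "op_mul_mat_cublas fallback";
}

int main() {
    Tensor f16_w  = {F16,  {512, 128, 1, 1}};
    Tensor bf16_w = {BF16, {512, 128, 1, 1}};
    Tensor x      = {F32,  {512, 8, 16, 1}}; // 8 tokens x 16 heads: batched src1

    printf("F16  src0, batched src1 -> %s\n", pick_branch(f16_w,  x, false, false, false));
    printf("BF16 src0, batched src1 -> %s\n", pick_branch(bf16_w, x, false, false, false));
    return 0;
}
```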
I will retry using …, as IIRC it was ….

I will also try ….

I'll also reread the cuBLAS docs and see if I can get a ….

I did try using the ….
Feel free to do that, it would be a good addition. For speculative decoding in particular, however, I think it's more important to extend the ….
Just closing this PR with some final numbers:

This PR

BF16
PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
---|---|---|---|---|---|---|---|---|---|
0 | 1 | 1 | 1 | 0.000 | 0.00 | 0.156 | 6.40 | 0.156 | 6.40 |
0 | 1 | 2 | 2 | 0.000 | 0.00 | 0.298 | 6.71 | 0.298 | 6.70 |
0 | 1 | 3 | 3 | 0.000 | 0.00 | 0.319 | 9.42 | 0.319 | 9.41 |
0 | 1 | 4 | 4 | 0.000 | 0.00 | 0.346 | 11.56 | 0.346 | 11.55 |
0 | 1 | 5 | 5 | 0.000 | 0.00 | 0.369 | 13.56 | 0.369 | 13.55 |
0 | 1 | 6 | 6 | 0.000 | 0.00 | 0.395 | 15.17 | 0.396 | 15.16 |
0 | 1 | 7 | 7 | 0.000 | 0.00 | 0.424 | 16.51 | 0.424 | 16.49 |
0 | 1 | 8 | 8 | 0.000 | 0.00 | 0.453 | 17.68 | 0.453 | 17.66 |
0 | 1 | 12 | 12 | 0.000 | 0.00 | 0.576 | 20.84 | 0.576 | 20.82 |
0 | 1 | 16 | 16 | 0.001 | 0.00 | 0.697 | 22.94 | 0.698 | 22.92 |
0 | 1 | 32 | 32 | 0.001 | 0.00 | 1.178 | 27.16 | 1.179 | 27.14 |
0 | 1 | 64 | 64 | 0.002 | 0.00 | 2.146 | 29.82 | 2.148 | 29.79 |
Main
F16 attn_v_b.weight + BF16 attn_k_b.weight
PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
---|---|---|---|---|---|---|---|---|---|
0 | 1 | 1 | 1 | 0.000 | 0.00 | 0.151 | 6.64 | 0.151 | 6.63 |
0 | 1 | 2 | 2 | 0.000 | 0.00 | 0.294 | 6.79 | 0.294 | 6.79 |
0 | 1 | 3 | 3 | 0.000 | 0.00 | 0.313 | 9.59 | 0.313 | 9.59 |
0 | 1 | 4 | 4 | 0.000 | 0.00 | 0.336 | 11.90 | 0.336 | 11.90 |
0 | 1 | 5 | 5 | 0.000 | 0.00 | 0.360 | 13.89 | 0.360 | 13.88 |
0 | 1 | 6 | 6 | 0.000 | 0.00 | 0.386 | 15.55 | 0.386 | 15.54 |
0 | 1 | 7 | 7 | 0.000 | 0.00 | 0.412 | 16.98 | 0.412 | 16.97 |
0 | 1 | 8 | 8 | 0.000 | 0.00 | 0.440 | 18.20 | 0.440 | 18.19 |
0 | 1 | 12 | 12 | 0.000 | 0.00 | 0.552 | 21.73 | 0.553 | 21.71 |
0 | 1 | 16 | 16 | 0.001 | 0.00 | 0.666 | 24.03 | 0.666 | 24.01 |
0 | 1 | 32 | 32 | 0.001 | 0.00 | 1.113 | 28.75 | 1.114 | 28.72 |
0 | 1 | 64 | 64 | 0.002 | 0.00 | 2.020 | 31.69 | 2.022 | 31.65 |
Q8_0 attn_v_b.weight + BF16 attn_k_b.weight
PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
---|---|---|---|---|---|---|---|---|---|
0 | 1 | 1 | 1 | 0.000 | 0.00 | 0.154 | 6.49 | 0.154 | 6.49 |
0 | 1 | 2 | 2 | 0.000 | 0.00 | 0.302 | 6.63 | 0.302 | 6.62 |
0 | 1 | 3 | 3 | 0.000 | 0.00 | 0.319 | 9.41 | 0.319 | 9.41 |
0 | 1 | 4 | 4 | 0.000 | 0.00 | 0.344 | 11.62 | 0.344 | 11.61 |
0 | 1 | 5 | 5 | 0.000 | 0.00 | 0.367 | 13.64 | 0.367 | 13.63 |
0 | 1 | 6 | 6 | 0.000 | 0.00 | 0.391 | 15.34 | 0.391 | 15.33 |
0 | 1 | 7 | 7 | 0.000 | 0.00 | 0.418 | 16.76 | 0.418 | 16.75 |
0 | 1 | 8 | 8 | 0.000 | 0.00 | 0.447 | 17.88 | 0.448 | 17.87 |
0 | 1 | 12 | 12 | 0.000 | 0.00 | 0.558 | 21.51 | 0.558 | 21.49 |
0 | 1 | 16 | 16 | 0.001 | 0.00 | 0.670 | 23.89 | 0.670 | 23.86 |
0 | 1 | 32 | 32 | 0.001 | 0.00 | 1.120 | 28.58 | 1.121 | 28.55 |
0 | 1 | 64 | 64 | 0.002 | 0.00 | 2.032 | 31.49 | 2.034 | 31.46 |
Q8_0 attn_v_b.weight + Q8_0 attn_k_b.weight
PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
---|---|---|---|---|---|---|---|---|---|
0 | 1 | 1 | 1 | 0.000 | 0.00 | 0.148 | 6.75 | 0.148 | 6.75 |
0 | 1 | 2 | 2 | 0.000 | 0.00 | 0.154 | 12.95 | 0.155 | 12.93 |
0 | 1 | 3 | 3 | 0.000 | 0.00 | 0.173 | 17.36 | 0.173 | 17.34 |
0 | 1 | 4 | 4 | 0.000 | 0.00 | 0.196 | 20.41 | 0.196 | 20.39 |
0 | 1 | 5 | 5 | 0.000 | 0.00 | 0.223 | 22.41 | 0.223 | 22.39 |
0 | 1 | 6 | 6 | 0.000 | 0.00 | 0.250 | 24.04 | 0.250 | 24.01 |
0 | 1 | 7 | 7 | 0.000 | 0.00 | 0.277 | 25.23 | 0.278 | 25.20 |
0 | 1 | 8 | 8 | 0.000 | 0.00 | 0.312 | 25.66 | 0.312 | 25.63 |
0 | 1 | 12 | 12 | 0.000 | 0.00 | 0.417 | 28.77 | 0.418 | 28.73 |
0 | 1 | 16 | 16 | 0.001 | 0.00 | 0.530 | 30.21 | 0.530 | 30.17 |
0 | 1 | 32 | 32 | 0.001 | 0.00 | 0.979 | 32.70 | 0.980 | 32.66 |
0 | 1 | 64 | 64 | 0.002 | 0.00 | 1.891 | 33.85 | 1.893 | 33.81 |
It seems that Q8_0 has gone from getting around 1/2 the tokens/s of BF16 a couple of months ago to 2x the tokens/s! :O

Safe to say I'm requanting all my deepseek models now!
The discussion leading to finding this starts here: #10466 (comment)
I'm still running the full test and:
- Not sure if the ggml_cont_2d change works or helps yet.
- The threshold of 32 may need to be changed.

So leaving as a draft PR for now.

I'll hopefully have my 3 tests ready for comparison later tonight (main, proposed change without ggml_cont_2d, and proposed change with ggml_cont_2d), and will post the results then.

@JohannesGaessler Can you test on v2-lite to see if the threshold is in a similar ballpark?
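For readers unfamiliar with what the ggml_cont_2d switch refers to, below is a minimal standalone ggml sketch of the general pattern under discussion: building one of two graph variants depending on the number of tokens, with one variant flattening to a contiguous 2D tensor before the matmul. The shapes, the 32-token threshold, and which side of the branch flattens are placeholder assumptions for illustration, not the actual PR change:

```cpp
// Standalone sketch (not the actual PR change): switch between a batched 3D mul_mat and a
// flattened, contiguous 2D mul_mat depending on the number of tokens. Assumes a recent ggml
// where ggml_graph_compute_with_ctx is declared in ggml-cpu.h.
#include "ggml.h"
#include "ggml-cpu.h"
#include <cstdio>
#include <cstring>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 256*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd   = 64;
    const int64_t n_head   = 4;
    const int64_t n_tokens = 8;    // try values above and below the threshold
    const int64_t n_tokens_threshold = 32; // placeholder value taken from the discussion

    // placeholder stand-ins for a weight and a per-head batched activation
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
    struct ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_tokens, n_head);

    // zero the data; only the graph structure matters for this sketch
    memset(w->data, 0, ggml_nbytes(w));
    memset(x->data, 0, ggml_nbytes(x));

    struct ggml_tensor * cur;
    if (n_tokens >= n_tokens_threshold) {
        // large batch: keep the batched 3D form (one matmul per head)
        cur = ggml_mul_mat(ctx, w, x);
    } else {
        // small batch: flatten to a single contiguous 2D matmul
        struct ggml_tensor * x2d = ggml_cont_2d(ctx, x, n_embd, n_tokens*n_head);
        cur = ggml_mul_mat(ctx, w, x2d);
    }

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, cur);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    printf("result dims: %lld x %lld x %lld\n",
           (long long) cur->ne[0], (long long) cur->ne[1], (long long) cur->ne[2]);

    ggml_free(ctx);
    return 0;
}
```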