
llama-graph: fix for MLA with FA causing extra overhead for small batches #14198



@jukofyork
Collaborator

jukofyork commented Jun 15, 2025

The discussion leading to finding this starts here: #10466 (comment)

I'm still running the full test and:

  • Haven't checked that the ggml_cont_2d change works or helps yet.
  • The threshold of 32 may need to be changed.

So I'm leaving this as a draft PR for now.
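
For reference, here is a rough sketch of the kind of layout flip being discussed (not the actual diff; the tensor names, shapes and helper below are illustrative): below the threshold the v_mla multiplication is built in the broadcast "matrix-vector" form, and above it as a permuted 2D matrix-matrix multiply.

#include "ggml.h"

// Sketch only -- assumes `kqv` is the flash-attention output with shape
// [head_dim, n_head, n_tokens, 1] and `v_mla` is the per-head decompression
// matrix applied after attention; the threshold value is exactly what is being tuned.
static ggml_tensor * apply_v_mla(ggml_context * ctx0, ggml_tensor * v_mla,
                                 ggml_tensor * kqv, int64_t n_head, int64_t n_tokens) {
    ggml_tensor * cur;
    if (n_tokens < 32) {
        // broadcast form: src1 has ne[1] == 1, so the multiply is effectively a
        // matrix-vector product batched over dims 2/3 (heads and tokens)
        cur = ggml_reshape_4d(ctx0, kqv, kqv->ne[0], 1, n_head, n_tokens);
        cur = ggml_mul_mat(ctx0, v_mla, cur);
    } else {
        // 2D form: move n_tokens into dim 1 so the multiply runs as a regular GEMM
        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
        cur = ggml_mul_mat(ctx0, v_mla, cur);
        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
        cur = ggml_cont(ctx0, cur); // or ggml_cont_2d(), the variant mentioned above
    }
    return cur;
}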


I'll hopefully have my 3 tests ready for comparison later tonight (main, proposed change without ggml_cont_2d, and proposed change with ggml_cont_2d), and will post the results then.

@JohannesGaessler Can you test on v2-lite to see if the threshold is in a similar ballpark?

@jukofyork
Collaborator Author

jukofyork commented Jun 15, 2025

The first run has finished and it looks like the 32 crossover was just an artefact of me not using PP = 512 for both runs:

[benchmark plot]

The original threshold that @fairydreaming used in his code (which he tested on the CPU backend) did this permutation flip at n_tokens > n_heads, so it might be worth trying that later (I'll leave it at 32 for now, so I can test the use of ggml_cont_2d() and plot all 3 together).

I've set the ggml_cont_2d() test run going, but I have to go now.

@JohannesGaessler
Collaborator

I'm not seeing an improvement with an RTX 4090, Deepseek V2 Lite q4_0, and llama-bench/batched-bench.

@jukofyork
Collaborator Author

jukofyork commented Jun 15, 2025

> I'm not seeing an improvement with an RTX 4090, Deepseek V2 Lite q4_0, and llama-bench/batched-bench.

What quant is your wv_b? Perhaps it only affects BF16?

The Deepseek-V3-0324-Q4_K_XL.gguf quant tested here uses q4_K for the expert tensors and q6_K for everything else, apart from wk_b and wv_b (v_mla in this code), which are in BF16:

llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type q4_K:  174 tensors
llama_model_loader: - type q6_K:  429 tensors
llama_model_loader: - type bf16:  122 tensors
CUDA_VISIBLE_DEVICES=0 ~/llama.cpp/build/bin/llama-batched-bench \
              --model ~/models/gguf/Deepseek-V3-0324-Q4_K_XL.gguf \
              --ctx_size 32768 \
              --n-gpu-layers 99 \
              --flash-attn \
              --numa distribute \
              --threads 80 \
              --override-tensor exps=CPU \
              -pps \
              -npp 0 \
              -npl 1,2,3,4,5,6,7,8,12,16,32,64 \
              -ntg 1

This PR

main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 1, n_gpu_layers = 99, n_threads = 80, n_threads_batch = 80
| PP | TG |  B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |   T s | S t/s |
|----|----|----|------|--------|----------|--------|----------|-------|-------|
|  0 |  1 |  1 |    1 |  0.000 |     0.00 |  0.156 |     6.40 | 0.156 |  6.40 |
|  0 |  1 |  2 |    2 |  0.000 |     0.00 |  0.298 |     6.71 | 0.298 |  6.70 |
|  0 |  1 |  3 |    3 |  0.000 |     0.00 |  0.319 |     9.42 | 0.319 |  9.41 |
|  0 |  1 |  4 |    4 |  0.000 |     0.00 |  0.346 |    11.56 | 0.346 | 11.55 |
|  0 |  1 |  5 |    5 |  0.000 |     0.00 |  0.369 |    13.56 | 0.369 | 13.55 |
|  0 |  1 |  6 |    6 |  0.000 |     0.00 |  0.395 |    15.17 | 0.396 | 15.16 |
|  0 |  1 |  7 |    7 |  0.000 |     0.00 |  0.424 |    16.51 | 0.424 | 16.49 |
|  0 |  1 |  8 |    8 |  0.000 |     0.00 |  0.453 |    17.68 | 0.453 | 17.66 |
|  0 |  1 | 12 |   12 |  0.000 |     0.00 |  0.576 |    20.84 | 0.576 | 20.82 |
|  0 |  1 | 16 |   16 |  0.001 |     0.00 |  0.697 |    22.94 | 0.698 | 22.92 |
|  0 |  1 | 32 |   32 |  0.001 |     0.00 |  1.178 |    27.16 | 1.179 | 27.14 |
|  0 |  1 | 64 |   64 |  0.002 |     0.00 |  2.146 |    29.82 | 2.148 | 29.79 |

Main

main: n_kv_max = 32768, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 1, n_gpu_layers = 99, n_threads = 80, n_threads_batch = 80
| PP | TG |  B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |   T s | S t/s |
|----|----|----|------|--------|----------|--------|----------|-------|-------|
|  0 |  1 |  1 |    1 |  0.000 |     0.00 |  0.161 |     6.22 | 0.161 |  6.21 |
|  0 |  1 |  2 |    2 |  0.000 |     0.00 |  0.432 |     4.63 | 0.432 |  4.63 |
|  0 |  1 |  3 |    3 |  0.000 |     0.00 |  0.436 |     6.88 | 0.436 |  6.87 |
|  0 |  1 |  4 |    4 |  0.000 |     0.00 |  0.461 |     8.68 | 0.461 |  8.67 |
|  0 |  1 |  5 |    5 |  0.000 |     0.00 |  0.487 |    10.26 | 0.487 | 10.26 |
|  0 |  1 |  6 |    6 |  0.000 |     0.00 |  0.524 |    11.45 | 0.524 | 11.45 |
|  0 |  1 |  7 |    7 |  0.000 |     0.00 |  0.541 |    12.95 | 0.541 | 12.94 |
|  0 |  1 |  8 |    8 |  0.000 |     0.00 |  0.569 |    14.05 | 0.570 | 14.04 |
|  0 |  1 | 12 |   12 |  0.000 |     0.00 |  0.682 |    17.59 | 0.683 | 17.58 |
|  0 |  1 | 16 |   16 |  0.001 |     0.00 |  0.797 |    20.07 | 0.798 | 20.06 |
|  0 |  1 | 32 |   32 |  0.001 |     0.00 |  1.238 |    25.85 | 1.239 | 25.83 |
|  0 |  1 | 64 |   64 |  0.002 |     0.00 |  2.144 |    29.85 | 2.146 | 29.82 |

(Setting -npp larger than 0 shows the same pattern, but it takes ages to run as I can't offload unless PP > ~2048.)

If I don't use BF16 for those two tensors I get about 1/2 the tokens/s (or at least I used to...), and my RTX 5000 Ada should be almost the same as the RTX 4090.

@jukofyork
Collaborator Author

> I've set the ggml_cont_2d() test run going, but I have to go now.

It looks like the use of ggml_cont_2d() makes no difference:

[benchmark plot]

(which is kind of expected, as the only time it would trigger for the speculative-decoding tests that produce this graph is the PP part, which runs once at PP=512 and once at PP=1024)

@jukofyork
Collaborator Author

Also, llama-batched-bench won't let me use -npl values greater than 64, but I think the pattern shows no crossover, and the threshold should probably just use n_head in place of 32.
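
In code that would amount to something like this (illustrative only; n_head and n_tokens as named in the graph build):

// hypothetical: base the layout flip on the head count rather than a fixed constant
const bool use_mat_vec_form = n_tokens <= n_head;   // instead of n_tokens < 32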

@JohannesGaessler
Collaborator

From what I can tell, the problem is that BF16 and FP32 in particular do not have support for batched matrix multiplications with batch sizes > 1 in the CUDA backend. A matrix-vector multiplication triggers the batch-size-1 kernel, which does have support for batching, and that ends up being faster. I think a better change would be to just add the missing support in the CUDA backend than to try and work around it.

@jukofyork
Collaborator Author

jukofyork commented Jun 15, 2025

> From what I can tell, the problem is that BF16 and FP32 in particular do not have support for batched matrix multiplications with batch sizes > 1 in the CUDA backend. A matrix-vector multiplication triggers the batch-size-1 kernel, which does have support for batching, and that ends up being faster. I think a better change would be to just add the missing support in the CUDA backend than to try and work around it.

I see. So it looks like BF16 with batch size 1 gets caught by the first case here:

    if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
        // the custom F16 vector kernel can be used over batched cuBLAS GEMM
        // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
    } else if (!split && use_mul_mat_vec_q) {
        ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
    } else if (!split && use_mul_mat_q) {
        ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
            !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
        // general KQ + KQV multi-batch without FlashAttention
        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
    } else if (use_mul_mat_vec) {
        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
    } else if (use_mul_mat_vec_q) {
        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
    } else if (use_mul_mat_q) {
        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
    } else {
        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
    }

but for larger batch sizes it's slipping past all of the checks, and then ggml_cuda_op_mul_mat ends up spawning a separate mul_mat call for each element in the batch?

@jukofyork
Collaborator Author

I will retry using F16 and Q8_0 for the attn_v_b tensors tomorrow, e.g.:

safe_sed_function "src/llama-quant.cpp" \
  "/^static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor \\* tensor, llama_ftype ftype) {" \
  "static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {\n\
    const std::string name = ggml_get_name(tensor);\n\
    if (name.find(\"_exps\") != std::string::npos) {\n\
        return GGML_TYPE_Q4_K;\n\
    } else if (name.find(\"attn_k_b\") != std::string::npos) {\n\
        return GGML_TYPE_BF16;\n\
    } else if (name.find(\"attn_v_b\") != std::string::npos) {\n\
        return GGML_TYPE_F16;\n\
    }\n\
    return GGML_TYPE_Q6_K;\n\
}"

as IIRC it was attn_k_b that overflowed, and looking at the code for ggml_cuda_mul_mat_batched_cublas(), it looks like it could get quite a boost if it works.

I will also try Q8_0 to see if I still get the much lower tokens/s that I did originally (things may have changed a lot since then, so there might be no reason to use BF16 any more...).


I'll also reread the cuBLAS docs and see if I can get a BF16 version of ggml_cuda_mul_mat_batched_cublas() going tomorrow.

I did try using CUBLAS_COMPUTE_32F_FAST_16BF for something and got no gains, but I think that may have been in ggml_cuda_op_mul_mat_cublas(), as I don't remember anything about batches.
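
For what it's worth, here is a hedged sketch of what the BF16 call itself could look like (not the actual ggml_cuda_mul_mat_batched_cublas() change; the shapes, strides and transposes below are placeholders that follow the pattern of the existing F16 path):

#include <cublas_v2.h>
#include <cuda_bf16.h>

// Sketch only: BF16 inputs with FP32 accumulation via cublasGemmStridedBatchedEx().
// All parameters (m, n, k, strides, batch_count) are illustrative placeholders.
static cublasStatus_t gemm_strided_batched_bf16(
        cublasHandle_t handle,
        const __nv_bfloat16 * A, const __nv_bfloat16 * B, float * C,
        int m, int n, int k,
        long long strideA, long long strideB, long long strideC,
        int batch_count) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;

    return cublasGemmStridedBatchedEx(handle,
            CUBLAS_OP_T, CUBLAS_OP_N,            // same op layout the F16 path uses
            m, n, k,
            &alpha,
            A, CUDA_R_16BF, k, strideA,          // BF16 src0
            B, CUDA_R_16BF, k, strideB,          // BF16 src1
            &beta,
            C, CUDA_R_32F,  m, strideC,          // FP32 dst
            batch_count,
            CUBLAS_COMPUTE_32F,                  // or CUBLAS_COMPUTE_32F_FAST_16BF
            CUBLAS_GEMM_DEFAULT);
}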

@JohannesGaessler
Collaborator

> I'll also reread the cuBLAS docs and see if I can get a BF16 version of ggml_cuda_mul_mat_batched_cublas() going tomorrow.

Feel free to do that; it would be a good addition. For speculative decoding in particular, however, I think it's more important to extend the mul_mat_vec kernels with support for batch sizes > 1.
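
To illustrate the idea (this is not ggml's actual mul_mat_vec kernel, just a generic sketch with placeholder names): instead of launching one matrix-vector product per token, a single kernel handles a small fixed number of columns and reuses each loaded matrix element for all of them.

#include <cstdint>

// Sketch only -- a generic FP32 kernel, not ggml's implementation. One warp per
// output row; each loaded matrix element is reused for all n_cols columns.
// Launch as e.g. mul_mat_vec_cols<4><<<n_rows, 32>>>(A, x, y, n_rows, k);
template <int n_cols>
__global__ void mul_mat_vec_cols(
        const float * __restrict__ A,   // [n_rows, k], row-major matrix
        const float * __restrict__ x,   // [k, n_cols], column-major "vectors"
        float       * __restrict__ y,   // [n_rows, n_cols], column-major result
        const int n_rows, const int k) {
    const int row = blockIdx.x;
    if (row >= n_rows) {
        return;
    }

    float sum[n_cols] = {0.0f};

    for (int i = threadIdx.x; i < k; i += 32) {
        const float a = A[(int64_t) row*k + i];   // loaded once, used for every column
#pragma unroll
        for (int c = 0; c < n_cols; ++c) {
            sum[c] += a * x[(int64_t) c*k + i];
        }
    }

#pragma unroll
    for (int c = 0; c < n_cols; ++c) {
        // warp-level reduction of the partial dot products
        for (int offset = 16; offset > 0; offset >>= 1) {
            sum[c] += __shfl_down_sync(0xffffffff, sum[c], offset);
        }
        if (threadIdx.x == 0) {
            y[(int64_t) c*n_rows + row] = sum[c];
        }
    }
}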

@jukofyork
Collaborator Author

jukofyork commented Jun 16, 2025

Just closing this PR with some final numbers:

This PR

BF16 attn_v_b.weight + BF16 attn_k_b.weight

| PP | TG |  B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |   T s | S t/s |
|----|----|----|------|--------|----------|--------|----------|-------|-------|
|  0 |  1 |  1 |    1 |  0.000 |     0.00 |  0.156 |     6.40 | 0.156 |  6.40 |
|  0 |  1 |  2 |    2 |  0.000 |     0.00 |  0.298 |     6.71 | 0.298 |  6.70 |
|  0 |  1 |  3 |    3 |  0.000 |     0.00 |  0.319 |     9.42 | 0.319 |  9.41 |
|  0 |  1 |  4 |    4 |  0.000 |     0.00 |  0.346 |    11.56 | 0.346 | 11.55 |
|  0 |  1 |  5 |    5 |  0.000 |     0.00 |  0.369 |    13.56 | 0.369 | 13.55 |
|  0 |  1 |  6 |    6 |  0.000 |     0.00 |  0.395 |    15.17 | 0.396 | 15.16 |
|  0 |  1 |  7 |    7 |  0.000 |     0.00 |  0.424 |    16.51 | 0.424 | 16.49 |
|  0 |  1 |  8 |    8 |  0.000 |     0.00 |  0.453 |    17.68 | 0.453 | 17.66 |
|  0 |  1 | 12 |   12 |  0.000 |     0.00 |  0.576 |    20.84 | 0.576 | 20.82 |
|  0 |  1 | 16 |   16 |  0.001 |     0.00 |  0.697 |    22.94 | 0.698 | 22.92 |
|  0 |  1 | 32 |   32 |  0.001 |     0.00 |  1.178 |    27.16 | 1.179 | 27.14 |
|  0 |  1 | 64 |   64 |  0.002 |     0.00 |  2.146 |    29.82 | 2.148 | 29.79 |

Main

F16 attn_v_b.weight + BF16 attn_k_b.weight

| PP | TG |  B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |   T s | S t/s |
|----|----|----|------|--------|----------|--------|----------|-------|-------|
|  0 |  1 |  1 |    1 |  0.000 |     0.00 |  0.151 |     6.64 | 0.151 |  6.63 |
|  0 |  1 |  2 |    2 |  0.000 |     0.00 |  0.294 |     6.79 | 0.294 |  6.79 |
|  0 |  1 |  3 |    3 |  0.000 |     0.00 |  0.313 |     9.59 | 0.313 |  9.59 |
|  0 |  1 |  4 |    4 |  0.000 |     0.00 |  0.336 |    11.90 | 0.336 | 11.90 |
|  0 |  1 |  5 |    5 |  0.000 |     0.00 |  0.360 |    13.89 | 0.360 | 13.88 |
|  0 |  1 |  6 |    6 |  0.000 |     0.00 |  0.386 |    15.55 | 0.386 | 15.54 |
|  0 |  1 |  7 |    7 |  0.000 |     0.00 |  0.412 |    16.98 | 0.412 | 16.97 |
|  0 |  1 |  8 |    8 |  0.000 |     0.00 |  0.440 |    18.20 | 0.440 | 18.19 |
|  0 |  1 | 12 |   12 |  0.000 |     0.00 |  0.552 |    21.73 | 0.553 | 21.71 |
|  0 |  1 | 16 |   16 |  0.001 |     0.00 |  0.666 |    24.03 | 0.666 | 24.01 |
|  0 |  1 | 32 |   32 |  0.001 |     0.00 |  1.113 |    28.75 | 1.114 | 28.72 |
|  0 |  1 | 64 |   64 |  0.002 |     0.00 |  2.020 |    31.69 | 2.022 | 31.65 |

Q8_0 attn_v_b.weight + BF16 attn_k_b.weight

| PP | TG |  B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |   T s | S t/s |
|----|----|----|------|--------|----------|--------|----------|-------|-------|
|  0 |  1 |  1 |    1 |  0.000 |     0.00 |  0.154 |     6.49 | 0.154 |  6.49 |
|  0 |  1 |  2 |    2 |  0.000 |     0.00 |  0.302 |     6.63 | 0.302 |  6.62 |
|  0 |  1 |  3 |    3 |  0.000 |     0.00 |  0.319 |     9.41 | 0.319 |  9.41 |
|  0 |  1 |  4 |    4 |  0.000 |     0.00 |  0.344 |    11.62 | 0.344 | 11.61 |
|  0 |  1 |  5 |    5 |  0.000 |     0.00 |  0.367 |    13.64 | 0.367 | 13.63 |
|  0 |  1 |  6 |    6 |  0.000 |     0.00 |  0.391 |    15.34 | 0.391 | 15.33 |
|  0 |  1 |  7 |    7 |  0.000 |     0.00 |  0.418 |    16.76 | 0.418 | 16.75 |
|  0 |  1 |  8 |    8 |  0.000 |     0.00 |  0.447 |    17.88 | 0.448 | 17.87 |
|  0 |  1 | 12 |   12 |  0.000 |     0.00 |  0.558 |    21.51 | 0.558 | 21.49 |
|  0 |  1 | 16 |   16 |  0.001 |     0.00 |  0.670 |    23.89 | 0.670 | 23.86 |
|  0 |  1 | 32 |   32 |  0.001 |     0.00 |  1.120 |    28.58 | 1.121 | 28.55 |
|  0 |  1 | 64 |   64 |  0.002 |     0.00 |  2.032 |    31.49 | 2.034 | 31.46 |

Q8_0 attn_v_b.weight + Q8_0 attn_k_b.weight

| PP | TG |  B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |   T s | S t/s |
|----|----|----|------|--------|----------|--------|----------|-------|-------|
|  0 |  1 |  1 |    1 |  0.000 |     0.00 |  0.148 |     6.75 | 0.148 |  6.75 |
|  0 |  1 |  2 |    2 |  0.000 |     0.00 |  0.154 |    12.95 | 0.155 | 12.93 |
|  0 |  1 |  3 |    3 |  0.000 |     0.00 |  0.173 |    17.36 | 0.173 | 17.34 |
|  0 |  1 |  4 |    4 |  0.000 |     0.00 |  0.196 |    20.41 | 0.196 | 20.39 |
|  0 |  1 |  5 |    5 |  0.000 |     0.00 |  0.223 |    22.41 | 0.223 | 22.39 |
|  0 |  1 |  6 |    6 |  0.000 |     0.00 |  0.250 |    24.04 | 0.250 | 24.01 |
|  0 |  1 |  7 |    7 |  0.000 |     0.00 |  0.277 |    25.23 | 0.278 | 25.20 |
|  0 |  1 |  8 |    8 |  0.000 |     0.00 |  0.312 |    25.66 | 0.312 | 25.63 |
|  0 |  1 | 12 |   12 |  0.000 |     0.00 |  0.417 |    28.77 | 0.418 | 28.73 |
|  0 |  1 | 16 |   16 |  0.001 |     0.00 |  0.530 |    30.21 | 0.530 | 30.17 |
|  0 |  1 | 32 |   32 |  0.001 |     0.00 |  0.979 |    32.70 | 0.980 | 32.66 |
|  0 |  1 | 64 |   64 |  0.002 |     0.00 |  1.891 |    33.85 | 1.893 | 33.81 |

It seems that Q8_0 has gone from getting around 1/2 the tokens/s of BF16 a couple of months ago to 2x the tokens/s now! :O

Safe to say I'm requanting all my deepseek models now!

@jukofyork closed this Jun 16, 2025