[CUDA backend ONLY] Use just K-cache for MLA + FA: 47% saving on KV-cache size #13529

Draft · jukofyork wants to merge 1 commit into master from mla-fa-disable-v-cache

Conversation

@jukofyork (Collaborator) commented May 14, 2025

From #13435:

> Going forward I would suggest re-writing the code in other backends as well to use only the K tensor. The KV cache size could then be reduced by ~47% by simply not allocating and filling the V cache. At least as long as FlashAttention is used this should be relatively simple. So for a first version it would I think be fine to only deduplicate K and V if FA is used.

I've only tested this with #13435 for now, but it should still work with the other backends' flash attention implementations so long as they don't assume that the V-cache they are passed is contiguous:

    } else {
        // note: MLA with flash attention now uses the last 512 elements of the K-cache in place of a V-cache
        v = ggml_view_3d(ctx0, kv_self->k_l[il],
                n_embd_head_v, n_kv, n_head_kv,
                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k - n_embd_head_v)); // byte offset skipping the n_rot RoPEed elements
    }
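
For anyone wondering where the ~47% figure comes from, here is the rough arithmetic, assuming the usual DeepSeek MLA dimensions (a 512-element compressed KV latent plus a 64-element RoPEed part per K row; these numbers come from the model hyperparameters, not from this patch):

    // back-of-the-envelope arithmetic for the cache saving (DeepSeek MLA dims assumed)
    const int n_embd_head_k = 512 + 64; // compressed KV latent + n_rot RoPEed elements = 576 per token per layer
    const int n_embd_head_v = 512;      // only the latent part is needed in place of V
    // before: separate K and V caches -> 576 + 512 = 1088 elements per token per layer
    // after:  K-cache only, with V read as a view of its last 512 elements -> 576 elements
    // saving: 512 / 1088 ~= 47%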

The full context of 160k tokens now takes up less than 11GB:

llama_kv_cache_unified: kv_size = 163840, type_k = 'f16', type_v = 'f16', n_layer = 61, can_shift = 0, padding = 256
llama_kv_cache_unified:      CUDA0 KV buffer size = 10980.00 MiB
llama_kv_cache_unified: KV self size  = 10980.00 MiB, K (f16): 10980.00 MiB, V (f16):    0.00 MiB

😮
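
As a sanity check on that buffer size (a rough calculation under the same assumed dimensions):

    // 163840 tokens * 61 layers * 576 elements * 2 bytes (F16) = 11,513,364,480 bytes
    // 11,513,364,480 / (1024 * 1024)                           = 10980 MiB exactly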


I've just disabled context shifting for now; as I said in the other post, I'm not at all confident I can cleanly change everything that is required to deal with the empty V-cache:

[image]


I will leave it as a draft for now, as the other backends' flash attention implementations need to either:

A. Check they can work with the non-contiguous V-cache passed to them.
B. Copy @JohannesGaessler's strategy of only using the last 512 elements of the K-cache in place of the V-cache.

I think (B) is preferable, as there are likely to be some significant gains possible regarding CPU cache (re-)use, etc.

@jukofyork jukofyork changed the title MLA + FA now only uses K-cache - 47% saving on KV-cache szie (only for use with #13435 for now) MLA + FA now only uses K-cache - 47% saving on KV-cache size (only for use with #13435 for now) May 14, 2025
@pwilkin (Contributor) commented May 14, 2025

Wow, this is pretty huge. Would go a long way towards supporting long contexts on potato devices.

@jukofyork (Collaborator, Author)

You'll need to use something like this to get the merge of the 2 PRs:

git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp

git fetch origin pull/13435/head:pr-13435
git fetch origin pull/13529/head:pr-13529

git merge pr-13435 --no-edit
git merge pr-13529 --no-edit

(there may be a better way, but I'm pretty dumb when it comes to using git...)

@jukofyork (Collaborator, Author) commented May 14, 2025

> Wow, this is pretty huge. Would go a long way towards supporting long contexts on potato devices.

Yeah, it's nuts: I can now get 160k context, with a ubatch of 4096 using a Q6_K model (with non-shared experts stored in RAM), all on a single 32GB RTX 5000 Ada card!
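
For anyone wanting to reproduce this kind of setup, the invocation would look something like the following (the model filename and the -ot pattern are illustrative; the idea is just to keep the non-shared expert tensors in system RAM while the rest of the model and the full KV-cache stay on the GPU):

./bin/llama-cli -m ./models/deepseek-q6_k.gguf -c 163840 -ub 4096 -fa -ngl 99 -ot "exps=CPU"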

@Panchovix

#13435 got merged, so this could be converted back to PR now I think.

@jukofyork (Collaborator, Author)

> #13435 got merged, so this could be converted back to PR now I think.

It really needs people to test if the other back-ends can handle a non-contiguous V-cache view like this first, or preferably do as @JohannesGaessler suggested and make the other back-ends' FA implementation use just the K-cache like he did.

@ggerganov (Member)

> #13435 got merged, so this could be converted back to PR now I think.

> It really needs people to test if the other back-ends can handle a non-contiguous V-cache view like this first, or preferably do as @JohannesGaessler suggested and make the other back-ends' FA implementation use just the K-cache like he did.

The CPU implementation does not seem to support it. This command generates junk for me:

make -j && ./bin/llama-cli -m ../models/deepseek-v2-lite-chat/ggml-model-q8_0-mla.gguf -no-cnv -p "I believe the meaning of life is" --top-k 1 -n 32 -fa -dev none

Maybe look into fixing it and adding test-backend-ops tests that verify this use case. It will make it easier to support it in the rest of the backends.

@bartowski1182 (Contributor)

Does this have any speed implications by not storing the values?

@jukofyork (Collaborator, Author)

> #13435 got merged, so this could be converted back to PR now I think.

> It really needs people to test if the other back-ends can handle a non-contiguous V-cache view like this first, or preferably do as @JohannesGaessler suggested and make the other back-ends' FA implementation use just the K-cache like he did.

> The CPU implementation does not seem to support it. This command generates junk for me:

> make -j && ./bin/llama-cli -m ../models/deepseek-v2-lite-chat/ggml-model-q8_0-mla.gguf -no-cnv -p "I believe the meaning of life is" --top-k 1 -n 32 -fa -dev none

> Maybe look into fixing it and adding test-backend-ops tests that verify this use case. It will make it easier to support it in the rest of the backends.

Yeah, I think having a non-cont V-cache is likely to break most of the backends.

I'll try and have a look at fixing this and adding the test today or tomorrow.

@jukofyork (Collaborator, Author) commented May 15, 2025

> Does this have any speed implications by not storing the values?

If the backends are rewritten to only use the K-cache, there could be a big performance improvement for some of them, since they no longer have to access a separate V-cache in memory (e.g. a CPU's L3 cache only has to hold K-cache elements).

If the backends are instead just fixed to accept the non-contiguous view of the K-cache in place of the V-cache, there will likely be some degree of performance reduction though.

@jukofyork jukofyork changed the title MLA + FA now only uses K-cache - 47% saving on KV-cache size (only for use with #13435 for now) [CUDA backend ONLY] Use just K-cache for MLA + FA: 47% saving on KV-cache size May 19, 2025
@jukofyork (Collaborator, Author)

I changed the title to make it clear that it will only work for the CUDA backend for now.

I'm away until Thursday, so I will try to get this working for the CPU backend when I get back home.

@jukofyork (Collaborator, Author) commented Jun 4, 2025

Just looking to see if I can revive this PR, but I'm not sure how best to proceed:

1. Carry on using llama-kv-cache-unified.cpp and add something like this to the constructor:

    // note: MLA with flash attention can use the last 512 elements of K-cache in place of a V-cache
    if (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0 && !v_trans && type_k == type_v) {
        v_view_of_k = true;
    }

and then add lots of special cases to all the members that access v.

2. Make a new subclass of llama_kv_cache which will be 99% the same as llama-kv-cache-unified.cpp.


I don't really like the look of either of these options though:

  • Both solutions are ugly and I think will be hard to maintain.
  • Neither solution really accounts for the fact that the CUDA backend doesn't care about the v cache being populated when using MLA+FA (ie: it just reads everything from k now AFAIK and doesn't need/use any of the dimensions read from the v tensor - @JohannesGaessler is this correct?).
  • Neither solves the problem of the other backends expecting v to be contiguous, and crashing or outputting gibberish when passed a view of k in place of v.

I could try to factor out all the common code into a common base class for option (2), but:

  1. Not sure what the llama.cpp consensus is on nested inheritance like this (ie: it can end up an unmaintainable mess in the future too).
  2. Is there still ongoing work on the refactoring of the KV-cache stuff?

@ggerganov @slaren @ngxson What are your thoughts on how best to tackle this?

@ggerganov (Member)

The KV cache can have a flag is_mla similar to v_trans and handle it in the get_v() and cpy_v() methods. Unless I am missing something, this should be pretty clean.

The other thing that is missing is to update the CPU implementation to support non-contiguous v. After you implement this, add a test in test-backend-ops to exercise it, and make all backends "not support" it by updating the respective supports_op functions. Then in new PRs maintainers will implement support for this in the other backends.

@JohannesGaessler (Collaborator)

One thing to keep in mind is that (for the memory savings to be automatic) the allocation code would already need to check whether or not the non-contiguous V cache is supported. Unless the list of backends with FA support for DeepSeek is very short and we could simply add support to those backends that don't yet have it (I don't have a good overview).

> Neither solution really accounts for the fact that the CUDA backend doesn't care about the v cache being populated when using MLA+FA (ie: it just reads everything from k now AFAIK and doesn't need/use any of the dimensions read from the v tensor - @JohannesGaessler is this correct?).

Unless I did the implementation wrong it should work even if V is null.

@jukofyork (Collaborator, Author)

> The KV cache can have a flag is_mla similar to v_trans and handle it in the get_v() and cpy_v() methods. Unless I am missing something, this should be pretty clean.

> The other thing that is missing is to update the CPU implementation to support non-contiguous v. After you implement this, add a test in test-backend-ops to exercise it, and make all backends "not support" it by updating the respective supports_op functions. Then in new PRs maintainers will implement support for this in the other backends.

I'm still not sure whether we should just leave the v tensor as a nullptr or try to make it a strided view of k.

Leaving it as a nullptr would make it clearer IMO, and I doubt any of the backends will work without changes even if we make it a strided view of k...
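
For concreteness, a minimal sketch of what the strided-view route could look like inside get_v(), reusing the view from the top of this PR (the flag and member names here are illustrative, not the actual llama-kv-cache-unified.cpp fields):

    // sketch only: assumes an is_mla flag as suggested above, plus access to the
    // relevant hparams; not the actual llama-kv-cache-unified.cpp implementation
    ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
        if (is_mla) {
            // V is just a strided view of the last n_embd_head_v elements of each K row
            return ggml_view_3d(ctx, k_l[il],
                    n_embd_head_v, n_kv, n_head_kv,
                    ggml_row_size(k_l[il]->type, n_embd_k_gqa),
                    ggml_row_size(k_l[il]->type, n_embd_head_k),
                    ggml_row_size(k_l[il]->type, n_embd_head_k - n_embd_head_v));
        }
        // ... existing (non-MLA) path returning a view of v_l[il] ...
    }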

@jukofyork jukofyork closed this Jun 12, 2025
@jukofyork jukofyork force-pushed the mla-fa-disable-v-cache branch from 271560e to 7d51644 on June 12, 2025 11:42
@jukofyork jukofyork reopened this Jun 12, 2025
@jukofyork (Collaborator, Author) commented Jun 12, 2025

> Leaving it as a nullptr would make it clearer

Something seems to be reading v, so it still needs to have the view made.


I think I've managed to get this back to the same state as before (ie: CUDA only, context-shifting disabled, and no tests yet).

@jukofyork (Collaborator, Author) commented Jun 12, 2025

> #13435 got merged, so this could be converted back to PR now I think.

> It really needs people to test if the other back-ends can handle a non-contiguous V-cache view like this first, or preferably do as @JohannesGaessler suggested and make the other back-ends' FA implementation use just the K-cache like he did.

> The CPU implementation does not seem to support it. This command generates junk for me:

> make -j && ./bin/llama-cli -m ../models/deepseek-v2-lite-chat/ggml-model-q8_0-mla.gguf -no-cnv -p "I believe the meaning of life is" --top-k 1 -n 32 -fa -dev none

> Maybe look into fixing it and adding test-backend-ops tests that verify this use case. It will make it easier to support it in the rest of the backends.

Got a bit further figuring out where the CPU backend might be going wrong:

const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

Which then gets treated as contiguous for these:

ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);

v_to_float(v_data, V32, DV);

ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);

BUT I'm not 100% sure it's here, as there is this assert at the top of ggml_compute_forward_flash_attn_ext_f16:

GGML_ASSERT(nbv0 == ggml_type_size(v->type));

which should have asserted out before then...?


I can't really see a better fix than just making a temp copy of this using the stride:

const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

as the parameters of ggml_vec_mad_f16 don't really allow for any finer granularity for the section of memory to be copied:

inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {


There is also the potential to store the 64-element RoPEed part separately, and then k and v would be identical, but that would lose the ability to context-shift and possibly hurt the memory locality for flash attention (which expects the n_rot RoPEed part to be in the first elements)?
