Skip to content

[CUDA backend ONLY] Use just K-cache for MLA + FA: 47% saving on KV-cache size #13529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,11 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache
{
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));

// note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
if (!v_mla || !cparams.flash_attn) {
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
}
}

const auto & kq_mask = inp->get_kq_mask();
Expand Down Expand Up @@ -1341,7 +1345,11 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache
{
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));

// note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
if (!v_mla || !cparams.flash_attn) {
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
}
}

const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
Expand Down
35 changes: 27 additions & 8 deletions src/llama-kv-cache-unified.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(

cells.resize(kv_size);

const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);

for (uint32_t il = 0; il < hparams.n_layer; il++) {
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
Expand Down Expand Up @@ -93,7 +95,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
ggml_tensor * v;

k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);

// note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, !is_mla || v_trans ? kv_size : 0);

ggml_format_name(k, "cache_k_l%d", il);
ggml_format_name(v, "cache_v_l%d", il);
Expand Down Expand Up @@ -700,7 +704,9 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
}

bool llama_kv_cache_unified::get_can_shift() const {
return true;
const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);

return !is_mla || v_trans; // TODO: allow context shifting for MLA with flash attention;
}

uint32_t llama_kv_cache_unified::get_size() const {
Expand Down Expand Up @@ -733,12 +739,25 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
auto * v = layers[ikv].v;

if (!v_trans) {
// note: v->nb[1] <= v->nb[2]
return ggml_view_3d(ctx, v,
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
0);
const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);

if (!is_mla) {
// note: v->nb[1] <= v->nb[2]
return ggml_view_3d(ctx, v,
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
0);
} else {
auto * k = layers[ikv].k;

// note: v->nb[1] == v->nb[2] for MLA as transforms into MQA
return ggml_view_3d(ctx, k,
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
ggml_row_size(k->type, hparams.n_embd_head_k), // v->nb[1]
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), // v->nb[2]
hparams.n_embd_head_k - hparams.n_embd_head_v); // offset by n_rot elements
}
}

// note: v->nb[1] > v->nb[2]
Expand Down
Loading