kv-cache : rework kv_cell #13706

Open · wants to merge 6 commits into master from gg/kv-cache-simplify-part2
Conversation

ggerganov (Member) commented May 22, 2025

cont #13194

The KV cell editing logic is now implemented via the new struct llama_kv_cells_unified in the new source file src/llama-kv-cells.h. The goal is to simplify the implementation in llama-kv-cache.cpp and make it easier to understand and extend in the future.

One of the primary simplifications is that llama_kv_cache_unified no longer tracks the number of used cells manually. This is now tracked automatically by llama_kv_cells_unified based on the edits that are applied, such as adding and removing sequences from cells. The same applies to the has_shift flag.

  • The cell information (i.e. pos, delta, seq) is now a structure of arrays for better cache locality
  • The sequences that belong to a cell are now tracked with a std::bitset instead of std::set
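A minimal sketch of what such a cell store might look like (hypothetical names and simplified logic, not the actual code in src/llama-kv-cells.h): the per-cell fields live in parallel arrays, sequence membership is a fixed-size bitset, and the used counter maintains itself inside the mutators.

```cpp
#include <bitset>
#include <cstdint>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// hypothetical sketch, not the actual llama.cpp implementation:
// cell data as a structure of arrays, sequence membership as a bitset,
// with the 'used' counter maintained automatically by the mutators
struct kv_cells_sketch {
    static constexpr size_t MAX_SEQ = 64; // assumed bound on parallel sequences

    std::vector<llama_pos>            pos;   // position per cell, -1 == empty
    std::vector<llama_pos>            shift; // accumulated position delta per cell
    std::vector<std::bitset<MAX_SEQ>> seq;   // which sequences occupy each cell

    uint32_t used = 0; // number of non-empty cells, never updated by callers

    void resize(uint32_t n) {
        pos.assign(n, -1);
        shift.assign(n, 0);
        seq.assign(n, {});
    }

    uint32_t size() const { return (uint32_t) pos.size(); }

    bool seq_has(uint32_t i, llama_seq_id s) const { return seq[i].test(s); }

    // place a position into a cell; an empty cell becomes used
    void pos_set(uint32_t i, llama_pos p) {
        if (pos[i] == -1) { used++; }
        pos[i] = p;
    }

    void seq_add(uint32_t i, llama_seq_id s) { seq[i].set(s); }

    // remove a sequence from a cell; if no sequence remains, the cell is
    // cleared and 'used' decreases - no manual bookkeeping at call sites
    void seq_rm(uint32_t i, llama_seq_id s) {
        seq[i].reset(s);
        if (seq[i].none() && pos[i] != -1) {
            pos[i]   = -1;
            shift[i] = 0;
            used--;
        }
    }
};
```

The bitset makes per-cell sequence queries a constant-time bit test rather than a std::set lookup, and the parallel arrays keep the hot pos data contiguous in memory.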

Here is an example of the position shift logic before and after the change:

// before
    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
            has_shift = true;

            cells[i].pos   += delta;
            cells[i].delta += delta;

            if (cells[i].pos < 0) {
                if (!cells[i].is_empty()) {
                    used--;
                }
                cells[i].pos = -1;
                cells[i].seq_id.clear();
                if (new_head == size) {
                    new_head = i;
                }
            }
        }
    }

// after
    for (uint32_t i = 0; i < cells.size(); ++i) {
        if (!cells.pos_in(i, p0, p1)) {
            continue;
        }

        if (cells.seq_has(i, seq_id)) {
            if (cells.pos_add(i, delta)) {
                if (new_head == cells.size()) {
                    new_head = i;
                }
            }
        }
    }
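The bookkeeping that used to be inline (setting has_shift, decrementing used, clearing the cell) now lives inside the helpers. A hedged sketch of what pos_in and pos_add might do, assuming pos_add returns true exactly when the shift empties the cell (the sketch omits sequence clearing for brevity):

```cpp
#include <cstdint>
#include <vector>

using llama_pos = int32_t;

// hypothetical standalone sketch of the helpers used in the "after" loop;
// the real methods are members of llama_kv_cells_unified
struct cells_pos_sketch {
    std::vector<llama_pos> pos;   // -1 == empty
    std::vector<llama_pos> shift; // accumulated delta, consumed by the K-shift
    uint32_t used      = 0;
    bool     has_shift = false;

    // true if cell i holds a position in [p0, p1)
    bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
        return pos[i] >= p0 && pos[i] < p1;
    }

    // shift cell i by delta; returns true if the cell became empty,
    // updating 'used' and 'has_shift' internally
    bool pos_add(uint32_t i, llama_pos delta) {
        pos[i]   += delta;
        shift[i] += delta;
        has_shift = true;

        if (pos[i] < 0) {
            pos[i]   = -1;
            shift[i] = 0;
            used--;
            return true;
        }
        return false;
    }
};
```

With this split, the caller only decides *which* cells to touch; *how* a shift affects the cache state is decided in one place.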

Next

  • Add a heap for efficiently tracking min and max pos per sequence
  • Efficiently track the max occupied KV cell (i.e. n = cell_max()) instead of searching for it on every batch
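The first item could be sketched roughly as follows (a hypothetical illustration, not the planned implementation; the PR mentions a heap, and an ordered multiset is used here as a stand-in that gives the same O(log n) updates with O(1) min/max access):

```cpp
#include <cstdint>
#include <map>
#include <set>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// hypothetical sketch of per-sequence min/max position tracking,
// avoiding a full scan of the cells on every query
struct seq_pos_sketch {
    std::map<llama_seq_id, std::multiset<llama_pos>> seq_pos;

    void add(llama_seq_id s, llama_pos p) { seq_pos[s].insert(p); }

    void rm(llama_seq_id s, llama_pos p) {
        auto it = seq_pos[s].find(p);
        if (it != seq_pos[s].end()) { seq_pos[s].erase(it); }
    }

    // smallest tracked position for sequence s, or -1 if none
    llama_pos pos_min(llama_seq_id s) const {
        auto it = seq_pos.find(s);
        return (it == seq_pos.end() || it->second.empty()) ? -1 : *it->second.begin();
    }

    // largest tracked position for sequence s, or -1 if none
    llama_pos pos_max(llama_seq_id s) const {
        auto it = seq_pos.find(s);
        return (it == seq_pos.end() || it->second.empty()) ? -1 : *it->second.rbegin();
    }
};
```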

ggerganov (Member, Author) commented May 22, 2025

In the next PR I will try to rework these 3 methods with something like batch_process_info_t init(const llama_batch & batch);:

// =============================================================================================================
// TODO: refactor and simplify this [TAG: KV_API]
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
// different KV caches require different batch splitting strategies
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
// find an empty slot of size "n_tokens" in the cache
virtual bool find_slot(const llama_ubatch & batch) = 0;

The main goal is to be able to run SWA caches with just n_swa + n_ubatch cells as explained in #13194 (comment). This refactor will also absorb the find_slot logic so that the llama_context won't need to be aware of this implementation detail. The batch -> sbatch -> ubatch logic will also be better contained within the llama_kv_cache and won't leak into llama_decode() as it does now, which will be useful later for implementing new KV caches.
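A rough sketch of what such an interface could look like (the name batch_process_info_t comes from the comment above; the fields, the dummy splitting strategy, and all other types here are assumptions for illustration only):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// stand-ins for the real llama.cpp types, for illustration only
struct llama_batch  { int32_t n_tokens = 0; };
struct llama_ubatch { int32_t n_tokens = 0; };

// hypothetical result of the proposed init(batch) call: the cache has
// already split the batch and reserved slots, so llama_decode() no
// longer needs to know about sbatch/ubatch splitting or find_slot
struct batch_process_info_t {
    bool success = false;               // slots found for all ubatches
    std::vector<llama_ubatch> ubatches; // split per the cache's own strategy
};

struct llama_kv_cache_sketch {
    virtual batch_process_info_t init(const llama_batch & batch) = 0;
    virtual ~llama_kv_cache_sketch() = default;
};

// toy implementation that just chunks the batch into fixed-size ubatches
struct dummy_cache : llama_kv_cache_sketch {
    int32_t n_ubatch = 32; // assumed micro-batch size

    batch_process_info_t init(const llama_batch & batch) override {
        batch_process_info_t res;
        for (int32_t i = 0; i < batch.n_tokens; i += n_ubatch) {
            llama_ubatch ub;
            ub.n_tokens = std::min(n_ubatch, batch.n_tokens - i);
            res.ubatches.push_back(ub);
        }
        res.success = true;
        return res;
    }
};
```

Under this shape, simulating a full cache (the set_full() case below) reduces to calling init() with the worst-case batch and discarding the result instead of processing it.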

When this rework is ready, I will use the new init(batch) mechanism to remove the set_full() call from the interface:

// simulate full cache, used for allocating worst-case compute buffers
// TODO: remove
virtual void set_full() = 0;

Simulating a full cache will then be achieved by initializing the appropriate batches and simply not processing them.

Any suggestions about the plan are welcome.

@ggerganov ggerganov requested a review from slaren May 22, 2025 13:26
@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part2 branch from 0a8cdc3 to eda2e13 on May 23, 2025 09:03
}

// note: call only if the cell is not empty
llama_pos get_pos(uint32_t i) const {
Review comment (Member):
Should this be pos_get for consistency with pos_set etc?
