llama : refactor kv cache guard #12695

Merged · 8 commits · Apr 2, 2025

Changes from all commits

2 changes: 2 additions & 0 deletions examples/parallel/parallel.cpp
@@ -106,6 +106,8 @@ int main(int argc, char ** argv) {

     common_params params;

+    params.n_predict = 128;
+
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
         return 1;
     }
59 changes: 8 additions & 51 deletions src/llama-context.cpp
@@ -1201,33 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd = hparams.n_embd;

-    // TODO: remove this stuff
-    class batch_guard {
-    public:
-        batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
-        }
-
-        ~batch_guard() {
-            if (!is_done) {
-                kv_slot_restorer.restore();
-            }
-        }
-
-        void done() {
-            is_done = true;
-        }
-
-        void save(const llama_kv_cache_slot_info & slot_info) {
-            kv_slot_restorer.save(slot_info);
-        }
-
-    private:
-        bool is_done = false;
-
-        llama_kv_slot_restorer kv_slot_restorer;
-    };
-
-    batch_guard bg(*kv_self);
+    llama_kv_cache_guard kv_guard(kv_self.get());

     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

@@ -1280,6 +1254,9 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -2;
     };

+    // handle any pending defrags/shifts
+    kv_self_update();
+
     int64_t n_outputs_prev = 0;

     while (sbatch.n_tokens > 0) {
@@ -1319,22 +1296,12 @@ int llama_context::decode(llama_batch & inp_batch) {

         // find KV slot
         {
-            kv_self_update();
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);

-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
-                kv_self->head = 0;
+                return 1;
             }

-            const auto slot_info = kv_self->find_slot(ubatch);
-            if (!slot_info) {
-                LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
-                return -3;
-            }
-
-            bg.save(slot_info);
-
             if (!kv_self->recurrent) {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
@@ -1371,16 +1338,6 @@ int llama_context::decode(llama_batch & inp_batch) {
             }
         }

-        // update the kv ring buffer
-        {
-            kv_self->head += ubatch.n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self->head >= kv_self->size) {
-                kv_self->head = 0;
-            }
-        }
-
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
@@ -1467,7 +1424,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     }

     // finalize the batch processing
-    bg.done();
+    kv_guard.commit();

     // set output mappings
     {
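Note on the decode() changes above: the per-call batch_guard class is replaced by llama_kv_cache_guard, whose declaration lives in src/llama-kv-cache.h and is therefore not visible in this diff. Based on the kv_guard.commit() call above and the restore()/commit() methods added below, a minimal sketch of the RAII shape it implies could look like this (names and member types here are illustrative assumptions, not the authoritative header):

```cpp
// Illustrative sketch only - the actual llama_kv_cache_guard is declared in
// src/llama-kv-cache.h, which this diff does not show.
struct llama_kv_cache_guard {
    explicit llama_kv_cache_guard(llama_kv_cache_unified * kv) : kv(kv) {}

    // On every exit path out of decode(), roll back uncommitted slot
    // allocations; restore() is a no-op once commit() has cleared the
    // pending ranges.
    ~llama_kv_cache_guard() {
        kv->restore();
    }

    // Mark the batch as fully processed so the allocated cells are kept.
    void commit() {
        kv->commit();
    }

private:
    llama_kv_cache_unified * kv; // non-owning
};
```

The practical difference from batch_guard is that the rollback bookkeeping now lives inside the cache itself (the pending ranges recorded by find_slot below), so early returns in decode() no longer require the caller to save slot information explicitly.
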
69 changes: 61 additions & 8 deletions src/llama-kv-cache.cpp
@@ -11,8 +11,6 @@
 #include <map>
 #include <stdexcept>

-static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
-
 llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }

@@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 return false;
             }
         }
+
+        return true;
     }

     for (uint32_t i = 0; i < size; ++i) {
@@ -446,16 +446,66 @@ void llama_kv_cache_unified::defrag() {
     }
 }

+void llama_kv_cache_unified::restore() {
+    if (pending.ranges.empty()) {
+        return;
+    }
+
+    // TODO: tmp - move to llama_kv_cache_recurrent
+    if (recurrent) {
+        seq_rm(-1, -1, -1);
+        return;
+    }
+
+    uint32_t new_head = size;
+
+    for (auto & range : pending.ranges) {
+        for (uint32_t i = range.c0; i < range.c1; ++i) {
+            cells[i].seq_id.clear();
+
+            // keep count of the number of used cells
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+        }
+
+        new_head = std::min(new_head, range.c0);
+    }
+
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_unified::commit() {
+    if (pending.ranges.empty()) {
+        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
+            __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+        return;
+    }
+
+    pending.ranges.clear();
+}
+
 bool llama_kv_cache_unified::get_can_shift() const {
     return can_shift;
 }

-llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
+bool llama_kv_cache_unified::find_slot(
        const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs = ubatch.n_seqs;
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;

+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head > used + 2*ubatch.n_tokens) {
+        head = 0;
+    }
+
     if (recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
@@ -477,7 +527,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
                 // too big seq_id
                 // TODO: would it be possible to resize the cache instead?
                 LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
-                return llama_kv_cache_slot_info_failed;
+                return false;
             }
             if (j > 0) {
                 llama_kv_cell & seq = cells[seq_id];
@@ -616,14 +666,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
             [](const llama_kv_cell& cell){ return !cell.is_empty(); });

         // sanity check
-        return llama_kv_cache_slot_info(n >= n_seqs);
+        return n >= n_seqs;
     }

     // otherwise, one cell per token.

     if (n_tokens > size) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
-        return llama_kv_cache_slot_info_failed;
+        return false;
     }

     uint32_t n_tested = 0;
@@ -651,7 +701,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(

         if (n_tested >= size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return llama_kv_cache_slot_info_failed;
+            return false;
         }
     }

@@ -668,7 +718,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(

     used += n_tokens;

-    return llama_kv_cache_slot_info(head, head + n_tokens);
+    pending.ranges.push_back({head, head + n_tokens});
+
+    return true;
 }

 uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
@@ -1033,6 +1085,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
+        commit();

         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
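
To make the new commit/restore contract concrete, here is a small self-contained toy (not llama.cpp code) of the same pending-range bookkeeping that find_slot(), commit() and restore() implement above: each successful slot search records a [c0, c1) cell range, commit() forgets the records once the batch is finalized, and restore() releases every cell touched since the last commit.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy stand-in for the unified KV cache's cell bookkeeping.
struct toy_kv_cache {
    struct range { uint32_t c0, c1; }; // half-open interval [c0, c1)

    std::vector<bool>  cell_used;
    std::vector<range> pending;        // ranges written since the last commit

    explicit toy_kv_cache(uint32_t size) : cell_used(size, false) {}

    // claim n contiguous cells starting at `head` and remember the range
    bool find_slot(uint32_t head, uint32_t n) {
        if (head + n > cell_used.size()) {
            return false;
        }
        for (uint32_t i = head; i < head + n; ++i) {
            cell_used[i] = true;
        }
        pending.push_back({head, head + n});
        return true;
    }

    // batch fully processed - keep the cells, forget the pending ranges
    void commit() {
        pending.clear();
    }

    // batch aborted - release every cell touched since the last commit
    void restore() {
        for (const auto & r : pending) {
            for (uint32_t i = r.c0; i < r.c1; ++i) {
                cell_used[i] = false;
            }
        }
        pending.clear();
    }
};

int main() {
    toy_kv_cache kv(16);

    kv.find_slot(0, 4);
    kv.commit();   // first ubatch succeeded - cells 0..3 stay used

    kv.find_slot(4, 8);
    kv.restore();  // second ubatch failed - cells 4..11 are released again

    uint32_t used = 0;
    for (bool u : kv.cell_used) used += u;
    std::printf("used cells: %u\n", used); // prints: used cells: 4
}
```

This mirrors how decode() uses the guard: if any ubatch fails or an error path returns early, the guard's destructor calls restore() and the partially filled cells become reusable; only a fully successful decode reaches kv_guard.commit().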