kv-cache : refactor the update/defrag mechanism #13988

Merged on Jun 4, 2025 (4 commits)

Changes from all commits

83 changes: 56 additions & 27 deletions src/llama-context.cpp

@@ -429,30 +429,62 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-bool llama_context::kv_self_update() {
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
+
+    memory_force_optimize = true;
+}
+
+bool llama_context::kv_self_update(bool optimize) {
     if (!memory) {
         return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (!kv_self->update(*this)) {
-        // no updates have been performed
-        return false;
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
+
+        const auto kv_state = kv_self->init_update(this, optimize);
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                    // noop
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    // no updates need to be performed
+                    return false;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+                    return false;
+                }
+        }
+
+        if (!kv_state->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
     }
 
     // if the KV cache did any computation, we have to reserve a new worst-case graph
     const auto kv_state = kv_self->init_full();
     if (!kv_state) {
-        throw std::runtime_error("failed to initialize KV cache");
+        throw std::runtime_error("failed to initialize memory state");
     }
 
     const uint32_t n_seqs   = cparams.n_seq_max;
     const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
     if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
     }
 
     return true;
@@ -940,13 +972,13 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
+    bool did_optimize = false;
+
     // handle any pending defrags/shifts
-    kv_self_update();
+    kv_self_update(false);
 
     llama_memory_state_ptr kv_state;
 
-    bool did_defrag = false;
-
     while (true) {
         kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
         if (!kv_state) {
@@ -957,25 +989,32 @@
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+
+                    return -2;
+                }
             case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                 {
-                    if (!did_defrag) {
-                        did_defrag = true;
+                    if (!did_optimize) {
+                        did_optimize = true;
 
-                        kv_self->defrag_sched(-1.0f);
-                        if (kv_self_update()) {
-                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
 
                             continue;
                         }
                     }
 
-                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
 
                     return 1;
                 }
             case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                 {
                     LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
 
                     return -2;
                 }
         }
@@ -1189,11 +1228,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold > 0.0f) {
-        kv_self->defrag_sched(cparams.defrag_thold);
-    }
-
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
@@ -2283,7 +2317,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
 
 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update();
+    ctx->kv_self_update(false);
 }
 
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2538,13 +2572,8 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 
 // deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
-    // force defrag
-    kv->defrag_sched(-1.0f);
+    ctx->kv_self_defrag_sched();
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
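
A condensed sketch of the retry flow that `decode()` now implements (not code from this PR; `try_place_batch` and `find_slot_for` are hypothetical stand-ins for the `init_batch()` + status-check sequence): on the first failed slot search it runs one optimization pass via `kv_self_update(true)` and retries the same batch once.

```cpp
// Illustrative only, assuming access to llama_context internals.
static bool try_place_batch(llama_context & ctx, const llama_batch & batch) {
    bool did_optimize = false;

    while (true) {
        if (ctx.find_slot_for(batch)) { // hypothetical stand-in for
            return true;                // init_batch() + status check
        }

        if (!did_optimize) {
            did_optimize = true;

            // run pending shifts + a forced defrag once, then retry
            if (ctx.kv_self_update(/*optimize =*/ true)) {
                continue;
            }
        }

        return false; // no slot even after optimizing; decode() returns 1
    }
}
```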

6 changes: 5 additions & 1 deletion src/llama-context.h

@@ -52,7 +52,8 @@ struct llama_context {
 
     // return true of the KV cache was updated
     // TODO: remove
-    bool kv_self_update();
+    bool kv_self_update(bool optimize);
+    void kv_self_defrag_sched();
 
     enum llama_pooling_type pooling_type() const;
 
@@ -231,6 +232,9 @@ struct llama_context {
 
     std::unique_ptr<llama_memory_i> memory;
 
+    // TODO: temporary, until the llama_kv_self_defrag() API is removed
+    bool memory_force_optimize = false;
+
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;

19 changes: 8 additions & 11 deletions src/llama-kv-cache-recurrent.cpp

@@ -1,6 +1,7 @@
 #include "llama-kv-cache-recurrent.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-batch.h"
 #include "llama-model.h"
 
@@ -386,6 +387,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
     return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }
 
+llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(lctx);
+    GGML_UNUSED(optimize);
+
+    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+}
+
 bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize
@@ -419,17 +427,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
-    // noop
-    return false;
-}
-
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
-    GGML_UNUSED(thold);
-    // noop
-}
-
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;
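
The recurrent cache opts out of the new update path by reporting `LLAMA_MEMORY_STATUS_NO_UPDATE` instead of overriding the old `update()`/`defrag_sched()` no-ops. A sketch of the caller-side contract this implies, using the names from this PR (the surrounding declarations of `cache` and `lctx` are assumed):

```cpp
// Illustrative only: how a caller is expected to react to init_update().
const auto state = cache->init_update(lctx, /*optimize =*/ false);

switch (state->get_status()) {
    case LLAMA_MEMORY_STATUS_NO_UPDATE:
        // nothing scheduled - e.g. llama_kv_cache_recurrent above
        break;
    case LLAMA_MEMORY_STATUS_SUCCESS:
        if (!state->apply()) {
            // update was prepared but could not be applied
        }
        break;
    case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
    case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
        // preparation failed; the caller logs an error and bails out
        break;
}
```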

4 changes: 1 addition & 3 deletions src/llama-kv-cache-recurrent.h

@@ -52,9 +52,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool prepare(const std::vector<llama_ubatch> & ubatches);
 

59 changes: 31 additions & 28 deletions src/llama-kv-cache-unified-iswa.cpp

@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     assert(heads_base.size() == heads_swa.size());
 
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(
             this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
 }
 
-bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
-    bool res = false;
-
-    res = res | kv_base->update(lctx);
-    res = res | kv_swa ->update(lctx);
-
-    return res;
-}
-
-void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
-    kv_base->defrag_sched(thold);
-    kv_swa ->defrag_sched(thold);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
-        llama_kv_cache_unified_iswa * kv) : status(status) {
-    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
-    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv,
+        llama_context * lctx,
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_update(lctx, optimize);
+    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
         llama_kv_cache_unified_iswa * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
        std::vector<uint32_t> heads_swa,
        std::vector<llama_ubatch> ubatches)
-    : status(status),
-    sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)) {
-    // note: here we copy the ubatches. not sure if this is ideal
-    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
-    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
-}
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)) {
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
 
 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
 
@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
 
 const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
     return ubatches[i_next];
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_base.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_swa.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
 }
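
The iSWA wrapper now derives its own status from its two children via `llama_memory_status_combine()`, which this PR defines elsewhere. A sketch of a plausible reading of its semantics, assuming failures dominate and `SUCCESS` wins over `NO_UPDATE` (so that a single child needing work marks the combined state as having work to apply):

```cpp
// Sketch only, not the actual implementation from this PR.
static llama_memory_status combine_sketch(llama_memory_status s0, llama_memory_status s1) {
    // any failure propagates unchanged
    if (s0 == LLAMA_MEMORY_STATUS_FAILED_PREPARE || s0 == LLAMA_MEMORY_STATUS_FAILED_COMPUTE) {
        return s0;
    }
    if (s1 == LLAMA_MEMORY_STATUS_FAILED_PREPARE || s1 == LLAMA_MEMORY_STATUS_FAILED_COMPUTE) {
        return s1;
    }

    // if either child has an update to apply, so does the combined state
    if (s0 == LLAMA_MEMORY_STATUS_SUCCESS || s1 == LLAMA_MEMORY_STATUS_SUCCESS) {
        return LLAMA_MEMORY_STATUS_SUCCESS;
    }

    return LLAMA_MEMORY_STATUS_NO_UPDATE;
}
```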

18 changes: 10 additions & 8 deletions src/llama-kv-cache-unified-iswa.h

@@ -54,9 +54,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -86,12 +84,16 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
 
     // used to create a full-cache state
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv);
 
+    // used to create an update state
+    llama_kv_cache_unified_iswa_state(
+            llama_kv_cache_unified_iswa * kv,
+            llama_context * lctx,
+            bool optimize);
+
     // used to create a state from a batch
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv,
             llama_sbatch sbatch,
             std::vector<uint32_t> heads_base,
@@ -120,7 +122,7 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
     const llama_kv_cache_unified_state * get_swa() const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     //llama_kv_cache_unified_iswa * kv;
 
@@ -131,6 +133,6 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
 
     std::vector<llama_ubatch> ubatches;
 
-    std::unique_ptr<llama_kv_cache_unified_state> state_base;
-    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
+    llama_memory_state_ptr state_base;
+    llama_memory_state_ptr state_swa;
 };
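
From the public API side only the timing changes: per the `llama-context.cpp` diff above, the deprecated `llama_kv_self_defrag()` no longer defrags immediately but merely sets `memory_force_optimize`, and the work runs inside the next update. A short usage sketch:

```cpp
// After this PR, the deprecated call only schedules the optimization;
// the actual shift + defrag happens inside the next update call:
llama_kv_self_defrag(ctx); // schedule only (sets memory_force_optimize)
llama_kv_self_update(ctx); // shift/defrag run here, graph is re-reserved
```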