llama : KV cache view API + better KV cache management #4170

Merged · 4 commits · Nov 23, 2023 (changes shown from 2 commits)
llama.cpp: 39 changes (35 additions, 4 deletions)
@@ -1280,6 +1280,7 @@ struct llama_kv_cache {
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)

// computed before each graph build
uint32_t n = 0;
@@ -1504,6 +1505,7 @@ static bool llama_kv_cache_init(

cache.head = 0;
cache.size = n_ctx;
cache.used = 0;

cache.cells.clear();
cache.cells.resize(n_ctx);
@@ -1605,6 +1607,8 @@ static bool llama_kv_cache_find_slot(
}
}

cache.used += n_tokens;

return true;
}

@@ -1625,6 +1629,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
cache.cells[i].seq_id.clear();
}
cache.head = 0;
cache.used = 0;
}

static void llama_kv_cache_seq_rm(
@@ -1647,14 +1652,17 @@ static void llama_kv_cache_seq_rm(
continue;
}
if (cache.cells[i].seq_id.empty()) {
// keep count of the number of used cells
if (cache.cells[i].pos >= 0) cache.used--;

cache.cells[i].pos = -1;
if (new_head == cache.size) new_head = i;
}
}
}

// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size) cache.head = new_head;
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
}

static void llama_kv_cache_seq_cp(
@@ -1680,6 +1688,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id

for (uint32_t i = 0; i < cache.size; ++i) {
if (!cache.cells[i].has_seq_id(seq_id)) {
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
@@ -1690,7 +1699,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
}

// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size) cache.head = new_head;
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
}

static void llama_kv_cache_seq_shift(
@@ -1711,6 +1720,7 @@ static void llama_kv_cache_seq_shift(
cache.cells[i].delta += delta;

if (cache.cells[i].pos < 0) {
if (!cache.cells[i].seq_id.empty()) cache.used--;
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
@@ -5469,6 +5479,12 @@ static int llama_decode_internal(
batch.seq_id = seq_id_arr.data();
}

// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) {
kv_self.head = 0;
}

if (!llama_kv_cache_find_slot(kv_self, batch)) {
return 1;
}
@@ -5479,7 +5495,7 @@ static int llama_decode_internal(
//kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));

//printf("kv_self.n = %d\n", kv_self.n);
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

ggml_allocr_reset(lctx.alloc);

@@ -8790,7 +8806,17 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
}

int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
return ctx->kv_self.head;
int result = 0;

for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
result += ctx->kv_self.cells[i].seq_id.size();
}

return result;
}

int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
return ctx->kv_self.used;
}

void llama_kv_cache_clear(struct llama_context * ctx) {
@@ -8960,10 +8986,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const size_t kv_buf_size = kv_self.buf.size;
const uint32_t kv_head = kv_self.head;
const uint32_t kv_size = kv_self.size;
const uint32_t kv_used = kv_self.used;

data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
data_ctx->write(&kv_head, sizeof(kv_head));
data_ctx->write(&kv_size, sizeof(kv_size));
data_ctx->write(&kv_used, sizeof(kv_used));

if (kv_buf_size) {
const size_t elt_size = ggml_element_size(kv_self.k);
@@ -9086,10 +9114,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
size_t kv_buf_size;
uint32_t kv_head;
uint32_t kv_size;
uint32_t kv_used;

memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);

if (kv_buf_size) {
GGML_ASSERT(kv_self.buf.size == kv_buf_size);
Expand Down Expand Up @@ -9124,6 +9154,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {

ctx->kv_self.head = kv_head;
ctx->kv_self.size = kv_size;
ctx->kv_self.used = kv_used;

ctx->kv_self.cells.resize(kv_size);

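The llama.cpp changes above boil down to one piece of bookkeeping and one heuristic: `used` counts the cells that currently hold at least one sequence, and the search `head` is rewound to the start of the cache whenever enough unused cells are known to sit before it. A minimal standalone sketch of that heuristic follows; the `*_sketch` types and the `maybe_rewind_head` helper are illustrative stand-ins, not the actual llama.cpp structures.

```cpp
#include <cstdint>
#include <set>
#include <vector>

// Simplified stand-ins for llama.cpp's KV cache types (hypothetical names).
struct llama_kv_cell_sketch {
    int32_t pos = -1;            // -1 means the cell is empty
    std::set<int32_t> seq_id;    // sequences assigned to this cell
};

struct llama_kv_cache_sketch {
    uint32_t head = 0;           // where the next slot search starts
    uint32_t used = 0;           // cells with at least one seq_id assigned
    std::vector<llama_kv_cell_sketch> cells;
};

// Same condition as in llama_decode_internal: if there are plenty of unused
// cells before the current head, restart the slot search from the beginning
// of the cache so those cells can be filled.
void maybe_rewind_head(llama_kv_cache_sketch & cache, uint32_t n_tokens) {
    if (cache.head > cache.used + 2*n_tokens) {
        cache.head = 0;
    }
}
```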
llama.h: 9 changes (6 additions, 3 deletions)
@@ -361,9 +361,12 @@ extern "C" {
// KV cache
//

// Returns the number of tokens in the KV cache
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);

// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx);

// Clear the KV cache
LLAMA_API void llama_kv_cache_clear(
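As a rough usage sketch of the two query functions declared above: `llama_get_kv_cache_used_cells` is a cheap O(1) read of the new `used` counter, while `llama_get_kv_cache_token_count` walks every cell and is intended for debugging only. The helper name `print_kv_cache_stats` and the assumption of an already-initialized `ctx` are illustrative, not part of the PR.

```cpp
#include <cstdio>
#include "llama.h"

// Assumes `ctx` is a valid llama_context that has already decoded some batches.
void print_kv_cache_stats(const struct llama_context * ctx) {
    // O(1): number of cells with at least one sequence assigned
    const int used_cells = llama_get_kv_cache_used_cells(ctx);

    // O(n_ctx): total (cell, seq_id) assignments - slow, debug use only
    const int token_count = llama_get_kv_cache_token_count(ctx);

    printf("KV cache: %d used cells, %d token entries\n", used_cells, token_count);
}
```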