Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : KV cache view API + better KV cache management #4170

Merged
merged 4 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <regex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <cinttypes>
Expand Down Expand Up @@ -495,6 +496,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.chatml = true;
} else if (arg == "--infill") {
params.infill = true;
} else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
params.dump_kv_cache = true;
} else if (arg == "--multiline-input") {
params.multiline_input = true;
} else if (arg == "--simple-io") {
Expand Down Expand Up @@ -835,6 +838,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
#endif // GGML_USE_CUBLAS
#endif
printf(" --verbose-prompt print prompt before generation\n");
printf(" -dkvc, --dump-kv-cache\n");
printf(" verbose print of the KV cache\n");
printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
Expand Down Expand Up @@ -1386,3 +1391,77 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
}

//
// KV cache utils
//

void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";

printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);

llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences;

for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
if (i % row_size == 0) {
printf("\n%5d: ", i);
}
int seq_count = 0;
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j] >= 0) { seq_count++; }
}
putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
}

printf("\n=== Done dumping\n");
}

void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";

printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);

std::unordered_map<llama_seq_id, size_t> seqs;
llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences;

for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j] < 0) { continue; }
if (seqs.find(cs_curr[j]) == seqs.end()) {
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
seqs[cs_curr[j]] = seqs.size();
}
}
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
}

printf("=== Sequence legend: ");
for (const auto & it : seqs) {
printf("%zu=%d, ", it.second, it.first);
}
printf("'+'=other sequence ids");

c_curr = view.cells;
cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
if (i % row_size == 0) {
printf("\n%5d: ", i);
}
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j] >= 0) {
const auto & it = seqs.find(cs_curr[j]);
putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
} else {
putchar('.');
}
}
putchar(' ');
}

printf("\n=== Done dumping\n");
}
11 changes: 11 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ struct gpt_params {
bool numa = false; // attempt optimizations that help on some NUMA systems
bool verbose_prompt = false; // print prompt tokens before generation
bool infill = false; // use infill mode
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes

// multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector
Expand Down Expand Up @@ -218,3 +219,13 @@ std::string get_sortable_timestamp();
void dump_non_result_info_yaml(
FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);

//
// KV cache utils
//

// Dump the KV cache view with the number of sequences per cell.
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);

// Dump the KV cache view showing individual sequences in each cell (long output).
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
9 changes: 9 additions & 0 deletions examples/parallel/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ int main(int argc, char ** argv) {
// insert new requests as soon as the previous one is done
const bool cont_batching = params.cont_batching;

const bool dump_kv_cache = params.dump_kv_cache;

#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("parallel", "log"));
LOG_TEE("Log start\n");
Expand Down Expand Up @@ -172,6 +174,8 @@ int main(int argc, char ** argv) {
int32_t n_total_gen = 0;
int32_t n_cache_miss = 0;

struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);

const auto t_main_start = ggml_time_us();

LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
Expand Down Expand Up @@ -201,6 +205,11 @@ int main(int argc, char ** argv) {
LOG_TEE("Processing requests ...\n\n");

while (true) {
if (dump_kv_cache) {
llama_kv_cache_view_update(ctx, &kvc_view);
dump_kv_cache_view_seqs(kvc_view, 40);
}

llama_batch_clear(batch);

// decode any currently ongoing sequences
Expand Down
128 changes: 124 additions & 4 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,7 @@ struct llama_kv_cache {
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)

// computed before each graph build
uint32_t n = 0;
Expand Down Expand Up @@ -1504,6 +1505,7 @@ static bool llama_kv_cache_init(

cache.head = 0;
cache.size = n_ctx;
cache.used = 0;

cache.cells.clear();
cache.cells.resize(n_ctx);
Expand Down Expand Up @@ -1605,6 +1607,8 @@ static bool llama_kv_cache_find_slot(
}
}

cache.used += n_tokens;

return true;
}

Expand All @@ -1625,6 +1629,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
cache.cells[i].seq_id.clear();
}
cache.head = 0;
cache.used = 0;
}

static void llama_kv_cache_seq_rm(
Expand All @@ -1647,14 +1652,17 @@ static void llama_kv_cache_seq_rm(
continue;
}
if (cache.cells[i].seq_id.empty()) {
// keep count of the number of used cells
if (cache.cells[i].pos >= 0) cache.used--;

cache.cells[i].pos = -1;
if (new_head == cache.size) new_head = i;
}
}
}

// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size) cache.head = new_head;
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
}

static void llama_kv_cache_seq_cp(
Expand All @@ -1680,6 +1688,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id

for (uint32_t i = 0; i < cache.size; ++i) {
if (!cache.cells[i].has_seq_id(seq_id)) {
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
Expand All @@ -1690,7 +1699,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
}

// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size) cache.head = new_head;
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
}

static void llama_kv_cache_seq_shift(
Expand All @@ -1711,6 +1720,7 @@ static void llama_kv_cache_seq_shift(
cache.cells[i].delta += delta;

if (cache.cells[i].pos < 0) {
if (!cache.cells[i].seq_id.empty()) cache.used--;
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
Expand Down Expand Up @@ -5469,6 +5479,12 @@ static int llama_decode_internal(
batch.seq_id = seq_id_arr.data();
}

// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) {
kv_self.head = 0;
}

if (!llama_kv_cache_find_slot(kv_self, batch)) {
return 1;
}
Expand All @@ -5479,7 +5495,7 @@ static int llama_decode_internal(
//kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));

//printf("kv_self.n = %d\n", kv_self.n);
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

ggml_allocr_reset(lctx.alloc);

Expand Down Expand Up @@ -8789,8 +8805,107 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
}
}

struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
struct llama_kv_cache_view result = {
/*.n_cells = */ 0,
/*.n_max_seq = */ n_max_seq,
/*.token_count = */ 0,
/*.used_cells = */ llama_get_kv_cache_used_cells(ctx),
/*.max_contiguous = */ 0,
/*.max_contiguous_idx = */ -1,
/*.cells = */ nullptr,
/*.cells_sequences = */ nullptr,
};
return result;
}

void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
if (view->cells != nullptr) {
free(view->cells);
view->cells = nullptr;
}
if (view->cells_sequences != nullptr) {
free(view->cells_sequences);
view->cells_sequences = nullptr;
}
}

void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
view->n_cells = int32_t(ctx->kv_self.size);
void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
view->cells = (struct llama_kv_cache_view_cell *)p;
p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
view->cells_sequences = (llama_seq_id *)p;
}

const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
llama_kv_cache_view_cell * c_curr = view->cells;
llama_seq_id * cs_curr = view->cells_sequences;
int32_t used_cells = 0;
int32_t token_count = 0;
int32_t curr_contig_idx = -1;
uint32_t max_contig = 0;
int32_t max_contig_idx = -1;

for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
const size_t curr_size = kv_cells[i].seq_id.size();
token_count += curr_size;
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;

if (curr_size > 0) {
if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
max_contig = i - curr_contig_idx;
max_contig_idx = curr_contig_idx;
}
curr_contig_idx = -1;
} else if (curr_contig_idx < 0) {
curr_contig_idx = i;
}

int seq_idx = 0;
for (const llama_seq_id it : kv_cells[i].seq_id) {
if (seq_idx >= view->n_max_seq) {
break;
}
cs_curr[seq_idx] = it;
seq_idx++;
}
if (seq_idx != 0) {
used_cells++;
}
for (; seq_idx < view->n_max_seq; seq_idx++) {
cs_curr[seq_idx] = -1;
}
}
if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
max_contig_idx = curr_contig_idx;
max_contig = kv_cells.size() - curr_contig_idx;
}
view->max_contiguous = max_contig;
view->max_contiguous_idx = max_contig_idx;
view->token_count = token_count;
view->used_cells = used_cells;
if (uint32_t(used_cells) != ctx->kv_self.used) {
LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
__func__, ctx->kv_self.used, used_cells);
}
}

int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
return ctx->kv_self.head;
int result = 0;

for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
result += ctx->kv_self.cells[i].seq_id.size();
}

return result;
}

int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
return ctx->kv_self.used;
}

void llama_kv_cache_clear(struct llama_context * ctx) {
Expand Down Expand Up @@ -8960,10 +9075,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const size_t kv_buf_size = kv_self.buf.size;
const uint32_t kv_head = kv_self.head;
const uint32_t kv_size = kv_self.size;
const uint32_t kv_used = kv_self.used;

data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
data_ctx->write(&kv_head, sizeof(kv_head));
data_ctx->write(&kv_size, sizeof(kv_size));
data_ctx->write(&kv_used, sizeof(kv_used));

if (kv_buf_size) {
const size_t elt_size = ggml_element_size(kv_self.k);
Expand Down Expand Up @@ -9086,10 +9203,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
size_t kv_buf_size;
uint32_t kv_head;
uint32_t kv_size;
uint32_t kv_used;

memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);

if (kv_buf_size) {
GGML_ASSERT(kv_self.buf.size == kv_buf_size);
Expand Down Expand Up @@ -9124,6 +9243,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {

ctx->kv_self.head = kv_head;
ctx->kv_self.size = kv_size;
ctx->kv_self.used = kv_used;

ctx->kv_self.cells.resize(kv_size);

Expand Down
Loading
Loading