From 6289ed6b1b969888c93305f08d3dcef03bf6a478 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 18 Sep 2023 18:00:25 +0300 Subject: [PATCH] llama : add llama_kv_cache_shift_seq + no more context swaps --- common/common.cpp | 1 + examples/main/main.cpp | 21 ++++++++------ llama.cpp | 65 ++++++++++++++++++++++++++++++------------ llama.h | 6 ++-- 4 files changed, 64 insertions(+), 29 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index b638efe9ebae86..fd50891f8cdbc6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -781,6 +781,7 @@ std::tuple llama_init_from_gpt_par std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads); + llama_kv_cache_keep_seq(lctx, -1); llama_reset_timings(lctx); } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3e78fdaa05459c..ed2d9e2f706c5a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -499,17 +499,22 @@ int main(int argc, char ** argv) { break; } - const int n_left = n_past - params.n_keep; - LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep); + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; - // always keep the first token - BOS - n_past = std::max(1, params.n_keep); - n_past_guidance = std::max(1, params.n_keep + guidance_offset); + LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + n_past, n_left, n_ctx, params.n_keep, n_discard); - LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + llama_kv_cache_rm_seq (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_shift_seq(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + + n_past -= n_discard; - // insert n_left/2 tokens at the start of embd from last_tokens - embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size()); + if (ctx_guidance) { + n_past_guidance -= n_discard; + } + + LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); diff --git a/llama.cpp b/llama.cpp index 6634f753fd165e..b223249e62382b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1007,7 +1007,8 @@ struct llama_layer { }; struct llama_kv_cell { - llama_pos pos = -1; + llama_pos pos = -1; + llama_pos delta = 0; std::set seq_id; @@ -1018,7 +1019,7 @@ struct llama_kv_cell { // ring-buffer of cached KV data struct llama_kv_cache { - bool is_roped = false; + bool has_shift = false; uint32_t head = 0; uint32_t size = 0; @@ -1333,9 +1334,13 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t } } -void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) { +void llama_kv_cache_rm_seq( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id)) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.cells[i].seq_id.erase(seq_id); if (cache.cells[i].seq_id.empty()) { cache.cells[i].pos = -1; @@ -1353,18 +1358,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) } } -void llama_kv_cache_shift( - struct llama_context & ctx, +void llama_kv_cache_shift_seq( + struct llama_kv_cache & cache, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - auto & hparams = ctx.model.hparams; - auto & cache = ctx.kv_self; - for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.cells[i].pos += delta; + if (cache.cells[i].pos < 0) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + } else { + cache.has_shift = true; + cache.cells[i].delta = delta; + } } } } @@ -2595,6 +2604,8 @@ static struct ggml_cgraph * llm_build_llama( const int32_t n_tokens = batch.n_tokens; const int32_t n_kv = llama_kv_cache_cell_max(kv_self); + const bool do_rope_shift = kv_self.has_shift; + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { @@ -2698,6 +2709,16 @@ static struct ggml_cgraph * llm_build_llama( } } + // K_shift + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc) && do_rope_shift) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + for (int il = 0; il < n_layer; ++il) { ggml_format_name(inpL, "layer_inp_%d", il); @@ -2723,6 +2744,17 @@ static struct ggml_cgraph * llm_build_llama( ggml_set_name(cur, "attention_norm_0"); } + if (do_rope_shift) { + ggml_build_forward_expand(gf, + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_head_kv, n_ctx, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), + K_shift, n_embd_head, 0, 0, freq_base, freq_scale)); + } + // self-attention { // compute Q and K and RoPE them @@ -4033,7 +4065,8 @@ static bool llama_eval_internal( #endif // update the kv ring buffer - lctx.kv_self.head += n_tokens; + lctx.kv_self.head += n_tokens; + lctx.kv_self.has_shift = false; #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) @@ -6562,10 +6595,6 @@ struct llama_context * llama_new_context_with_model( return nullptr; } - if (model->arch == LLM_ARCH_LLAMA) { - ctx->kv_self.is_roped = true; - } - { const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v); LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); @@ -6803,16 +6832,16 @@ void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1 llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1); } -void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id) { - llama_kv_cache_rm_seq(ctx->kv_self, seq_id); +void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + llama_kv_cache_rm_seq(ctx->kv_self, seq_id, p0, p1); } void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) { llama_kv_cache_keep_seq(ctx->kv_self, seq_id); } -void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - llama_kv_cache_shift(*ctx, seq_id, p0, p1, delta); +void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + llama_kv_cache_shift_seq(ctx->kv_self, seq_id, p0, p1, delta); } // Returns the *maximum* size of the state diff --git a/llama.h b/llama.h index ec05fa6ea0eec0..4a5f2e3bf70cfd 100644 --- a/llama.h +++ b/llama.h @@ -324,15 +324,15 @@ extern "C" { // Remove all tokens data of cells in [c0, c1) LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1); - // Removes all tokens that belong to the specified sequence - LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id); + // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); // Removes all tokens that do not belong to the specified sequence LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly - LLAMA_API void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); + LLAMA_API void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); // // State / sessions