From 2e1ddf422bcfd796fa2fcda13bae2074066f6a95 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sun, 25 Feb 2024 09:59:53 -0500
Subject: [PATCH] mamba : make the server and parallel examples work with
 whole sequences

A seq_id is dedicated to the system prompt in both cases.

* llama : make llama_kv_cache_seq_rm return whether it succeeded or not
---
 examples/parallel/parallel.cpp | 20 ++++++++++------
 examples/server/server.cpp     | 43 ++++++++++++++++++++++++----------
 llama.cpp                      | 38 ++++++++++++++++++++----------
 llama.h                        |  2 +-
 4 files changed, 70 insertions(+), 33 deletions(-)
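With llama_kv_cache_seq_rm now returning a bool, callers that assumed a partial
removal always succeeds need a fallback path. A minimal sketch of the intended
calling pattern, mirroring the server change below; ctx, seq_id and p0 are
illustrative names, and the system prompt is assumed to be kept in sequence 0
as in both examples:

    // try to drop only the tail of the sequence, starting at position p0
    if (!llama_kv_cache_seq_rm(ctx, seq_id, p0, -1)) {
        // a recurrent model (e.g. Mamba) cannot erase only part of its state,
        // so clear the whole sequence and restore the shared system prompt from seq 0
        llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
        llama_kv_cache_seq_cp(ctx, 0, seq_id, -1, -1);
        // the sequence-specific prompt then has to be re-processed from scratch
    }
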
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 7d11fcd5930806..a2ef0fb039c3f4 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -107,6 +107,9 @@ int main(int argc, char ** argv) {
     // number of simultaneous "clients" to simulate
     const int32_t n_clients = params.n_parallel;
 
+    // dedicate one sequence to the system prompt
+    params.n_parallel += 1;
+
     // requests to simulate
     const int32_t n_seq = params.n_sequences;
 
@@ -196,8 +199,8 @@ int main(int argc, char ** argv) {
     }
 
     // assign the system KV cache to all parallel sequences
-    for (int32_t i = 1; i < n_clients; ++i) {
-        llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
+    for (int32_t i = 1; i <= n_clients; ++i) {
+        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
     }
 
     LOG_TEE("\n");
@@ -221,15 +224,17 @@ int main(int argc, char ** argv) {
 
            client.i_batch = batch.n_tokens;
 
-           llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
+           llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
 
            client.n_decoded += 1;
        }
 
        if (batch.n_tokens == 0) {
            // all sequences have ended - clear the entire KV cache
-           for (int i = 0; i < n_clients; ++i) {
-               llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
+           for (int i = 1; i <= n_clients; ++i) {
+               llama_kv_cache_seq_rm(ctx, i, -1, -1);
+               // but keep the system prompt
+               llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
            }
 
            LOG_TEE("%s: clearing the KV cache\n", __func__);
@@ -255,7 +260,7 @@ int main(int argc, char ** argv) {
                tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
 
                for (size_t i = 0; i < tokens_prompt.size(); ++i) {
-                   llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
+                   llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
                }
 
                // extract the logits only for the last token
@@ -366,7 +371,8 @@ int main(int argc, char ** argv) {
                }
 
                // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-               llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
+               llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
+               llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
 
                const auto t_main_end = ggml_time_us();
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8aadc95a9728f7..17dfdde0b6e132 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -438,12 +438,16 @@ struct llama_server_context
                return false;
            }
 
-           if (params.n_ctx < 2048) { // request larger context for the image embedding
+           if (params.n_ctx != 0 && params.n_ctx < 2048) { // request larger context for the image embedding
                params.n_ctx = 2048;
            }
        }
 
+       // dedicate one sequence to the system prompt
+       params.n_parallel += 1;
+
        std::tie(model, ctx) = llama_init_from_gpt_params(params);
+       params.n_parallel -= 1; // but be sneaky about it
        if (model == nullptr)
        {
            LOG_ERROR("unable to load model", {{"model", params.model}});
@@ -923,9 +927,9 @@ struct llama_server_context
            }
 
            // assign the system KV cache to all parallel sequences
-           for (int32_t i = 1; i < params.n_parallel; ++i)
+           for (int32_t i = 1; i <= params.n_parallel; ++i)
            {
-               llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
+               llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
            }
        }
 
@@ -1400,7 +1404,7 @@ struct llama_server_context
                    std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
                    for (int i = 0; i < (int) append_tokens.size(); ++i)
                    {
-                       llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
+                       llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id + 1 }, true);
                        slot.n_past += 1;
                    }
                }
@@ -1636,8 +1640,8 @@ struct llama_server_context
                    {"n_system_tokens", system_tokens.size()},
                    {"n_cache_tokens",  slot.cache_tokens.size()}
                });
-               llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
-               llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+               llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
+               llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
                for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
                {
@@ -1689,7 +1693,7 @@ struct llama_server_context
 
                // TODO: we always have to take into account the "system_tokens"
                //       this is not great and needs to be improved somehow
-               llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
+               llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
 
                slot.n_past += 1;
            }
@@ -1852,13 +1856,28 @@ struct llama_server_context
                    }
                }
 
+               // keep only the common part
                int p0 = (int) system_tokens.size() + slot.n_past;
                LOG_INFO("kv cache rm [p0, end)", {
                    { "slot_id", slot.id },
                    { "task_id", slot.task_id },
                    { "p0",      p0 }
                });
-               llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
+               if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+                   // could not partially delete (likely using a non-Transformer model)
+                   // TODO: logging
+                   llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+                   llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
+
+                   // there is no common part left (except for the system prompt)
+                   // TODO: maybe find a way to refactor this to reuse the !cache_prompt case above
+                   slot.n_past = 0;
+                   slot.n_past_se = 0;
+                   slot.ga_i = 0;
+                   slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
+                   // TODO: is the system prompt ever in the sampling context?
+                   llama_sampling_reset(slot.ctx_sampling);
+               }
 
                LOG_VERBOSE("prompt ingested", {
                                            {"n_past", slot.n_past},
@@ -1887,7 +1906,7 @@ struct llama_server_context
                            ga_i += ga_w/ga_n;
                        }
                    }
-                   llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+                   llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
                    slot_npast++;
                }
 
@@ -1941,9 +1960,9 @@ struct llama_server_context
                        LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
-                       llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
-                       llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                       llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
+                       llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
+                       llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
+                       llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
 
                        slot.n_past_se -= bd;
 
diff --git a/llama.cpp b/llama.cpp
index 9992e63724fa05..67a9d21eff51da 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2231,7 +2231,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
     cache.used = 0;
 }
 
-static void llama_kv_cache_seq_rm(
+static bool llama_kv_cache_seq_rm(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
@@ -2241,11 +2241,23 @@ static void llama_kv_cache_seq_rm(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
+    // models like Mamba can't have a state partially erased
     if (cache.unlimited) {
-        // can only remove whole sequences for models like Mamba
-        GGML_ASSERT(p0 == 0);
-        GGML_ASSERT((uint32_t)seq_id < cache.size);
-        GGML_ASSERT(cache.cells[seq_id].pos < p1);
+        if (seq_id >= (int64_t) cache.size) {
+            // could be fatal
+            return false;
+        }
+        if (0 <= seq_id) {
+            // partial intersection is invalid
+            if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
+                return false;
+            }
+        } else {
+            // seq_id is negative, then the range should include everything or nothing
+            if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+                return false;
+            }
+        }
     }
 
     for (uint32_t i = 0; i < cache.size; ++i) {
@@ -2269,6 +2281,8 @@ static void llama_kv_cache_seq_rm(
 
     // If we freed up a slot, set head to it so searching can start there.
     if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
+
+    return true;
 }
 
 static void llama_kv_cache_seq_cp(
@@ -12323,13 +12337,11 @@ struct llama_context * llama_new_context_with_model(
     // Mamba only needs a constant number of KV cache cells per sequence
     if (model->arch == LLM_ARCH_MAMBA) {
-        // Mamba needs as many KV cells as there are sequences kept at any time
-        // The extra cell allows dedicating a sequence id to the system prompt
-        // TODO: find a better way to get the max number of parallel sequences
-        kv_size = params.n_parallel + 1;
+        // Mamba needs at least as many KV cells as there are sequences kept at any time
+        kv_size = std::max((uint32_t) 1, params.n_parallel);
 
        // it's probably best to keep as much precision as possible for the states
-        type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
-        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_state
+        type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
     }
 
     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
@@ -12839,8 +12851,8 @@ void llama_kv_cache_clear(struct llama_context * ctx) {
     llama_kv_cache_clear(ctx->kv_self);
 }
 
-void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
+bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
 }
 
 void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
diff --git a/llama.h b/llama.h
index 63e1ba5e8d6cd1..2daacb4828801f 100644
--- a/llama.h
+++ b/llama.h
@@ -504,7 +504,7 @@ extern "C" {
     //   seq_id < 0 : match any sequence
     //   p0 < 0     : [0,  p1]
     //   p1 < 0     : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_rm(
+    LLAMA_API bool llama_kv_cache_seq_rm(
            struct llama_context * ctx,
                    llama_seq_id   seq_id,
                       llama_pos   p0,