Recurrent-state KV cache: copy sequence states through a dedicated graph.

Summary of the changes below:
  * llama_kv_cell gains `src` and llama_kv_cache gains `do_copy`; together they
    replace the earlier reuse of `delta` / `has_shift` for state copies.
  * A new I32 [kv_size] input tensor `inp_s_copy` carries each cell's `src`.
  * build_s_copy(), llama_build_graph_s_copy() and llama_set_s_copy() gather the
    rows of k_l / v_l selected by `inp_s_copy` and copy them back over k_l / v_l.
    llama_kv_cache_update_internal() runs that graph when the cache is
    `unlimited` and `do_copy` is set, then resets every cell's `src` to its own
    index and clears `do_copy`.
  * build_k_shift() drops its `unlimited` branch and asserts kv_self.size == n_ctx.
  * llama_kv_cache_seq_cp() now also erases the destination's own seq_id when the
    source cell does not hold its own seq_id.

Review notes:
  * llama_set_s_copy() checks the host buffer with GGML_ASSERT (not assert) so
    the check is not compiled out under NDEBUG; this matches the inp_s_mask
    check in llama_set_inputs().
  * NOTE(review): only the leading indentation of each line was reconstructed;
    column alignment inside lines is not known. The two hunks in
    llama_set_inputs() have identical -/+ text here and are presumably
    whitespace-only realignments. Apply with whitespace-insensitive matching
    (e.g. `git apply --ignore-whitespace`) and confirm against the target tree.
  * NOTE(review): the template arguments in `std::set<llama_seq_id>` and
    `std::vector<uint32_t>` (context lines) were restored by hand - verify.

diff --git a/llama.cpp b/llama.cpp
index 67a9d21eff51da..37078e994b2d5c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1743,6 +1743,7 @@ struct llama_layer {
 struct llama_kv_cell {
     llama_pos pos = -1;
     llama_pos delta = 0;
+    int32_t src = 0; // used by recurrent state models to copy states
 
     std::set<llama_seq_id> seq_id;
 
@@ -1763,6 +1764,7 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
+    bool do_copy = false;
     // with Mamba, a cell can hold the state for more than one past token
     bool unlimited = false;
 
@@ -2001,7 +2003,8 @@ struct llama_context {
     struct ggml_tensor * inp_K_shift; // I32 [kv_size]
     struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls; // I32 [n_batch]
-    struct ggml_tensor * inp_s_mask; // F32 [kv_size] (only used by constant state models like Mamba)
+    struct ggml_tensor * inp_s_copy; // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask; // F32 [kv_size]
     struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
 
 #ifdef GGML_USE_MPI
@@ -2043,9 +2046,9 @@ static bool llama_kv_cache_init(
 
     if (cache.unlimited) {
         for (uint32_t i = 0; i < cache.size; ++i) {
-            cache.cells[i].delta = i;
+            cache.cells[i].src = i;
         }
-    } // else, delta is already initialized to zero
+    }
 
 #ifdef GGML_USE_CLBLAST
     offload = false;
@@ -2296,19 +2299,20 @@ static void llama_kv_cache_seq_cp(
 
     if (cache.unlimited) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            seq_id_src = cache.cells[seq_id_src].delta;
+            seq_id_src = cache.cells[seq_id_src].src;
             GGML_ASSERT((uint32_t) seq_id_src < cache.size);
             // intent to "copy from"
             // supports copy chains thanks to taking the source of the source
-            cache.cells[seq_id_dst].delta = seq_id_src;
+            cache.cells[seq_id_dst].src = seq_id_src;
 
-            // prevent the destination from getting cleared if the source is not empty
+            // preserve the "keep or clear" status of the copied sequence
             if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
                 cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            } else {
+                cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
             }
-            // repurposed as a "need copy" flag
-            // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = true;
+
+            cache.do_copy = true;
 
             cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
@@ -5335,21 +5339,7 @@ struct llm_build_context {
     struct ggml_cgraph * build_k_shift() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-        // TODO: do this in a another graph with a dedicated input tensor
-        if (kv_self.unlimited) {
-            for (int il = 0; il < n_layer; ++il) {
-                ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
-                ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
-
-                conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_K_shift);
-                ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_K_shift);
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
-            }
-
-            return gf;
-        }
+        GGML_ASSERT(kv_self.size == n_ctx);
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * tmp =
@@ -5369,6 +5359,25 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_s_copy() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
+            ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+
+            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
+            ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
+
+            // TODO: name the intermediate tensors with cb()
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
+        }
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -7985,6 +7994,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
     return result;
 }
 
+static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
+    llama_batch dummy;
+    dummy.n_tokens = 0;
+
+    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+    struct llm_build_context llm(lctx, dummy, cb, false);
+
+    llm.init();
+
+    struct ggml_cgraph * result = llm.build_s_copy();
+
+    llm.free();
+
+    return result;
+}
+
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
      const llama_batch & batch,
@@ -8120,6 +8146,18 @@ static void llama_set_k_shift(llama_context & lctx) {
     }
 }
 
+static void llama_set_s_copy(llama_context & lctx) {
+    const int64_t kv_size = lctx.kv_self.size;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+    for (int i = 0; i < kv_size; ++i) {
+        data[i] = lctx.kv_self.cells[i].src;
+    }
+}
+
 static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     //
     // set input data
@@ -8234,7 +8272,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     if (kv_self.unlimited) {
-        const int64_t n_kv = kv_self.n;
+        const int64_t n_kv = kv_self.n;
 
         {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
@@ -8242,9 +8280,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
             // states which are not affected by the current batch are left untouched
             for (int i = 0; i < n_kv; ++i) {
-                llama_seq_id seq_id = i + lctx.kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
-                bool has_self_seq = kv_cell.has_seq_id(seq_id);
+                llama_seq_id seq_id = i + lctx.kv_self.head;
+                llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
+                bool has_self_seq = kv_cell.has_seq_id(seq_id);
 
                 data[i] = (float) has_self_seq;
 
@@ -8731,7 +8769,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 
 static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     // apply K-shift if needed
-    if ((lctx.kv_self.unlimited || lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) && lctx.kv_self.has_shift) {
+    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
         llama_set_k_shift(lctx);
 
         {
@@ -8746,7 +8784,27 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
             kv_self.has_shift = false;
 
             for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = kv_self.unlimited ? i : 0;
+                kv_self.cells[i].delta = 0;
+            }
+        }
+    }
+
+    if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+        llama_set_s_copy(lctx);
+
+        {
+            ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
+
+            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+        }
+
+        {
+            auto & kv_self = lctx.kv_self;
+
+            kv_self.do_copy = false;
+
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].src = i;
             }
         }
     }
@@ -12458,7 +12516,7 @@ struct llama_context * llama_new_context_with_model(
             // graph inputs
             {
                 ggml_init_params init_params = {
-                    /* .mem_size */ ggml_tensor_overhead()*(8 + 2*(ctx->kv_self.unlimited)),
+                    /* .mem_size */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
                     /* .mem_buffer */ nullptr,
                     /* .no_alloc */ true,
                 };
@@ -12473,6 +12531,7 @@ struct llama_context * llama_new_context_with_model(
                 ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
                 ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
                 if (ctx->kv_self.unlimited) {
+                    ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
                     ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
                     ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
                 }
@@ -12486,6 +12545,7 @@ struct llama_context * llama_new_context_with_model(
                 ggml_set_name(ctx->inp_mean, "inp_mean");
                 ggml_set_name(ctx->inp_cls, "inp_cls");
                 if (ctx->kv_self.unlimited) {
+                    ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
                     ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
                     ggml_set_name(ctx->inp_s_seq, "inp_s_seq");
                 }