 #include "llama-batch.h"
 #include "llama-cparams.h"

-#include "llama-kv-cache.h"
-#include "llama-kv-cache-iswa.h"
+#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache-unified-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-recurrent.h"

@@ -277,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                 for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];

-                    // TODO: reimplement this like in llama_kv_cache
+                    // TODO: reimplement this like in llama_kv_cache_unified
                     if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
                         if (hparams.use_alibi) {
                             f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
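
A standalone sketch of the masking rule visible in this hunk may help: a key token i0 contributes to query token i1 only when both share a sequence id and, under causal attention, pos[i0] <= pos[i1]; with ALiBi the visible entry carries the relative distance instead of zero. The 0.0f visible value and the -INFINITY default are not shown in this excerpt and are assumed from the usual KQ-mask convention.

```cpp
// Hedged sketch of the KQ-mask rule from the hunk above (not library code).
// The 0.0f / -INFINITY values are assumptions based on the usual convention;
// only the same-sequence, causal and ALiBi checks appear in the diff itself.
#include <cmath>
#include <limits>

static float kq_mask_value(int pos_k, int pos_q, bool same_seq, bool causal_attn, bool use_alibi) {
    const bool visible = same_seq && (!causal_attn || pos_k <= pos_q);
    if (!visible) {
        return -std::numeric_limits<float>::infinity(); // masked out
    }
    // ALiBi encodes the relative distance directly in the mask value
    return use_alibi ? -std::fabs(float(pos_k - pos_q)) : 0.0f;
}
```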
@@ -294,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     }
 }

-void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     mctx->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->set_input_v_idxs(self_v_idxs, ubatch);

     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }

-bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
+bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);

     this->mctx = mctx;

@@ -319,7 +319,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
     return res;
 }

-void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);

@@ -331,8 +331,8 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }

-bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
+bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);

     this->mctx = mctx;

@@ -1186,7 +1186,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }

 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);

     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);

@@ -1399,17 +1399,17 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
+static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
         ggml_context * ctx0,
         const llama_ubatch & ubatch,
         const llama_hparams & hparams,
         const llama_cparams & cparams,
-        const llama_kv_cache_context * mctx_cur) {
+        const llama_kv_cache_unified_context * mctx_cur) {

-    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);

     {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

         const auto n_kv     = mctx_cur->get_n_kv();
         const auto n_tokens = ubatch.n_tokens;
@@ -1427,16 +1427,16 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
     return inp;
 }

-llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);

-    auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);

-    return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }

 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv * inp,
+        llm_graph_input_attn_kv_unified * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
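
The `res->add_input(std::move(inp))` line in this hunk follows a pattern worth spelling out: the graph result takes ownership of the heap-allocated input object, and the builder returns a raw, non-owning pointer cast back to the concrete type. A minimal stand-in is shown below; the `graph_result` and `input_*` types are illustrative, not the llama.cpp classes.

```cpp
// Minimal stand-in for the add_input ownership pattern; types are illustrative.
#include <memory>
#include <utility>
#include <vector>

struct input_base { virtual ~input_base() = default; };
struct input_attn : input_base { /* per-layer mask/index tensors would live here */ };

struct graph_result {
    std::vector<std::unique_ptr<input_base>> inputs;

    // takes ownership, hands back a non-owning pointer for the caller to keep
    input_base * add_input(std::unique_ptr<input_base> inp) {
        inputs.push_back(std::move(inp));
        return inputs.back().get();
    }
};

// mirrors: return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
static input_attn * build_attn_input(graph_result & res) {
    auto inp = std::make_unique<input_attn>();
    return static_cast<input_attn *>(res.add_input(std::move(inp)));
}
```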
@@ -1488,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }

 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_iswa * inp,
+        llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1513,7 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }

 ggml_tensor * llm_graph_context::build_attn_with_sinks(
-        llm_graph_input_attn_kv_iswa * inp,
+        llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1636,10 +1636,10 @@ ggml_tensor * llm_graph_context::build_attn(
 // TODO: maybe separate the inner implementation into a separate function
 // like with the non-sliding window equivalent
 // once sliding-window hybrid caches are a thing.
-llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);

-    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);

     const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

@@ -1656,7 +1656,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
     }

     {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");

         const auto n_kv = mctx_cur->get_swa()->get_n_kv();

@@ -1669,7 +1669,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }

-    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
 }

 ggml_tensor * llm_graph_context::build_rs(
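
The two GGML_ASSERTs in these hunks pair the builders with the cache types: `build_attn_inp_kv_unified()` expects `LLAMA_SWA_TYPE_NONE`, while `build_attn_inp_kv_unified_iswa()` expects a sliding-window configuration. A hedged sketch of how calling code might branch on this is below; the member functions and the `swa_type` check come from the hunks above, the free helper and its signature are hypothetical, and the internal llama-graph.h / llama-hparams.h headers are assumed to be available.

```cpp
// Hypothetical helper showing the builder choice implied by the asserts above.
// build_attn_inp_kv_unified(_iswa)() and hparams.swa_type are from this file;
// everything else here is an illustrative sketch.
#include "llama-graph.h"
#include "llama-hparams.h"

static void build_attn_input_for_model(llm_graph_context & ctx, const llama_hparams & hparams) {
    if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) {
        // non-SWA models: unified KV cache input
        auto * inp = ctx.build_attn_inp_kv_unified();
        (void) inp; // would be passed to build_attn(...) per layer
    } else {
        // SWA models: unified iSWA variant (separate SWA / non-SWA masks)
        auto * inp = ctx.build_attn_inp_kv_unified_iswa();
        (void) inp;
    }
}
```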
@@ -1792,7 +1792,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);

     auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
-    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());

     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
