Commit 8738fa1

Revert "kv-cache : drop the "unified" prefix (ggml-org#15467)"
1 parent c07a0b5 commit 8738fa1

15 files changed: +360 / -346 lines

include/llama.h

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,8 @@ extern "C" {
 
     typedef struct llama_memory_i * llama_memory_t;
 
+    struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
+
     typedef int32_t llama_pos;
    typedef int32_t llama_token;
     typedef int32_t llama_seq_id;
@@ -470,6 +472,8 @@ extern "C" {
     LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
     LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
+    DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
+
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
 
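
A minimal migration sketch (not part of this commit), assuming only the declarations shown above from include/llama.h: code that used the deprecated llama_get_kv_self accessor can switch to llama_get_memory. `ctx` stands for an already-created llama_context.

#include "llama.h"

// Returns the memory handle for a context, using the accessor restored above.
llama_memory_t get_ctx_memory(llama_context * ctx) {
    // deprecated accessor, kept only as a compatibility shim:
    //   struct llama_kv_cache * kv = llama_get_kv_self(ctx);
    // replacement suggested by the deprecation message:
    return llama_get_memory(ctx);
}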

src/llama-context.cpp

Lines changed: 5 additions & 0 deletions
@@ -2338,6 +2338,11 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }
 
+// deprecated
+llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
+    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
+}
+
 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
     ctx->kv_self_update(false);
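
Caller-side note (a sketch under assumptions, not part of this commit): because the shim above relies on dynamic_cast, it returns nullptr whenever the context's memory object is not derived from llama_kv_cache, so legacy callers should null-check rather than assume a cache is always present. `has_kv_self` is a hypothetical helper.

#include "llama.h"

// True if the deprecated accessor can still hand back a KV cache for this context.
bool has_kv_self(llama_context * ctx) {
    return llama_get_kv_self(ctx) != nullptr; // compiles with a deprecation warning
}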

src/llama-graph.cpp

Lines changed: 27 additions & 27 deletions
@@ -4,8 +4,8 @@
 #include "llama-batch.h"
 #include "llama-cparams.h"
 
-#include "llama-kv-cache.h"
-#include "llama-kv-cache-iswa.h"
+#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache-unified-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-recurrent.h"
 
@@ -277,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                 for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
 
-                    // TODO: reimplement this like in llama_kv_cache
+                    // TODO: reimplement this like in llama_kv_cache_unified
                     if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
                         if (hparams.use_alibi) {
                             f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
@@ -294,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     mctx->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->set_input_v_idxs(self_v_idxs, ubatch);
 
     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
+bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -319,7 +319,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
     return res;
 }
 
-void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
@@ -331,8 +331,8 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
+bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -1186,7 +1186,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
 
@@ -1399,17 +1399,17 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
+static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
         ggml_context * ctx0,
         const llama_ubatch & ubatch,
        const llama_hparams & hparams,
        const llama_cparams & cparams,
-        const llama_kv_cache_context * mctx_cur) {
+        const llama_kv_cache_unified_context * mctx_cur) {
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
     {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
         const auto n_kv = mctx_cur->get_n_kv();
         const auto n_tokens = ubatch.n_tokens;
@@ -1427,16 +1427,16 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
     return inp;
 }
 
-llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
-    auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
 
-    return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-    llm_graph_input_attn_kv * inp,
+    llm_graph_input_attn_kv_unified * inp,
     ggml_tensor * wo,
     ggml_tensor * wo_b,
     ggml_tensor * q_cur,
@@ -1488,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-    llm_graph_input_attn_kv_iswa * inp,
+    llm_graph_input_attn_kv_unified_iswa * inp,
     ggml_tensor * wo,
     ggml_tensor * wo_b,
     ggml_tensor * q_cur,
@@ -1513,7 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_attn_with_sinks(
-    llm_graph_input_attn_kv_iswa * inp,
+    llm_graph_input_attn_kv_unified_iswa * inp,
     ggml_tensor * wo,
     ggml_tensor * wo_b,
     ggml_tensor * q_cur,
@@ -1636,10 +1636,10 @@ ggml_tensor * llm_graph_context::build_attn(
 // TODO: maybe separate the inner implementation into a separate function
 //       like with the non-sliding window equivalent
 //       once sliding-window hybrid caches are a thing.
-llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
     const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
@@ -1656,7 +1656,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
     }
 
     {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
@@ -1669,7 +1669,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
-    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1792,7 +1792,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
     auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
-    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
 
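
A hedged sketch (hypothetical helper, not part of this commit) of how a model graph picks between the two builders restored above, mirroring the GGML_ASSERT messages in build_attn_inp_kv_unified_impl and build_attn_inp_kv_unified_iswa: the plain unified input serves full attention, the iSWA variant serves models that declare a sliding window. It assumes llm_graph_context exposes hparams and that the includes below match the source layout.

#include "llama-graph.h"
#include "llama-hparams.h"

// Chooses the attention-input builder that matches the model's SWA setting.
void build_self_attn_input(const llm_graph_context & gf) {
    if (gf.hparams.swa_type == LLAMA_SWA_TYPE_NONE) {
        llm_graph_input_attn_kv_unified * inp = gf.build_attn_inp_kv_unified();
        (void) inp; // would be passed to the matching build_attn(...) overload per layer
    } else {
        llm_graph_input_attn_kv_unified_iswa * inp = gf.build_attn_inp_kv_unified_iswa();
        (void) inp; // would be passed to the iSWA build_attn(...) overload per layer
    }
}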

src/llama-graph.h

Lines changed: 25 additions & 25 deletions
@@ -19,8 +19,8 @@ struct llama_cparams;
 
 struct llama_memory_context_i;
 
-class llama_kv_cache_context;
-class llama_kv_cache_iswa_context;
+class llama_kv_cache_unified_context;
+class llama_kv_cache_unified_iswa_context;
 class llama_memory_recurrent_context;
 class llama_memory_hybrid_context;
 
@@ -152,7 +152,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
+            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -161,7 +161,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 
     const llama_hparams hparams;
 
-    const llama_kv_cache_context * mctx;
+    const llama_kv_cache_unified_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -257,17 +257,17 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {
     const llama_cparams cparams;
 };
 
-class llm_graph_input_attn_kv : public llm_graph_input_i {
+class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv(
+    llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_context * mctx) :
+            const llama_kv_cache_unified_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv() = default;
+    ~llm_graph_input_attn_kv_unified() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -290,20 +290,20 @@ class llm_graph_input_attn_kv : public llm_graph_input_i {
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_context * mctx;
+    const llama_kv_cache_unified_context * mctx;
 };
 
-class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
+class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_iswa(
+    llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_iswa_context * mctx) :
+            const llama_kv_cache_unified_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_iswa() = default;
+    ~llm_graph_input_attn_kv_unified_iswa() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -330,7 +330,7 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_iswa_context * mctx;
+    const llama_kv_cache_unified_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -351,7 +351,7 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
-            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
+            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
             std::unique_ptr<llm_graph_input_rs> inp_rs,
             const llama_memory_hybrid_context * mctx) :
         inp_attn(std::move(inp_attn)),
@@ -361,11 +361,11 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
-    std::unique_ptr<llm_graph_input_rs> inp_rs;
+    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
+    std::unique_ptr<llm_graph_input_rs> inp_rs;
 
-    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
-    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
+    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
 
     const llama_memory_hybrid_context * mctx;
 };
@@ -703,10 +703,10 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv * build_attn_inp_kv() const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
 
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv * inp,
+            llm_graph_input_attn_kv_unified * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -717,11 +717,11 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_iswa * inp,
+            llm_graph_input_attn_kv_unified_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -734,7 +734,7 @@ struct llm_graph_context {
 
     // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
     ggml_tensor * build_attn_with_sinks(
-            llm_graph_input_attn_kv_iswa * inp,
+            llm_graph_input_attn_kv_unified_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -765,7 +765,7 @@ struct llm_graph_context {
     //
 
     // TODO: move this implementation to llama_memory_recurrent.
-    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
+    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
    //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
    //       `llama_memory_recurrent`
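
A small usage sketch (hypothetical, based only on the accessors shown above): hybrid-memory models pull the unified attention input and the recurrent-state input out of the same combined llm_graph_input_mem_hybrid.

#include "llama-graph.h"

// Splits a hybrid graph input into its attention and recurrent halves.
void split_hybrid_input(const llm_graph_input_mem_hybrid * inp) {
    llm_graph_input_attn_kv_unified * attn = inp->get_attn(); // type restored by this revert
    llm_graph_input_rs * rs = inp->get_recr();
    (void) attn;
    (void) rs;
}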
