kv-cache : use ggml_set_rows #14285

Merged · 7 commits · Jul 3, 2025
6 changes: 6 additions & 0 deletions ggml/src/ggml-cann/ggml-cann.cpp
@@ -2086,6 +2086,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
return false;
}
} break;
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CPY: {
ggml_tensor *src = op->src[0];
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
6 changes: 6 additions & 0 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -2222,6 +2222,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
default:
return false;
}
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:
6 changes: 6 additions & 0 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -4285,6 +4285,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
}
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;
6 changes: 6 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -10333,6 +10333,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
} break;
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CONT:
case GGML_OP_CPY:
case GGML_OP_DUP:
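The four backend changes above follow one pattern: backends that do not yet implement GGML_OP_SET_ROWS report it as unsupported so the scheduler can fall back to a backend that does. For context, here is a minimal sketch of how such a node could be built, assuming the `ggml_set_rows(ctx, dst, src, idxs)` signature introduced in ggml-org/llama.cpp#14274 (names and shapes are illustrative, not the exact llama.cpp code):

```cpp
#include "ggml.h"

// Sketch only: scatter per-token K rows into arbitrary KV cells in one graph node.
static void build_k_store_sketch(ggml_context * ctx, ggml_cgraph * gf,
                                 ggml_tensor * k_cache,  // e.g. F16 [n_embd_k, kv_size]
                                 ggml_tensor * k_cur,    // F32 [n_embd_k, n_tokens]
                                 ggml_tensor * k_idxs) { // I64 [n_tokens], filled at set_input() time
    // one ggml_set_rows node replaces the previous "ggml_cpy into a contiguous view";
    // the destination cells of a ubatch no longer have to be contiguous
    ggml_tensor * k_store = ggml_set_rows(ctx, k_cache, k_cur, k_idxs);
    ggml_build_forward_expand(gf, k_store);
}
```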
68 changes: 46 additions & 22 deletions src/llama-graph.cpp
@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
}

void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);

mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);

if (self_kq_mask_swa) {
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);

mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);

mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}

void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}

void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);

mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);

const int64_t n_rs = mctx->get_recr()->get_n_rs();

@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
}
}

void llm_graph_input_one::set_input(const llama_ubatch *) {
void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
GGML_ASSERT(one && ggml_nelements(one) == 1);
float f_one = 1.0f;
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -997,6 +1002,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {

const auto n_kv = inp->mctx->get_attn()->get_n_kv();

inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);

inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
@@ -1198,8 +1206,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()

const auto n_kv = mctx_cur->get_n_kv();

inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);

inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1230,8 +1240,11 @@ ggml_tensor * llm_graph_context::build_attn(

// store to KV cache
{
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
const auto & k_idxs = inp->get_k_idxs();
const auto & v_idxs = inp->get_v_idxs();

ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}

const auto & kq_mask = inp->get_kq_mask();
@@ -1290,11 +1303,15 @@

// optionally store to KV cache
if (k_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();

ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
}

if (v_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();

ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}

const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1398,8 +1415,11 @@ ggml_tensor * llm_graph_context::build_attn(

// store to KV cache
{
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
const auto & k_idxs = inp->get_k_idxs();
const auto & v_idxs = inp->get_v_idxs();

ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
}

const auto & kq_mask = inp->get_kq_mask();
@@ -1434,8 +1454,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
{
const auto n_kv = mctx_cur->get_base()->get_n_kv();

inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);

inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1446,8 +1468,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

const auto n_kv = mctx_cur->get_swa()->get_n_kv();

inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);

inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
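In the hunks above, `build_attn` now passes per-token destination indices to `cpy_k`/`cpy_v` instead of relying on a single head offset. The cache-side implementation is not part of this excerpt; the change it enables looks roughly like the following contrast (a sketch under assumed shapes, with illustrative variable names):

```cpp
#include "ggml.h"

// Illustrative only -- not the exact code in llama-kv-cache-unified.cpp.

// before this PR: every token of the ubatch was written to a contiguous slot at `head`
static void store_k_contiguous_sketch(ggml_context * ctx, ggml_cgraph * gf,
                                      ggml_tensor * k,     // K cache   [n_embd_k_gqa, kv_size]
                                      ggml_tensor * k_cur, // current K [n_embd_k_gqa, n_tokens]
                                      int64_t head) {
    const int64_t n_embd   = k->ne[0];
    const int64_t n_tokens = k_cur->ne[1];

    ggml_tensor * k_view = ggml_view_1d(ctx, k, n_tokens*n_embd,
                                        head*n_embd*ggml_element_size(k));
    ggml_build_forward_expand(gf, ggml_cpy(ctx, k_cur, k_view));
}

// with ggml_set_rows: each token carries its own destination cell in the I64 k_idxs input,
// so the cells chosen by prepare() no longer need to be contiguous
static void store_k_set_rows_sketch(ggml_context * ctx, ggml_cgraph * gf,
                                    ggml_tensor * k,       // K cache   [n_embd_k_gqa, kv_size]
                                    ggml_tensor * k_cur,   // current K [n_embd_k_gqa, n_tokens]
                                    ggml_tensor * k_idxs) { // I64 [n_tokens]
    ggml_build_forward_expand(gf, ggml_set_rows(ctx, k, k_cur, k_idxs));
}
```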
24 changes: 23 additions & 1 deletion src/llama-graph.h
@@ -249,8 +254,14 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {

void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }

ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]

ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]

@@ -274,9 +280,19 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {

void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }

ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]

ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
@@ -319,8 +335,14 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {

ggml_tensor * s_copy; // I32 [kv_size]

ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }

ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]

ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]

@@ -336,7 +358,7 @@ class llm_graph_input_one : public llm_graph_input_i {
llm_graph_input_one() {}
virtual ~llm_graph_input_one() = default;

void set_input(const llama_ubatch *) override;
void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * one = nullptr; // F32
};
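The new `self_k_idxs`/`self_v_idxs` members are plain I64 [n_batch] graph inputs. The `build_input_*_idxs` / `set_input_*_idxs` helpers they pair with live in the KV-cache implementation and are not shown in this excerpt; a minimal sketch of what such a pair could look like, with hypothetical names:

```cpp
#include "ggml.h"
#include "ggml-backend.h"
#include <cstdint>
#include <vector>

// build side: allocate the I64 [n_tokens] input that ggml_set_rows will consume
static ggml_tensor * build_input_k_idxs_sketch(ggml_context * ctx, int64_t n_tokens) {
    ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
    ggml_set_input(idxs);
    return idxs;
}

// set_input side: copy the destination cell of each token, as chosen by the cache's
// prepare()/find_slot() step, into the graph input before evaluation
static void set_input_k_idxs_sketch(ggml_tensor * dst, const std::vector<int64_t> & cells) {
    ggml_backend_tensor_set(dst, cells.data(), 0, cells.size()*sizeof(int64_t));
}
```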
32 changes: 16 additions & 16 deletions src/llama-kv-cache-unified-iswa.cpp
@@ -113,20 +113,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT
}

auto heads_base = kv_base->prepare(ubatches);
if (heads_base.empty()) {
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;
}

auto heads_swa = kv_swa->prepare(ubatches);
if (heads_swa.empty()) {
auto sinfos_swa = kv_swa->prepare(ubatches);
if (sinfos_swa.empty()) {
break;
}

assert(heads_base.size() == heads_swa.size());
assert(sinfos_base.size() == sinfos_swa.size());

return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);

// if it fails, try equal split
@@ -144,20 +144,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT
}

auto heads_base = kv_base->prepare(ubatches);
if (heads_base.empty()) {
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;
}

auto heads_swa = kv_swa->prepare(ubatches);
if (heads_swa.empty()) {
auto sinfos_swa = kv_swa->prepare(ubatches);
if (sinfos_swa.empty()) {
break;
}

assert(heads_base.size() == heads_swa.size());
assert(sinfos_base.size() == sinfos_swa.size());

return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);

// TODO: if we fail again, we should attempt different splitting strategies
@@ -220,13 +220,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(

llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

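The iswa context now forwards `slot_info_vec_t` (one `slot_info` per ubatch) instead of `std::vector<uint32_t>` head offsets. The real `slot_info` type is defined in llama-kv-cache-unified.h and is not part of this excerpt; conceptually it carries something like the following (illustrative sketch only):

```cpp
#include <cstdint>
#include <vector>

// Illustrative only -- the real definition lives in llama-kv-cache-unified.h.
struct slot_info_sketch {
    std::vector<uint32_t> idxs; // destination KV cell for each token of the ubatch

    bool empty() const { return idxs.empty(); }
};

// one entry per ubatch, replacing the old vector of head offsets;
// with ggml_set_rows the cells of a ubatch no longer need to be contiguous
using slot_info_vec_sketch_t = std::vector<slot_info_sketch>;
```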
6 changes: 4 additions & 2 deletions src/llama-kv-cache-unified-iswa.h
@@ -74,6 +74,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {

class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;

// used for errors
llama_kv_cache_unified_iswa_context(llama_memory_status status);

@@ -90,8 +92,8 @@ class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
// used to create a batch processing context from a batch
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches);

virtual ~llama_kv_cache_unified_iswa_context();