Commit 23214c9

ggml: avoid rebuild of GGML graph for each token (#7456)
Introduces caching of the GGML graph to avoid an unnecessary full rebuild between tokens. KV cache parameters, which change with each token, are updated directly in the cached GGML graph. Caching can be disabled with the GGML_DISABLE_GRAPH_CACHING environment variable.
1 parent 7eee341 · commit 23214c9
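
The opt-out is a plain environment-variable check. A minimal sketch of that mechanism (the helper name is illustrative; the commit itself performs this check inline in src/llama.cpp):

#include <cstdlib>

// Graph caching is skipped whenever GGML_DISABLE_GRAPH_CACHING is set
// (to any value), mirroring the getenv() check added in llama_decode_internal.
static bool graph_caching_disabled_by_env() {
    return std::getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr;
}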

File tree (4 files changed: +158 −8 lines)

ggml/include/ggml-backend.h
ggml/include/ggml.h
ggml/src/ggml-backend.cpp
src/llama.cpp

ggml/include/ggml-backend.h

Lines changed: 6 additions & 0 deletions
@@ -321,6 +321,12 @@ extern "C" {
     GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
 #endif
 
+    // Utility to query whether cached GGML graph is in use
+    GGML_API bool ggml_use_cached_graph(ggml_backend_sched_t sched);
+
+    // Set whether or not to use GGML graph caching
+    GGML_API void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value);
+
 #ifdef __cplusplus
 }
 #endif
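
For orientation, a sketch of how a caller might drive the two new entry points across tokens (a minimal, hypothetical loop; decode_one_token is not part of this commit, and the real wiring lives in src/llama.cpp below):

#include "ggml-backend.h"

// Hypothetical per-token helper: build/allocate the graph only while the
// scheduler is not using a cached graph, then opt in to caching so that the
// next call can skip the rebuild and reuse the existing splits.
static enum ggml_status decode_one_token(ggml_backend_sched_t sched, struct ggml_cgraph * gf) {
    if (!ggml_use_cached_graph(sched)) {
        ggml_backend_sched_alloc_graph(sched, gf); // split + allocate as usual
        ggml_set_cached_graph(sched, true);        // reuse this graph from now on
    }
    // with caching active, the scheduler skips its reset/alloc path internally
    return ggml_backend_sched_graph_compute(sched, gf);
}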

ggml/include/ggml.h

Lines changed: 7 additions & 0 deletions
@@ -574,6 +574,13 @@ extern "C" {
         GGML_TENSOR_FLAG_LOSS   =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
+    // Flag (used on GGML_OP_CPY nodes) indicating whether the node is associated with the K or V cache
+    enum ggml_kv_cache_flag {
+        GGML_KV_CACHE_FLAG_NONE = 0,
+        GGML_KV_CACHE_FLAG_K    = 1,
+        GGML_KV_CACHE_FLAG_V    = 2
+    };
+
     // n-dimensional tensor
     struct ggml_tensor {
         enum ggml_type type;
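
The enum itself carries no storage; a sketch of how a consumer could classify a copy node as a K- or V-cache write and map it onto these values (the classify_kv_cpy helper is illustrative and not part of this commit; the "k_cache_view-"/"v_cache_view-" name prefixes are the ones matched later in src/llama.cpp):

#include <cstring>
#include "ggml.h"

// Illustrative helper: map a GGML_OP_CPY node onto the new flag values by
// inspecting the name of its destination view (src[1]).
static enum ggml_kv_cache_flag classify_kv_cpy(const struct ggml_tensor * node) {
    if (node->op != GGML_OP_CPY || node->src[1] == NULL) {
        return GGML_KV_CACHE_FLAG_NONE;
    }
    if (strncmp(node->src[1]->name, "k_cache_view-", strlen("k_cache_view-")) == 0) {
        return GGML_KV_CACHE_FLAG_K;
    }
    if (strncmp(node->src[1]->name, "v_cache_view-", strlen("v_cache_view-")) == 0) {
        return GGML_KV_CACHE_FLAG_V;
    }
    return GGML_KV_CACHE_FLAG_NONE;
}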

ggml/src/ggml-backend.cpp

Lines changed: 32 additions & 1 deletion
@@ -1382,6 +1382,13 @@ struct ggml_backend_sched_split {
     struct ggml_cgraph graph;
 };
 
+// Object to facilitate GGML graph caching
+struct ggml_cached_graph {
+    bool is_active;
+    ggml_backend_t input_backend;
+    struct ggml_tensor * input_cpy[GGML_SCHED_MAX_SPLIT_INPUTS];
+};
+
 struct ggml_backend_sched {
     bool is_reset; // true if the scheduler has been reset since the last graph split
     bool is_alloc;
@@ -1427,6 +1434,8 @@ struct ggml_backend_sched {
     size_t context_buffer_size;
 
     bool debug;
+
+    struct ggml_cached_graph cached_graph;
 };
 
 #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
@@ -2113,6 +2122,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
             struct ggml_tensor * input = split->inputs[j];
             struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
 
+            if (!sched->cached_graph.is_active) {
+                sched->cached_graph.input_backend = input_backend;
+                sched->cached_graph.input_cpy[j] = input_cpy;
+            }
+            else {
+                input_backend = sched->cached_graph.input_backend;
+                input_cpy = sched->cached_graph.input_cpy[j];
+            }
             if (input->flags & GGML_TENSOR_FLAG_INPUT) {
                 // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
                 if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
@@ -2245,6 +2262,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
 
     ggml_backend_sched_reset(sched);
 
+    sched->cached_graph.is_active = false;
+
     return sched;
 }
 
@@ -2321,6 +2340,9 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st
 }
 
 enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+
+    if (!sched->cached_graph.is_active)
+    {
     if (!sched->is_reset && !sched->is_alloc) {
         ggml_backend_sched_reset(sched);
     }
@@ -2330,7 +2352,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch
             return GGML_STATUS_ALLOC_FAILED;
         }
     }
-
+    }
     return ggml_backend_sched_compute_splits(sched);
 }
 
@@ -2595,3 +2617,12 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 
     return true;
 }
+
+bool ggml_use_cached_graph(ggml_backend_sched_t sched) {
+    return sched->cached_graph.is_active;
+}
+
+void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value) {
+    sched->cached_graph.is_active = set_value;
+}
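
Because an active cached graph makes the scheduler skip its reset/alloc path and reuse the recorded split inputs, a caller that changes the graph topology has to switch caching off first. A minimal sketch of such an invalidation step (the helper name is hypothetical, not part of this commit):

#include "ggml-backend.h"

// Hypothetical invalidation helper: turn graph caching off and reset the
// scheduler so the next compute call re-splits and re-allocates the graph.
static void invalidate_cached_graph(ggml_backend_sched_t sched) {
    ggml_set_cached_graph(sched, false);
    ggml_backend_sched_reset(sched);
}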

src/llama.cpp

Lines changed: 113 additions & 7 deletions
@@ -7,6 +7,7 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
+#include "../ggml/src/ggml-impl.h"
 
 #if defined(GGML_USE_VULKAN)
 #  include "ggml-vulkan.h"
@@ -3254,6 +3255,17 @@ struct llama_sbatch {
     }
 };
 
+// Object used to allow caching of GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    bool is_active = false;
+    ggml_cgraph * gf;
+    size_t n;
+    ggml_backend_t backend_res;
+    ggml_backend_t backend_embd;
+    struct ggml_tensor * res;
+    struct ggml_tensor * embd;
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
@@ -3352,6 +3364,8 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+
+    struct ggml_cached_graph cached_graph;
 };
 
 struct llama_lora_weight {
@@ -9146,7 +9160,6 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-
     ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
 }
 
@@ -17181,11 +17194,44 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
+        ggml_cgraph * gf;
+        // the output is always the last tensor in the graph
+        struct ggml_tensor * res;
+        struct ggml_tensor * embd;
+
+        bool n_has_changed_since_last_token = false;
+        if (lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+        lctx.cached_graph.n = kv_self.n;
+
+        // Re-build graph only if graph caching is not possible
+        if (!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+        gf = llama_build_graph(lctx, ubatch, false);
+
+        // Set whether GGML graph caching is in use within the GGML module, based on
+        // whether caching was activated here during the previous token
+        ggml_set_cached_graph(lctx.sched, lctx.cached_graph.is_active);
+
+        // Disable future graph caching if the env var is set,
+        // if there are multiple devices, if the batch size is greater than 1,
+        // or if nsplits is not 2.
+        // TODO: enable graph caching for these cases
+        bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+            || (llama_get_device_count(model) > 1)
+            || (ggml_backend_sched_get_n_splits(lctx.sched) != 2);
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                disable_cached_ggml_graph = true;
+                break;
+            }
+        }
+
+        // Set whether graph caching should be used for future tokens
+        lctx.cached_graph.is_active = !disable_cached_ggml_graph;
 
         // the output is always the last tensor in the graph
-        struct ggml_tensor * res  = ggml_graph_node(gf, -1);
-        struct ggml_tensor * embd = ggml_graph_node(gf, -2);
+        res  = ggml_graph_node(gf, -1);
+        embd = ggml_graph_node(gf, -2);
 
         if (lctx.n_outputs == 0) {
             // no output
@@ -17205,10 +17251,60 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+        lctx.cached_graph.res  = res;
+        lctx.cached_graph.embd = embd;
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
+        }
+        else {
+            gf   = lctx.cached_graph.gf;
+            res  = lctx.cached_graph.res;
+            embd = lctx.cached_graph.embd;
+        }
+        lctx.cached_graph.gf = gf;
+
+        // Update K and V cache parameters in the cached graph.
+        if (gf != nullptr && gf->nodes != nullptr && ggml_use_cached_graph(lctx.sched)) {
+
+            const struct llama_hparams & hparams = model.hparams;
+            const int64_t kv_head = kv_self.head;
+
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * node = gf->nodes[i];
+                if (node->op == GGML_OP_CPY) {
+
+                    // K cache
+                    const char* k_prefix = "k_cache_view-";
+                    if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
+                        int il = atoi(node->src[1]->name + strlen(k_prefix)); // Layer index from name
+                        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+                        ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                        size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
+                        node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
+                    }
+
+                    // V cache
+                    const char* v_prefix = "v_cache_view-";
+                    if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
+                        int il = atoi(node->src[1]->name + strlen(v_prefix)); // Layer index from name
+                        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+                        ggml_tensor * tmp_tensor = kv_self.v_l[il];
+                        size_t tmp_offset;
+                        if (cparams.flash_attn) {
+                            tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                        } else {
+                            tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
+                        }
+                        node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
+                    }
+
+                }
+            }
+
+        }
+
         llama_set_inputs(lctx, ubatch);
 
         llama_graph_compute(lctx, gf, n_threads, threadpool);
@@ -17231,11 +17327,15 @@ static int llama_decode_internal(
         // extract logits
        if (res) {
            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(lctx.logits != nullptr);
-
            float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
            const int32_t n_outputs_new = lctx.n_outputs;
+            if (!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_res = backend_res;
+            else
+                backend_res = lctx.cached_graph.backend_res;
+
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(lctx.logits != nullptr);
 
            if (n_outputs_new) {
                GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
@@ -17247,6 +17347,12 @@ static int llama_decode_internal(
         // extract embeddings
         if (embd) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+
+            if (!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_embd = backend_embd;
+            else
+                backend_embd = lctx.cached_graph.backend_embd;
             GGML_ASSERT(backend_embd != nullptr);
 
             switch (cparams.pooling_type) {
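
As a worked example of the offset arithmetic in the cached-graph K/V update above, with illustrative numbers that are not taken from the commit (an F16 cache, n_embd_k_gqa = n_embd_v_gqa = 1024, kv_head = 37):

#include <cstdio>

int main() {
    const size_t elem_size    = 2;    // bytes per F16 element
    const size_t n_embd_k_gqa = 1024; // per-position K row width (elements)
    const size_t kv_head      = 37;   // first KV cell written this ubatch

    // K cache (and V cache with flash attention): one contiguous row per
    // position, so the view starts kv_head full rows into the buffer.
    const size_t k_offset = n_embd_k_gqa * elem_size * kv_head; // 75776 bytes

    // V cache without flash attention is stored transposed (see the
    // ggml_transpose in llm_build_kv_store), so advancing by one position
    // moves only one element within a row.
    const size_t v_offset = elem_size * kv_head;                // 74 bytes

    printf("k_offset = %zu bytes, v_offset = %zu bytes\n", k_offset, v_offset);
    return 0;
}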
