 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
+#include "../ggml/src/ggml-impl.h"
 
 #if defined(GGML_USE_VULKAN)
 #  include "ggml-vulkan.h"
@@ -3254,6 +3255,17 @@ struct llama_sbatch {
     }
 };
 
+// Object used to allow caching of GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    bool is_active = false;
+    ggml_cgraph * gf;
+    size_t n;
+    ggml_backend_t backend_res;
+    ggml_backend_t backend_embd;
+    struct ggml_tensor * res;
+    struct ggml_tensor * embd;
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
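
The struct added above holds everything the decode path needs to replay the previous token's graph: the graph pointer, the kv_self.n value it was built for, the output tensors res/embd, and the backends they were assigned to. To make the control flow in llama_decode_internal (further down) easier to follow, here is a minimal, self-contained sketch of the same build-once/reuse-later pattern; Graph, build_graph(), update_inputs() and compute() are placeholder stand-ins, not llama.cpp or ggml API.

    #include <cstdio>

    struct Graph { int n_nodes; };                    // stand-in for ggml_cgraph

    struct cached_graph_state {                       // mirrors the role of ggml_cached_graph
        bool    is_active = false;                    // may the next token reuse the graph?
        Graph * gf        = nullptr;                  // graph built for the previous token
        int     n         = -1;                       // kv_self.n the graph was built for
    };

    static Graph * build_graph()          { static Graph g{42}; return &g; } // expensive step
    static void    update_inputs(Graph *) {}                                 // cheap per-token step
    static void    compute(Graph * g)     { std::printf("compute %d nodes\n", g->n_nodes); }

    int main() {
        cached_graph_state cache;
        for (int token = 0; token < 4; ++token) {
            const int  kv_n      = 32;                // pretend kv_self.n stays constant
            const bool n_changed = (cache.n != kv_n);
            cache.n = kv_n;

            if (!cache.is_active || n_changed) {      // rebuild only when reuse is impossible
                cache.gf        = build_graph();
                cache.is_active = true;               // allow reuse on the following tokens
            }
            update_inputs(cache.gf);                  // inputs still change every token
            compute(cache.gf);
        }
        return 0;
    }
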
@@ -3352,6 +3364,8 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+
+    struct ggml_cached_graph cached_graph;
 };
 
 struct llama_lora_weight {
@@ -9146,7 +9160,6 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-
     ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
 }
 
@@ -17181,11 +17194,44 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
+        ggml_cgraph * gf;
+        // the output is always the last tensor in the graph
+        struct ggml_tensor * res;
+        struct ggml_tensor * embd;
+
+        bool n_has_changed_since_last_token = false;
+        if (lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+        lctx.cached_graph.n = kv_self.n;
+
+        // Re-build the graph only if graph caching is not possible
+        if (!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+            gf = llama_build_graph(lctx, ubatch, false);
+
+            // Set whether GGML graph caching is in use within the GGML module, based on
+            // whether caching was activated here during the previous token
+            ggml_set_cached_graph(lctx.sched, lctx.cached_graph.is_active);
+
+            // Disable future graph caching in the presence of the env var,
+            // if there are multiple devices, if the batch size is greater than 1,
+            // or if nsplits is not 2.
+            // TODO: enable graph caching for these cases
+            bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+                || (llama_get_device_count(model) > 1)
+                || (ggml_backend_sched_get_n_splits(lctx.sched) != 2);
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                    disable_cached_ggml_graph = true;
+                    break;
+                }
+            }
+
+            // Set whether graph caching should be used for future tokens
+            lctx.cached_graph.is_active = !disable_cached_ggml_graph;
 
         // the output is always the last tensor in the graph
-        struct ggml_tensor * res  = ggml_graph_node(gf, -1);
-        struct ggml_tensor * embd = ggml_graph_node(gf, -2);
+        res  = ggml_graph_node(gf, -1);
+        embd = ggml_graph_node(gf, -2);
 
         if (lctx.n_outputs == 0) {
             // no output
@@ -17205,10 +17251,60 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+        lctx.cached_graph.res = res;
+        lctx.cached_graph.embd = embd;
 
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
+        }
+        else {
+            gf = lctx.cached_graph.gf;
+            res = lctx.cached_graph.res;
+            embd = lctx.cached_graph.embd;
+        }
+        lctx.cached_graph.gf = gf;
+
+        // Update K and V cache parameters in cached graph.
+        if (gf != nullptr && gf->nodes != nullptr && ggml_use_cached_graph(lctx.sched)) {
+
+            const struct llama_hparams & hparams = model.hparams;
+            const int64_t kv_head = kv_self.head;
+
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * node = gf->nodes[i];
+                if (node->op == GGML_OP_CPY) {
+
+                    // K cache
+                    const char * k_prefix = "k_cache_view-";
+                    if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
+                        int il = atoi(node->src[1]->name + strlen(k_prefix)); // layer index from name
+                        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+                        ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                        size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
+                        node->src[1]->data = static_cast<char *>(tmp_tensor->data) + tmp_offset;
+                    }
+
+                    // V cache
+                    const char * v_prefix = "v_cache_view-";
+                    if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
+                        int il = atoi(node->src[1]->name + strlen(v_prefix)); // layer index from name
+                        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+                        ggml_tensor * tmp_tensor = kv_self.v_l[il];
+                        size_t tmp_offset;
+                        if (cparams.flash_attn) {
+                            tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                        } else {
+                            tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
+                        }
+                        node->src[1]->data = static_cast<char *>(tmp_tensor->data) + tmp_offset;
+                    }
+
+                }
+            }
+
+        }
+
         llama_set_inputs(lctx, ubatch);
 
         llama_graph_compute(lctx, gf, n_threads, threadpool);
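
When a cached graph is replayed, the destination of every k_cache_view-<il> / v_cache_view-<il> copy has to be re-pointed at the slot for the current kv_self.head, which is what the byte-offset arithmetic in the loop above does. A small self-contained sketch of that arithmetic follows; the fixed 2-byte element size models an F16 cache only (ggml_row_size/ggml_element_size also handle quantized types), and the numbers are illustrative.

    #include <cstddef>
    #include <cstdio>

    // Simplified F16 row size: 2 bytes per element, n elements per row.
    static size_t row_size_f16(size_t n) { return 2 * n; }

    int main() {
        const size_t n_embd_k_gqa = 1024;  // example: head_dim * n_head_kv
        const size_t kv_head      = 7;     // first free slot in the KV cache

        // K cache: one contiguous row of n_embd_k_gqa elements per cached token,
        // so advancing to slot kv_head moves by kv_head whole rows.
        const size_t k_offset = row_size_f16(n_embd_k_gqa) * kv_head;
        std::printf("k_cache_view offset: %zu bytes\n", k_offset);   // 2 * 1024 * 7 = 14336

        // The V cache without flash attention is stored transposed, so advancing to
        // the next token moves by a single element, not a whole row.
        const size_t v_offset = kv_head * 2;                         // 2 = F16 element size
        std::printf("v_cache_view offset: %zu bytes\n", v_offset);   // 7 * 2 = 14
        return 0;
    }
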
@@ -17231,11 +17327,15 @@ static int llama_decode_internal(
         // extract logits
         if (res) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(lctx.logits != nullptr);
-
             float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
             const int32_t n_outputs_new = lctx.n_outputs;
+            if (!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_res = backend_res;
+            else
+                backend_res = lctx.cached_graph.backend_res;
+
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(lctx.logits != nullptr);
 
             if (n_outputs_new) {
                 GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
@@ -17247,6 +17347,12 @@ static int llama_decode_internal(
         // extract embeddings
         if (embd) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+
+
+            if (!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_embd = backend_embd;
+            else
+                backend_embd = lctx.cached_graph.backend_embd;
             GGML_ASSERT(backend_embd != nullptr);
 
             switch (cparams.pooling_type) {
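
The two hunks above also cache the backend handles for the logits and embeddings tensors: they are taken from ggml_backend_sched_get_tensor_backend only on a token where the graph was rebuilt, and reused from lctx.cached_graph otherwise, presumably because the scheduler's tensor-to-backend mapping is only refreshed when a graph is split and allocated. A small self-contained sketch of that lookup-once pattern follows; Backend and sched_lookup_backend() are placeholders, not ggml API.

    #include <cassert>

    struct Backend { int id; };                               // stand-in for ggml_backend_t

    static Backend * sched_lookup_backend() {                 // stand-in for the scheduler query
        static Backend b{1};
        return &b;
    }

    int main() {
        Backend * cached_backend = nullptr;
        for (int token = 0; token < 3; ++token) {
            const bool graph_reused = (token > 0);            // first token builds, later ones reuse

            Backend * backend = nullptr;
            if (!graph_reused) {
                backend        = sched_lookup_backend();      // query only on the build pass
                cached_backend = backend;                     // remember it for cached-graph tokens
            } else {
                backend = cached_backend;                     // reuse the remembered handle
            }
            assert(backend != nullptr);
        }
        return 0;
    }
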