@@ -101,7 +101,8 @@ llama_context::llama_context(
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
-    cparams.op_offload = params.op_offload;
+    cparams.op_offload  = params.op_offload;
+    cparams.graph_reuse = params.graph_reuse;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -227,8 +228,8 @@ llama_context::llama_context(
 
     LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
-    // buffer used to store the computation graph and the tensor meta data
-    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+    gf_res_prev.reset(new llm_graph_result(max_nodes));
+    gf_res_reserve.reset(new llm_graph_result(max_nodes));
 
     // TODO: move these checks to ggml_backend_sched
     // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -388,10 +389,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
     return sched.get();
 }
 
-ggml_context * llama_context::get_ctx_compute() const {
-    return ctx_compute.get();
-}
-
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -678,38 +675,52 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto * gf = graph_init();
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    auto * res = gf_res_prev.get();
+    auto * gf  = res->get_gf();
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    // the new graph parameters
+    // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
+    const auto gparams = graph_params(res, ubatch, mctx, gtype);
 
-    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+    const bool can_reuse = cparams.graph_reuse && res->update(gparams);
+    if (can_reuse) {
+        LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
+        n_reused++;
+    } else {
+        res->reset();
 
-    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+        // const auto t_start_us = ggml_time_us();
+
+        gf = model.build_graph(gparams);
+
+        // LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+            ret = GGML_STATUS_FAILED;
+            return nullptr;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+            ret = GGML_STATUS_ALLOC_FAILED;
+            return nullptr;
+        }
     }
 
     res->set_inputs(&ubatch);
 
-    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
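Conceptually, the change above turns graph construction into a params-keyed cache: the parameters that fix the graph topology are compared against the previous call, the old graph is reused with only its inputs refreshed on a match, and otherwise the graph is rebuilt and reallocated. The following is a minimal, self-contained sketch of that pattern only; graph_params_t, graph_cache and update() here are illustrative stand-ins, not the llm_graph_params/llm_graph_result types used in the diff, and the real parameter set covers everything that determines the topology.

#include <cstdint>
#include <cstdio>

// hypothetical, stripped-down stand-in for the topology-determining parameters
struct graph_params_t {
    int      gtype;      // graph type (encoder/decoder/...)
    uint32_t n_tokens;   // ubatch size
    uint32_t n_outputs;  // number of requested outputs

    bool operator==(const graph_params_t & other) const {
        return gtype == other.gtype && n_tokens == other.n_tokens && n_outputs == other.n_outputs;
    }
};

// hypothetical cache holding the parameters of the previously built graph
struct graph_cache {
    bool           valid  = false;
    graph_params_t params = {};

    // returns true when the previously built graph can be reused as-is
    bool update(const graph_params_t & gparams) {
        if (valid && params == gparams) {
            return true;
        }
        params = gparams;
        valid  = true;
        return false; // caller must rebuild and reallocate the graph
    }
};

int main() {
    graph_cache cache;

    const graph_params_t decode_step = { /*gtype =*/ 1, /*n_tokens =*/ 1, /*n_outputs =*/ 1 };

    for (int i = 0; i < 3; ++i) {
        if (cache.update(decode_step)) {
            printf("step %d: reuse previous graph, only refresh inputs\n", i);
        } else {
            printf("step %d: build and allocate a new graph\n", i);
        }
    }
    return 0;
}

On the reuse path only res->set_inputs(&ubatch) runs, which is why the scheduler reset and eval-callback setup move inside the rebuild branch in the diff above.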
@@ -767,9 +778,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     n_outputs = n_tokens;
 
-    ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
     const auto causal_attn_org = cparams.causal_attn;
 
     // always use non-causal attention for encoder graphs
@@ -778,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -846,7 +854,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    if (!cparams.graph_reuse) {
+        ggml_backend_sched_reset(sched.get());
+    }
 
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
@@ -1005,11 +1015,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
             n_outputs = n_outputs_new;
         }
 
-        ggml_backend_sched_reset(sched.get());
-        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1192,7 +1199,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    if (!cparams.graph_reuse) {
+        ggml_backend_sched_reset(sched.get());
+    }
 
     return 0;
 }
@@ -1275,20 +1284,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 // graph
 //
 
-int32_t llama_context::graph_max_nodes() const {
-    return std::max<int32_t>(65536, 5*model.n_tensors());
-}
-
-ggml_cgraph * llama_context::graph_init() {
-    ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute_meta.size(),
-        /*.mem_buffer =*/ buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
-
-    ctx_compute.reset(ggml_init(params));
-
-    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
+uint32_t llama_context::graph_max_nodes() const {
+    return std::max<uint32_t>(65536u, 5u*model.n_tensors());
 }
 
 ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
@@ -1301,6 +1298,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
         LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
     }
 
+    gf_res_prev->reset();
+    ggml_backend_sched_reset(sched.get());
+
     // store the n_outputs as it is, and restore it afterwards
     // TODO: not sure if needed, might simplify in the future by removing this
     const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1310,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+    auto * res = gf_res_reserve.get();
 
-    this->n_outputs = save_n_outputs;
+    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
 
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
-        return nullptr;
-    }
+    res->reset();
 
-    ggml_backend_sched_reset(sched.get());
+    auto * gf = model.build_graph(gparams);
+
+    this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
     if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1331,28 +1329,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     return gf;
 }
 
-llm_graph_result_ptr llama_context::graph_build(
-        ggml_context * ctx,
-        ggml_cgraph * gf,
-        const llama_ubatch & ubatch,
-        llm_graph_type gtype,
-        const llama_memory_context_i * mctx) {
-    return model.build_graph(
-        {
-            /*.ctx         =*/ ctx,
-            /*.arch        =*/ model.arch,
-            /*.hparams     =*/ model.hparams,
-            /*.cparams     =*/ cparams,
-            /*.ubatch      =*/ ubatch,
-            /*.sched       =*/ sched.get(),
-            /*.backend_cpu =*/ backend_cpu,
-            /*.cvec        =*/ &cvec,
-            /*.loras       =*/ &loras,
-            /*.mctx        =*/ mctx,
-            /*.cross       =*/ &cross,
-            /*.n_outputs   =*/ n_outputs,
-            /*.cb          =*/ graph_get_cb(),
-        }, gf, gtype);
+llm_graph_params llama_context::graph_params(
+        llm_graph_result_i * res,
+        const llama_ubatch & ubatch,
+        const llama_memory_context_i * mctx,
+        llm_graph_type gtype) const {
+    return {
+        /*.arch        =*/ model.arch,
+        /*.hparams     =*/ model.hparams,
+        /*.cparams     =*/ cparams,
+        /*.ubatch      =*/ ubatch,
+        /*.gtype       =*/ gtype,
+        /*.sched       =*/ sched.get(),
+        /*.backend_cpu =*/ backend_cpu,
+        /*.cvec        =*/ &cvec,
+        /*.loras       =*/ &loras,
+        /*.mctx        =*/ mctx,
+        /*.cross       =*/ &cross,
+        /*.n_outputs   =*/ n_outputs,
+        /*.cb          =*/ graph_get_cb(),
+        /*.res         =*/ res,
+    };
 }
 
 ggml_status llama_context::graph_compute(
@@ -1930,6 +1927,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
     data.t_eval_ms   = 1e-3 * t_eval_us;
     data.n_p_eval    = std::max(1, n_p_eval);
     data.n_eval      = std::max(1, n_eval);
+    data.n_reused    = std::max(0, n_reused);
 
     return data;
 }
@@ -1938,6 +1936,7 @@ void llama_context::perf_reset() {
     t_start_us  = ggml_time_us();
     t_eval_us   = n_eval = 0;
     t_p_eval_us = n_p_eval = 0;
+    n_reused    = 0;
 }
 
 //
@@ -2064,8 +2063,13 @@ void llama_context::opt_epoch_iter(
                 break;
             }
 
-            auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+            auto * res = gf_res_prev.get();
+
+            const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
+
+            res->reset();
+
+            auto * gf = model.build_graph(gparams);
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2187,6 +2191,7 @@ llama_context_params llama_context_default_params() {
        /*.no_perf     =*/ true,
        /*.op_offload  =*/ true,
        /*.swa_full    =*/ true,
+       /*.graph_reuse =*/ false,
     };
 
     return result;
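For reference, the new flag is surfaced through llama_context_params and is off by default, as the initializer above shows. A minimal usage sketch follows, assuming the usual llama.h entry points; llama_init_from_model stands in for whichever context constructor the caller already uses, and model loading is omitted.

#include "llama.h"

// sketch only: create a context with graph reuse enabled
static llama_context * make_ctx_with_graph_reuse(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.graph_reuse = true; // new field introduced by this change; defaults to false
    return llama_init_from_model(model, cparams); // assumed llama.h constructor
}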
@@ -2807,6 +2812,7 @@ void llama_perf_context_print(const llama_context * ctx) {
     LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
     LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+    LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
 }
 
 void llama_perf_context_reset(llama_context * ctx) {
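Since perf_get_data() above populates n_reused on llama_perf_context_data, callers can read the counter programmatically instead of relying on the debug log or the print above. A small sketch, assuming llama_perf_context() returns the struct filled by perf_get_data() and that n_reused is an int field on it:

#include <cstdio>
#include "llama.h"

// sketch only: report how often a previously built graph was reused
static void print_graph_reuse_stats(const llama_context * ctx) {
    const llama_perf_context_data data = llama_perf_context(ctx); // existing perf query
    printf("graphs reused: %d\n", data.n_reused);                 // field added by this change
}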