@@ -371,31 +371,11 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
 // these are used by the llama_context to extact the relevant data, based on the compute parameters
 
-// TODO: this interface seems redundant - remove it
-class llm_graph_result_i {
-public:
-    virtual ~llm_graph_result_i() = default;
-
-    virtual ggml_tensor * get_tokens() const = 0;
-    virtual ggml_tensor * get_logits() const = 0;
-    virtual ggml_tensor * get_embd() const = 0;
-    virtual ggml_tensor * get_embd_pooled() const = 0;
-
-    virtual ggml_cgraph * get_gf() = 0;
-    virtual ggml_context * get_ctx() = 0;
-
-    virtual void reset() = 0;
-
-    virtual void set_inputs(const llama_ubatch * ubatch) = 0;
-
-    virtual bool can_reuse(const llm_graph_params & params) = 0;
-};
-
-using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
-
 // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
 using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
 
+class llm_graph_result;
+
 struct llm_graph_params {
     llm_arch arch = LLM_ARCH_UNKNOWN;
 
@@ -418,8 +398,7 @@ struct llm_graph_params {
 
     llm_graph_cb cb;
 
-    // TODO: temporary
-    llm_graph_result_i * res;
+    llm_graph_result * res;
 
     // return true if the "other" params would result in a graph with the same topology as with the current params
     // having the same topology allows us to reuse the graph in some cases
@@ -462,27 +441,27 @@ struct llm_graph_params {
     }
 };
 
-class llm_graph_result : public llm_graph_result_i {
+class llm_graph_result {
 public:
     llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
         reset();
     }
 
     virtual ~llm_graph_result() = default;
 
-    ggml_tensor * get_tokens() const override { return t_tokens; }
-    ggml_tensor * get_logits() const override { return t_logits; }
-    ggml_tensor * get_embd() const override { return t_embd; }
-    ggml_tensor * get_embd_pooled() const override { return t_embd_pooled; }
+    ggml_tensor * get_tokens() const { return t_tokens; }
+    ggml_tensor * get_logits() const { return t_logits; }
+    ggml_tensor * get_embd() const { return t_embd; }
+    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
 
-    ggml_cgraph * get_gf() override { return gf; }
-    ggml_context * get_ctx() override { return ctx_compute.get(); }
+    ggml_cgraph * get_gf() { return gf; }
+    ggml_context * get_ctx() { return ctx_compute.get(); }
 
     void set_max_nodes(int64_t max_nodes) {
         this->max_nodes = max_nodes;
     }
 
-    void reset() override {
+    void reset() {
         t_tokens = nullptr;
         t_logits = nullptr;
         t_embd = nullptr;
@@ -503,7 +482,7 @@ class llm_graph_result : public llm_graph_result_i {
         gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
     }
 
-    void set_inputs(const llama_ubatch * ubatch) override {
+    void set_inputs(const llama_ubatch * ubatch) {
         for (auto & input : inputs) {
             input->set_input(ubatch);
         }
@@ -514,7 +493,7 @@ class llm_graph_result : public llm_graph_result_i {
     // would be identical to the existing graph. in that case, we simply have to update the memory
     // contexts of the input tensors of the graph and we can reuse it for another computation
     // return true if the graph was updated and can be reused
-    bool can_reuse(const llm_graph_params & params) override {
+    bool can_reuse(const llm_graph_params & params) {
         if (!this->params.allow_reuse(params)) {
             return false;
         }
@@ -533,6 +512,10 @@ class llm_graph_result : public llm_graph_result_i {
         return inputs.back().get();
     }
 
+    void set_params(const llm_graph_params & params) {
+        this->params = params;
+    }
+
     // important graph nodes
     ggml_tensor * t_tokens = nullptr;
     ggml_tensor * t_logits = nullptr;
@@ -550,12 +533,15 @@ class llm_graph_result : public llm_graph_result_i {
 
     int64_t max_nodes;
 
+private:
     // keep a copy of the previous graph parameters
     // we will use this to determine whether the graph can be reused by comparing them with the new parameters
     // note: these are updated after constructing the new graph
     llm_graph_params params;
 };
 
+using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
+
 //
 // llm_graph_context
 //
@@ -613,6 +599,7 @@ struct llm_graph_context {
     llm_graph_result * res;
 
     ggml_context * ctx0 = nullptr;
+    ggml_cgraph * gf = nullptr;
 
     llm_graph_context(const llm_graph_params & params);
     virtual ~llm_graph_context() = default;
@@ -698,7 +685,6 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_attn_mha(
-            ggml_cgraph * gf,
             ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
@@ -711,7 +697,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_no_cache * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -726,7 +711,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -742,7 +726,6 @@ struct llm_graph_context {
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -757,7 +740,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_cross * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -779,7 +761,6 @@ struct llm_graph_context {
     // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     // `llama_memory_recurrent`
     ggml_tensor * build_rs(
-            ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
             int32_t state_size,
@@ -794,17 +775,15 @@ struct llm_graph_context {
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             ggml_tensor * s,
             int32_t state_size,
             int32_t n_seqs,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             const llama_ubatch & ubatch,
-             int il) const;
+            int il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(
             ggml_tensor * token_shift,
@@ -821,7 +800,6 @@ struct llm_graph_context {
     //
 
     void build_pooling(
-            ggml_cgraph * gf,
             ggml_tensor * cls,
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
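
For readers unfamiliar with the new API, below is a minimal usage sketch of how a caller such as llama_context might drive the consolidated llm_graph_result class after this change. The wrapper function, its name, and the ownership of the result object are illustrative assumptions rather than part of this diff; only the llm_graph_result methods it calls (can_reuse, reset, set_params, set_inputs, get_gf) appear in the header above, and the actual graph construction step is elided.

```cpp
// Hypothetical sketch (not part of the PR): reuse the previously built graph when the
// topology is unchanged, otherwise rebuild it and remember the new parameters.
#include "llama-graph.h" // internal header declaring llm_graph_result / llm_graph_params

static ggml_cgraph * prepare_graph_sketch(
        llm_graph_result       & res,     // result object kept alive across decode calls (assumption)
        const llm_graph_params & params,  // parameters describing the graph we want to compute
        const llama_ubatch     & ubatch) {
    if (!res.can_reuse(params)) {
        // topology changed: drop the old nodes and allocate a fresh ggml_cgraph
        res.reset();

        // ... build the new graph into res via llm_graph_context (elided) ...

        // remember the parameters so the next call can compare against them
        res.set_params(params);
    }

    // in both cases the input tensors must be (re)filled from the current ubatch
    res.set_inputs(&ubatch);

    // the caller then schedules and computes this graph with the ggml backend
    return res.get_gf();
}
```

This mirrors the intent of the comments above can_reuse(): when allow_reuse() accepts the new parameters, the existing graph is updated in place, and a full rebuild happens only when the topology differs.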