@@ -429,7 +429,7 @@ bool gpt2_eval(
429429 };
430430
431431 struct ggml_context * ctx0 = ggml_init (params);
432- struct ggml_cgraph gf = {} ;
432+ struct ggml_cgraph * gf = ggml_new_graph (ctx0) ;
433433
434434 struct ggml_tensor * embd = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, N);
435435 memcpy (embd->data , embd_inp.data (), N*ggml_element_size (embd));
@@ -491,8 +491,8 @@ bool gpt2_eval(
491491 struct ggml_tensor * k = ggml_view_1d (ctx0, model.memory_k , N*n_embd, (ggml_element_size (model.memory_k )*n_embd)*(il*n_ctx + n_past));
492492 struct ggml_tensor * v = ggml_view_1d (ctx0, model.memory_v , N*n_embd, (ggml_element_size (model.memory_v )*n_embd)*(il*n_ctx + n_past));
493493
494- ggml_build_forward_expand (& gf, ggml_cpy (ctx0, Kcur, k));
495- ggml_build_forward_expand (& gf, ggml_cpy (ctx0, Vcur, v));
494+ ggml_build_forward_expand (gf, ggml_cpy (ctx0, Kcur, k));
495+ ggml_build_forward_expand (gf, ggml_cpy (ctx0, Vcur, v));
496496 }
497497
498498 // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
@@ -673,8 +673,8 @@ bool gpt2_eval(
673673 // inpL = ggml_soft_max_inplace(ctx0, inpL);
674674
675675 // run the computation
676- ggml_build_forward_expand (& gf, inpL);
677- ggml_graph_compute_with_ctx (ctx0, & gf, n_threads);
676+ ggml_build_forward_expand (gf, inpL);
677+ ggml_graph_compute_with_ctx (ctx0, gf, n_threads);
678678
679679 // if (n_past%100 == 0) {
680680 // ggml_graph_print (&gf);
@@ -767,7 +767,7 @@ int main(int argc, char ** argv) {
767767 size_t mem_per_token = 0 ;
768768 gpt2_eval (model, params.n_threads , 0 , { 0 , 1 , 2 , 3 }, logits, mem_per_token);
769769
770- for (int i = embd.size (); i < embd_inp.size () + params.n_predict ; i++) {
770+ for (size_t i = embd.size (); i < embd_inp.size () + params.n_predict ; i++) {
771771 // predict
772772 if (embd.size () > 0 ) {
773773 const int64_t t_start_us = ggml_time_us ();
@@ -805,9 +805,9 @@ int main(int argc, char ** argv) {
805805 embd.push_back (id);
806806 } else {
807807 // if here, it means we are still processing the input prompt
808- for (int k = i; k < embd_inp.size (); k++) {
808+ for (size_t k = i; k < embd_inp.size (); k++) {
809809 embd.push_back (embd_inp[k]);
810- if (embd.size () >= params.n_batch ) {
810+ if (int32_t ( embd.size () ) >= params.n_batch ) {
811811 break ;
812812 }
813813 }
0 commit comments