@@ -48,9 +48,16 @@ int main(int argc, char ** argv) {
4848 // tokenize prompt
4949 auto tokens = common_tokenize (ctx, params.prompt , true );
5050
51+ // prepare the batch
52+ llama_batch batch = llama_batch_init (tokens.size (), 0 , 1 );
53+ for (size_t i = 0 ; i < tokens.size (); i++) {
54+ common_batch_add (batch, tokens[i], i, {0 }, false );
55+ }
56+ batch.logits [batch.n_tokens - 1 ] = true ; // generate next token
57+
5158 // evaluate prompt
52- llama_decode (ctx, llama_batch_get_one (tokens. data (), tokens. size ()) );
53- n_past += tokens. size () ;
59+ llama_decode (ctx, batch );
60+ n_past += batch. n_tokens ;
5461
5562 // save state (rng, logits, embedding and kv_cache) to file
5663 {
@@ -77,8 +84,12 @@ int main(int argc, char ** argv) {
7784 printf (" %s" , next_token_str.c_str ());
7885 result0 += next_token_str;
7986
80- if (llama_decode (ctx, llama_batch_get_one (&next_token, 1 ))) {
87+ common_batch_clear (batch);
88+ common_batch_add (batch, next_token, n_past, {0 }, true );
89+
90+ if (llama_decode (ctx, batch)) {
8191 fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
92+ llama_batch_free (batch);
8293 llama_free (ctx);
8394 llama_free_model (model);
8495 return 1 ;
@@ -133,8 +144,12 @@ int main(int argc, char ** argv) {
133144 printf (" %s" , next_token_str.c_str ());
134145 result1 += next_token_str;
135146
136- if (llama_decode (ctx2, llama_batch_get_one (&next_token, 1 ))) {
147+ common_batch_clear (batch);
148+ common_batch_add (batch, next_token, n_past, {0 }, true );
149+
150+ if (llama_decode (ctx2, batch)) {
137151 fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
152+ llama_batch_free (batch);
138153 llama_free (ctx2);
139154 llama_free_model (model);
140155 return 1 ;
@@ -221,8 +236,12 @@ int main(int argc, char ** argv) {
221236 printf (" %s" , next_token_str.c_str ());
222237 result2 += next_token_str;
223238
224- if (llama_decode (ctx3, llama_batch_get_one (&next_token, 1 ))) {
239+ common_batch_clear (batch);
240+ common_batch_add (batch, next_token, n_past, {1 }, true );
241+
242+ if (llama_decode (ctx3, batch)) {
225243 fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
244+ llama_batch_free (batch);
226245 llama_free (ctx3);
227246 llama_free_model (model);
228247 return 1 ;
@@ -236,6 +255,7 @@ int main(int argc, char ** argv) {
236255 llama_sampler_free (smpl2);
237256 llama_sampler_free (smpl3);
238257
258+ llama_batch_free (batch);
239259 llama_free (ctx3);
240260 llama_free_model (model);
241261
0 commit comments