@@ -135,7 +135,7 @@ namespace gcpp
135
135
136
136
while (abs_pos < args.max_tokens )
137
137
{
138
- std::string prompt_string;
138
+ std::string prompt_string;
139
139
std::vector<int > prompt;
140
140
current_pos = 0 ;
141
141
{
@@ -255,6 +255,58 @@ namespace gcpp
255
255
{ return true ; });
256
256
}
257
257
258
+ std::string decode (gcpp::Gemma &model, hwy::ThreadPool &pool,
259
+ hwy::ThreadPool &inner_pool, const InferenceArgs &args,
260
+ int verbosity, const gcpp::AcceptFunc &accept_token, std::string &prompt_string)
261
+ {
262
+ std::string generated_text;
263
+ // Seed the random number generator
264
+ std::random_device rd;
265
+ std::mt19937 gen (rd ());
266
+ int prompt_size{};
267
+ if (model.model_training == ModelTraining::GEMMA_IT)
268
+ {
269
+ // For instruction-tuned models: add control tokens.
270
+ prompt_string = " <start_of_turn>user\n " + prompt_string +
271
+ " <end_of_turn>\n <start_of_turn>model\n " ;
272
+ }
273
+ // Encode the prompt string into tokens
274
+ std::vector<int > prompt;
275
+ HWY_ASSERT (model.Tokenizer ().Encode (prompt_string, &prompt).ok ());
276
+ // Placeholder for generated token IDs
277
+ std::vector<int > generated_tokens;
278
+ // Define lambda for token decoding
279
+ StreamFunc stream_token = [&generated_tokens](int token, float /* probability */ ) -> bool {
280
+ generated_tokens.push_back (token);
281
+ return true ; // Continue generating
282
+ };
283
+ // Decode tokens
284
+ prompt_size = prompt.size ();
285
+ GenerateGemma (model, args, prompt, /* start_pos=*/ 0 , pool, inner_pool, stream_token, accept_token, gen, verbosity);
286
+ HWY_ASSERT (model.Tokenizer ().Decode (generated_tokens, &generated_text).ok ());
287
+ generated_text = generated_text.substr (prompt_string.size ());
288
+
289
+ return generated_text;
290
+ }
291
+
292
+ std::string completion (LoaderArgs &loader, InferenceArgs &inference, AppArgs &app, std::string &prompt_string)
293
+ {
294
+ hwy::ThreadPool inner_pool (0 );
295
+ hwy::ThreadPool pool (app.num_threads );
296
+ if (app.num_threads > 10 )
297
+ {
298
+ PinThreadToCore (app.num_threads - 1 ); // Main thread
299
+
300
+ pool.Run (0 , pool.NumThreads (),
301
+ [](uint64_t /* task*/ , size_t thread)
302
+ { PinThreadToCore (thread); });
303
+ }
304
+ gcpp::Gemma model (loader, pool);
305
+ return decode (model, pool, inner_pool, inference, app.verbosity , /* accept_token=*/ [](int )
306
+ { return true ; }, prompt_string);
307
+
308
+ }
309
+
258
310
} // namespace gcpp
259
311
260
312
void chat_base (int argc, char **argv)
@@ -283,7 +335,30 @@ void chat_base(int argc, char **argv)
283
335
PROFILER_PRINT_RESULTS (); // Must call outside the zone above.
284
336
// return 1;
285
337
}
338
+ std::string completion_base (int argc, char **argv)
339
+ {
340
+ gcpp::LoaderArgs loader (argc, argv);
341
+ gcpp::InferenceArgs inference (argc, argv);
342
+ gcpp::AppArgs app (argc, argv);
343
+ std::string prompt_string = argv[argc-1 ];
344
+ return gcpp::completion (loader, inference, app, prompt_string);
345
+ }
346
+ std::string completion_base_wrapper (const std::vector<std::string> &args,std::string &prompt_string)
347
+ {
348
+ int argc = args.size () + 2 ; // +1 for the program name
349
+ std::vector<char *> argv_vec;
350
+ argv_vec.reserve (argc);
286
351
352
+ argv_vec.push_back (const_cast <char *>(" pygemma" ));
353
+
354
+ for (const auto &arg : args)
355
+ {
356
+ argv_vec.push_back (const_cast <char *>(arg.c_str ()));
357
+ }
358
+ argv_vec.push_back (const_cast <char *>(prompt_string.c_str ()));
359
+ char **argv = argv_vec.data ();
360
+ return completion_base (argc, argv);
361
+ }
287
362
void show_help_wrapper ()
288
363
{
289
364
// Assuming ShowHelp does not critically depend on argv content
@@ -294,12 +369,11 @@ void show_help_wrapper()
294
369
ShowHelp (loader, inference, app);
295
370
}
296
371
297
- void chat_base_wrapper (const std::vector<std::string> &args)
372
+ std::string chat_base_wrapper (const std::vector<std::string> &args)
298
373
{
299
374
int argc = args.size () + 1 ; // +1 for the program name
300
375
std::vector<char *> argv_vec;
301
376
argv_vec.reserve (argc);
302
-
303
377
argv_vec.push_back (const_cast <char *>(" pygemma" ));
304
378
305
379
for (const auto &arg : args)
@@ -308,12 +382,15 @@ void chat_base_wrapper(const std::vector<std::string> &args)
308
382
}
309
383
310
384
char **argv = argv_vec.data ();
385
+
311
386
chat_base (argc, argv);
312
387
}
313
388
389
+
314
390
// Python module definition for `pygemma`.
//
// Exposes three callables:
//   chat_base  -> chat_base_wrapper
//   show_help  -> show_help_wrapper
//   completion -> completion_base_wrapper
//
// The registered names contain no surrounding whitespace so that they can be
// reached with ordinary attribute access (e.g. pygemma.completion).
PYBIND11_MODULE(pygemma, m)
{
    m.doc() = " Pybind11 integration for chat_base function";
    m.def("chat_base", &chat_base_wrapper, " A wrapper for the chat_base function accepting Python list of strings as arguments");
    m.def("show_help", &show_help_wrapper, " A wrapper for show_help function");
    m.def("completion", &completion_base_wrapper, " A wrapper for inference function");
}
0 commit comments