
Commit 24a7940

Merge pull request #5 from namtranase/add-feature
add completion function
2 parents: ede8eba + df10530

2 files changed (+109, -18)

src/gemma_binding.cpp (80 additions, 3 deletions)
@@ -135,7 +135,7 @@ namespace gcpp

       while (abs_pos < args.max_tokens)
       {
-        std::string prompt_string;
+        std::string prompt_string;
         std::vector<int> prompt;
         current_pos = 0;
         {
@@ -255,6 +255,58 @@ namespace gcpp
                          { return true; });
   }

+  std::string decode(gcpp::Gemma &model, hwy::ThreadPool &pool,
+                     hwy::ThreadPool &inner_pool, const InferenceArgs &args,
+                     int verbosity, const gcpp::AcceptFunc &accept_token, std::string &prompt_string)
+  {
+    std::string generated_text;
+    // Seed the random number generator.
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    int prompt_size{};
+    if (model.model_training == ModelTraining::GEMMA_IT)
+    {
+      // For instruction-tuned models: add control tokens.
+      prompt_string = "<start_of_turn>user\n" + prompt_string +
+                      "<end_of_turn>\n<start_of_turn>model\n";
+    }
+    // Encode the prompt string into tokens.
+    std::vector<int> prompt;
+    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+    // Collect generated token IDs here.
+    std::vector<int> generated_tokens;
+    // Stream callback: append each token and keep generating.
+    StreamFunc stream_token = [&generated_tokens](int token, float /* probability */) -> bool {
+      generated_tokens.push_back(token);
+      return true; // Continue generating.
+    };
+    // Generate, then decode the streamed tokens back into text.
+    prompt_size = prompt.size();
+    GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool, stream_token, accept_token, gen, verbosity);
+    HWY_ASSERT(model.Tokenizer().Decode(generated_tokens, &generated_text).ok());
+    // The stream echoes the prompt, so strip it from the front.
+    generated_text = generated_text.substr(prompt_string.size());
+
+    return generated_text;
+  }
+
+  std::string completion(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app, std::string &prompt_string)
+  {
+    hwy::ThreadPool inner_pool(0);
+    hwy::ThreadPool pool(app.num_threads);
+    if (app.num_threads > 10)
+    {
+      PinThreadToCore(app.num_threads - 1); // Main thread.
+
+      pool.Run(0, pool.NumThreads(),
+               [](uint64_t /*task*/, size_t thread)
+               { PinThreadToCore(thread); });
+    }
+    gcpp::Gemma model(loader, pool);
+    return decode(model, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int)
+                  { return true; }, prompt_string);
+  }
+
 } // namespace gcpp

 void chat_base(int argc, char **argv)
@@ -283,7 +335,30 @@ void chat_base(int argc, char **argv)
   PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
   // return 1;
 }
+std::string completion_base(int argc, char **argv)
+{
+  gcpp::LoaderArgs loader(argc, argv);
+  gcpp::InferenceArgs inference(argc, argv);
+  gcpp::AppArgs app(argc, argv);
+  // The prompt travels as the last argv entry (see completion_base_wrapper).
+  std::string prompt_string = argv[argc - 1];
+  return gcpp::completion(loader, inference, app, prompt_string);
+}
+std::string completion_base_wrapper(const std::vector<std::string> &args, std::string &prompt_string)
+{
+  int argc = args.size() + 2; // +1 for the program name, +1 for the prompt
+  std::vector<char *> argv_vec;
+  argv_vec.reserve(argc);

+  argv_vec.push_back(const_cast<char *>("pygemma"));
+
+  for (const auto &arg : args)
+  {
+    argv_vec.push_back(const_cast<char *>(arg.c_str()));
+  }
+  argv_vec.push_back(const_cast<char *>(prompt_string.c_str()));
+  char **argv = argv_vec.data();
+  return completion_base(argc, argv);
+}
 void show_help_wrapper()
 {
   // Assuming ShowHelp does not critically depend on argv content
@@ -294,12 +369,11 @@ void show_help_wrapper()
   ShowHelp(loader, inference, app);
 }

-void chat_base_wrapper(const std::vector<std::string> &args)
+std::string chat_base_wrapper(const std::vector<std::string> &args)
 {
   int argc = args.size() + 1; // +1 for the program name
   std::vector<char *> argv_vec;
   argv_vec.reserve(argc);
-
   argv_vec.push_back(const_cast<char *>("pygemma"));

   for (const auto &arg : args)
@@ -308,12 +382,15 @@ void chat_base_wrapper(const std::vector<std::string> &args)
   }

   char **argv = argv_vec.data();
+
   chat_base(argc, argv);
 }

+
 PYBIND11_MODULE(pygemma, m)
 {
   m.doc() = "Pybind11 integration for chat_base function";
   m.def("chat_base", &chat_base_wrapper, "A wrapper for the chat_base function accepting a Python list of strings as arguments");
   m.def("show_help", &show_help_wrapper, "A wrapper for the show_help function");
+  m.def("completion", &completion_base_wrapper, "A wrapper for the one-shot inference function");
 }
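
Taken together, these hunks expose one-shot inference to Python: completion_base_wrapper rebuilds an argv vector (program name, the forwarded flags, then the prompt as the final entry), and completion_base reads the prompt back out of argv[argc - 1]. A minimal usage sketch from the Python side, assuming the pygemma module has been built and the tokenizer/weights paths below are placeholders:

    import pygemma

    # Flags are forwarded to gcpp::LoaderArgs / InferenceArgs / AppArgs;
    # the prompt is passed separately and becomes the last argv entry.
    text = pygemma.completion(
        [
            "--tokenizer", "/path/to/tokenizer.spm",
            "--compressed_weights", "/path/to/2b-it-sfp.sbs",
            "--model", "2b-it",
        ],
        "Explain what a tokenizer does in one sentence.",
    )
    print(text)

The prompt is positional rather than a --flag because the C++ side pulls it off the end of argv instead of routing it through an argument parser.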

tests/test_chat.py (29 additions, 15 deletions)
@@ -18,25 +18,39 @@ def main():
     parser.add_argument(
         "--model", type=str, required=True, help="Model type identifier."
     )
-
-    args = parser.parse_args()
-
-    # Now using the parsed arguments
-    pygemma.chat_base(
-        [
-            "--tokenizer",
-            args.tokenizer,
-            "--compressed_weights",
-            args.compressed_weights,
-            "--model",
-            args.model,
-        ]
+    parser.add_argument(
+        "--input", type=str, required=False, help="Input text to send to the model. If None, switch to chat mode.",
+        default="Hello."
     )
-
+    # Now using the parsed arguments
+    args = parser.parse_args()
+    if args.input is not None:
+        string = pygemma.completion(
+            [
+                "--tokenizer",
+                args.tokenizer,
+                "--compressed_weights",
+                args.compressed_weights,
+                "--model",
+                args.model,
+            ], args.input
+        )
+        print(string)
+    else:
+        return pygemma.chat_base(
+            [
+                "--tokenizer",
+                args.tokenizer,
+                "--compressed_weights",
+                args.compressed_weights,
+                "--model",
+                args.model,
+            ]
+        )
     # Optionally, show help if needed
     # pygemma.show_help()


 if __name__ == "__main__":
     main()
-# python tests/test_chat.py --tokenizer /path/to/tokenizer.spm --compressed_weights /path/to/weights.sbs --model model_identifier
+# python tests/test_chat.py --tokenizer ../Model_Weight/tokenizer.spm --compressed_weights ../Model_Weight/2b-it-sfp.sbs --model 2b-it
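
With the new --input flag, the same script performs a one-shot completion instead of entering chat mode; a hypothetical invocation in the style of the comment above:

    # python tests/test_chat.py --tokenizer ../Model_Weight/tokenizer.spm --compressed_weights ../Model_Weight/2b-it-sfp.sbs --model 2b-it --input "What is a tokenizer?"

Note that because --input defaults to "Hello.", args.input is never None as written, so the chat branch is only reached if that default is removed.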
