Skip to content

Commit 11ea5c7

Browse files
authored
infill: fix tokenization (#3508)
* infill tokens correction * server infill tokens correction * removing any leading whitespace from infill suffix and removing leading space token from suffix when params.escape * removing any leading whitespace from infill suffix and removing leading space token from suffix when params.escape * only rm when params.escape, rm space if possible which is added back or rm added space token * only rm when params.escape, rm space if possible which is added back or rm added space token * Revert "only rm when params.escape, rm space if possible which is added back or rm added space token" This reverts commit 63ba0b6. * fix interactive prompt escaping and fix server infill leading space handling * rm unnecessary bool check
1 parent 95bd60a commit 11ea5c7

File tree

2 files changed

+46
-6
lines changed

2 files changed

+46
-6
lines changed

examples/infill/infill.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,22 @@ int main(int argc, char ** argv) {
233233
const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
234234
LOG("add_bos: %d\n", add_bos);
235235

236+
bool suff_rm_leading_spc = params.escape;
237+
if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
238+
params.input_suffix.erase(0, 1);
239+
suff_rm_leading_spc = false;
240+
}
236241
std::vector<llama_token> embd_inp;
237-
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
238-
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
242+
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
243+
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
244+
const int space_token = 29871;
245+
if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
246+
inp_sfx.erase(inp_sfx.begin());
247+
}
239248
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
249+
if (add_bos) {
250+
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
251+
}
240252
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
241253
embd_inp = inp_pfx;
242254
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
@@ -627,10 +639,27 @@ int main(int argc, char ** argv) {
627639
buffer.clear();
628640
// done taking input, reset color
629641
console::set_display(console::reset);
642+
643+
if (params.escape) {
644+
//process escape sequences, for the initial prompt this is done in common.cpp when we load the params, but for the interactive mode we need to do it here
645+
process_escapes(params.input_prefix);
646+
process_escapes(params.input_suffix);
647+
}
648+
suff_rm_leading_spc = params.escape;
649+
if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
650+
params.input_suffix.erase(0, 1);
651+
suff_rm_leading_spc = false;
652+
}
630653
// tokenize new prefix and suffix
631-
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
632-
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
654+
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
655+
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
656+
if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
657+
inp_sfx.erase(inp_sfx.begin());
658+
}
633659
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
660+
if (add_bos) {
661+
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
662+
}
634663
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
635664
embd_inp = inp_pfx;
636665
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());

examples/server/server.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,9 +344,20 @@ struct llama_server_context
344344

345345
void loadInfill()
346346
{
347-
auto prefix_tokens = tokenize(params.input_prefix, true); // always add BOS
348-
auto suffix_tokens = tokenize(params.input_suffix, true); // always add BOS
347+
bool suff_rm_leading_spc = true;
348+
if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
349+
params.input_suffix.erase(0, 1);
350+
suff_rm_leading_spc = false;
351+
}
352+
353+
auto prefix_tokens = tokenize(params.input_prefix, false);
354+
auto suffix_tokens = tokenize(params.input_suffix, false);
355+
const int space_token = 29871;
356+
if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
357+
suffix_tokens.erase(suffix_tokens.begin());
358+
}
349359
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
360+
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
350361
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
351362
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
352363
prefix_tokens.push_back(llama_token_middle(ctx));

0 commit comments

Comments
 (0)