Commit c9b2297

main : add token healing
1 parent 38c0347 commit c9b2297

File tree: 5 files changed (+249, -6 lines)

common/common.cpp

Lines changed: 25 additions & 0 deletions
```diff
@@ -1306,6 +1306,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
         return true;
     }
+    if (arg == "-th" || arg == "--token-healing") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        sparams.token_healing_enabled = true;
+        auto & th_type       = sparams.token_healing_type;
+        auto & th_n_rollback = sparams.token_healing_n_rollback;
+        std::string value(argv[i]);
+        /**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
+        else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
+        else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+        else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+        else if (value[0] == 'r') {
+            th_type = llama_token_healing_type::ROLLBACK_MULTI;
+            th_n_rollback = std::stoi(value.substr(1));
+            if (th_n_rollback <= 0) {
+                sparams.token_healing_enabled = false;
+            }
+        } else { invalid_param = true; }
+        return true;
+    }
     if (arg == "--override-kv") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1503,6 +1525,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -j SCHEMA, --json-schema SCHEMA\n");
     printf("                        JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
     printf("                        For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
+    printf("  -th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}\n");
+    printf("                        Token healing type. (default: 0, disabled)\n");
+    printf("                        1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens\n");
     printf("  --cfg-negative-prompt PROMPT\n");
     printf("                        negative prompt to use for guidance. (default: empty)\n");
     printf("  --cfg-negative-prompt-file FNAME\n");
```

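For reference, a minimal sketch of what a flag like `-th r3` resolves to after the parsing above. This is illustrative only; the field names are the ones this commit adds to `llama_sampling_params`, and the include path is assumed:

```cpp
#include "sampling.h" // assumes the common/sampling.h added in this commit is on the include path

// Illustrative sketch: the sampling parameters that "-th r3" would produce per the parsing above.
llama_sampling_params make_params_for_th_r3() {
    llama_sampling_params sparams;
    sparams.token_healing_enabled    = true;
    sparams.token_healing_type       = llama_token_healing_type::ROLLBACK_MULTI; // the "r" in "r{N}"
    sparams.token_healing_n_rollback = 3;                                        // the "3" in "r3"
    return sparams;
}
```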
common/sampling.cpp

Lines changed: 142 additions & 3 deletions
```diff
@@ -2,6 +2,109 @@
 #include "sampling.h"
 #include <random>
 
+//
+// Token healing (internal)
+//
+
+static bool startswith(const std::string & str, const std::string & prefix) {
+    return str.rfind(prefix, 0) != std::string::npos;
+}
+
+static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static std::vector<llama_token> token_healing_find_prefix(
+        const llama_context * ctx_main,
+        const std::string & prefix,
+        const bool include_partial_prefix) {
+    // Example: prefix=" world" -> " world", " worldwide", ...
+    // If `include_partial_prefix`, include also: " w", " wo", ...
+    std::vector<llama_token> candidates;
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix) ||
+            (include_partial_prefix && startswith(prefix, token))) {
+            candidates.push_back(token_id);
+        }
+    }
+    return candidates;
+}
+
+//
+// Token healing (external)
+//
+
+std::string llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove,
+        int * n_removed) {
+    if (n_removed != nullptr) {
+        *n_removed = 0;
+    }
+    if (tokens.empty()) {
+        return "";
+    }
+
+    const llama_model * model = llama_get_model(ctx_main);
+    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const int n_ctx = tokens.size();
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx : std::min(max_to_remove, n_ctx);
+    int removed = 0;
+    std::string prefix;
+    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
+    // and stop early if a special token is encountered
+    while (removed < max_to_remove) {
+        const llama_token next_token_id = tokens[n_ctx - removed - 1];
+        if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
+            // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
+            break;
+        }
+        std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
+        if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
+            break;
+        }
+        removed += 1;
+        prefix = new_prefix;
+    }
+    if (removed == 0) { // E.g. if the last token is a special token
+        return "";
+    }
+    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                               th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
+    if (removed == 1 && candidates.size() == 1) {
+        LOG("token_healing: nothing to heal\n");
+        return "";
+    }
+    // Finalize outputs
+    if (n_removed != nullptr) {
+        *n_removed = removed;
+    }
+    tokens.resize(n_ctx - removed);
+    return prefix;
+}
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
+    ctx_sampling->token_healing_prefix = prefix;
+}
+
+//
+// Sampling
+//
+
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
@@ -64,6 +167,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
             grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
     }
 
+    ctx->token_healing_prefix.clear();
+
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
     ctx->n_valid = 0;
@@ -122,7 +227,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 }
 
 std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
+    std::string result = "(Token healing) -> CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
             const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
@@ -303,8 +408,27 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     cur.clear();
 
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    // Constrain tokens based on the remaining token healing prefix (if any)
+    const auto & th_type   = params.token_healing_type;
+    const auto & th_prefix = ctx_sampling->token_healing_prefix;
+    if (params.token_healing_enabled && !th_prefix.empty()) {
+        const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                                   th_type == llama_token_healing_type::DYNAMIC_MULTI;
+        std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
+
+        LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
+        for (const llama_token token_id : th_candidates) {
+            LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
+        }
+
+        // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
+        for (const llama_token token_id : th_candidates) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
+    } else {
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
@@ -367,4 +491,19 @@ void llama_sampling_accept(
     if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
+
+    if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
+        std::string & th_prefix = ctx_sampling->token_healing_prefix;
+        if (!th_prefix.empty()) {
+            const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
+            if (new_token_piece.size() < th_prefix.size()) {
+                // Shift prefix constraint (for multi step token healing)
+                th_prefix = th_prefix.substr(new_token_piece.size());
+            } else {
+                // Prefix has been generated => no more constrained generation
+                th_prefix.clear();
+                LOG("token_healing: done\n");
+            }
+        }
+    }
 }
```

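The prefix-shifting rule in `llama_sampling_accept` is what makes the multi-step variants work: each accepted token "consumes" part of the remaining healing prefix. A standalone sketch of just that rule, with hypothetical string values (the real code operates on `ctx_sampling->token_healing_prefix`):

```cpp
#include <iostream>
#include <string>

// Standalone sketch of the prefix-shifting rule from llama_sampling_accept above.
int main() {
    std::string th_prefix = " worl";          // suffix removed from the prompt, still to be re-generated
    const std::string accepted_piece = " wo"; // piece of the token the sampler just accepted

    if (accepted_piece.size() < th_prefix.size()) {
        th_prefix = th_prefix.substr(accepted_piece.size()); // shift the constraint: "rl" remains
    } else {
        th_prefix.clear();                                   // prefix fully covered -> unconstrained again
    }
    std::cout << "remaining prefix: '" << th_prefix << "'\n"; // prints: remaining prefix: 'rl'
    return 0;
}
```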
common/sampling.h

Lines changed: 28 additions & 0 deletions
```diff
@@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
     TEMPERATURE = 't'
 };
 
+enum class llama_token_healing_type : uint8_t {
+    ROLLBACK_LAST,  // roll back last token with a single constrained decoding step
+    ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
+    DYNAMIC_ONCE,   // dynamic roll back, single constrained decoding step
+    DYNAMIC_MULTI   // dynamic roll back, multiple constrained decoding steps
+};
+
 // sampling parameters
 typedef struct llama_sampling_params {
     int32_t n_prev = 64; // number of previous tokens to remember
@@ -62,6 +69,10 @@
 
     std::vector<llama_token> penalty_prompt_tokens;
     bool use_penalty_prompt_tokens = false;
+
+    llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
+    bool token_healing_enabled = false;
+    int  token_healing_n_rollback = -1; // number of tokens to roll back
 } llama_sampling_params;
 
 // general sampler context
@@ -78,6 +89,8 @@ struct llama_sampling_context {
     // internal
     grammar_parser::parse_state parsed_grammar;
 
+    std::string token_healing_prefix;
+
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
@@ -153,3 +166,18 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+//
+// Token healing
+//
+
+// Roll back `tokens` for constrained generation according to the token healing
+// strategy. Returns the prefix for constrained generation.
+std::string llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove = -1,
+        int * n_removed = nullptr);
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
```

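A minimal sketch of how a caller is expected to use the two entry points, mirroring what `examples/main` does below. Setup and error handling are omitted; `ctx`, `ctx_sampling`, `params` and `sparams` are assumed to be an already-initialized context, sampling context, and parameter structs:

```cpp
// Minimal usage sketch (assumes llama_context * ctx, llama_sampling_context * ctx_sampling,
// gpt_params params and llama_sampling_params sparams already exist).
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);

int n_removed = 0;
const std::string healing_prefix = llama_token_healing_rollback(
        ctx, sparams.token_healing_type, embd_inp,
        sparams.token_healing_n_rollback, &n_removed);

// The next sampling step(s) are constrained to tokens that re-cover the removed suffix.
llama_token_healing_set_prefix(ctx_sampling, healing_prefix);
```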
examples/main/README.md

Lines changed: 13 additions & 0 deletions
```diff
@@ -259,6 +259,19 @@ A more practical use case might be to prevent the generation of `\code{begin}` a
 
 Example usage: `--logit-bias 29905-inf`
 
+### Token healing
+
+- `-th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}`: Set the token healing strategy (default: 0, 0 = disabled).
+
+Token healing (a.k.a. token alignment) alleviates tokenization artifacts for text completion.
+
+- `-th 1`: Roll back the last token and constrain the bytes of the next token to start with the chopped off last token [0, 2].
+- `-th d1`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix and do a single constrained decoding step [2].
+- `-th d`: Like `d1` but allow multiple decoding steps until the removed suffix is generated.
+- `-th r{N}`: Like `d` but roll back `N` tokens, where `-th r3` is recommended [1].
+
+Sources: [0](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb), [1](https://arxiv.org/abs/2403.08688), [2](https://arxiv.org/abs/2402.01035).
+
 ### RNG Seed
 
 - `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
```

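As a rough illustration of the idea (hypothetical tokenization): if a prompt ends mid-word, say `...The sky is bl`, a greedy tokenizer may encode the trailing `bl` as its own token, a position the model rarely saw during training. With `-th 1`, that last token is rolled back and the next sampled token is constrained to start with `bl` (for example a token for ` blue`), so the completion reads as if the awkward boundary had never been there.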
examples/main/main.cpp

Lines changed: 41 additions & 3 deletions
```diff
@@ -264,6 +264,17 @@ int main(int argc, char ** argv) {
     LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
     LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
 
+    if (sparams.token_healing_enabled && (params.instruct || params.chatml || params.conversation || !params.input_suffix.empty())) {
+        sparams.token_healing_enabled = false;
+        LOG("token_healing: disabled due to custom suffix/conversation mode");
+    }
+    std::string token_healing_prefix;
+    int token_healing_n_removed = 0;
+    if (!params.interactive_first && sparams.token_healing_enabled) {
+        token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
+                                                            sparams.token_healing_n_rollback, &token_healing_n_removed);
+    }
+
     // Should not run without any tokens
     if (embd_inp.empty()) {
         embd_inp.push_back(llama_token_bos(model));
@@ -283,7 +294,7 @@
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
-        original_prompt_len = original_inp.size();
+        original_prompt_len = original_inp.size() - token_healing_n_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -502,6 +513,7 @@
     int n_consumed         = 0;
     int n_session_consumed = 0;
    int n_past_guidance    = 0;
+    int n_bytes_to_skip    = 0; // to skip printing when generating token healing prefix
 
     std::vector<int> input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@@ -527,6 +539,7 @@
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
     }
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
@@ -741,7 +754,16 @@
         if (input_echo && display) {
             for (auto id : embd) {
                 const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation);
-                printf("%s", token_str.c_str());
+
+                // Suppress printing while generating token healing prefix
+                if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
+                    printf("%s", token_str.substr(n_bytes_to_skip).c_str());
+                    n_bytes_to_skip = 0;
+                } else if (n_bytes_to_skip > 0) {
+                    n_bytes_to_skip -= token_str.size();
+                } else {
+                    printf("%s", token_str.c_str());
+                }
 
                 if (embd.size() > 1) {
                     input_tokens.push_back(id);
@@ -820,6 +842,7 @@
             }
         }
 
+        token_healing_n_removed = 0;
        if (n_past > 0 && is_interacting) {
            LOG("waiting for user input\n");
 
@@ -903,13 +926,24 @@
                     embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
                 }
 
+                if (sparams.token_healing_enabled) {
+                    // Limit token healing rollback to new tokens only (otherwise would need to shift everything)
+                    const int n_new_tokens = embd_inp.size() - original_size;
+                    const int max_to_remove = sparams.token_healing_n_rollback < 0
+                                              ? n_new_tokens
+                                              : std::min(sparams.token_healing_n_rollback, n_new_tokens);
+                    token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
+                                                                        max_to_remove, &token_healing_n_removed);
+                    n_bytes_to_skip = token_healing_prefix.size();
+                }
+
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
                     const llama_token token = embd_inp[i];
                     output_tokens.push_back(token);
                     output_ss << llama_token_to_piece(ctx, token);
                 }
 
-                n_remain -= line_inp.size();
+                n_remain -= line_inp.size() + token_healing_n_removed;
                 LOG("n_remain: %d\n", n_remain);
             } else {
                 LOG("empty line, passing control back\n");
@@ -921,6 +955,10 @@
         if (n_past > 0) {
             if (is_interacting) {
                 llama_sampling_reset(ctx_sampling);
+                if (token_healing_n_removed > 0) {
+                    // Set new prefix after an interaction
+                    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+                }
             }
             is_interacting = false;
         }
```

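The echo-suppression change above is easy to misread, so here is a self-contained sketch of just the byte-skipping rule with hypothetical token pieces. The healed prefix is already on screen (the user typed it), so its bytes are not echoed again when the model re-generates them:

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Standalone sketch of the n_bytes_to_skip logic from examples/main above (hypothetical values).
int main() {
    int n_bytes_to_skip = 4;                                       // e.g. healing prefix " wor"
    const std::vector<std::string> pieces = { " wo", "rld", "!" }; // hypothetical generated pieces

    for (const std::string & token_str : pieces) {
        if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int) token_str.size()) {
            printf("%s", token_str.substr(n_bytes_to_skip).c_str()); // print only the new bytes
            n_bytes_to_skip = 0;
        } else if (n_bytes_to_skip > 0) {
            n_bytes_to_skip -= token_str.size();                     // piece fully overlaps the prefix: skip it
        } else {
            printf("%s", token_str.c_str());                         // normal echo
        }
    }
    printf("\n"); // prints "ld!" -- the " wor" already shown is not repeated
    return 0;
}
```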