 #include "sampling.h"
 #include <random>
 
+//
+// Token healing (internal)
+//
+
+static bool startswith(const std::string & str, const std::string & prefix) {
+    return str.rfind(prefix, 0) == 0;
+}
+
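+// Check whether any token in the vocab starts with `prefix` (a linear scan over the whole vocab)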
+static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static std::vector<llama_token> token_healing_find_prefix(
+        const llama_context * ctx_main,
+        const std::string & prefix,
+        const bool include_partial_prefix) {
+    // Example: prefix=" world" -> " world", " worldwide", ...
+    // If `include_partial_prefix`, also include shorter pieces: " w", " wo", ...
+    std::vector<llama_token> candidates;
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix) ||
+            (include_partial_prefix && startswith(prefix, token))) {
+            candidates.push_back(token_id);
+        }
+    }
+    return candidates;
+}
+
+//
+// Token healing (external)
+//
+
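+// Roll back `tokens` across the final token boundary so that the removed piece can be
+// regenerated under a prefix constraint, avoiding the unnatural tokenization the model
+// would otherwise see when the prompt ends mid-token. Truncates `tokens` in place,
+// stores the number of removed tokens in `*n_removed` (if non-null), and returns the
+// removed text, i.e. the healing prefix ("" when there is nothing to heal).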
+std::string llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove,
+        int * n_removed) {
+    if (n_removed != nullptr) {
+        *n_removed = 0;
+    }
+    if (tokens.empty()) {
+        return "";
+    }
+
+    const llama_model * model = llama_get_model(ctx_main);
+    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE ||
+                            th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const int n_ctx = tokens.size();
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx : std::min(max_to_remove, n_ctx);
+    int removed = 0;
+    std::string prefix;
+    // Roll back tokens a fixed amount, or until no token in the vocab can cover the rolled-back text,
+    // stopping early if a special token is encountered
+    while (removed < max_to_remove) {
+        const llama_token next_token_id = tokens[n_ctx - removed - 1];
+        if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
+            // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
+            break;
+        }
+        std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
+        if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
+            break;
+        }
+        removed += 1;
+        prefix = new_prefix;
+    }
+    if (removed == 0) { // E.g. if the last token is a special token
+        return "";
+    }
+    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                               th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
+    if (removed == 1 && candidates.size() == 1) {
+        LOG("token_healing: nothing to heal\n");
+        return "";
+    }
+    // Finalize outputs
+    if (n_removed != nullptr) {
+        *n_removed = removed;
+    }
+    tokens.resize(n_ctx - removed);
+    return prefix;
+}
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
+    ctx_sampling->token_healing_prefix = prefix;
+}
+
+//
+// Sampling
+//
+
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
@@ -64,6 +167,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
                 grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
     }
 
+    ctx->token_healing_prefix.clear();
+
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
     ctx->n_valid = 0;
@@ -122,7 +227,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 }
 
 std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
+    std::string result = "(Token healing) -> CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
             const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
@@ -303,8 +408,27 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     cur.clear();
 
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    // Constrain tokens based on the remaining token healing prefix (if any)
+    const auto & th_type   = params.token_healing_type;
+    const auto & th_prefix = ctx_sampling->token_healing_prefix;
+    if (params.token_healing_enabled && !th_prefix.empty()) {
+        const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                                   th_type == llama_token_healing_type::DYNAMIC_MULTI;
+        std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
+
+        LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
+        for (const llama_token token_id : th_candidates) {
+            LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
+        }
+
+        // N.B. We could also enforce this constraint by setting the logits of rejected tokens to -inf
+        for (const llama_token token_id : th_candidates) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
+    } else {
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
@@ -367,4 +491,20 @@ void llama_sampling_accept(
     if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
+
+    if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
+        std::string & th_prefix = ctx_sampling->token_healing_prefix;
+        if (!th_prefix.empty()) {
+            const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
+            if (new_token_piece.size() < th_prefix.size()) {
+                // Shift the prefix constraint (for multi-step token healing)
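+                // e.g. remaining prefix "worl" with sampled piece "wo" leaves "rl"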
+                th_prefix = th_prefix.substr(new_token_piece.size());
+            } else {
+                // The prefix has been fully generated => no more constrained generation
+                th_prefix.clear();
+                LOG("token_healing: done\n");
+            }
+        }
+    }
 }
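
For context, here is a minimal sketch of how a caller might wire these pieces together. It is not part of this diff: the surrounding setup (loading the model, tokenizing the prompt, the decode/sample loop) and the helper name `heal_prompt_sketch` are assumptions, but the calls into the patched API match the signatures above.

```cpp
#include "sampling.h"

#include <string>
#include <vector>

// Roll the prompt back once before decoding; the modified sampler then
// re-covers the removed piece via constrained sampling. Assumes
// params.token_healing_enabled was set when creating ctx_sampling.
static void heal_prompt_sketch(llama_context * ctx_main,
                               llama_sampling_context * ctx_sampling,
                               std::vector<llama_token> & prompt_tokens) {
    int n_removed = 0;
    // Trim tokens off the end of the prompt and keep the removed text
    const std::string th_prefix = llama_token_healing_rollback(
            ctx_main, llama_token_healing_type::DYNAMIC_MULTI, prompt_tokens,
            /*max_to_remove =*/ -1, &n_removed);

    // Seed the sampler: candidate preparation will restrict sampling to
    // tokens covering th_prefix, and llama_sampling_accept will shrink the
    // prefix as constrained tokens are accepted
    llama_token_healing_set_prefix(ctx_sampling, th_prefix);

    // ... decode the shortened prompt_tokens, then run the usual
    //     llama_sampling_sample / llama_sampling_accept loop ...
}
```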