-
Notifications
You must be signed in to change notification settings - Fork 12.1k
llama : cache llama_token_to_piece #7587
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1702,12 +1702,13 @@ struct llama_mlock { | |
}; | ||
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>; | ||
|
||
static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { | ||
// NOTE: avoid ever using this except for building the token_to_piece caches | ||
static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) { | ||
std::vector<char> result(8, 0); | ||
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); | ||
const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special); | ||
if (n_tokens < 0) { | ||
result.resize(-n_tokens); | ||
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); | ||
int check = llama_token_to_piece(model, token, result.data(), result.size(), special); | ||
GGML_ASSERT(check == -n_tokens); | ||
} | ||
else { | ||
|
@@ -2162,7 +2163,9 @@ struct llama_vocab { | |
std::unordered_map<token, id> token_to_id; | ||
std::vector<token_data> id_to_token; | ||
|
||
std::vector<id> special_tokens_cache; | ||
std::vector<id> cache_special_tokens; | ||
std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = false); | ||
std::vector<token> cache_token_to_piece_special; // llama_token_to_piece(special = true); | ||
|
||
std::map<std::pair<std::string, std::string>, int> bpe_ranks; | ||
|
||
|
@@ -4592,20 +4595,14 @@ static void llm_load_vocab( | |
vocab.special_cls_id = 101; | ||
vocab.special_mask_id = 103; | ||
vocab.add_space_prefix = false; | ||
} else { | ||
if (tokenizer_model == "gpt2") { | ||
vocab.type = LLAMA_VOCAB_TYPE_BPE; | ||
} else if (tokenizer_model == "gpt2") { | ||
vocab.type = LLAMA_VOCAB_TYPE_BPE; | ||
|
||
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); | ||
if (add_space_prefix_keyidx != -1) { | ||
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); | ||
} | ||
} else { | ||
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str()); | ||
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); | ||
vocab.type = LLAMA_VOCAB_TYPE_SPM; | ||
return; | ||
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); | ||
if (add_space_prefix_keyidx != -1) { | ||
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); | ||
} | ||
|
||
// read bpe merges and populate bpe ranks | ||
const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); | ||
if (merges_keyidx == -1) { | ||
|
@@ -4639,6 +4636,8 @@ static void llm_load_vocab( | |
vocab.special_pad_id = -1; | ||
vocab.special_cls_id = -1; | ||
vocab.special_mask_id = -1; | ||
} else { | ||
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); | ||
} | ||
|
||
// for now, only BPE models have pre-tokenizers | ||
|
@@ -4833,17 +4832,38 @@ static void llm_load_vocab( | |
{ | ||
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { | ||
if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) { | ||
vocab.special_tokens_cache.push_back(id); | ||
vocab.cache_special_tokens.push_back(id); | ||
} | ||
} | ||
|
||
std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(), | ||
std::sort( vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(), | ||
[&] (const llama_vocab::id a, const llama_vocab::id b) { | ||
return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size(); | ||
} | ||
); | ||
|
||
LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size()); | ||
LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size()); | ||
} | ||
|
||
// build token to piece caches | ||
{ | ||
size_t size_cache = 0; | ||
|
||
std::vector<llama_vocab::token> cache_token_to_piece (n_vocab); | ||
std::vector<llama_vocab::token> cache_token_to_piece_special(n_vocab); | ||
|
||
for (uint32_t id = 0; id < n_vocab; ++id) { | ||
cache_token_to_piece[id] = llama_token_to_piece(&model, id, false); | ||
cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true); | ||
|
||
size_cache += cache_token_to_piece[id].size(); | ||
size_cache += cache_token_to_piece_special[id].size(); | ||
} | ||
|
||
std::swap(vocab.cache_token_to_piece, cache_token_to_piece); | ||
std::swap(vocab.cache_token_to_piece_special, cache_token_to_piece_special); | ||
|
||
LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0); | ||
} | ||
} | ||
|
||
|
@@ -13233,7 +13253,7 @@ struct fragment_buffer_variant { | |
|
||
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) { | ||
// for each special token | ||
for (const llama_vocab::id special_id : vocab.special_tokens_cache) { | ||
for (const llama_vocab::id special_id : vocab.cache_special_tokens) { | ||
const auto & special_token = vocab.id_to_token[special_id].text; | ||
|
||
// for each text fragment | ||
|
@@ -14392,7 +14412,7 @@ void llama_sample_repetition_penalties( | |
|
||
void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { | ||
GGML_ASSERT(ctx); | ||
const int64_t t_start_sample_us = ggml_time_us(); | ||
int64_t t_start_sample_us = ggml_time_us(); | ||
|
||
bool allow_eog = false; | ||
for (const auto & stack : grammar->stacks) { | ||
|
@@ -14404,12 +14424,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c | |
|
||
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded; | ||
candidates_decoded.reserve(candidates->size); | ||
std::vector<llama_grammar_candidate> candidates_grammar; | ||
|
||
std::vector<llama_grammar_candidate> candidates_grammar; | ||
candidates_grammar.reserve(candidates->size); | ||
|
||
for (size_t i = 0; i < candidates->size; ++i) { | ||
const llama_token id = candidates->data[i].id; | ||
const std::string piece = llama_token_to_piece(ctx, id, false); | ||
const llama_token id = candidates->data[i].id; | ||
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id); | ||
HanClinto marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if (llama_token_is_eog(&ctx->model, id)) { | ||
if (!allow_eog) { | ||
|
@@ -14609,7 +14630,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar | |
GGML_ASSERT(false); | ||
} | ||
|
||
const std::string piece = llama_token_to_piece(ctx, token, false); | ||
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token); | ||
|
||
// Note terminating 0 in decoded string | ||
const auto decoded = decode_utf8(piece, grammar->partial_utf8); | ||
|
@@ -18292,69 +18313,83 @@ static std::string llama_decode_text(const std::string & text) { | |
|
||
// does not write null-terminator to buf | ||
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) { | ||
// if we have a cache - use it | ||
{ | ||
const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Maybe we could get away w/ a single cache (built w/ special=true) and early-exit in special case at the top of the function? int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
if (!special && llama_is_control_token(model->vocab, token)) {
return 0;
}
// if we have a cache - use it
if (!model->vocab.cache_token_to_piece.empty()) {
....
}
... |
||
|
||
if (!cache.empty()) { | ||
HanClinto marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const auto & res = cache.at(token); | ||
if (length < (int) res.size()) { | ||
return -(int) res.size(); | ||
} | ||
memcpy(buf, res.c_str(), res.size()); | ||
return res.size(); | ||
} | ||
} | ||
|
||
if (0 <= token && token < llama_n_vocab(model)) { | ||
switch (llama_vocab_get_type(model->vocab)) { | ||
case LLAMA_VOCAB_TYPE_WPM: | ||
case LLAMA_VOCAB_TYPE_SPM: { | ||
// NOTE: we accept all unsupported token types, | ||
// suppressing them like CONTROL tokens. | ||
if (llama_is_normal_token(model->vocab, token)) { | ||
std::string result = model->vocab.id_to_token[token].text; | ||
llama_unescape_whitespace(result); | ||
if (length < (int) result.length()) { | ||
return -(int) result.length(); | ||
} | ||
memcpy(buf, result.c_str(), result.length()); | ||
return result.length(); | ||
} else if ( | ||
(llama_is_user_defined_token(model->vocab, token)) || | ||
(llama_is_control_token (model->vocab, token) && special)) { | ||
std::string result = model->vocab.id_to_token[token].text; | ||
if (length < (int) result.length()) { | ||
return -(int) result.length(); | ||
} | ||
memcpy(buf, result.c_str(), result.length()); | ||
return result.length(); | ||
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT | ||
if (length < 3) { | ||
return -3; | ||
} | ||
memcpy(buf, "\xe2\x96\x85", 3); | ||
return 3; | ||
} else if (llama_is_byte_token(model->vocab, token)) { | ||
if (length < 1) { | ||
return -1; | ||
case LLAMA_VOCAB_TYPE_WPM: | ||
case LLAMA_VOCAB_TYPE_SPM: { | ||
// NOTE: we accept all unsupported token types, | ||
// suppressing them like CONTROL tokens. | ||
if (llama_is_normal_token(model->vocab, token)) { | ||
std::string result = model->vocab.id_to_token[token].text; | ||
llama_unescape_whitespace(result); | ||
if (length < (int) result.length()) { | ||
return -(int) result.length(); | ||
} | ||
memcpy(buf, result.c_str(), result.length()); | ||
return result.length(); | ||
} else if ( | ||
(llama_is_user_defined_token(model->vocab, token)) || | ||
(llama_is_control_token (model->vocab, token) && special)) { | ||
std::string result = model->vocab.id_to_token[token].text; | ||
if (length < (int) result.length()) { | ||
return -(int) result.length(); | ||
} | ||
memcpy(buf, result.c_str(), result.length()); | ||
return result.length(); | ||
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT | ||
if (length < 3) { | ||
return -3; | ||
} | ||
memcpy(buf, "\xe2\x96\x85", 3); | ||
return 3; | ||
} else if (llama_is_byte_token(model->vocab, token)) { | ||
if (length < 1) { | ||
return -1; | ||
} | ||
buf[0] = llama_token_to_byte(model->vocab, token); | ||
return 1; | ||
} | ||
buf[0] = llama_token_to_byte(model->vocab, token); | ||
return 1; | ||
break; | ||
} | ||
break; | ||
} | ||
case LLAMA_VOCAB_TYPE_BPE: { | ||
// NOTE: we accept all unsupported token types, | ||
// suppressing them like CONTROL tokens. | ||
if (llama_is_normal_token(model->vocab, token)) { | ||
std::string result = model->vocab.id_to_token[token].text; | ||
result = llama_decode_text(result); | ||
if (length < (int) result.length()) { | ||
return -(int) result.length(); | ||
} | ||
memcpy(buf, result.c_str(), result.length()); | ||
return result.length(); | ||
} else if ( | ||
(llama_is_user_defined_token(model->vocab, token)) || | ||
(llama_is_control_token (model->vocab, token) && special)) { | ||
std::string result = model->vocab.id_to_token[token].text; | ||
if (length < (int) result.length()) { | ||
return -(int) result.length(); | ||
case LLAMA_VOCAB_TYPE_BPE: { | ||
// NOTE: we accept all unsupported token types, | ||
// suppressing them like CONTROL tokens. | ||
if (llama_is_normal_token(model->vocab, token)) { | ||
std::string result = model->vocab.id_to_token[token].text; | ||
result = llama_decode_text(result); | ||
if (length < (int) result.length()) { | ||
return -(int) result.length(); | ||
} | ||
memcpy(buf, result.c_str(), result.length()); | ||
return result.length(); | ||
} else if ( | ||
(llama_is_user_defined_token(model->vocab, token)) || | ||
(llama_is_control_token (model->vocab, token) && special)) { | ||
std::string result = model->vocab.id_to_token[token].text; | ||
if (length < (int) result.length()) { | ||
return -(int) result.length(); | ||
} | ||
memcpy(buf, result.c_str(), result.length()); | ||
return result.length(); | ||
} | ||
memcpy(buf, result.c_str(), result.length()); | ||
return result.length(); | ||
break; | ||
} | ||
break; | ||
} | ||
default: | ||
GGML_ASSERT(false); | ||
default: | ||
GGML_ASSERT(false); | ||
} | ||
} | ||
return 0; | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.