Skip to content

Commit 92b88a0

Browse files
committed
llama : cache llama_token_to_piece
ggml-ci
1 parent 0548a41 commit 92b88a0

File tree

2 files changed

+91
-69
lines changed

2 files changed

+91
-69
lines changed

llama.cpp

Lines changed: 89 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,12 +1651,13 @@ struct llama_mlock {
16511651
};
16521652
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
16531653

1654-
static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
1654+
// NOTE: avoid ever using this except for building the token_to_piece caches
1655+
static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
16551656
std::vector<char> result(8, 0);
1656-
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
1657+
const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
16571658
if (n_tokens < 0) {
16581659
result.resize(-n_tokens);
1659-
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
1660+
int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
16601661
GGML_ASSERT(check == -n_tokens);
16611662
}
16621663
else {
@@ -2086,7 +2087,11 @@ struct llama_vocab {
20862087
std::unordered_map<token, id> token_to_id;
20872088
std::vector<token_data> id_to_token;
20882089

2089-
std::unordered_map<token, id> special_tokens_cache;
2090+
bool has_cache = false;
2091+
2092+
std::unordered_map<token, id> cache_special_tokens;
2093+
std::unordered_map<id, token> cache_token_to_piece; // llama_token_to_piece(special = false);
2094+
std::unordered_map<id, token> cache_token_to_piece_special; // llama_token_to_piece(special = true);
20902095

20912096
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
20922097

@@ -4789,7 +4794,7 @@ static void llm_load_vocab(
47894794
// And skip the ones which are one character
47904795
if (utf8_str_len > 1) {
47914796
// At this point what we have left are special tokens only
4792-
vocab.special_tokens_cache[token] = id;
4797+
vocab.cache_special_tokens[token] = id;
47934798

47944799
// Count manually found special tokens
47954800
special_tokens_count_from_verification++;
@@ -4816,6 +4821,13 @@ static void llm_load_vocab(
48164821
);
48174822
}
48184823
}
4824+
4825+
for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
4826+
vocab.cache_token_to_piece[id] = llama_token_to_piece(&model, id, false);
4827+
vocab.cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true);
4828+
}
4829+
4830+
vocab.has_cache = true;
48194831
}
48204832

48214833
static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@@ -12898,7 +12910,7 @@ struct fragment_buffer_variant {
1289812910

1289912911
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
1290012912
// for each special token
12901-
for (const auto & st: vocab.special_tokens_cache) {
12913+
for (const auto & st: vocab.cache_special_tokens) {
1290212914
const auto & special_token = st.first;
1290312915
const auto & special_id = st.second;
1290412916

@@ -14058,7 +14070,7 @@ void llama_sample_repetition_penalties(
1405814070

1405914071
void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
1406014072
GGML_ASSERT(ctx);
14061-
const int64_t t_start_sample_us = ggml_time_us();
14073+
int64_t t_start_sample_us = ggml_time_us();
1406214074

1406314075
bool allow_eog = false;
1406414076
for (const auto & stack : grammar->stacks) {
@@ -14074,8 +14086,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
1407414086
candidates_grammar.reserve(candidates->size);
1407514087

1407614088
for (size_t i = 0; i < candidates->size; ++i) {
14077-
const llama_token id = candidates->data[i].id;
14078-
const std::string piece = llama_token_to_piece(ctx, id, false);
14089+
const llama_token id = candidates->data[i].id;
14090+
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);
1407914091

1408014092
if (llama_token_is_eog(&ctx->model, id)) {
1408114093
if (!allow_eog) {
@@ -14275,7 +14287,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
1427514287
GGML_ASSERT(false);
1427614288
}
1427714289

14278-
const std::string piece = llama_token_to_piece(ctx, token, false);
14290+
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);
1427914291

1428014292
// Note terminating 0 in decoded string
1428114293
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@@ -17948,69 +17960,79 @@ static std::string llama_decode_text(const std::string & text) {
1794817960

1794917961
// does not write null-terminator to buf
1795017962
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
17963+
if (model->vocab.has_cache) {
17964+
const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece;
17965+
const auto & res = cache.at(token);
17966+
if (length < (int) res.size()) {
17967+
return -(int) res.size();
17968+
}
17969+
memcpy(buf, res.c_str(), res.size());
17970+
return res.size();
17971+
}
17972+
1795117973
if (0 <= token && token < llama_n_vocab(model)) {
1795217974
switch (llama_vocab_get_type(model->vocab)) {
17953-
case LLAMA_VOCAB_TYPE_WPM:
17954-
case LLAMA_VOCAB_TYPE_SPM: {
17955-
// NOTE: we accept all unsupported token types,
17956-
// suppressing them like CONTROL tokens.
17957-
if (llama_is_normal_token(model->vocab, token)) {
17958-
std::string result = model->vocab.id_to_token[token].text;
17959-
llama_unescape_whitespace(result);
17960-
if (length < (int) result.length()) {
17961-
return -(int) result.length();
17962-
}
17963-
memcpy(buf, result.c_str(), result.length());
17964-
return result.length();
17965-
} else if (
17966-
(llama_is_user_defined_token(model->vocab, token)) ||
17967-
(llama_is_control_token (model->vocab, token) && special)) {
17968-
std::string result = model->vocab.id_to_token[token].text;
17969-
if (length < (int) result.length()) {
17970-
return -(int) result.length();
17971-
}
17972-
memcpy(buf, result.c_str(), result.length());
17973-
return result.length();
17974-
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
17975-
if (length < 3) {
17976-
return -3;
17977-
}
17978-
memcpy(buf, "\xe2\x96\x85", 3);
17979-
return 3;
17980-
} else if (llama_is_byte_token(model->vocab, token)) {
17981-
if (length < 1) {
17982-
return -1;
17975+
case LLAMA_VOCAB_TYPE_WPM:
17976+
case LLAMA_VOCAB_TYPE_SPM: {
17977+
// NOTE: we accept all unsupported token types,
17978+
// suppressing them like CONTROL tokens.
17979+
if (llama_is_normal_token(model->vocab, token)) {
17980+
std::string result = model->vocab.id_to_token[token].text;
17981+
llama_unescape_whitespace(result);
17982+
if (length < (int) result.length()) {
17983+
return -(int) result.length();
17984+
}
17985+
memcpy(buf, result.c_str(), result.length());
17986+
return result.length();
17987+
} else if (
17988+
(llama_is_user_defined_token(model->vocab, token)) ||
17989+
(llama_is_control_token (model->vocab, token) && special)) {
17990+
std::string result = model->vocab.id_to_token[token].text;
17991+
if (length < (int) result.length()) {
17992+
return -(int) result.length();
17993+
}
17994+
memcpy(buf, result.c_str(), result.length());
17995+
return result.length();
17996+
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
17997+
if (length < 3) {
17998+
return -3;
17999+
}
18000+
memcpy(buf, "\xe2\x96\x85", 3);
18001+
return 3;
18002+
} else if (llama_is_byte_token(model->vocab, token)) {
18003+
if (length < 1) {
18004+
return -1;
18005+
}
18006+
buf[0] = llama_token_to_byte(model->vocab, token);
18007+
return 1;
1798318008
}
17984-
buf[0] = llama_token_to_byte(model->vocab, token);
17985-
return 1;
18009+
break;
1798618010
}
17987-
break;
17988-
}
17989-
case LLAMA_VOCAB_TYPE_BPE: {
17990-
// NOTE: we accept all unsupported token types,
17991-
// suppressing them like CONTROL tokens.
17992-
if (llama_is_normal_token(model->vocab, token)) {
17993-
std::string result = model->vocab.id_to_token[token].text;
17994-
result = llama_decode_text(result);
17995-
if (length < (int) result.length()) {
17996-
return -(int) result.length();
17997-
}
17998-
memcpy(buf, result.c_str(), result.length());
17999-
return result.length();
18000-
} else if (
18001-
(llama_is_user_defined_token(model->vocab, token)) ||
18002-
(llama_is_control_token (model->vocab, token) && special)) {
18003-
std::string result = model->vocab.id_to_token[token].text;
18004-
if (length < (int) result.length()) {
18005-
return -(int) result.length();
18011+
case LLAMA_VOCAB_TYPE_BPE: {
18012+
// NOTE: we accept all unsupported token types,
18013+
// suppressing them like CONTROL tokens.
18014+
if (llama_is_normal_token(model->vocab, token)) {
18015+
std::string result = model->vocab.id_to_token[token].text;
18016+
result = llama_decode_text(result);
18017+
if (length < (int) result.length()) {
18018+
return -(int) result.length();
18019+
}
18020+
memcpy(buf, result.c_str(), result.length());
18021+
return result.length();
18022+
} else if (
18023+
(llama_is_user_defined_token(model->vocab, token)) ||
18024+
(llama_is_control_token (model->vocab, token) && special)) {
18025+
std::string result = model->vocab.id_to_token[token].text;
18026+
if (length < (int) result.length()) {
18027+
return -(int) result.length();
18028+
}
18029+
memcpy(buf, result.c_str(), result.length());
18030+
return result.length();
1800618031
}
18007-
memcpy(buf, result.c_str(), result.length());
18008-
return result.length();
18032+
break;
1800918033
}
18010-
break;
18011-
}
18012-
default:
18013-
GGML_ASSERT(false);
18034+
default:
18035+
GGML_ASSERT(false);
1801418036
}
1801518037
}
1801618038
return 0;

llama.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,8 @@ extern "C" {
424424

425425
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
426426

427-
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
428-
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
427+
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
428+
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
429429

430430
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
431431
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);

0 commit comments

Comments
 (0)