@@ -1651,12 +1651,13 @@ struct llama_mlock {
 };
 using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
 
-static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+// NOTE: avoid ever using this except for building the token_to_piece caches
+static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+    const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+        int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
         GGML_ASSERT(check == -n_tokens);
     }
     else {
@@ -2086,7 +2087,11 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
-    std::unordered_map<token, id> special_tokens_cache;
+    bool has_cache = false;
+
+    std::unordered_map<token, id> cache_special_tokens;
+    std::unordered_map<id, token> cache_token_to_piece;         // llama_token_to_piece(special = false);
+    std::unordered_map<id, token> cache_token_to_piece_special; // llama_token_to_piece(special = true);
 
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
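As a reading aid for the fields added above: `has_cache` records whether the two maps have been filled, and each map memoizes the rendered piece for a token id, once for `special = false` and once for `special = true`. A minimal sketch of the same shape, with plain `int`/`std::string` standing in for the `llama_vocab::id`/`llama_vocab::token` typedefs:

```cpp
#include <string>
#include <unordered_map>

// Illustrative sketch only, not the llama.cpp definition.
struct vocab_piece_cache {
    bool has_cache = false;

    std::unordered_map<int, std::string> cache_token_to_piece;         // special = false
    std::unordered_map<int, std::string> cache_token_to_piece_special; // special = true

    // Gated lookup: callers fall back to the slow detokenization path
    // while has_cache is still false (i.e. before the vocab has been loaded).
    const std::string * find(int id, bool special) const {
        if (!has_cache) {
            return nullptr;
        }
        const auto & cache = special ? cache_token_to_piece_special : cache_token_to_piece;
        const auto it = cache.find(id);
        return it == cache.end() ? nullptr : &it->second;
    }
};
```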
@@ -4789,7 +4794,7 @@ static void llm_load_vocab(
                     // And skip the ones which are one character
                     if (utf8_str_len > 1) {
                         // At this point what we have left are special tokens only
-                        vocab.special_tokens_cache[token] = id;
+                        vocab.cache_special_tokens[token] = id;
 
                         // Count manually found special tokens
                         special_tokens_count_from_verification++;
@@ -4816,6 +4821,13 @@ static void llm_load_vocab(
             );
         }
     }
+
+    for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
+        vocab.cache_token_to_piece[id]         = llama_token_to_piece(&model, id, false);
+        vocab.cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true);
+    }
+
+    vocab.has_cache = true;
 }
 
 static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
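Since the loop above fills both maps with consecutive ids `0..n_vocab-1`, an alternative layout would be a flat vector indexed by token id, which gives the same O(1) lookup without hashing. This is only a sketch of that option, not what the patch does; `detokenize` is a stand-in for the `llama_token_to_piece(&model, id, special)` calls above:

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for the per-token detokenizer used while building the cache.
std::string detokenize(int id, bool special);

struct flat_piece_cache {
    std::vector<std::string> plain;    // pieces rendered with special = false
    std::vector<std::string> special;  // pieces rendered with special = true

    void build(int n_vocab) {
        plain.resize(n_vocab);
        special.resize(n_vocab);
        for (int id = 0; id < n_vocab; ++id) {
            plain[id]   = detokenize(id, false);
            special[id] = detokenize(id, true);
        }
    }
};
```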
@@ -12898,7 +12910,7 @@ struct fragment_buffer_variant {
 
 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
     // for each special token
-    for (const auto & st: vocab.special_tokens_cache) {
+    for (const auto & st: vocab.cache_special_tokens) {
         const auto & special_token = st.first;
         const auto & special_id    = st.second;
 
@@ -14058,7 +14070,7 @@ void llama_sample_repetition_penalties(
 
 void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
     GGML_ASSERT(ctx);
-    const int64_t t_start_sample_us = ggml_time_us();
+    int64_t t_start_sample_us = ggml_time_us();
 
     bool allow_eog = false;
     for (const auto & stack : grammar->stacks) {
@@ -14074,8 +14086,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
     candidates_grammar.reserve(candidates->size);
 
     for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id    = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id, false);
+        const llama_token id      = candidates->data[i].id;
+        const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);
 
         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
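The point of binding `piece` as a `const std::string &` here is that the per-candidate loop no longer constructs a fresh `std::string` on every iteration (the old call did); the reference simply aliases the entry stored in the cache. A tiny sketch of that lookup pattern, with `cache` standing in for `ctx->model.vocab.cache_token_to_piece`:

```cpp
#include <string>
#include <unordered_map>

// Returns a reference to the cached piece; no copy, no allocation.
// Throws std::out_of_range if the id was never cached.
inline const std::string & piece_for(const std::unordered_map<int, std::string> & cache, int id) {
    return cache.at(id);
}
```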
@@ -14275,7 +14287,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
     const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@@ -17948,69 +17960,79 @@ static std::string llama_decode_text(const std::string & text) {
 
 // does not write null-terminator to buf
 int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
+    if (model->vocab.has_cache) {
+        const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece;
+        const auto & res = cache.at(token);
+        if (length < (int) res.size()) {
+            return -(int) res.size();
+        }
+        memcpy(buf, res.c_str(), res.size());
+        return res.size();
+    }
+
     if (0 <= token && token < llama_n_vocab(model)) {
         switch (llama_vocab_get_type(model->vocab)) {
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_SPM: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                llama_unescape_whitespace(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                (llama_is_user_defined_token(model->vocab, token)) ||
-                (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
-                if (length < 3) {
-                    return -3;
-                }
-                memcpy(buf, "\xe2\x96\x85", 3);
-                return 3;
-            } else if (llama_is_byte_token(model->vocab, token)) {
-                if (length < 1) {
-                    return -1;
+            case LLAMA_VOCAB_TYPE_WPM:
+            case LLAMA_VOCAB_TYPE_SPM: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    llama_unescape_whitespace(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                    (llama_is_user_defined_token(model->vocab, token)) ||
+                    (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
+                    if (length < 3) {
+                        return -3;
+                    }
+                    memcpy(buf, "\xe2\x96\x85", 3);
+                    return 3;
+                } else if (llama_is_byte_token(model->vocab, token)) {
+                    if (length < 1) {
+                        return -1;
+                    }
+                    buf[0] = llama_token_to_byte(model->vocab, token);
+                    return 1;
                 }
-                buf[0] = llama_token_to_byte(model->vocab, token);
-                return 1;
+                break;
             }
-            break;
-        }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                result = llama_decode_text(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                (llama_is_user_defined_token(model->vocab, token)) ||
-                (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
+            case LLAMA_VOCAB_TYPE_BPE: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    result = llama_decode_text(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                    (llama_is_user_defined_token(model->vocab, token)) ||
+                    (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
                 }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
+                break;
             }
-            break;
-        }
-        default:
-            GGML_ASSERT(false);
+            default:
+                GGML_ASSERT(false);
         }
     }
     return 0;
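One caller-side note on the fast path added above: it runs before the `0 <= token && token < llama_n_vocab(model)` check, so an id that never made it into the cache turns into a `std::out_of_range` from `cache.at(token)` rather than the `return 0` of the uncached path. A hedged sketch of a defensive wrapper around a buffer-based API with the same negative-return convention; `raw_token_to_piece` is a made-up stand-in, not the real entry point:

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Stand-in for a token-to-piece API that copies up to `length` bytes into `buf`
// and returns the piece length, or minus the required size if `buf` is too small.
int raw_token_to_piece(int token, char * buf, int length, bool special);

std::string safe_token_to_piece(int token, int n_vocab, bool special) {
    if (token < 0 || token >= n_vocab) {
        return "";                      // avoid an out-of-range cache lookup
    }
    std::vector<char> buf(8, 0);
    int n = raw_token_to_piece(token, buf.data(), (int) buf.size(), special);
    if (n < 0) {                        // buffer too small: retry with the required size
        buf.resize(-n);
        n = raw_token_to_piece(token, buf.data(), (int) buf.size(), special);
    }
    return std::string(buf.data(), (size_t) std::max(n, 0));
}
```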