@@ -943,16 +943,29 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
 static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
                            std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
     const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1;
-
-    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
-    prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
-                       true) < 0) {
-        printe("failed to tokenize the prompt\n");
+    int n_tokens = prompt.size() + 2 * is_first;
+    prompt_tokens.resize(n_tokens);
+    n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
+                              prompt_tokens.data(), prompt_tokens.size(),
+                              is_first, /*parse_special =*/ true);
+    if (n_tokens == std::numeric_limits<int32_t>::min()) {
+        printe("tokenization failed: input too large\n");
         return -1;
     }
-
-    return n_prompt_tokens;
+    if (n_tokens < 0) {
+        prompt_tokens.resize(-n_tokens);
+        int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
+                                   prompt_tokens.data(), prompt_tokens.size(),
+                                   is_first, /*parse_special =*/ true);
+        if (check != -n_tokens) {
+            printe("failed to tokenize the prompt (size mismatch)\n");
+            return -1;
+        }
+        n_tokens = check;
+    } else {
+        prompt_tokens.resize(n_tokens);
+    }
+    return n_tokens;
 }
 
 // Check if we have enough space in the context to evaluate this batch
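A minimal standalone sketch of the resize-and-retry pattern the new `tokenize_prompt` relies on, assuming the llama.h contract that `llama_tokenize()` returns the token count on success, the negated required count when the provided buffer is too small, and `INT32_MIN` when the result cannot be represented. The helper name `tokenize_with_retry` is illustrative and not part of this patch.

```cpp
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

#include "llama.h"

// Hypothetical helper illustrating the same pattern as the patched function.
static bool tokenize_with_retry(const llama_vocab * vocab, const std::string & text,
                                bool add_special, std::vector<llama_token> & out) {
    // first pass: guess one token per byte, plus room for special tokens when requested
    out.resize(text.size() + 2 * add_special);
    int32_t n = llama_tokenize(vocab, text.c_str(), (int32_t) text.size(),
                               out.data(), (int32_t) out.size(),
                               add_special, /*parse_special =*/ true);
    if (n == std::numeric_limits<int32_t>::min()) {
        return false; // tokenization result does not fit in int32_t
    }
    if (n < 0) {
        // buffer was too small: -n is the exact count needed, so resize and retry once
        out.resize(-n);
        n = llama_tokenize(vocab, text.c_str(), (int32_t) text.size(),
                           out.data(), (int32_t) out.size(),
                           add_special, /*parse_special =*/ true);
        if (n < 0) {
            return false;
        }
    }
    out.resize(n);
    return true;
}
```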