Consolidate tokenizer interface #2954

Closed · wants to merge 1 commit
35 changes: 17 additions & 18 deletions examples/models/llama2/runner/runner.cpp
@@ -81,14 +81,16 @@ Error Runner::load() {
if (tokenizer_->bos_tok() != bos_id_) {
ET_LOG(
Error,
"Tokenizer's BOS id %d does not match model's BOS id %d, will override tokenizer's BOS.",
"Tokenizer's BOS id %" PRIu64
" does not match model's BOS id %d, will override tokenizer's BOS.",
tokenizer_->bos_tok(),
bos_id_);
}
if (tokenizer_->eos_tok() != eos_id_) {
ET_LOG(
Error,
"Tokenizer's EOS id %d does not match model's EOS id %d, will override tokenizer's EOS.",
"Tokenizer's EOS id %" PRIu64
" does not match model's EOS id %d, will override tokenizer's EOS.",
tokenizer_->eos_tok(),
eos_id_);
}
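A note on the PRIu64 change above: bos_tok() and eos_tok() now return uint64_t, and %d is the wrong conversion for a 64-bit unsigned value. Printing one portably takes the format macro from <cinttypes>; a minimal standalone illustration (not part of this diff):

    #include <cinttypes>
    #include <cstdio>

    int main() {
      std::uint64_t tok = 42;
      // Adjacent string literals concatenate, so "%" PRIu64 becomes the
      // correct conversion specifier (e.g. "%llu") for this platform.
      std::printf("token id %" PRIu64 "\n", tok);
      return 0;
    }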
@@ -227,20 +229,18 @@ Error Runner::generate(
stats_.inference_start_ms = util::time_in_ms();
shouldStop_ = false;

// encode the (string) prompt into tokens sequence
int num_prompt_tokens = 0;
// max # of prompt tokens: len(prompt) + '\0', ?BOS, ?EOS
int* prompt_tokens = new int[prompt.size() + 1 + n_bos_ + n_eos_];

// Set the sequence length to the max seq length if not provided
seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;

tokenizer_->encode(
prompt.c_str(),
n_bos_,
append_eos_ ? n_eos_ : 0,
prompt_tokens,
&num_prompt_tokens);
Result<std::vector<uint64_t>> encode_res =
tokenizer_->encode(prompt, n_bos_, append_eos_ ? n_eos_ : 0);

ET_CHECK_OK_OR_RETURN_ERROR(
encode_res.error(), "Failed to encode prompt %s", prompt.c_str());

// encode the (string) prompt into tokens sequence
std::vector<uint64_t> prompt_tokens = encode_res.get();
int num_prompt_tokens = prompt_tokens.size();

ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
ET_CHECK_MSG(
@@ -303,13 +303,13 @@ Error Runner::generate(

// Print the prompt for consistent output between single token prefill and
// batch prefill.
int prev = prompt_tokens[0];
int cur;
uint64_t prev = prompt_tokens[0];
uint64_t cur;
for (int i = 1; i < num_prompt_tokens; i++) {
cur = prompt_tokens[i];
auto piece_res = tokenizer_->decode(prev, cur);
ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
util::safe_printf(piece_res.get());
util::safe_printf(piece_res.get().c_str());
fflush(stdout);
prev = cur;
}
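// Reviewer note, not in the diff: decode() takes (prev, cur) because the
// sentencepiece-style decoder in llama2.c-derived code uses the previous
// token to decide whitespace handling (e.g. stripping the leading space on
// the first piece after BOS).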
@@ -361,7 +361,7 @@ Error Runner::generate(
// print the token as string, decode it with the Tokenizer object
auto piece_res = tokenizer_->decode(prev_token, cur_token);
ET_CHECK(piece_res.ok());
const char* piece = piece_res.get();
const char* piece = piece_res.get().c_str();

// same as printf("%s", piece), but skips "unsafe" bytes
util::safe_printf(piece);
@@ -396,7 +396,6 @@ Error Runner::generate(
stats_callback(stats_);
}

delete[] prompt_tokens;
return Error::Ok;
}
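Worth spelling out for reviewers: the old interface made the caller compute an upper bound (len(prompt) + '\0' + optional BOS/EOS) and manage a raw int array. The consolidated interface returns an owning vector, so every early-return path in generate() is now leak-free. The call-site pattern, as in the diff above:

    Result<std::vector<uint64_t>> encode_res =
        tokenizer_->encode(prompt, n_bos_, append_eos_ ? n_eos_ : 0);
    ET_CHECK_OK_OR_RETURN_ERROR(
        encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
    std::vector<uint64_t> prompt_tokens = encode_res.get();
    // ~vector releases the storage on every exit path; the manual
    // delete[] at the end of generate() is gone.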

5 changes: 3 additions & 2 deletions examples/models/llama2/tokenizer/test/test_tokenizer.cpp
@@ -9,6 +9,7 @@
#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
#include <executorch/runtime/platform/runtime.h>
#include <gtest/gtest.h>
#include <vector>

using namespace ::testing;

@@ -28,8 +29,8 @@ class TokenizerExtensionTest : public Test {
};

TEST_F(TokenizerExtensionTest, EncodeWithoutLoadFails) {
Error error = tokenizer_->encode("hello world", 0, 0, nullptr, nullptr);
EXPECT_EQ(error, Error::NotSupported);
Result<std::vector<uint64_t>> res = tokenizer_->encode("hello world", 0, 0);
EXPECT_EQ(res.error(), Error::NotSupported);
}
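The decode test below follows the same Result-based pattern; its body is hidden by the truncated context, but a sketch of the updated assertion (argument values illustrative) would be:

    Result<std::string> res = tokenizer_->decode(0, 0);
    EXPECT_EQ(res.error(), Error::NotSupported);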

TEST_F(TokenizerExtensionTest, DecodeWithoutLoadFails) {
45 changes: 21 additions & 24 deletions examples/models/llama2/tokenizer/tokenizer.cpp
@@ -23,7 +23,7 @@ static int compare_tokens(const void* a, const void* b) {
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}

Tokenizer::Tokenizer(int32_t vocab_size, int32_t bos_tok, int32_t eos_tok)
Tokenizer::Tokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok)
: initialized_(false),
vocab_size_(vocab_size),
bos_tok_(bos_tok),
@@ -142,10 +142,10 @@ Tokenizer::~Tokenizer() {
*
* @param prev_token The previous token.
* @param token The current token.
* @return Result<const char*> A pointer to the string representation of the
* @return Result<std::string> The string representation of the
* token.
*/
Result<const char*> Tokenizer::decode(int32_t prev_token, int32_t token) {
Result<std::string> Tokenizer::decode(uint64_t prev_token, uint64_t token) {
if (!initialized_) {
ET_LOG(Error, "Tokenizer not initialized");
return Error::NotSupported;
@@ -162,7 +162,8 @@ Result<const char*> Tokenizer::decode(int32_t prev_token, int32_t token) {
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
piece = (char*)byte_pieces_ + byte_val * 2;
}
return piece;
std::string res(piece);
return res;
}
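A note on the return-type change: the old Result<const char*> pointed into the tokenizer's internal vocab_/byte_pieces_ storage, so the piece was only valid while the tokenizer lived. Result<std::string> hands the caller an owning copy. A hypothetical caller:

    // `piece` owns its bytes; it stays valid across later decode() calls
    // and after the tokenizer is destroyed.
    std::string piece = tokenizer.decode(prev_token, cur_token).get();

This is also why the call sites in runner.cpp above now go through .c_str() before util::safe_printf.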

static int32_t
@@ -183,23 +184,19 @@ str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
* @param eos The number of EOS to append to the token list.
* @param tokens The output tokens.
* @param n_tokens The number of tokens.
* @return Error
* @return Result<std::vector<uint64_t>>
*/
Error Tokenizer::encode(
const char* text,
int8_t bos,
int8_t eos,
int32_t* tokens,
int32_t* n_tokens) {
Result<std::vector<uint64_t>>
Tokenizer::encode(const std::string& text, int8_t bos, int8_t eos) {
if (!initialized_) {
ET_LOG(Error, "Tokenizer not initialized");
return Error::NotSupported;
}
// encode the string text (input) into an upper-bound preallocated tokens[]
// array bos != 0 means prepend the BOS token (=1), eos != 0 means append the
// EOS token (=2)
if (text == nullptr) {
ET_LOG(Error, "cannot encode null text");
if (text.empty()) {
ET_LOG(Error, "cannot encode empty text");
return Error::InvalidArgument;
}

@@ -210,12 +207,12 @@ Error Tokenizer::encode(
size_t str_len = 0;

// start at 0 tokens
*n_tokens = 0;
std::vector<uint64_t> tokens;

// add optional BOS token, if desired
if (bos > 0) {
while (bos--) {
tokens[(*n_tokens)++] = bos_tok_;
tokens.push_back(bos_tok_);
}
} else {
ET_LOG(Error, "bos %d should be >= 0", bos);
@@ -230,7 +227,7 @@ Error Tokenizer::encode(
const char* space = " ";
if (text[0] != '\0') {
int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
tokens[(*n_tokens)++] = dummy_prefix;
tokens.push_back(dummy_prefix);
}

// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
@@ -242,7 +239,7 @@ Error Tokenizer::encode(
// U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
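// Worked example (added for review, not in the diff): "é" encodes as
// 0xC3 0xA9. 0xC3 & 0xC0 == 0xC0, so it is a leading byte and resets the
// buffer; 0xA9 & 0xC0 == 0x80, so it is a continuation byte and is appended.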

// process the raw (UTF-8) byte sequence of the input string
for (const char* c = text; *c != '\0'; c++) {
for (const char* c = text.c_str(); *c != '\0'; c++) {
// reset buffer if the current byte is ASCII or a leading byte
// 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
// rest 0x80 is 10000000 in UTF-8, all continuation bytes start with "10" in
@@ -271,13 +268,13 @@ Error Tokenizer::encode(
int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
if (id != -1) {
// we found this codepoint in vocab, add it as a token
tokens[(*n_tokens)++] = id;
tokens.push_back(id);
} else {
// byte_fallback encoding: just encode each byte as a token
// +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
// so the individual bytes only start at index 3
for (int i = 0; i < str_len; i++) {
tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
tokens.push_back((unsigned char)str_buffer[i] + 3);
}
}
str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
@@ -290,7 +287,7 @@ Error Tokenizer::encode(
int best_id = -1;
int best_idx = -1;

for (int i = 0; i < (*n_tokens - 1); i++) {
for (int i = 0; i < tokens.size() - 1; i++) {
// check if we can merge the pair (tokens[i], tokens[i+1])
snprintf(
str_buffer,
@@ -314,24 +311,24 @@ Error Tokenizer::encode(
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
tokens[best_idx] = best_id;
// delete token at position best_idx+1, shift the entire sequence back 1
for (int i = best_idx + 1; i < (*n_tokens - 1); i++) {
for (int i = best_idx + 1; i < tokens.size() - 1; i++) {
tokens[i] = tokens[i + 1];
}
(*n_tokens)--; // token length decreased
tokens.pop_back(); // token length decreased
}
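// Illustrative trace of one merge step (hypothetical ids/scores): with
// tokens = [5, 9, 12], if the concatenation of pieces 5 and 9 exists in the
// vocab as id 40 with the best score, tokens becomes [40, 12]; the loop
// repeats until no adjacent pair can be merged.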

// add optional EOS (=2) token, if desired
if (eos >= 0) {
while (eos--) {
tokens[(*n_tokens)++] = eos_tok_;
tokens.push_back(eos_tok_);
}
} else {
ET_LOG(Error, "eos %d should be >= 0", eos);
return Error::InvalidArgument;
}

delete[] str_buffer;
return Error::Ok;
return Result(tokens);
}

} // namespace executor
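Taken together, encode and decode now compose without any manual buffer sizing. A hedged round-trip sketch (assumes a loaded tokenizer; error checks elided; actual ids depend on the vocabulary):

    std::vector<uint64_t> toks =
        tokenizer.encode("hello world", /*bos=*/1, /*eos=*/0).get();
    std::string text;
    for (size_t i = 1; i < toks.size(); i++) {
      // pairwise decode, mirroring the prefill loop in runner.cpp
      text += tokenizer.decode(toks[i - 1], toks[i]).get();
    }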
19 changes: 8 additions & 11 deletions examples/models/llama2/tokenizer/tokenizer.h
@@ -16,6 +16,7 @@
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -32,37 +33,33 @@ struct TokenIndex {

class Tokenizer {
public:
explicit Tokenizer(int32_t vocab_size, int32_t bos_tok, int32_t eos_tok);
explicit Tokenizer(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok);
~Tokenizer();

Error load(const std::string& tokenizer_path);

Error encode(
const char* text,
int8_t bos,
int8_t eos,
int32_t* tokens,
int32_t* n_tokens);
Result<std::vector<uint64_t>>
encode(const std::string& input, int8_t bos, int8_t eos);

Result<const char*> decode(int prev_token, int token);
Result<std::string> decode(uint64_t prev_token, uint64_t token);

// getters
int32_t vocab_size() const {
return vocab_size_;
}

int32_t bos_tok() const {
uint64_t bos_tok() const {
return bos_tok_;
}

int32_t eos_tok() const {
uint64_t eos_tok() const {
return eos_tok_;
}

private:
bool initialized_;
const int32_t vocab_size_;
int32_t bos_tok_, eos_tok_;
uint64_t bos_tok_, eos_tok_;
std::unique_ptr<char*[]> vocab_;
std::unique_ptr<float[]> vocab_scores_;
std::unique_ptr<TokenIndex[]> sorted_vocab_;
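Finally, a construction sketch against the consolidated header. The 32000/1/2 values follow llama2's usual <unk>, <s>, </s> layout noted in tokenizer.cpp; the torch::executor namespace is inferred from the namespace closer there, and the file path is purely illustrative:

    #include <executorch/examples/models/llama2/tokenizer/tokenizer.h>

    torch::executor::Tokenizer tokenizer(
        /*vocab_size=*/32000, /*bos_tok=*/1, /*eos_tok=*/2);
    Error err = tokenizer.load("/path/to/tokenizer.bin");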