[Tokenizer] Support ByteLevel BPE in tokenizer token table (#2248)
Ubospica authored Apr 30, 2024
1 parent 6a43570 commit ca7cdcc
Showing 4 changed files with 198 additions and 22 deletions.
20 changes: 16 additions & 4 deletions cpp/serve/engine.cc
@@ -56,9 +56,7 @@ class EngineImpl : public Engine {
}
this->request_stream_callback_ = std::move(request_stream_callback);
this->trace_recorder_ = trace_recorder;
this->tokenizer_ = Tokenizer::FromPath(engine_config->model);
this->token_table_ = tokenizer_->TokenTable();
this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_);

// Step 2. Initialize each model independently.
// Create the logit processor and sampler.
this->models_.clear();
@@ -100,6 +98,21 @@ class EngineImpl : public Engine {
engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1);
}

// Step 3. Initialize tokenizer and grammar
this->tokenizer_ = Tokenizer::FromPath(engine_config->model);
std::string token_table_postproc_method;
if (model_configs[0].count("token_table_postproc_method") == 0) {
// Backward compatibility: use "byte-fallback" by default
token_table_postproc_method = "byte-fallback";

@joshmackwilliams commented on May 2, 2024:

Is this a typo? Maybe it should be "byte_fallback"?

} else {
token_table_postproc_method =
model_configs[0].at("token_table_postproc_method").get<std::string>();
}
this->token_table_ =
Tokenizer::PostProcessTokenTable(tokenizer_->TokenTable(), token_table_postproc_method);
this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_);

// Step 4. Initialize engine actions that represent state transitions.
int max_num_tokens = engine_config->max_num_sequence;
DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr};
if (engine_config->speculative_mode != SpeculativeMode::kDisable) {
@@ -113,7 +126,6 @@
this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder);
Sampler sampler = this->models_[0]->CreateSampler(
max_num_tokens, static_cast<int>(this->models_.size()), trace_recorder);
// Step 3. Initialize engine actions that represent state transitions.
if (engine_config->speculative_mode != SpeculativeMode::kDisable) {
// Speculative decoding is only possible for more than one model.
ICHECK_GT(this->models_.size(), 1U);
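As a reference for the Step-3 lookup above, here is a minimal Python sketch of the backward-compatible config read (the helper name and dict-based config are assumptions for illustration, not MLC API):

def get_token_table_postproc_method(model_config: dict) -> str:
    # Configs generated before this change lack the field, so fall back to the
    # previously hard-coded byte-fallback behavior. Note that the hunk above
    # defaults to the hyphenated "byte-fallback", which the inline review above
    # flags as a likely typo: PostProcessToken matches on "byte_fallback".
    return model_config.get("token_table_postproc_method", "byte_fallback")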
105 changes: 89 additions & 16 deletions cpp/tokenizers.cc
@@ -9,10 +9,12 @@
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

#include <array>
#include <filesystem>
#include <fstream>
#include <string>

#include "./support/encoding.h"
#include "./support/load_bytes_from_file.h"

namespace mlc {
@@ -91,13 +93,8 @@ Tokenizer Tokenizer::FromPath(const String& _path) {
LOG(FATAL) << "Cannot find any tokenizer under: " << _path;
}

/*!
* \brief Post-process a raw token (which may be a raw byte or contain lower
* one eights block) to the actual token.
* We do this in order to conform with the tokenizers' setup.
*/
inline std::string PostProcessToken(std::string token) {
// 1. The token represents a byte.
/*! \brief ByteFallback decoder: transform tokens like <0x1B> to the raw byte 0x1B */
inline std::string ByteFallbackDecoder(const std::string& token) {
if (token.length() == 6 && token.substr(0, 3) == "<0x" && token.back() == '>') {
int byte = 0;
for (int i = 0; i < 2; ++i) {
@@ -108,15 +105,82 @@ inline std::string PostProcessToken(std::string token) {
ICHECK(byte >= 0 && byte < 256);
return std::string(/*n=*/1, static_cast<char>(byte));
}
return token;
}

// 2. The token contains "\u2581" which means space.
static const std::string& lower_one_eighth_block = "\u2581";
size_t pos = token.find(lower_one_eighth_block);
while (pos != std::string::npos) {
token.replace(pos, /*n=*/lower_one_eighth_block.length(), /*str=*/" ");
pos = token.find(lower_one_eighth_block);
/*! \brief SpaceReplacer decoder: transform "\u2581" back to space */
inline std::string SpaceReplacerDecoder(const std::string& token) {
// \u2581 is the unicode for "lower one eighth block"
// UTF8 encoding for \u2581 is 0xE2 0x96 0x81
std::string result;
for (size_t i = 0; i < token.size(); ++i) {
if (i + 2 < token.size() && token[i] == char(0xE2) && token[i + 1] == char(0x96) &&
token[i + 2] == char(0x81)) {
result += ' ';
i += 2;
} else {
result += token[i];
}
}
return result;
}
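
An equivalent Python sketch of the two decoding steps above (function names are illustrative; chr() stands in for the raw byte the C++ code emits, so bytes >= 0x80 come back as single codepoints rather than UTF-8 fragments):

def byte_fallback_decode(token: str) -> str:
    # Tokens of the form <0xNN> encode a single raw byte.
    if len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
        return chr(int(token[3:5], 16))
    return token

def space_replacer_decode(token: str) -> str:
    # U+2581 ("lower one eighth block") stands in for a space.
    return token.replace("\u2581", " ")

assert space_replacer_decode(byte_fallback_decode("\u2581of")) == " of"
assert space_replacer_decode(byte_fallback_decode("<0x0A>")) == "\n"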

/*! \brief ByteLevel decoder: inverts the bytes-to-unicode transformation in the encoding
* process as in
* https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59
*/
inline std::string ByteLevelDecoder(const std::string& token) {
// clang-format off
// The inverse map of bytes_to_unicode. -1 means there is no mapping to this unicode.
static const std::array<int, 324> unicode_to_byte_map = {
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,
69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91,
92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, -1,
174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209,
210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227,
228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245,
246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 127, 128,
129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,
147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 173
};
// clang-format on

auto unicode_codepoints = ParseUTF8(token.c_str());
std::string decoded;

for (auto unicode_codepoint : unicode_codepoints) {
ICHECK(unicode_codepoint >= 0 &&
unicode_codepoint < static_cast<int>(unicode_to_byte_map.size()));
int byte = unicode_to_byte_map[unicode_codepoint];
if (byte == -1) {
// If there is no mapping, add the codepoint itself to the result string
// Some tokenizers, such as Phi-2's, have raw tokens like \t\t
decoded += static_cast<char>(unicode_codepoint);
} else {
decoded += static_cast<char>(byte);
}
}
return decoded;
}
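
The 324-entry table above is the inverse of GPT-2's bytes_to_unicode map: 188 printable bytes map to themselves and the remaining 68 are shifted to codepoints 256..323, hence 324 slots. A Python sketch following the linked transformers reference (the decode helper name is an assumption):

def bytes_to_unicode() -> dict:
    # Printable bytes map to themselves; the other 68 bytes get codepoints
    # 256+n so that every byte has a visible one-character representation.
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(0xA1, 0xAC + 1))
          + list(range(0xAE, 0xFF + 1)))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))

unicode_to_byte = {c: b for b, c in bytes_to_unicode().items()}

def byte_level_decode(token: str) -> bytes:
    # Codepoints without a mapping (e.g. Phi-2's raw "\t\t") pass through,
    # mirroring the fallback branch in the C++ decoder above.
    return bytes(unicode_to_byte.get(ch, ord(ch)) for ch in token)

assert byte_level_decode("\u0120of") == b" of"  # "Ġof" decodes to " of"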

/*!
* \brief Post-process a raw token to the actual token with the given post-processing method.
*/
inline std::string PostProcessToken(const std::string& token, const std::string& postproc_method) {
if (postproc_method == "byte_fallback") {
return SpaceReplacerDecoder(ByteFallbackDecoder(token));
} else if (postproc_method == "byte_level") {
return ByteLevelDecoder(token);
} else {
LOG(FATAL) << "Unknown post-processing method: " << postproc_method;
}
return token;
}

const std::vector<std::string>& TokenizerObj::TokenTable() {
@@ -127,12 +191,21 @@ const std::vector<std::string>& TokenizerObj::TokenTable() {
int vocab_size = tokenizer->GetVocabSize();
token_table_.reserve(vocab_size);
for (int32_t token_id = 0; token_id < vocab_size; ++token_id) {
std::string token = tokenizer->IdToToken(token_id);
token_table_.push_back(PostProcessToken(token));
token_table_.push_back(tokenizer->IdToToken(token_id));
}
return token_table_;
}

std::vector<std::string> Tokenizer::PostProcessTokenTable(
const std::vector<std::string>& token_table, const std::string& postproc_method) {
std::vector<std::string> postprocessed_token_table;
postprocessed_token_table.reserve(token_table.size());
for (const std::string& token : token_table) {
postprocessed_token_table.push_back(PostProcessToken(token, postproc_method));
}
return postprocessed_token_table;
}

TVM_REGISTER_GLOBAL("mlc.Tokenizer").set_body_typed([](const String& path) {
return Tokenizer::FromPath(path);
});
21 changes: 20 additions & 1 deletion cpp/tokenizers.h
@@ -30,7 +30,7 @@ class TokenizerObj : public Object {
std::vector<int32_t> Encode(const std::string& text) const;
/*! \brief Decode token ids into text. */
std::string Decode(const std::vector<int32_t>& token_ids) const;
/*! \brief Return the token table of the tokenizer. */
/*! \brief Return the token table of the tokenizer. Special tokens are included. */
const std::vector<std::string>& TokenTable();

/*!
@@ -64,6 +64,25 @@ class Tokenizer : public ObjectRef {
/*! \brief Create a tokenizer from a directory path on disk. */
MLC_LLM_DLL static Tokenizer FromPath(const String& path);

/*!
* \brief Convert raw tokens provided by the tokenizer to their original string to simplify
* later processing. E.g. For LLaMA-2, convert "▁of" to " of".
*
* \param token_table The raw token table.
* \param postproc_method The postprocessing method to use. Currently only "byte_fallback"
* and "byte_level" are supported; they refer to the type of the tokenizer's decoder.
* - "byte_fallback": Use the decoding method of the byte-fallback BPE tokenizer, as used
* by LLaMA-2, Mixtral-7b, etc. This method: 1) transforms tokens like <0x1B> to the raw
* byte 0x1B (the byte-fallback step); 2) transforms \\u2581 to a space.
* - "byte_level": Use the decoding method of the byte-level BPE tokenizer, as used by
* LLaMA-3, GPT-2, Phi-2, etc. This method inverts the bytes-to-unicode transformation
* applied during encoding, as in
* https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59
* \returns The postprocessed token table containing the original strings.
*/
static std::vector<std::string> PostProcessTokenTable(const std::vector<std::string>& token_table,
const std::string& postproc_method);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Tokenizer, ObjectRef, TokenizerObj);

private:
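To make the contract above concrete, a few typical mappings in Python form (the token examples are assumed typical for the cited models, not drawn from MLC tests):

POSTPROC_EXAMPLES = {
    "byte_fallback": {"\u2581of": " of", "<0x0A>": "\n"},  # LLaMA-2 style
    "byte_level": {"\u0120of": " of"},  # GPT-2 / LLaMA-3 style ("Ġof")
}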
74 changes: 73 additions & 1 deletion python/mlc_llm/interface/gen_config.py
@@ -5,7 +5,7 @@
import re
import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from mlc_llm.conversation_template import ConvTemplateRegistry
from mlc_llm.model import Model
@@ -51,7 +51,11 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes
pad_token_id: int = None
bos_token_id: int = None
eos_token_id: int = None
# Tokenizer configuration
tokenizer_files: List[str] = dataclasses.field(default_factory=list)
# The method to post-process the token table. See
# cpp/tokenizers.h::Tokenizer::PostProcessTokenTable for details
token_table_postproc_method: Literal["byte_fallback", "byte_level"] = None
# Version control
version: str = VERSION

@@ -129,6 +133,70 @@ def json2rwkv_tokenizer(vocab: Path, out: Path) -> None:
msgpack.pack(idx2token, f)


def detect_token_table_postproc_method(output_path: Path) -> Literal["byte_fallback", "byte_level"]:
"""Detect the token table postprocessing method from tokenizer.json that is found under
output_path. Inspect the tokenizer's decoder field: a ByteFallback decoder yields
"byte_fallback", a ByteLevel decoder yields "byte_level", and anything else (including
a missing tokenizer.json) falls back to "byte_fallback".
See also cpp/tokenizers.h::Tokenizer::PostProcessTokenTable.
"""
output_tokenizer_path = output_path / "tokenizer.json"
if not output_tokenizer_path.exists():
logger.warning(
"Tokenizer token table postprocessing method is not detected as tokenizer.json "
"is not found, use ByteFallback (the same as LLaMA/LLaMA2) by default"
)
return "byte_fallback"

with output_tokenizer_path.open("r", encoding="utf-8") as in_file:
tokenizer_json = json.load(in_file)

# Find all decoders in tokenizer.json
decoders = []

if "decoder" not in tokenizer_json:
logger.warning(
"Decoder field is not found in tokenizer.json, use ByteFallback (the same as "
"LLaMA/LLaMA2) as the token table postprocessing method by default"
)
return "byte_fallback"

decoders_json = tokenizer_json["decoder"]
assert "type" in decoders_json, "Decoder type is not specified in tokenizer.json"
if decoders_json["type"] == "Sequence":
assert "decoders" in decoders_json
decoders = decoders_json["decoders"]
else:
decoders = [decoders_json]

is_byte_level = False
is_byte_fallback = False

for decoder in decoders:
if decoder["type"] == "ByteLevel":
is_byte_level = True
if decoder["type"] == "ByteFallback":
is_byte_fallback = True
assert not (
is_byte_level and is_byte_fallback
), "Tokenizer decoder cannot have both type ByteLevel and type ByteFallback"

if is_byte_level:
return "byte_level"
if is_byte_fallback:
return "byte_fallback"

logger.warning(
"Neither ByteLevel nor ByteFallback decoder is detected in tokenizer.json, use "
"ByteFallback (the same as LLaMA/LLaMA2) as the token table postprocessing method "
"by default"
)
return "byte_fallback"
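
A quick, hypothetical exercise of the detector (the sample decoder layout mimics the Sequence-wrapped ByteFallback arrangement in LLaMA-2's tokenizer.json; the temp path is arbitrary):

import json
import tempfile
from pathlib import Path

sample = {"decoder": {"type": "Sequence", "decoders": [
    {"type": "Replace"}, {"type": "ByteFallback"}, {"type": "Fuse"}]}}
out = Path(tempfile.mkdtemp())
(out / "tokenizer.json").write_text(json.dumps(sample), encoding="utf-8")
assert detect_token_table_postproc_method(out) == "byte_fallback"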


def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
config: Path,
model: Model,
@@ -255,6 +323,10 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
except Exception: # pylint: disable=broad-exception-caught
logger.exception("%s with the exception below. Skipping", FAILED)

# 3.4. Find the token table postprocessing method from tokenizer.json if it exists. If not
# detected, use "byte_fallback" as default.
mlc_chat_config.token_table_postproc_method = detect_token_table_postproc_method(output)

# Step 4. Load system default value
mlc_chat_config.apply_defaults()
# Step 5. Dump the configuration file to output directory
