Skip to content

Commit ca7cdcc

Browse files
authored
[Tokenizer] Support ByteLevel BPE in tokenizer token table (#2248)
1 parent 6a43570 commit ca7cdcc

File tree

4 files changed

+198
-22
lines changed

4 files changed

+198
-22
lines changed

cpp/serve/engine.cc

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ class EngineImpl : public Engine {
5656
}
5757
this->request_stream_callback_ = std::move(request_stream_callback);
5858
this->trace_recorder_ = trace_recorder;
59-
this->tokenizer_ = Tokenizer::FromPath(engine_config->model);
60-
this->token_table_ = tokenizer_->TokenTable();
61-
this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_);
59+
6260
// Step 2. Initialize each model independently.
6361
// Create the logit processor and sampler.
6462
this->models_.clear();
@@ -100,6 +98,21 @@ class EngineImpl : public Engine {
10098
engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1);
10199
}
102100

101+
// Step 3. Initialize tokenizer and grammar
102+
this->tokenizer_ = Tokenizer::FromPath(engine_config->model);
103+
std::string token_table_postproc_method;
104+
if (model_configs[0].count("token_table_postproc_method") == 0) {
105+
// Backward compatibility: use "byte_fallback" by default
106+
token_table_postproc_method = "byte_fallback";
107+
} else {
108+
token_table_postproc_method =
109+
model_configs[0].at("token_table_postproc_method").get<std::string>();
110+
}
111+
this->token_table_ =
112+
Tokenizer::PostProcessTokenTable(tokenizer_->TokenTable(), token_table_postproc_method);
113+
this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_);
114+
115+
// Step 4. Initialize engine actions that represent state transitions.
103116
int max_num_tokens = engine_config->max_num_sequence;
104117
DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr};
105118
if (engine_config->speculative_mode != SpeculativeMode::kDisable) {
@@ -113,7 +126,6 @@ class EngineImpl : public Engine {
113126
this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder);
114127
Sampler sampler = this->models_[0]->CreateSampler(
115128
max_num_tokens, static_cast<int>(this->models_.size()), trace_recorder);
116-
// Step 3. Initialize engine actions that represent state transitions.
117129
if (engine_config->speculative_mode != SpeculativeMode::kDisable) {
118130
// Speculative decoding is only possible for more than one model.
119131
ICHECK_GT(this->models_.size(), 1U);

cpp/tokenizers.cc

Lines changed: 89 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
#include <tvm/runtime/logging.h>
1010
#include <tvm/runtime/registry.h>
1111

12+
#include <array>
1213
#include <filesystem>
1314
#include <fstream>
1415
#include <string>
1516

17+
#include "./support/encoding.h"
1618
#include "./support/load_bytes_from_file.h"
1719

1820
namespace mlc {
@@ -91,13 +93,8 @@ Tokenizer Tokenizer::FromPath(const String& _path) {
9193
LOG(FATAL) << "Cannot find any tokenizer under: " << _path;
9294
}
9395

94-
/*!
95-
* \brief Post-process a raw token (which may be a raw byte or contain lower
96-
* one eights block) to the actual token.
97-
* We do this in order to conform with the tokenizers' setup.
98-
*/
99-
inline std::string PostProcessToken(std::string token) {
100-
// 1. The token represents a byte.
96+
/*! \brief ByteFallback decoder: transform tokens like <0x1B> to hex char byte 1B */
97+
inline std::string ByteFallbackDecoder(const std::string& token) {
10198
if (token.length() == 6 && token.substr(0, 3) == "<0x" && token.back() == '>') {
10299
int byte = 0;
103100
for (int i = 0; i < 2; ++i) {
@@ -108,15 +105,82 @@ inline std::string PostProcessToken(std::string token) {
108105
ICHECK(byte >= 0 && byte < 256);
109106
return std::string(/*n=*/1, static_cast<char>(byte));
110107
}
108+
return token;
109+
}
111110

112-
// 2. The token contains "\u2581" which means space.
113-
static const std::string& lower_one_eighth_block = "\u2581";
114-
size_t pos = token.find(lower_one_eighth_block);
115-
while (pos != std::string::npos) {
116-
token.replace(pos, /*n=*/lower_one_eighth_block.length(), /*str=*/" ");
117-
pos = token.find(lower_one_eighth_block);
111+
/*! \brief SpaceReplacer decoder: transform "\u2581" back to space */
112+
inline std::string SpaceReplacerDecoder(const std::string& token) {
113+
// \u2581 is the unicode for "lower one eighth block"
114+
// UTF8 encoding for \u2581 is 0xE2 0x96 0x81
115+
std::string result;
116+
for (size_t i = 0; i < token.size(); ++i) {
117+
if (i + 2 < token.size() && token[i] == char(0xE2) && token[i + 1] == char(0x96) &&
118+
token[i + 2] == char(0x81)) {
119+
result += ' ';
120+
i += 2;
121+
} else {
122+
result += token[i];
123+
}
124+
}
125+
return result;
126+
}
127+
128+
/*! \brief ByteLevel decoder: inverses the bytes-to-unicode transformation in the encoding
129+
* process as in
130+
* https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59
131+
*/
132+
inline std::string ByteLevelDecoder(const std::string& token) {
133+
// clang-format off
134+
// The inverse map of bytes_to_unicode. -1 means there is no mapping to this unicode.
135+
static const std::array<int, 324> unicode_to_byte_map = {
136+
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
137+
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
138+
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,
139+
69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91,
140+
92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
141+
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, -1, -1, -1, -1,
142+
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
143+
-1, -1, -1, -1, -1, -1, -1, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, -1,
144+
174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
145+
192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209,
146+
210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227,
147+
228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245,
148+
246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
149+
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 127, 128,
150+
129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,
151+
147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 173
152+
};
153+
// clang-format on
154+
155+
auto unicode_codepoints = ParseUTF8(token.c_str());
156+
std::string decoded;
157+
158+
for (auto unicode_codepoint : unicode_codepoints) {
159+
ICHECK(unicode_codepoint >= 0 &&
160+
unicode_codepoint < static_cast<int>(unicode_to_byte_map.size()));
161+
int byte = unicode_to_byte_map[unicode_codepoint];
162+
if (byte == -1) {
163+
// If there is no mapping, add the codepoint itself to the result string
164+
// Some tokenizer like Phi-2 have raw tokens like \t\t
165+
decoded += static_cast<char>(unicode_codepoint);
166+
} else {
167+
decoded += static_cast<char>(byte);
168+
}
169+
}
170+
return decoded;
171+
}
172+
173+
/*!
174+
* \brief Post-process a raw token to the actual token with the given post-processing method.
175+
*/
176+
inline std::string PostProcessToken(const std::string& token, const std::string& postproc_method) {
177+
if (postproc_method == "byte_fallback") {
178+
return SpaceReplacerDecoder(ByteFallbackDecoder(token));
179+
} else if (postproc_method == "byte_level") {
180+
return ByteLevelDecoder(token);
181+
} else {
182+
LOG(FATAL) << "Unknown post-processing method: " << postproc_method;
118183
}
119-
return token;
120184
}
121185

122186
const std::vector<std::string>& TokenizerObj::TokenTable() {
@@ -127,12 +191,21 @@ const std::vector<std::string>& TokenizerObj::TokenTable() {
127191
int vocab_size = tokenizer->GetVocabSize();
128192
token_table_.reserve(vocab_size);
129193
for (int32_t token_id = 0; token_id < vocab_size; ++token_id) {
130-
std::string token = tokenizer->IdToToken(token_id);
131-
token_table_.push_back(PostProcessToken(token));
194+
token_table_.push_back(tokenizer->IdToToken(token_id));
132195
}
133196
return token_table_;
134197
}
135198

199+
std::vector<std::string> Tokenizer::PostProcessTokenTable(
200+
const std::vector<std::string>& token_table, const std::string& postproc_method) {
201+
std::vector<std::string> postprocessed_token_table;
202+
postprocessed_token_table.reserve(token_table.size());
203+
for (const std::string& token : token_table) {
204+
postprocessed_token_table.push_back(PostProcessToken(token, postproc_method));
205+
}
206+
return postprocessed_token_table;
207+
}
208+
136209
TVM_REGISTER_GLOBAL("mlc.Tokenizer").set_body_typed([](const String& path) {
137210
return Tokenizer::FromPath(path);
138211
});

cpp/tokenizers.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class TokenizerObj : public Object {
3030
std::vector<int32_t> Encode(const std::string& text) const;
3131
/*! \brief Decode token ids into text. */
3232
std::string Decode(const std::vector<int32_t>& token_ids) const;
33-
/*! \brief Return the token table of the tokenizer. */
33+
/*! \brief Return the token table of the tokenizer. Special tokens are included. */
3434
const std::vector<std::string>& TokenTable();
3535

3636
/*!
@@ -64,6 +64,25 @@ class Tokenizer : public ObjectRef {
6464
/*! \brief Create a tokenizer from a directory path on disk. */
6565
MLC_LLM_DLL static Tokenizer FromPath(const String& path);
6666

67+
/*!
68+
* \brief Convert raw tokens provided by the tokenizer to their original string to simplify
69+
* later processing. E.g. For LLaMA-2, convert "▁of" to " of".
70+
*
71+
* \param token_table The raw token table.
72+
* \param postproc_method The postprocessing method to use. Now we only support "byte_fallback"
73+
* and "byte_level", which refers to the type of the decoder of the tokenizer.
74+
* - "byte_fallback": Use the decoding method in the byte-fallback BPE tokenizer. This is used
75+
* by LLaMA-2, Mixtral-7b, etc. This method: 1) transforms tokens like <0x1B> to hex char
76+
* byte 1B. (known as the byte-fallback method); 2) transforms \\u2581 to space.
77+
* - "byte_level": Use the decoding method in the byte-level BPE tokenizer. This is used by
78+
* LLaMA-3, GPT-2, Phi-2, etc. This method inverses the bytes-to-unicode transformation in
79+
* the encoding process as in
80+
* https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59
81+
* \returns The postprocessed token table containing the original strings.
82+
*/
83+
static std::vector<std::string> PostProcessTokenTable(const std::vector<std::string>& token_table,
84+
const std::string& postproc_method);
85+
6786
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Tokenizer, ObjectRef, TokenizerObj);
6887

6988
private:

python/mlc_llm/interface/gen_config.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
import shutil
77
from pathlib import Path
8-
from typing import Any, Dict, List, Optional, Union
8+
from typing import Any, Dict, List, Literal, Optional, Union
99

1010
from mlc_llm.conversation_template import ConvTemplateRegistry
1111
from mlc_llm.model import Model
@@ -51,7 +51,11 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes
5151
pad_token_id: int = None
5252
bos_token_id: int = None
5353
eos_token_id: int = None
54+
# Tokenizer configuration
5455
tokenizer_files: List[str] = dataclasses.field(default_factory=list)
56+
# The method to post-process the token table. See
57+
# cpp/tokenizers.h::Tokenizer::PostProcessTokenTable for details
58+
token_table_postproc_method: Optional[Literal["byte_fallback", "byte_level"]] = None
5559
# Version control
5660
version: str = VERSION
5761

@@ -129,6 +133,70 @@ def json2rwkv_tokenizer(vocab: Path, out: Path) -> None:
129133
msgpack.pack(idx2token, f)
130134

131135

136+
def detect_token_table_postproc_method(output_path: Path) -> Literal["byte_fallback", "byte_level"]:
137+
"""Detect the token table postprocessing method from tokenizer.json that is found under
138+
output_path. If not detected, use ByteFallback as default.
139+
140+
Check the decoder field of the tokenizer. If it uses ByteFallback decoder, return
141+
"byte_fallback". If it uses ByteLevel decoder, return "byte_level". Otherwise, use
142+
ByteFallback as default.
143+
144+
See also cpp/tokenizers.h::Tokenizer::PostProcessTokenTable.
145+
"""
146+
output_tokenizer_path = output_path / "tokenizer.json"
147+
if not output_tokenizer_path.exists():
148+
logger.warning(
149+
"Tokenizer token table postprocessing method is not detected as tokenizer.json "
150+
"is not found, use ByteFallback (the same as LLaMA/LLaMA2) by default"
151+
)
152+
return "byte_fallback"
153+
154+
with output_tokenizer_path.open("r", encoding="utf-8") as in_file:
155+
tokenizer_json = json.load(in_file)
156+
157+
# Find all decoders in tokenizer.json
158+
decoders = []
159+
160+
if "decoder" not in tokenizer_json:
161+
logger.warning(
162+
"Decoder field is not found in tokenizer.json, use ByteFallback (the same as "
163+
"LLaMA/LLaMA2) as the token table postprocessing method by default"
164+
)
165+
return "byte_fallback"
166+
167+
decoders_json = tokenizer_json["decoder"]
168+
assert "type" in decoders_json, "Decoder type is not specified in tokenizer.json"
169+
if decoders_json["type"] == "Sequence":
170+
assert "decoders" in decoders_json
171+
decoders = decoders_json["decoders"]
172+
else:
173+
decoders = [decoders_json]
174+
175+
is_byte_level = False
176+
is_byte_fallback = False
177+
178+
for decoder in decoders:
179+
if decoder["type"] == "ByteLevel":
180+
is_byte_level = True
181+
if decoder["type"] == "ByteFallback":
182+
is_byte_fallback = True
183+
assert not (
184+
is_byte_level and is_byte_fallback
185+
), "Tokenizer decoder cannot have both type ByteLevel and type ByteFallback"
186+
187+
if is_byte_level:
188+
return "byte_level"
189+
if is_byte_fallback:
190+
return "byte_fallback"
191+
192+
logger.warning(
193+
"Neither ByteLevel nor ByteFallback decoder is detected in tokenizer.json, use "
194+
"ByteFallback (the same as LLaMA/LLaMA2) as the token table postprocessing method "
195+
"by default"
196+
)
197+
return "byte_fallback"
198+
199+
132200
def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
133201
config: Path,
134202
model: Model,
@@ -255,6 +323,10 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
255323
except Exception: # pylint: disable=broad-exception-caught
256324
logger.exception("%s with the exception below. Skipping", FAILED)
257325

326+
# 3.4. Find the token table postprocessing method from tokenizer.json if it exists. If not
327+
# detected, use "byte_fallback" as default.
328+
mlc_chat_config.token_table_postproc_method = detect_token_table_postproc_method(output)
329+
258330
# Step 4. Load system default value
259331
mlc_chat_config.apply_defaults()
260332
# Step 5. Dump the configuration file to output directory

0 commit comments

Comments
 (0)