9
9
#include < tvm/runtime/logging.h>
10
10
#include < tvm/runtime/registry.h>
11
11
12
+ #include < array>
12
13
#include < filesystem>
13
14
#include < fstream>
14
15
#include < string>
15
16
17
+ #include " ./support/encoding.h"
16
18
#include " ./support/load_bytes_from_file.h"
17
19
18
20
namespace mlc {
@@ -91,13 +93,8 @@ Tokenizer Tokenizer::FromPath(const String& _path) {
91
93
LOG (FATAL) << " Cannot find any tokenizer under: " << _path;
92
94
}
93
95
94
- /* !
95
- * \brief Post-process a raw token (which may be a raw byte or contain lower
96
- * one eights block) to the actual token.
97
- * We do this in order to conform with the tokenizers' setup.
98
- */
99
- inline std::string PostProcessToken (std::string token) {
100
- // 1. The token represents a byte.
96
+ /* ! \brief ByteFallback decoder: transform a token of the form <0xXX> (e.g. <0x1B>) into the single raw byte 0xXX; any other token is returned unchanged */
97
+ inline std::string ByteFallbackDecoder (const std::string& token) {
101
98
if (token.length () == 6 && token.substr (0 , 3 ) == " <0x" && token.back () == ' >' ) {
102
99
int byte = 0 ;
103
100
for (int i = 0 ; i < 2 ; ++i) {
@@ -108,15 +105,82 @@ inline std::string PostProcessToken(std::string token) {
108
105
ICHECK (byte >= 0 && byte < 256 );
109
106
return std::string (/* n=*/ 1 , static_cast <char >(byte));
110
107
}
108
+ return token;
109
+ }
111
110
112
- // 2. The token contains "\u2581" which means space.
113
- static const std::string& lower_one_eighth_block = " \u2581 " ;
114
- size_t pos = token.find (lower_one_eighth_block);
115
- while (pos != std::string::npos) {
116
- token.replace (pos, /* n=*/ lower_one_eighth_block.length (), /* str=*/ " " );
117
- pos = token.find (lower_one_eighth_block);
111
/*!
 * \brief SpaceReplacer decoder: transform every "lower one eighth block" (U+2581) in the
 * token back to an ASCII space.
 * \param token The raw token, as UTF-8 encoded bytes.
 * \return A copy of the token in which each U+2581 is replaced by ' '; every other byte is
 * copied through untouched.
 */
inline std::string SpaceReplacerDecoder(const std::string& token) {
  // UTF-8 encoding of U+2581 is the three bytes 0xE2 0x96 0x81.
  static const char kLowerOneEighthBlock[] = "\xE2\x96\x81";
  // sizeof counts the terminating NUL, which is not part of the pattern.
  static const size_t kBlockLength = sizeof(kLowerOneEighthBlock) - 1;

  std::string result;
  // Each replacement turns three bytes into one, so the input size is an upper bound.
  result.reserve(token.size());

  size_t copy_from = 0;
  for (size_t hit = token.find(kLowerOneEighthBlock, copy_from); hit != std::string::npos;
       hit = token.find(kLowerOneEighthBlock, copy_from)) {
    // Copy the untouched bytes in front of the match, then emit the space.
    result.append(token, copy_from, hit - copy_from);
    result += ' ';
    copy_from = hit + kBlockLength;
  }
  // Whatever follows the last match (or the whole token if there was none).
  result.append(token, copy_from, std::string::npos);
  return result;
}
127
+
128
+ /* ! \brief ByteLevel decoder: inverses the bytes-to-unicode transformation in the encoding
129
+ * process as in
130
+ * https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59
131
+ */
132
+ inline std::string ByteLevelDecoder (const std::string& token) {
133
+ // clang-format off
134
+ // The inverse map of bytes_to_unicode. -1 means there is no mapping to this unicode.
135
+ static const std::array<int , 324 > unicode_to_byte_map = {
136
+ -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
137
+ -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , 33 , 34 , 35 , 36 , 37 , 38 , 39 , 40 , 41 , 42 , 43 , 44 , 45 ,
138
+ 46 , 47 , 48 , 49 , 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 , 60 , 61 , 62 , 63 , 64 , 65 , 66 , 67 , 68 ,
139
+ 69 , 70 , 71 , 72 , 73 , 74 , 75 , 76 , 77 , 78 , 79 , 80 , 81 , 82 , 83 , 84 , 85 , 86 , 87 , 88 , 89 , 90 , 91 ,
140
+ 92 , 93 , 94 , 95 , 96 , 97 , 98 , 99 , 100 , 101 , 102 , 103 , 104 , 105 , 106 , 107 , 108 , 109 , 110 , 111 ,
141
+ 112 , 113 , 114 , 115 , 116 , 117 , 118 , 119 , 120 , 121 , 122 , 123 , 124 , 125 , 126 , -1 , -1 , -1 , -1 ,
142
+ -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 , -1 ,
143
+ -1 , -1 , -1 , -1 , -1 , -1 , -1 , 161 , 162 , 163 , 164 , 165 , 166 , 167 , 168 , 169 , 170 , 171 , 172 , -1 ,
144
+ 174 , 175 , 176 , 177 , 178 , 179 , 180 , 181 , 182 , 183 , 184 , 185 , 186 , 187 , 188 , 189 , 190 , 191 ,
145
+ 192 , 193 , 194 , 195 , 196 , 197 , 198 , 199 , 200 , 201 , 202 , 203 , 204 , 205 , 206 , 207 , 208 , 209 ,
146
+ 210 , 211 , 212 , 213 , 214 , 215 , 216 , 217 , 218 , 219 , 220 , 221 , 222 , 223 , 224 , 225 , 226 , 227 ,
147
+ 228 , 229 , 230 , 231 , 232 , 233 , 234 , 235 , 236 , 237 , 238 , 239 , 240 , 241 , 242 , 243 , 244 , 245 ,
148
+ 246 , 247 , 248 , 249 , 250 , 251 , 252 , 253 , 254 , 255 , 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 ,
149
+ 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 , 31 , 32 , 127 , 128 ,
150
+ 129 , 130 , 131 , 132 , 133 , 134 , 135 , 136 , 137 , 138 , 139 , 140 , 141 , 142 , 143 , 144 , 145 , 146 ,
151
+ 147 , 148 , 149 , 150 , 151 , 152 , 153 , 154 , 155 , 156 , 157 , 158 , 159 , 160 , 173
152
+ };
153
+ // clang-format on
154
+
155
+ auto unicode_codepoints = ParseUTF8 (token.c_str ());
156
+ std::string decoded;
157
+
158
+ for (auto unicode_codepoint : unicode_codepoints) {
159
+ ICHECK (unicode_codepoint >= 0 &&
160
+ unicode_codepoint < static_cast <int >(unicode_to_byte_map.size ()));
161
+ int byte = unicode_to_byte_map[unicode_codepoint];
162
+ if (byte == -1 ) {
163
+ // If there is no mapping, add the codepoint itself to the result string
164
+ // Some tokenizer like Phi-2 have raw tokens like \t\t
165
+ decoded += static_cast <char >(unicode_codepoint);
166
+ } else {
167
+ decoded += static_cast <char >(byte);
168
+ }
169
+ }
170
+ return decoded;
171
+ }
172
+
173
+ /* !
174
+ * \brief Post-process a raw token to the actual token with the given post-processing method.
175
+ */
176
+ inline std::string PostProcessToken (const std::string& token, const std::string& postproc_method) {
177
+ if (postproc_method == " byte_fallback" ) {
178
+ return SpaceReplacerDecoder (ByteFallbackDecoder (token));
179
+ } else if (postproc_method == " byte_level" ) {
180
+ return ByteLevelDecoder (token);
181
+ } else {
182
+ LOG (FATAL) << " Unknown post-processing method: " << postproc_method;
118
183
}
119
- return token;
120
184
}
121
185
122
186
const std::vector<std::string>& TokenizerObj::TokenTable () {
@@ -127,12 +191,21 @@ const std::vector<std::string>& TokenizerObj::TokenTable() {
127
191
int vocab_size = tokenizer->GetVocabSize ();
128
192
token_table_.reserve (vocab_size);
129
193
for (int32_t token_id = 0 ; token_id < vocab_size; ++token_id) {
130
- std::string token = tokenizer->IdToToken (token_id);
131
- token_table_.push_back (PostProcessToken (token));
194
+ token_table_.push_back (tokenizer->IdToToken (token_id));
132
195
}
133
196
return token_table_;
134
197
}
135
198
199
+ std::vector<std::string> Tokenizer::PostProcessTokenTable (
200
+ const std::vector<std::string>& token_table, const std::string& postproc_method) {
201
+ std::vector<std::string> postprocessed_token_table;
202
+ postprocessed_token_table.reserve (token_table.size ());
203
+ for (const std::string& token : token_table) {
204
+ postprocessed_token_table.push_back (PostProcessToken (token, postproc_method));
205
+ }
206
+ return postprocessed_token_table;
207
+ }
208
+
136
209
// Registers a typed lambda under the name passed to TVM_REGISTER_GLOBAL; the lambda takes
// a path and returns Tokenizer::FromPath(path).
// NOTE(review): presumably this makes the tokenizer constructor reachable through the TVM
// global function registry -- confirm against tvm/runtime/registry.h.
TVM_REGISTER_GLOBAL (" mlc.Tokenizer" ).set_body_typed([](const String& path) {
137
210
return Tokenizer::FromPath (path);
138
211
});
0 commit comments