
Commit 949b0ad

Reland #66 and #67 (#74)
Summary: Reland #66 and #67 with unbypassable arc lint fixes
Reviewed By: kirklandsign
Differential Revision: D74693197
Pulled By: jackzhxng
1 parent ea21dc7 commit 949b0ad

File tree

include/pytorch/tokenizers/bpe_tokenizer_base.h
src/hf_tokenizer.cpp
src/tiktoken.cpp
targets.bzl

4 files changed: +40 -19 lines

include/pytorch/tokenizers/bpe_tokenizer_base.h

Lines changed: 21 additions & 0 deletions

```diff
@@ -25,6 +25,8 @@
 #include <pytorch/tokenizers/string_integer_map.h>
 #include <pytorch/tokenizers/tokenizer.h>
 
+#include "re2/re2.h"
+
 namespace tokenizers {
 namespace detail {
 
@@ -104,6 +106,25 @@ static Result<TokenMap> buildTokenMap(
   return buildTokenMap(std::move(pairs));
 }
 
+inline Result<std::unique_ptr<IRegex>> build_special_token_regex(
+    const TokenMap& special_token_map) {
+  std::string special_pattern;
+  const std::size_t count = special_token_map.size();
+
+  for (std::size_t i = 0; i < count; ++i) {
+    const auto& [token, _] = special_token_map.getElement(i);
+    if (!special_pattern.empty()) {
+      special_pattern += "|";
+    }
+    special_pattern += re2::RE2::QuoteMeta(std::string(token));
+  }
+
+  if (special_pattern.empty()) {
+    return static_cast<std::unique_ptr<IRegex>>(nullptr);
+  }
+  return create_regex(special_pattern);
+}
+
 class BPETokenizerBase : public Tokenizer {
  public:
   Result<std::vector<uint64_t>>
```
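
For readers skimming the diff: the new `detail::build_special_token_regex` helper escapes each special token with `re2::RE2::QuoteMeta` and joins the results with `|`, so tokens full of regex metacharacters (e.g. `<|endoftext|>`) are matched literally. A minimal standalone sketch of that pattern-building step, assuming only re2; the token strings are hypothetical:

```cpp
// Standalone sketch (not part of the commit): demonstrates the escaped
// alternation that build_special_token_regex produces. The token strings
// below are hypothetical examples.
#include <iostream>
#include <string>
#include <vector>

#include "re2/re2.h"

int main() {
  const std::vector<std::string> special_tokens = {
      "<|endoftext|>", "<|fim_prefix|>"};

  std::string pattern;
  for (const auto& token : special_tokens) {
    if (!pattern.empty()) {
      pattern += "|";  // alternation between tokens
    }
    // QuoteMeta escapes every character that could act as a regex
    // metacharacter (such as '|'), so each token matches literally.
    pattern += re2::RE2::QuoteMeta(token);
  }

  std::cout << pattern << "\n";
  // Matches because "<|endoftext|>" appears literally in the text.
  std::cout << std::boolalpha
            << re2::RE2::PartialMatch("hi <|endoftext|>", pattern) << "\n";
  return 0;
}
```

Note the helper's contract visible in the hunk: an empty token map produces an empty pattern, which is returned as a null `IRegex` inside a successful `Result` rather than an error, so callers treat a null `special_token_regex_` as "no special tokens".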

src/hf_tokenizer.cpp

Lines changed: 15 additions & 2 deletions

```diff
@@ -69,6 +69,12 @@ Error HFTokenizer::load(const std::string& path) {
         special_tokens,
         [](const auto& it) -> std::string { return it.at("content"); },
         [](const auto& it) -> std::uint64_t { return it.at("id"); }));
+
+    // Create special token regex to help later with encoding.
+    special_token_regex_ =
+        TK_UNWRAP(detail::build_special_token_regex(special_token_map));
+
+    // Store for future use.
     special_token_map_.emplace(std::move(special_token_map));
   } catch (const json::out_of_range& e) {
     fprintf(stderr, "Could not parse special tokens: %s\n", e.what());
@@ -142,8 +148,15 @@ Error HFTokenizer::load(const std::string& path) {
 
   // Pull out the token strings
   try {
-    const std::string bos_token = parsed_config_json.at("bos_token");
-    const std::string eos_token = parsed_config_json.at("eos_token");
+    const std::string bos_token = parsed_config_json.contains("bos_token") &&
+            !parsed_config_json["bos_token"].is_null()
+        ? parsed_config_json["bos_token"].get<std::string>()
+        : "";
+
+    const std::string eos_token = parsed_config_json.contains("eos_token") &&
+            !parsed_config_json["eos_token"].is_null()
+        ? parsed_config_json["eos_token"].get<std::string>()
+        : "";
     const auto bos_res = special_token_map_->tryGetInteger(bos_token);
     const auto eos_res = special_token_map_->tryGetInteger(eos_token);
     if (!bos_res) {
```
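
The second hunk hedges against `tokenizer_config.json` files that omit `bos_token`/`eos_token` or set them to `null`, where the old `.at(...)` lookup would throw. A minimal sketch of the same null-safe pattern, assuming nlohmann::json; the config literal is illustrative:

```cpp
// Sketch of the null-safe lookup introduced above (nlohmann::json).
// The config literal is illustrative; real files may omit either key.
#include <iostream>
#include <string>

#include <nlohmann/json.hpp>

int main() {
  const auto config =
      nlohmann::json::parse(R"({"bos_token": "<s>", "eos_token": null})");

  // .at("eos_token") would find the key here, but get<std::string>() on a
  // null value throws; checking contains() and is_null() first yields ""
  // instead of an exception.
  const std::string bos_token = config.contains("bos_token") &&
          !config["bos_token"].is_null()
      ? config["bos_token"].get<std::string>()
      : "";
  const std::string eos_token = config.contains("eos_token") &&
          !config["eos_token"].is_null()
      ? config["eos_token"].get<std::string>()
      : "";

  std::cout << "bos='" << bos_token << "' eos='" << eos_token << "'\n";
  return 0;
}
```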

src/tiktoken.cpp

Lines changed: 1 addition & 17 deletions

```diff
@@ -32,7 +32,6 @@
 #include <fstream>
 #include <limits>
 #include <unordered_set>
-#include "re2/re2.h"
 
 namespace tokenizers {
 
@@ -47,21 +46,6 @@ static Result<std::unique_ptr<IRegex>> _create_regex(
   return create_regex(pattern);
 }
 
-static Result<std::unique_ptr<IRegex>> _build_special_token_regex(
-    const std::vector<std::pair<std::string, std::uint64_t>>& special_encoder) {
-  std::string special_pattern;
-  for (const auto& ele : special_encoder) {
-    if (!special_pattern.empty()) {
-      special_pattern += "|";
-    }
-    special_pattern += re2::RE2::QuoteMeta(ele.first);
-  }
-  if (special_pattern.empty()) {
-    return static_cast<std::unique_ptr<IRegex>>(nullptr);
-  }
-  return _create_regex(special_pattern);
-}
-
 static Result<std::pair<std::string, uint64_t>> _parse(
     const std::string& line) {
   // Tiktoken format
@@ -153,7 +137,7 @@ Error Tiktoken::load(const std::string& path) {
 
   _regex = TK_UNWRAP(_create_regex(_pattern));
   special_token_regex_ =
-      TK_UNWRAP(_build_special_token_regex(special_token_map));
+      TK_UNWRAP(detail::build_special_token_regex(TokenMap(special_token_map)));
 
   // initialize vocab_size, bos_tok, eos_tok
   vocab_size_ = token_map_->size() + special_token_map_->size();
```
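
This file loses its private copy of the regex builder in favor of the shared `detail::build_special_token_regex`. The call site wraps its `std::vector<std::pair<std::string, std::uint64_t>>` in a `TokenMap` because the shared helper iterates via `size()`/`getElement(i)` rather than over a raw vector (as the header hunk shows). A standalone sketch of that interface shape; `TokenMapLike` is a stand-in for illustration, not the real `TokenMap` from `string_integer_map.h`:

```cpp
// Illustrative stand-in for the TokenMap interface assumed by
// build_special_token_regex: size() plus indexed (token, id) access.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

class TokenMapLike {
 public:
  explicit TokenMapLike(
      std::vector<std::pair<std::string, std::uint64_t>> pairs)
      : pairs_(std::move(pairs)) {}

  std::size_t size() const { return pairs_.size(); }

  // Mirrors the (token, id) access pattern the header hunk uses:
  //   const auto& [token, _] = special_token_map.getElement(i);
  std::pair<const std::string&, std::uint64_t> getElement(
      std::size_t i) const {
    return {pairs_[i].first, pairs_[i].second};
  }

 private:
  std::vector<std::pair<std::string, std::uint64_t>> pairs_;
};

int main() {
  TokenMapLike map({{"<|endoftext|>", 0}, {"<|fim_prefix|>", 1}});
  for (std::size_t i = 0; i < map.size(); ++i) {
    const auto& [token, id] = map.getElement(i);
    std::cout << token << " -> " << id << "\n";
  }
  return 0;
}
```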

targets.bzl

Lines changed: 3 additions & 0 deletions

```diff
@@ -77,6 +77,9 @@ def define_common_targets():
         exported_deps = [
             ":headers",
         ],
+        exported_external_deps = [
+            "re2",
+        ],
         visibility = [
             "//pytorch/tokenizers/...",
         ],
```
