4 files changed: +40 −19 lines

include/pytorch/tokenizers (BPETokenizerBase header)
@@ -25,6 +25,8 @@
 #include <pytorch/tokenizers/string_integer_map.h>
 #include <pytorch/tokenizers/tokenizer.h>
 
+#include "re2/re2.h"
+
 namespace tokenizers {
 namespace detail {
 
@@ -104,6 +106,25 @@ static Result<TokenMap> buildTokenMap(
   return buildTokenMap(std::move(pairs));
 }
 
+inline Result<std::unique_ptr<IRegex>> build_special_token_regex(
+    const TokenMap& special_token_map) {
+  std::string special_pattern;
+  const std::size_t count = special_token_map.size();
+
+  for (std::size_t i = 0; i < count; ++i) {
+    const auto& [token, _] = special_token_map.getElement(i);
+    if (!special_pattern.empty()) {
+      special_pattern += "|";
+    }
+    special_pattern += re2::RE2::QuoteMeta(std::string(token));
+  }
+
+  if (special_pattern.empty()) {
+    return static_cast<std::unique_ptr<IRegex>>(nullptr);
+  }
+  return create_regex(special_pattern);
+}
+
 class BPETokenizerBase : public Tokenizer {
  public:
   Result<std::vector<uint64_t>>
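For reference, here is a standalone sketch of the pattern this helper builds, using re2 directly rather than the project's IRegex/create_regex wrapper (the token strings are illustrative, not taken from the PR). Each special token is escaped with RE2::QuoteMeta and joined into one alternation, so metacharacters such as '|' inside tokens match literally; the nullptr return above avoids compiling an empty pattern, which would otherwise match at every position.

#include <re2/re2.h>

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Illustrative special tokens; real ones come from the tokenizer's map.
  const std::vector<std::string> specials = {"<|endoftext|>", "<|fim_prefix|>"};

  std::string pattern;
  for (const std::string& tok : specials) {
    if (!pattern.empty()) {
      pattern += "|";
    }
    // QuoteMeta escapes regex metacharacters such as '|' and '<' so the
    // token is matched as a literal string.
    pattern += re2::RE2::QuoteMeta(tok);
  }

  re2::RE2 special_regex(pattern);
  // Prints 1: the special token is found as a literal substring.
  std::cout << re2::RE2::PartialMatch("hi <|endoftext|>", special_regex)
            << std::endl;
  return 0;
}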
(HFTokenizer source)

@@ -69,6 +69,12 @@ Error HFTokenizer::load(const std::string& path) {
         special_tokens,
         [](const auto& it) -> std::string { return it.at("content"); },
         [](const auto& it) -> std::uint64_t { return it.at("id"); }));
+
+    // Create special token regex to help later with encoding.
+    special_token_regex_ =
+        TK_UNWRAP(detail::build_special_token_regex(special_token_map));
+
+    // Store for future use.
     special_token_map_.emplace(std::move(special_token_map));
   } catch (const json::out_of_range& e) {
     fprintf(stderr, "Could not parse special tokens: %s\n", e.what());
@@ -142,8 +148,15 @@ Error HFTokenizer::load(const std::string& path) {
 
   // Pull out the token strings
   try {
-    const std::string bos_token = parsed_config_json.at("bos_token");
-    const std::string eos_token = parsed_config_json.at("eos_token");
+    const std::string bos_token = parsed_config_json.contains("bos_token") &&
+            !parsed_config_json["bos_token"].is_null()
+        ? parsed_config_json["bos_token"].get<std::string>()
+        : "";
+
+    const std::string eos_token = parsed_config_json.contains("eos_token") &&
+            !parsed_config_json["eos_token"].is_null()
+        ? parsed_config_json["eos_token"].get<std::string>()
+        : "";
     const auto bos_res = special_token_map_->tryGetInteger(bos_token);
     const auto eos_res = special_token_map_->tryGetInteger(eos_token);
     if (!bos_res) {
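The old at(...) lookups threw whenever a config omitted bos_token/eos_token or set them to JSON null; the replacement falls back to an empty string, which then simply fails the tryGetInteger lookup below. A minimal standalone sketch of the same pattern, assuming nlohmann::json (which the json::out_of_range handling elsewhere in this file suggests); the get_string_or_empty helper is illustrative, not part of the PR:

#include <nlohmann/json.hpp>

#include <iostream>
#include <string>

// Illustrative helper: read a string field, treating a missing key and an
// explicit null the same way -- fall back to "".
static std::string get_string_or_empty(
    const nlohmann::json& config, const std::string& key) {
  return config.contains(key) && !config[key].is_null()
      ? config[key].get<std::string>()
      : "";
}

int main() {
  const auto config = nlohmann::json::parse(R"({"bos_token": null})");
  // Both print empty quotes instead of throwing, as config.at("bos_token")
  // (null) and config.at("eos_token") (missing key) would.
  std::cout << "'" << get_string_or_empty(config, "bos_token") << "'\n";
  std::cout << "'" << get_string_or_empty(config, "eos_token") << "'\n";
  return 0;
}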
(Tiktoken source)

@@ -32,7 +32,6 @@
 #include <fstream>
 #include <limits>
 #include <unordered_set>
-#include "re2/re2.h"
 
 namespace tokenizers {
 
@@ -47,21 +46,6 @@ static Result<std::unique_ptr<IRegex>> _create_regex(
   return create_regex(pattern);
 }
 
-static Result<std::unique_ptr<IRegex>> _build_special_token_regex(
-    const std::vector<std::pair<std::string, std::uint64_t>>& special_encoder) {
-  std::string special_pattern;
-  for (const auto& ele : special_encoder) {
-    if (!special_pattern.empty()) {
-      special_pattern += "|";
-    }
-    special_pattern += re2::RE2::QuoteMeta(ele.first);
-  }
-  if (special_pattern.empty()) {
-    return static_cast<std::unique_ptr<IRegex>>(nullptr);
-  }
-  return _create_regex(special_pattern);
-}
-
 static Result<std::pair<std::string, uint64_t>> _parse(
     const std::string& line) {
   // Tiktoken format
@@ -153,7 +137,7 @@ Error Tiktoken::load(const std::string& path) {
 
   _regex = TK_UNWRAP(_create_regex(_pattern));
   special_token_regex_ =
-      TK_UNWRAP(_build_special_token_regex(special_token_map));
+      TK_UNWRAP(detail::build_special_token_regex(TokenMap(special_token_map)));
 
   // initialize vocab_size, bos_tok, eos_tok
   vocab_size_ = token_map_->size() + special_token_map_->size();
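With this change Tiktoken drops its private copy of the regex builder and reuses the shared detail::build_special_token_regex from the header above; its std::vector<std::pair<std::string, std::uint64_t>> of special tokens is wrapped in a TokenMap at the call site, so the HF and Tiktoken paths now build the special-token regex identically.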
(build targets file)

@@ -77,6 +77,9 @@ def define_common_targets():
         exported_deps = [
             ":headers",
         ],
+        exported_external_deps = [
+            "re2",
+        ],
         visibility = [
             "//pytorch/tokenizers/...",
         ],