-
Notifications
You must be signed in to change notification settings - Fork 12.1k
Improve BERT tokenization for accented characters and non-latin scripts #5740
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,10 +68,12 @@ | |
#include <cstdio> | ||
#include <cstring> | ||
#include <ctime> | ||
#include <cwctype> | ||
#include <forward_list> | ||
#include <fstream> | ||
#include <functional> | ||
#include <initializer_list> | ||
#include <locale> | ||
#include <map> | ||
#include <memory> | ||
#include <mutex> | ||
|
@@ -8897,37 +8899,46 @@ struct llm_tokenizer_wpm { | |
} | ||
|
||
std::vector<std::string> preprocess(const std::string & text) { | ||
std::string ori_str = normalize(text); | ||
uint64_t ori_size = ori_str.size(); | ||
// normalization form D | ||
std::vector<uint32_t> codepoints = codepoints_from_utf8(text); | ||
std::vector<uint32_t> nfd_codepoints; | ||
for (uint32_t code : codepoints) { | ||
auto it = nfd_map.find(code); | ||
if (it != nfd_map.end()) { | ||
for (uint32_t c : it->second) { | ||
nfd_codepoints.push_back(c); | ||
} | ||
} else { | ||
nfd_codepoints.push_back(code); | ||
} | ||
} | ||
|
||
// single punct / single symbol / single digit | ||
// baseline: add whitespace on the left and right of punct and chinese characters | ||
std::vector<std::string> words; | ||
// strip accents, strip control, uniformize whitespace, | ||
// to lowercase, pad chinese characters, pad punctuation | ||
std::string new_str = ""; | ||
uint64_t i = 0; | ||
while (i < ori_size) { | ||
int utf_char_len = utf8_len(ori_str[i]); | ||
if ((utf_char_len == 1) && ispunct(ori_str[i])) { | ||
new_str += " "; | ||
new_str += ori_str[i]; | ||
new_str += " "; | ||
i += 1; | ||
for (uint32_t code : nfd_codepoints) { | ||
int type = codepoint_type(code); | ||
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) { | ||
continue; | ||
} | ||
else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) { | ||
code = to_lower(code); | ||
if (type == CODEPOINT_TYPE_WHITESPACE) { | ||
code = ' '; | ||
} | ||
std::string s = codepoint_to_utf8(code); | ||
if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) { | ||
new_str += " "; | ||
new_str += ori_str.substr(i, 3); | ||
new_str += s; | ||
new_str += " "; | ||
i += 3; | ||
} | ||
else { | ||
new_str += ori_str[i]; | ||
i += 1; | ||
} else { | ||
new_str += s; | ||
} | ||
} | ||
|
||
// split by whitespace | ||
uint64_t l = 0; | ||
uint64_t r = 0; | ||
std::vector<std::string> words; | ||
while (r < new_str.size()) { | ||
// if is whitespace | ||
if (isspace(new_str[r])) { | ||
|
@@ -8945,47 +8956,20 @@ struct llm_tokenizer_wpm { | |
return words; | ||
} | ||
|
||
std::string normalize(const std::string & text) { | ||
// TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98 | ||
std::string text2 = strip_accents(text); | ||
for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) { | ||
char c = text2[i]; | ||
if (c >= 'A' && c <= 'Z') { | ||
text2[i] = c - 'A' + 'a'; | ||
} | ||
uint32_t to_lower(uint32_t code) { | ||
#if defined(_WIN32) | ||
if (code > 0xFFFF) { | ||
return code; | ||
} | ||
return text2; | ||
#endif | ||
return std::tolower(wchar_t(code), std::locale("en_US.UTF-8")); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @iamlemec This depends on the system having the `en_US.UTF-8` locale installed. Even if … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also, it's not safe to assume that lowercasing one letter at a time is the same as lowercasing the whole string. See the notes at cppreference and the special-cased entry for capital sigma in UnicodeData. |
||
} | ||
|
||
bool is_chinese_char(const std::string & str) { | ||
int len = str.length(); | ||
unsigned int codepoint = 0; | ||
int num_bytes = 0; | ||
int i = 0; | ||
unsigned char ch = static_cast<unsigned char>(str[i]); | ||
if (ch <= 0x7f) { | ||
codepoint = ch; | ||
num_bytes = 1; | ||
} else if ((ch >> 5) == 0x06) { | ||
codepoint = ch & 0x1f; | ||
num_bytes = 2; | ||
} else if ((ch >> 4) == 0x0e) { | ||
codepoint = ch & 0x0f; | ||
num_bytes = 3; | ||
} else if ((ch >> 3) == 0x1e) { | ||
codepoint = ch & 0x07; | ||
num_bytes = 4; | ||
} | ||
for (int j = 1; j < num_bytes; ++j) { | ||
if (i + j >= len) { | ||
return false; // incomplete UTF-8 character | ||
} | ||
unsigned char next_ch = static_cast<unsigned char>(str[i + j]); | ||
if ((next_ch >> 6) != 0x02) { | ||
return false; // invalid trailing byte | ||
} | ||
codepoint = (codepoint << 6) | (next_ch & 0x3f); | ||
} | ||
// Returns true when `code` is an ASCII punctuation/symbol character, i.e. one of
// the four printable ASCII ranges that are neither alphanumeric nor space:
//   '!'..'/' (33-47), ':'..'@' (58-64), '['..'`' (91-96), '{'..'~' (123-126).
// The ranges are spelled out instead of calling ispunct() so the answer does not
// depend on the process locale, and so code points 128-255 (which are not ASCII)
// are never reported as ASCII punctuation.
bool is_ascii_punct(uint32_t code) {
    return (code >= 33  && code <= 47)  ||
           (code >= 58  && code <= 64)  ||
           (code >= 91  && code <= 96)  ||
           (code >= 123 && code <= 126);
}
|
||
bool is_chinese_char(uint32_t codepoint) { | ||
if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) || | ||
(codepoint >= 0x3400 && codepoint <= 0x4DBF) || | ||
(codepoint >= 0x20000 && codepoint <= 0x2A6DF) || | ||
|
@@ -9001,41 +8985,6 @@ struct llm_tokenizer_wpm { | |
return false; | ||
} | ||
|
||
std::string strip_accents(const std::string & input_string) { | ||
std::string resultString; | ||
std::map<std::string, char> accent_map = { | ||
{"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'}, | ||
{"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'}, | ||
{"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'}, | ||
{"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'}, | ||
{"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'}, | ||
{"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'}, | ||
{"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'}, | ||
{"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'}, | ||
{"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'}, | ||
}; | ||
|
||
for (size_t i = 0; i < input_string.length();) { | ||
int len = utf8_len(input_string[i]); | ||
std::string curChar = input_string.substr(i, len); | ||
auto iter = accent_map.find(curChar); | ||
if (iter != accent_map.end()) { | ||
resultString += iter->second; | ||
} else { | ||
resultString += curChar; | ||
} | ||
i += len; | ||
} | ||
|
||
return resultString; | ||
} | ||
|
||
// Returns the number of bytes in the UTF-8 sequence introduced by `src`,
// decided purely from the upper four bits of that byte:
//   0x0-0xB -> 1, 0xC-0xD -> 2, 0xE -> 3, 0xF -> 4.
static size_t utf8_len(char src) {
    const uint8_t top_nibble = static_cast<uint8_t>(src) >> 4;
    if (top_nibble < 0xC) {
        return 1;
    }
    if (top_nibble < 0xE) {
        return 2;
    }
    return top_nibble == 0xE ? 3 : 4;
}
|
||
const llama_vocab & vocab; | ||
}; | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.