Skip to content

Commit

Permalink
Unicode codepoint flags for custom regexs (ggerganov#7245)
Browse files Browse the repository at this point in the history
* Replace CODEPOINT_TYPE_* with codepoint_flags
* Update and bugfix brute force random test
* Deterministic brute force random test
* Unicode normalization NFD
* Get rid of BOM
  • Loading branch information
jaime-m-p authored May 17, 2024
1 parent 0fc1e82 commit b43272a
Show file tree
Hide file tree
Showing 7 changed files with 7,297 additions and 2,407 deletions.
8 changes: 4 additions & 4 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12576,16 +12576,16 @@ struct llm_tokenizer_wpm {
// to lowercase, pad chinese characters, pad punctuation
std::string new_str = "";
for (uint32_t code : cpts_nfd) {
int type = unicode_cpt_type(code);
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
const codepoint_flags flags = unicode_cpt_flags(code);
if (flags.is_accent_mark || flags.is_control) {
continue;
}
code = unicode_tolower(code);
if (type == CODEPOINT_TYPE_SEPARATOR) {
if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ?
code = ' ';
}
std::string s = unicode_cpt_to_utf8(code);
if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
new_str += " ";
new_str += s;
new_str += " ";
Expand Down
162 changes: 116 additions & 46 deletions scripts/gen-unicode-data.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,134 @@
import regex
import ctypes
import unicodedata


def get_matches(regex_expr):
regex_expr_compiled = regex.compile(regex_expr)
unicode_ranges = []
current_range = None
class CoodepointFlags (ctypes.Structure):
_fields_ = [ # see definition in unicode.h
("is_undefined", ctypes.c_uint16, 1),
("is_number", ctypes.c_uint16, 1), # regex: \p{N}
("is_letter", ctypes.c_uint16, 1), # regex: \p{L}
("is_separator", ctypes.c_uint16, 1), # regex: \p{Z}
("is_accent_mark", ctypes.c_uint16, 1), # regex: \p{M}
("is_punctuation", ctypes.c_uint16, 1), # regex: \p{P}
("is_symbol", ctypes.c_uint16, 1), # regex: \p{S}
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
]

for codepoint in range(0x110000):
char = chr(codepoint)
if regex_expr_compiled.match(char):
if current_range is None:
current_range = [codepoint, codepoint]
else:
current_range[1] = codepoint
elif current_range is not None:
unicode_ranges.append(tuple(current_range))
current_range = None

if current_range is not None:
unicode_ranges.append(tuple(current_range))
assert (ctypes.sizeof(CoodepointFlags) == 2)

return unicode_ranges

MAX_CODEPOINTS = 0x110000

def print_cat(mode, cat, ranges):
if mode == "range":
print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat)) # noqa: NP100
if mode == "map":
print("const std::map<uint32_t, uint32_t> unicode_map_{} = {{".format(cat)) # noqa: NP100
for i, values in enumerate(ranges):
end = ",\n" if (i % 4 == 3 or i + 1 == len(ranges)) else ", "
values = ["0x%08X" % value for value in values]
print("{" + ", ".join(values) + "}", end=end) # noqa: NP100
print("};") # noqa: NP100
print("") # noqa: NP100
regex_number = regex.compile(r'\p{N}')
regex_letter = regex.compile(r'\p{L}')
regex_separator = regex.compile(r'\p{Z}')
regex_accent_mark = regex.compile(r'\p{M}')
regex_punctuation = regex.compile(r'\p{P}')
regex_symbol = regex.compile(r'\p{S}')
regex_control = regex.compile(r'\p{C}')
regex_whitespace = regex.compile(r'\s')

codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
table_whitespace = []
table_lowercase = []
table_uppercase = []
table_nfd = []

print_cat("range", "number", get_matches(r'\p{N}'))
print_cat("range", "letter", get_matches(r'\p{L}'))
print_cat("range", "separator", get_matches(r'\p{Z}'))
print_cat("range", "accent_mark", get_matches(r'\p{M}'))
print_cat("range", "punctuation", get_matches(r'\p{P}'))
print_cat("range", "symbol", get_matches(r'\p{S}'))
print_cat("range", "control", get_matches(r'\p{C}'))
for codepoint in range(MAX_CODEPOINTS):
# convert codepoint to unicode character
char = chr(codepoint)

print_cat("range", "whitespace", get_matches(r'\s'))
# regex categories
flags = codepoint_flags[codepoint]
flags.is_number = bool(regex_number.match(char))
flags.is_letter = bool(regex_letter.match(char))
flags.is_separator = bool(regex_separator.match(char))
flags.is_accent_mark = bool(regex_accent_mark.match(char))
flags.is_punctuation = bool(regex_punctuation.match(char))
flags.is_symbol = bool(regex_symbol.match(char))
flags.is_control = bool(regex_control.match(char))
flags.is_undefined = bytes(flags)[0] == 0
assert (not flags.is_undefined)

# whitespaces
if bool(regex_whitespace.match(char)):
table_whitespace.append(codepoint)

map_lowercase = []
map_uppercase = []
for codepoint in range(0x110000):
char = chr(codepoint)
# lowercase conversion
lower = ord(char.lower()[0])
upper = ord(char.upper()[0])
if codepoint != lower:
map_lowercase.append((codepoint, lower))
table_lowercase.append((codepoint, lower))

# uppercase conversion
upper = ord(char.upper()[0])
if codepoint != upper:
map_uppercase.append((codepoint, upper))
print_cat("map", "lowercase", map_lowercase)
print_cat("map", "uppercase", map_uppercase)
table_uppercase.append((codepoint, upper))

# NFD normalization
norm = ord(unicodedata.normalize('NFD', char)[0])
if codepoint != norm:
table_nfd.append((codepoint, norm))


# group ranges with same flags
ranges_flags = [(0, codepoint_flags[0])] # start, flags
for codepoint, flags in enumerate(codepoint_flags):
if bytes(flags) != bytes(ranges_flags[-1][1]):
ranges_flags.append((codepoint, flags))
ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))


# group ranges with same nfd
ranges_nfd = [(0, 0, 0)] # start, last, nfd
for codepoint, norm in table_nfd:
start = ranges_nfd[-1][0]
if ranges_nfd[-1] != (start, codepoint - 1, norm):
ranges_nfd.append(None)
start = codepoint
ranges_nfd[-1] = (start, codepoint, norm)


# Generate 'unicode-data.cpp'


def out(line=""):
print(line, end='\n') # noqa


out("""\
// generated with scripts/gen-unicode-data.py
#include "unicode-data.h"
#include <cstdint>
#include <vector>
#include <unordered_map>
#include <unordered_set>
""")

out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
for codepoint, flags in ranges_flags:
flags = int.from_bytes(bytes(flags), "little")
out("{0x%06X, 0x%04X}," % (codepoint, flags))
out("};\n")

out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
out(", ".join("0x%06X" % cpt for cpt in table_whitespace))
out("};\n")

out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
for tuple in table_lowercase:
out("{0x%06X, 0x%06X}," % tuple)
out("};\n")

out("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
for tuple in table_uppercase:
out("{0x%06X, 0x%06X}," % tuple)
out("};\n")

# TODO: generate unicode_map_nfd
out("const std::vector<range_nfd> unicode_ranges_nfd = { // start, last, nfd")
for triple in ranges_nfd:
out("{0x%06X, 0x%06X, 0x%06X}," % triple)
out("};\n")
Loading

0 comments on commit b43272a

Please sign in to comment.