Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unicode in grammars (fixes #2501) #2553

Merged
merged 4 commits into from
Aug 17, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 132 additions & 25 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2082,37 +2082,81 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
// grammar - internal
//

// Decoder state for a UTF-8 sequence whose bytes have only partly been seen.
// Passed into and returned from decode_utf8 so that a sequence split across
// two strings can be resumed. Member order matters: the struct is
// brace-initialized positionally as { value, n_remain }.
struct llama_partial_utf8 {
uint32_t value; // bit value so far (unshifted)
int n_remain; // num bytes remaining; -1 indicates invalid sequence
};

// Grammar state carried between sampling steps.
// Member order matters: llama_grammar_init brace-initializes this struct
// positionally as { rules, stacks, partial_utf8 }.
struct llama_grammar {
// rule definitions; const, so fixed once the grammar is constructed
const std::vector<std::vector<llama_grammar_element>> rules;
// current set of parse stacks; replaced after every accepted code point
std::vector<std::vector<const llama_grammar_element *>> stacks;

// buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8;
};

// A token being checked against the grammar.
// Brace-initialized positionally as { index, code_points, partial_utf8 }.
struct llama_grammar_candidate {
    // position of the token in the caller's candidates array
    size_t               index;
    // non-owning pointer to the token's decoded code points, terminated by 0;
    // the pointed-to buffer must outlive the candidate
    const uint32_t     * code_points;
    // incomplete UTF-8 sequence left over at the end of the token, if any
    llama_partial_utf8   partial_utf8;
};

// NOTE: assumes valid utf8 (but checks for overrun)
// adds a terminating 0 for use as pointer
std::vector<uint32_t> decode_utf8(const char * src) {
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const char * src,
llama_partial_utf8 partial_start) {
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
const char * pos = src;
std::vector<uint32_t> code_points;
uint32_t value = partial_start.value;
int n_remain = partial_start.n_remain;

// continue previous decode, if applicable
while (*pos != 0 && n_remain > 0) {
uint8_t next_byte = static_cast<uint8_t>(*pos);
if ((next_byte >> 6) != 2) {
// invalid sequence, abort
code_points.push_back(0);
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
}
value = (value << 6) + (next_byte & 0x3F);
++pos;
--n_remain;
}

if (partial_start.n_remain > 0 && n_remain == 0) {
code_points.push_back(value);
}

// decode any subsequent utf-8 sequences, which may end in an incomplete one
while (*pos != 0) {
uint8_t first_byte = static_cast<uint8_t>(*pos);
uint8_t highbits = first_byte >> 4;
int len = lookup[highbits];
uint8_t mask = (1 << (8 - len)) - 1;
uint32_t value = first_byte & mask;
const char * end = pos + len; // may overrun!
n_remain = lookup[highbits] - 1;

if (n_remain < 0) {
// invalid sequence, abort
code_points.clear();
code_points.push_back(0);
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
}

uint8_t mask = (1 << (7 - n_remain)) - 1;
value = first_byte & mask;
++pos;
for ( ; pos < end && *pos != 0; ++pos) {
while (*pos != 0 && n_remain > 0) {
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
++pos;
--n_remain;
}
if (n_remain == 0) {
code_points.push_back(value);
}
code_points.push_back(value);
}
code_points.push_back(0);
return code_points;

return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
}

// returns true iff pos points to the end of one of the definitions of a rule
Expand Down Expand Up @@ -2149,6 +2193,56 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
return std::make_pair(found == is_positive_char, pos);
}

// Decides whether a partially received UTF-8 sequence can still be completed into a code point
// that satisfies the char range starting at pos (works for both regular and negated ranges).
// pos must point at a LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_NOT element; this is asserted.
static bool llama_grammar_match_partial_char(
        const llama_grammar_element * pos,
        const llama_partial_utf8      partial_utf8) {

    const bool positive = pos->type == LLAMA_GRETYPE_CHAR;
    LLAMA_ASSERT(positive || pos->type == LLAMA_GRETYPE_CHAR_NOT);

    const uint32_t prefix  = partial_utf8.value;
    const int      pending = partial_utf8.n_remain;

    // an invalid sequence can never match
    if (pending < 0) {
        return false;
    }
    // one byte outstanding with a prefix below 2 would be a 7-bit char spread
    // over 2 bytes (overlong encoding)
    if (pending == 1 && prefix < 2) {
        return false;
    }

    // every outstanding continuation byte contributes 6 bits, so the sequence can
    // complete to any code point in [low, high]
    const int      shift = pending * 6;
    uint32_t       low   = prefix << shift;
    const uint32_t high  = low | ((1 << shift) - 1);

    // an all-zero prefix: raise the lower bound to the first code point that
    // needs a sequence of this length
    if (low == 0) {
        switch (pending) {
            case 2:  low = 1 << 11; break;
            case 3:  low = 1 << 16; break;
            default: break;
        }
    }

    for (;;) {
        // [a-z] style entries occupy two elements, single chars occupy one; a single
        // char is handled as the degenerate range [c, c]
        const bool is_range = pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER;
        const auto range_lo = pos->value;
        const auto range_hi = is_range ? pos[1].value : pos->value;

        // the reachable interval overlaps this alternative
        if (range_lo <= high && low <= range_hi) {
            return positive;
        }

        pos += is_range ? 2 : 1;
        if (pos->type != LLAMA_GRETYPE_CHAR_ALT) {
            break;
        }
    }

    // no alternative overlapped the reachable interval
    return !positive;
}


// transforms a grammar pushdown stack into N possible stacks, all ending
// at a character range (terminal element)
static void llama_grammar_advance_stack(
Expand Down Expand Up @@ -2249,19 +2343,27 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
std::vector<llama_grammar_candidate> rejects;

if (stack.empty()) {
// accept nothing; EOS is handled elsewhere
rejects.insert(rejects.end(), candidates.begin(), candidates.end());
for (auto tok : candidates) {
if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
rejects.push_back(tok);
}
}
return rejects;
}

const llama_grammar_element * stack_pos = stack.back();

std::vector<llama_grammar_candidate> next_candidates;
for (auto tok : candidates) {
if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) {
if (tok.code_points[1] != 0) {
next_candidates.push_back({ tok.index, tok.code_points + 1 });
if (*tok.code_points == 0) {
// reached end of full codepoints in token, reject iff it ended in a partial sequence
// that cannot satisfy this position in grammar
if (tok.partial_utf8.n_remain != 0 &&
!llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
rejects.push_back(tok);
}
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
} else {
rejects.push_back(tok);
}
Expand All @@ -2279,7 +2381,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_

auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
for (auto tok : next_rejects) {
rejects.push_back({ tok.index, tok.code_points - 1 });
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
}

return rejects;
Expand Down Expand Up @@ -2344,7 +2446,7 @@ struct llama_grammar * llama_grammar_init(
}
} while (true);

return new llama_grammar{ std::move(vec_rules), std::move(stacks) };
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
}

void llama_grammar_free(struct llama_grammar * grammar) {
Expand Down Expand Up @@ -2650,8 +2752,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c

const llama_token eos = llama_token_eos();

std::vector<std::vector<uint32_t>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar;
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar;

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
Expand All @@ -2663,8 +2765,10 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
} else if (*str == 0) {
candidates->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(str));
candidates_grammar.push_back({ i, candidates_decoded.back().data() });
candidates_decoded.push_back(decode_utf8(str, grammar->partial_utf8));
candidates_grammar.push_back({
i, candidates_decoded.back().first.data(), candidates_decoded.back().second
});
}
}

Expand Down Expand Up @@ -2865,11 +2969,14 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
}

const char * str = llama_token_to_str(ctx, token);

// Note terminating 0 in decoded string
auto code_points = decode_utf8(str);
const auto decoded = decode_utf8(str, grammar->partial_utf8);
const auto & code_points = decoded.first;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
}
grammar->partial_utf8 = decoded.second;
LLAMA_ASSERT(!grammar->stacks.empty());

ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
Expand Down