Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extending grammar integration tests #6644

Merged
merged 10 commits into from
Apr 29, 2024
Next Next commit
Cleaning up integration tests to share code between tests and make it…
… simpler to add new tests.
  • Loading branch information
HanClinto committed Apr 29, 2024
commit 9cd07c2f9d575096cd68d3cfee5bc84b25803163
113 changes: 33 additions & 80 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,8 @@
#include <cassert>
#include <string>

static void test_simple_grammar() {
// Test case for a simple grammar
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";

grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
static llama_grammar* build_grammar(const std::string & grammar_str) {
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());

// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
Expand All @@ -30,28 +24,45 @@ number ::= [0-9]+)""";
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));

std::string input = "123+456";
return grammar;
}

static bool match_string(const std::string & input, llama_grammar* grammar) {
auto decoded = decode_utf8(input, {});

const auto & code_points = decoded.first;

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
assert(!grammar->stacks.empty());
if (grammar->stacks.empty()) {
// no stacks means that the grammar failed to match at this point
return false;
}
}

bool completed_grammar = false;

for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
// An empty stack means that the grammar has been completed
return true;
}
}

assert(completed_grammar);
return false;
}

static void test_simple_grammar() {
// Test case for a simple grammar
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";

auto grammar = build_grammar(grammar_str);

bool matched = match_string("123+456", grammar);

assert(matched);

// Clean up allocated memory
llama_grammar_free(grammar);
Expand All @@ -68,17 +79,7 @@ variable ::= [a-zA-Z_][a-zA-Z0-9_]*
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
ws ::= [ \t\n\r]?)""";

grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());

// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());

// Ensure we have a root node
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));

std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
auto grammar = build_grammar(grammar_str);

// Save the original grammar stacks so that we can reset after every new string we want to test
auto original_stacks = grammar->stacks;
Expand Down Expand Up @@ -130,68 +131,19 @@ ws ::= [ \t\n\r]?)""";

// Passing strings
for (const auto & test_string : test_strings_pass) {
auto decoded = decode_utf8(test_string, {});

const auto & code_points = decoded.first;

int pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
++pos;
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);

// Expect that each code point will not cause the grammar to fail
if (grammar->stacks.empty()) {
fprintf(stdout, "Error at position %d\n", pos);
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
}
assert(!grammar->stacks.empty());
}

bool completed_grammar = false;

for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}
bool matched = match_string(test_string, grammar);

assert(completed_grammar);
assert(matched);

// Reset the grammar stacks
grammar->stacks = original_stacks;
}

// Failing strings
for (const auto & test_string : test_strings_fail) {
auto decoded = decode_utf8(test_string, {});

const auto & code_points = decoded.first;
bool parse_failed = false;

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
parse_failed = true;
break;
}
assert(!grammar->stacks.empty());
}

bool completed_grammar = false;

for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}
bool matched = match_string(test_string, grammar);

// Ensure that the grammar is not completed, or that each string failed to match as-expected
assert((!completed_grammar) || parse_failed);
assert(!matched);

// Reset the grammar stacks
grammar->stacks = original_stacks;
Expand Down Expand Up @@ -231,13 +183,14 @@ number ::= [0-9]+)""";
// Ensure we did NOT parsed correctly
assert(parsed_grammar.rules.empty());

fprintf(stderr, "End of expected error. Test successful.\n");
fprintf(stderr, "End of expected error.\n");
}

int main() {
test_simple_grammar();
test_complex_grammar();
test_failure_missing_root();
test_failure_missing_reference();
fprintf(stdout, "All tests passed.\n");
return 0;
}