Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests: Added integration tests for GBNF parser #6472

Merged
merged 6 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ TEST_TARGETS = \
tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \
tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \
tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease \
tests/test-json-schema-to-grammar
tests/test-json-schema-to-grammar tests/test-grammar-integration

# Code coverage output files
COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report
Expand Down Expand Up @@ -918,6 +918,10 @@ tests/test-grammar-parser: tests/test-grammar-parser.cpp ggml.o llama.o grammar-
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

tests/test-grammar-integration: tests/test-grammar-integration.cpp ggml.o llama.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

tests/test-double-float: tests/test-double-float.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-gpt2 AR

llama_test(test-grammar-parser.cpp)
llama_test(test-llama-grammar.cpp)
llama_test(test-grammar-integration.cpp)
llama_test(test-grad0.cpp)
# llama_test(test-opt.cpp) # SLOW
llama_test(test-backend-ops.cpp)
Expand Down
243 changes: 243 additions & 0 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
#ifdef NDEBUG
#undef NDEBUG
#endif

#define LLAMA_API_INTERNAL

#include "ggml.h"
#include "llama.h"
#include "grammar-parser.h"
#include "unicode.h"
#include <cassert>
#include <string>

static void test_simple_grammar() {
// Test case for a simple grammar
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";

grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());

// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());

// Ensure we have a root node
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));

std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));

std::string input = "123+456";

auto decoded = decode_utf8(input, {});

const auto & code_points = decoded.first;

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
assert(!grammar->stacks.empty());
}

bool completed_grammar = false;

for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}

assert(completed_grammar);

// Clean up allocated memory
llama_grammar_free(grammar);
}

static void test_complex_grammar() {
// Test case for a more complex grammar, with both failure strings and success strings
const std::string grammar_str = R"""(root ::= expression
expression ::= term ws (("+"|"-") ws term)*
term ::= factor ws (("*"|"/") ws factor)*
factor ::= number | variable | "(" expression ")" | function-call
number ::= [0-9]+
variable ::= [a-zA-Z_][a-zA-Z0-9_]*
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
ws ::= [ \t\n\r]?)""";

grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());

// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());

// Ensure we have a root node
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));

std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));

// Save the original grammar stacks so that we can reset after every new string we want to test
auto original_stacks = grammar->stacks;

// Test a few strings
std::vector<std::string> test_strings_pass = {
"42",
"1*2*3*4*5",
"x",
"x+10",
"x1+y2",
"(a+b)*(c-d)",
"func()",
"func(x,y+2)",
"a*(b+c)-d/e",
"f(g(x),h(y,z))",
"x + 10",
"x1 + y2",
"(a + b) * (c - d)",
"func()",
"func(x, y + 2)",
"a * (b + c) - d / e",
"f(g(x), h(y, z))",
"123+456",
"123*456*789-123/456+789*123",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
};

std::vector<std::string> test_strings_fail = {
"+",
"/ 3x",
"x + + y",
"a * / b",
"func(,)",
"func(x y)",
"(a + b",
"x + y)",
"a + b * (c - d",
"42 +",
"x +",
"x + 10 +",
"(a + b) * (c - d",
"func(",
"func(x, y + 2",
"a * (b + c) - d /",
"f(g(x), h(y, z)",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
};

// Passing strings
for (const auto & test_string : test_strings_pass) {
auto decoded = decode_utf8(test_string, {});

const auto & code_points = decoded.first;

int pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
++pos;
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);

// Expect that each code point will not cause the grammar to fail
if (grammar->stacks.empty()) {
fprintf(stdout, "Error at position %d\n", pos);
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
}
assert(!grammar->stacks.empty());
}

bool completed_grammar = false;

for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}

assert(completed_grammar);

// Reset the grammar stacks
grammar->stacks = original_stacks;
}

// Failing strings
for (const auto & test_string : test_strings_fail) {
auto decoded = decode_utf8(test_string, {});

const auto & code_points = decoded.first;
bool parse_failed = false;

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
if (grammar->stacks.empty()) {
parse_failed = true;
break;
}
assert(!grammar->stacks.empty());
}

bool completed_grammar = false;

for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}

// Ensure that the grammar is not completed, or that each string failed to match as-expected
assert((!completed_grammar) || parse_failed);

// Reset the grammar stacks
grammar->stacks = original_stacks;
}

// Clean up allocated memory
llama_grammar_free(grammar);
}

static void test_failure_missing_root() {
// Test case for a grammar that is missing a root rule
const std::string grammar_str = R"""(rot ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";

grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());

// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());

// Ensure we do NOT have a root node
assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
}

static void test_failure_missing_reference() {
// Test case for a grammar that is missing a referenced rule
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= numero
number ::= [0-9]+)""";

fprintf(stderr, "Expected error: ");

grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());

// Ensure we did NOT parsed correctly
assert(parsed_grammar.rules.empty());

fprintf(stderr, "End of expected error. Test successful.\n");
}

int main() {
test_simple_grammar();
test_complex_grammar();
test_failure_missing_root();
test_failure_missing_reference();
return 0;
}
Loading