Skip to content

Commit 57dd02c

Browse files
authored
Tests: Added integration tests for GBNF parser (#6472)
* Added integration tests for GBNF parser to validate correctness of parsing, as well as correctness of string matching. Intended for use to pin behavior while working on performance improvements. * Fixing whitespace errors and cleaning error message alert to be clearer. * Removing hacky include to llama.cpp from grammar integration test now that needed functions are available via internal API. * Comment cleanup. * Reorganizing tests for readability. * Cleaning up debug message to make a bit more sense.
1 parent 75cd4c7 commit 57dd02c

File tree

3 files changed

+249
-1
lines changed

3 files changed

+249
-1
lines changed

Makefile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ TEST_TARGETS = \
1010
tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \
1111
tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \
1212
tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease \
13-
tests/test-json-schema-to-grammar
13+
tests/test-json-schema-to-grammar tests/test-grammar-integration
1414

1515
# Code coverage output files
1616
COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report
@@ -918,6 +918,10 @@ tests/test-grammar-parser: tests/test-grammar-parser.cpp ggml.o llama.o grammar-
918918
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
919919
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
920920

921+
tests/test-grammar-integration: tests/test-grammar-integration.cpp ggml.o llama.o grammar-parser.o $(OBJS)
922+
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
923+
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
924+
921925
tests/test-double-float: tests/test-double-float.cpp ggml.o $(OBJS)
922926
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
923927
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-gpt2 AR
5959

6060
llama_test(test-grammar-parser.cpp)
6161
llama_test(test-llama-grammar.cpp)
62+
llama_test(test-grammar-integration.cpp)
6263
llama_test(test-grad0.cpp)
6364
# llama_test(test-opt.cpp) # SLOW
6465
llama_test(test-backend-ops.cpp)

tests/test-grammar-integration.cpp

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
#ifdef NDEBUG
2+
#undef NDEBUG
3+
#endif
4+
5+
#define LLAMA_API_INTERNAL
6+
7+
#include "ggml.h"
8+
#include "llama.h"
9+
#include "grammar-parser.h"
10+
#include "unicode.h"
11+
#include <cassert>
12+
#include <string>
13+
14+
static void test_simple_grammar() {
15+
// Test case for a simple grammar
16+
const std::string grammar_str = R"""(root ::= expr
17+
expr ::= term ("+" term)*
18+
term ::= number
19+
number ::= [0-9]+)""";
20+
21+
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
22+
23+
// Ensure we parsed correctly
24+
assert(!parsed_grammar.rules.empty());
25+
26+
// Ensure we have a root node
27+
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
28+
29+
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
30+
llama_grammar* grammar = llama_grammar_init(
31+
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
32+
33+
std::string input = "123+456";
34+
35+
auto decoded = decode_utf8(input, {});
36+
37+
const auto & code_points = decoded.first;
38+
39+
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
40+
auto prev_stacks = grammar->stacks;
41+
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
42+
assert(!grammar->stacks.empty());
43+
}
44+
45+
bool completed_grammar = false;
46+
47+
for (const auto & stack : grammar->stacks) {
48+
if (stack.empty()) {
49+
completed_grammar = true;
50+
break;
51+
}
52+
}
53+
54+
assert(completed_grammar);
55+
56+
// Clean up allocated memory
57+
llama_grammar_free(grammar);
58+
}
59+
60+
static void test_complex_grammar() {
61+
// Test case for a more complex grammar, with both failure strings and success strings
62+
const std::string grammar_str = R"""(root ::= expression
63+
expression ::= term ws (("+"|"-") ws term)*
64+
term ::= factor ws (("*"|"/") ws factor)*
65+
factor ::= number | variable | "(" expression ")" | function-call
66+
number ::= [0-9]+
67+
variable ::= [a-zA-Z_][a-zA-Z0-9_]*
68+
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
69+
ws ::= [ \t\n\r]?)""";
70+
71+
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
72+
73+
// Ensure we parsed correctly
74+
assert(!parsed_grammar.rules.empty());
75+
76+
// Ensure we have a root node
77+
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
78+
79+
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
80+
llama_grammar* grammar = llama_grammar_init(
81+
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
82+
83+
// Save the original grammar stacks so that we can reset after every new string we want to test
84+
auto original_stacks = grammar->stacks;
85+
86+
// Test a few strings
87+
std::vector<std::string> test_strings_pass = {
88+
"42",
89+
"1*2*3*4*5",
90+
"x",
91+
"x+10",
92+
"x1+y2",
93+
"(a+b)*(c-d)",
94+
"func()",
95+
"func(x,y+2)",
96+
"a*(b+c)-d/e",
97+
"f(g(x),h(y,z))",
98+
"x + 10",
99+
"x1 + y2",
100+
"(a + b) * (c - d)",
101+
"func()",
102+
"func(x, y + 2)",
103+
"a * (b + c) - d / e",
104+
"f(g(x), h(y, z))",
105+
"123+456",
106+
"123*456*789-123/456+789*123",
107+
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
108+
};
109+
110+
std::vector<std::string> test_strings_fail = {
111+
"+",
112+
"/ 3x",
113+
"x + + y",
114+
"a * / b",
115+
"func(,)",
116+
"func(x y)",
117+
"(a + b",
118+
"x + y)",
119+
"a + b * (c - d",
120+
"42 +",
121+
"x +",
122+
"x + 10 +",
123+
"(a + b) * (c - d",
124+
"func(",
125+
"func(x, y + 2",
126+
"a * (b + c) - d /",
127+
"f(g(x), h(y, z)",
128+
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
129+
};
130+
131+
// Passing strings
132+
for (const auto & test_string : test_strings_pass) {
133+
auto decoded = decode_utf8(test_string, {});
134+
135+
const auto & code_points = decoded.first;
136+
137+
int pos = 0;
138+
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
139+
++pos;
140+
auto prev_stacks = grammar->stacks;
141+
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
142+
143+
// Expect that each code point will not cause the grammar to fail
144+
if (grammar->stacks.empty()) {
145+
fprintf(stdout, "Error at position %d\n", pos);
146+
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
147+
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
148+
}
149+
assert(!grammar->stacks.empty());
150+
}
151+
152+
bool completed_grammar = false;
153+
154+
for (const auto & stack : grammar->stacks) {
155+
if (stack.empty()) {
156+
completed_grammar = true;
157+
break;
158+
}
159+
}
160+
161+
assert(completed_grammar);
162+
163+
// Reset the grammar stacks
164+
grammar->stacks = original_stacks;
165+
}
166+
167+
// Failing strings
168+
for (const auto & test_string : test_strings_fail) {
169+
auto decoded = decode_utf8(test_string, {});
170+
171+
const auto & code_points = decoded.first;
172+
bool parse_failed = false;
173+
174+
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
175+
auto prev_stacks = grammar->stacks;
176+
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
177+
if (grammar->stacks.empty()) {
178+
parse_failed = true;
179+
break;
180+
}
181+
assert(!grammar->stacks.empty());
182+
}
183+
184+
bool completed_grammar = false;
185+
186+
for (const auto & stack : grammar->stacks) {
187+
if (stack.empty()) {
188+
completed_grammar = true;
189+
break;
190+
}
191+
}
192+
193+
// Ensure that the grammar is not completed, or that each string failed to match as-expected
194+
assert((!completed_grammar) || parse_failed);
195+
196+
// Reset the grammar stacks
197+
grammar->stacks = original_stacks;
198+
}
199+
200+
// Clean up allocated memory
201+
llama_grammar_free(grammar);
202+
}
203+
204+
static void test_failure_missing_root() {
205+
// Test case for a grammar that is missing a root rule
206+
const std::string grammar_str = R"""(rot ::= expr
207+
expr ::= term ("+" term)*
208+
term ::= number
209+
number ::= [0-9]+)""";
210+
211+
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
212+
213+
// Ensure we parsed correctly
214+
assert(!parsed_grammar.rules.empty());
215+
216+
// Ensure we do NOT have a root node
217+
assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
218+
}
219+
220+
static void test_failure_missing_reference() {
221+
// Test case for a grammar that is missing a referenced rule
222+
const std::string grammar_str = R"""(root ::= expr
223+
expr ::= term ("+" term)*
224+
term ::= numero
225+
number ::= [0-9]+)""";
226+
227+
fprintf(stderr, "Expected error: ");
228+
229+
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
230+
231+
// Ensure we did NOT parsed correctly
232+
assert(parsed_grammar.rules.empty());
233+
234+
fprintf(stderr, "End of expected error. Test successful.\n");
235+
}
236+
237+
int main() {
238+
test_simple_grammar();
239+
test_complex_grammar();
240+
test_failure_missing_root();
241+
test_failure_missing_reference();
242+
return 0;
243+
}

0 commit comments

Comments
 (0)