Skip to content

Commit 9b3c89f

Browse files
committed
Added integration tests for GBNF parser to validate correctness of parsing, as well as correctness of string matching. Intended for use to pin behavior while working on performance improvements.
1 parent 5fb1574 commit 9b3c89f

File tree

3 files changed

+242
-1
lines changed

3 files changed

+242
-1
lines changed

Makefile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ TEST_TARGETS = \
1010
tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \
1111
tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \
1212
tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease \
13-
tests/test-json-schema-to-grammar
13+
tests/test-json-schema-to-grammar tests/test-grammar-integration
1414

1515
# Code coverage output files
1616
COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report
@@ -914,6 +914,10 @@ tests/test-grammar-parser: tests/test-grammar-parser.cpp ggml.o llama.o grammar-
914914
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
915915
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
916916

917+
tests/test-grammar-integration: tests/test-grammar-integration.cpp ggml.o grammar-parser.o $(OBJS)
918+
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
919+
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
920+
917921
tests/test-double-float: tests/test-double-float.cpp ggml.o $(OBJS)
918922
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
919923
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-gpt2 AR
5959

6060
llama_test(test-grammar-parser.cpp)
6161
llama_test(test-llama-grammar.cpp)
62+
llama_test(test-grammar-integration.cpp)
6263
llama_test(test-grad0.cpp)
6364
# llama_test(test-opt.cpp) # SLOW
6465
llama_test(test-backend-ops.cpp)

tests/test-grammar-integration.cpp

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
#ifdef NDEBUG
2+
#undef NDEBUG
3+
#endif
4+
5+
#include "llama.cpp" // TODO: not great
6+
#include "grammar-parser.h"
7+
#include <cassert>
8+
#include <string>
9+
10+
static void test_failure_missing_root() {
11+
// Test case for a grammar that is missing a root rule
12+
const std::string grammar_str = R"""(rot ::= expr
13+
expr ::= term ("+" term)*
14+
term ::= number
15+
number ::= [0-9]+)""";
16+
17+
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
18+
19+
// Ensure we parsed correctly
20+
assert(!parsed_grammar.rules.empty());
21+
22+
// Ensure we do NOT have a root node
23+
assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
24+
}
25+
26+
static void test_failure_missing_reference() {
27+
// Test case for a grammar that is missing a referenced rule
28+
const std::string grammar_str = R"""(root ::= expr
29+
expr ::= term ("+" term)*
30+
term ::= numero
31+
number ::= [0-9]+)""";
32+
33+
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
34+
35+
// Ensure we did NOT parsed correctly
36+
assert(parsed_grammar.rules.empty());
37+
38+
fprintf(stderr, "^ If previous line displays an error, then this test passed.\n");
39+
}
40+
41+
static void test_simple_grammar() {
42+
// Test case for a simple grammar
43+
const std::string grammar_str = R"""(root ::= expr
44+
expr ::= term ("+" term)*
45+
term ::= number
46+
number ::= [0-9]+)""";
47+
48+
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
49+
50+
// Ensure we parsed correctly
51+
assert(!parsed_grammar.rules.empty());
52+
53+
// Ensure we have a root node
54+
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
55+
56+
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
57+
llama_grammar* grammar = llama_grammar_init(
58+
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
59+
60+
std::string input = "123+456";
61+
62+
auto decoded = decode_utf8(input, {});
63+
64+
const auto & code_points = decoded.first;
65+
66+
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
67+
auto prev_stacks = grammar->stacks;
68+
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
69+
assert(!grammar->stacks.empty());
70+
}
71+
72+
bool completed_grammar = false;
73+
74+
for (const auto & stack : grammar->stacks) {
75+
if (stack.empty()) {
76+
completed_grammar = true;
77+
break;
78+
}
79+
}
80+
81+
assert(completed_grammar);
82+
83+
// Clean up allocated memory
84+
llama_grammar_free(grammar);
85+
}
86+
87+
static void test_complex_grammar() {
88+
// Test case for a more complex grammar
89+
const std::string grammar_str = R"""(root ::= expression
90+
expression ::= term ws (("+"|"-") ws term)*
91+
term ::= factor ws (("*"|"/") ws factor)*
92+
factor ::= number | variable | "(" expression ")" | function-call
93+
number ::= [0-9]+
94+
variable ::= [a-zA-Z_][a-zA-Z0-9_]*
95+
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
96+
ws ::= [ \t\n\r]?)""";
97+
98+
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
99+
100+
// Ensure we parsed correctly
101+
assert(!parsed_grammar.rules.empty());
102+
103+
// Ensure we have a root node
104+
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
105+
106+
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
107+
llama_grammar* grammar = llama_grammar_init(
108+
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
109+
110+
// Save the original grammar stacks so that we can reset after every new string we want to test
111+
auto original_stacks = grammar->stacks;
112+
113+
// Test a few strings
114+
std::vector<std::string> test_strings_pass = {
115+
"42",
116+
"1*2*3*4*5",
117+
"x",
118+
"x+10",
119+
"x1+y2",
120+
"(a+b)*(c-d)",
121+
"func()",
122+
"func(x,y+2)",
123+
"a*(b+c)-d/e",
124+
"f(g(x),h(y,z))",
125+
"x + 10",
126+
"x1 + y2",
127+
"(a + b) * (c - d)",
128+
"func()",
129+
"func(x, y + 2)",
130+
"a * (b + c) - d / e",
131+
"f(g(x), h(y, z))",
132+
"123+456",
133+
"123*456*789-123/456+789*123",
134+
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
135+
};
136+
137+
std::vector<std::string> test_strings_fail = {
138+
"+",
139+
"/ 3x",
140+
"x + + y",
141+
"a * / b",
142+
"func(,)",
143+
"func(x y)",
144+
"(a + b",
145+
"x + y)",
146+
"a + b * (c - d",
147+
"42 +",
148+
"x +",
149+
"x + 10 +",
150+
"(a + b) * (c - d",
151+
"func(",
152+
"func(x, y + 2",
153+
"a * (b + c) - d /",
154+
"f(g(x), h(y, z)",
155+
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
156+
};
157+
158+
for (const auto & test_string : test_strings_pass) {
159+
auto decoded = decode_utf8(test_string, {});
160+
161+
const auto & code_points = decoded.first;
162+
163+
int pos = 0;
164+
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
165+
++pos;
166+
auto prev_stacks = grammar->stacks;
167+
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
168+
169+
// Expect that each code point will not cause the grammar to fail
170+
if (grammar->stacks.empty()) {
171+
fprintf(stdout, "Error at position %d\n", pos);
172+
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
173+
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
174+
}
175+
assert(!grammar->stacks.empty());
176+
}
177+
178+
bool completed_grammar = false;
179+
180+
for (const auto & stack : grammar->stacks) {
181+
if (stack.empty()) {
182+
completed_grammar = true;
183+
break;
184+
}
185+
}
186+
187+
assert(completed_grammar);
188+
189+
// Reset the grammar stacks
190+
grammar->stacks = original_stacks;
191+
}
192+
193+
for (const auto & test_string : test_strings_fail) {
194+
auto decoded = decode_utf8(test_string, {});
195+
196+
const auto & code_points = decoded.first;
197+
bool parse_failed = false;
198+
199+
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
200+
auto prev_stacks = grammar->stacks;
201+
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
202+
if (grammar->stacks.empty()) {
203+
parse_failed = true;
204+
break;
205+
}
206+
assert(!grammar->stacks.empty());
207+
}
208+
209+
bool completed_grammar = false;
210+
211+
for (const auto & stack : grammar->stacks) {
212+
if (stack.empty()) {
213+
completed_grammar = true;
214+
break;
215+
}
216+
}
217+
218+
// Ensure that the grammar is not completed, or that each string failed to match as-expected
219+
assert((!completed_grammar) || parse_failed);
220+
221+
// Reset the grammar stacks
222+
grammar->stacks = original_stacks;
223+
}
224+
225+
// Clean up allocated memory
226+
llama_grammar_free(grammar);
227+
}
228+
229+
int main() {
230+
test_simple_grammar();
231+
test_complex_grammar();
232+
test_failure_missing_root();
233+
test_failure_missing_reference();
234+
// Add more test cases as needed
235+
return 0;
236+
}

0 commit comments

Comments
 (0)