1
+ #ifdef NDEBUG
2
+ #undef NDEBUG
3
+ #endif
4
+
5
+ #include " llama.cpp" // TODO: not great
6
+ #include " grammar-parser.h"
7
+ #include < cassert>
8
+ #include < string>
9
+
10
+ static void test_failure_missing_root () {
11
+ // Test case for a grammar that is missing a root rule
12
+ const std::string grammar_str = R"""( rot ::= expr
13
+ expr ::= term ("+" term)*
14
+ term ::= number
15
+ number ::= [0-9]+)""" ;
16
+
17
+ grammar_parser::parse_state parsed_grammar = grammar_parser::parse (grammar_str.c_str ());
18
+
19
+ // Ensure we parsed correctly
20
+ assert (!parsed_grammar.rules .empty ());
21
+
22
+ // Ensure we do NOT have a root node
23
+ assert (parsed_grammar.symbol_ids .find (" root" ) == parsed_grammar.symbol_ids .end ());
24
+ }
25
+
26
+ static void test_failure_missing_reference () {
27
+ // Test case for a grammar that is missing a referenced rule
28
+ const std::string grammar_str = R"""( root ::= expr
29
+ expr ::= term ("+" term)*
30
+ term ::= numero
31
+ number ::= [0-9]+)""" ;
32
+
33
+ grammar_parser::parse_state parsed_grammar = grammar_parser::parse (grammar_str.c_str ());
34
+
35
+ // Ensure we did NOT parsed correctly
36
+ assert (parsed_grammar.rules .empty ());
37
+
38
+ fprintf (stderr, " ^ If previous line displays an error, then this test passed.\n " );
39
+ }
40
+
41
+ static void test_simple_grammar () {
42
+ // Test case for a simple grammar
43
+ const std::string grammar_str = R"""( root ::= expr
44
+ expr ::= term ("+" term)*
45
+ term ::= number
46
+ number ::= [0-9]+)""" ;
47
+
48
+ grammar_parser::parse_state parsed_grammar = grammar_parser::parse (grammar_str.c_str ());
49
+
50
+ // Ensure we parsed correctly
51
+ assert (!parsed_grammar.rules .empty ());
52
+
53
+ // Ensure we have a root node
54
+ assert (!(parsed_grammar.symbol_ids .find (" root" ) == parsed_grammar.symbol_ids .end ()));
55
+
56
+ std::vector<const llama_grammar_element*> grammar_rules (parsed_grammar.c_rules ());
57
+ llama_grammar* grammar = llama_grammar_init (
58
+ grammar_rules.data (), grammar_rules.size (), parsed_grammar.symbol_ids .at (" root" ));
59
+
60
+ std::string input = " 123+456" ;
61
+
62
+ auto decoded = decode_utf8 (input, {});
63
+
64
+ const auto & code_points = decoded.first ;
65
+
66
+ for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
67
+ auto prev_stacks = grammar->stacks ;
68
+ grammar->stacks = llama_grammar_accept (grammar->rules , grammar->stacks , *it);
69
+ assert (!grammar->stacks .empty ());
70
+ }
71
+
72
+ bool completed_grammar = false ;
73
+
74
+ for (const auto & stack : grammar->stacks ) {
75
+ if (stack.empty ()) {
76
+ completed_grammar = true ;
77
+ break ;
78
+ }
79
+ }
80
+
81
+ assert (completed_grammar);
82
+
83
+ // Clean up allocated memory
84
+ llama_grammar_free (grammar);
85
+ }
86
+
87
+ static void test_complex_grammar () {
88
+ // Test case for a more complex grammar
89
+ const std::string grammar_str = R"""( root ::= expression
90
+ expression ::= term ws (("+"|"-") ws term)*
91
+ term ::= factor ws (("*"|"/") ws factor)*
92
+ factor ::= number | variable | "(" expression ")" | function-call
93
+ number ::= [0-9]+
94
+ variable ::= [a-zA-Z_][a-zA-Z0-9_]*
95
+ function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
96
+ ws ::= [ \t\n\r]?)""" ;
97
+
98
+ grammar_parser::parse_state parsed_grammar = grammar_parser::parse (grammar_str.c_str ());
99
+
100
+ // Ensure we parsed correctly
101
+ assert (!parsed_grammar.rules .empty ());
102
+
103
+ // Ensure we have a root node
104
+ assert (!(parsed_grammar.symbol_ids .find (" root" ) == parsed_grammar.symbol_ids .end ()));
105
+
106
+ std::vector<const llama_grammar_element*> grammar_rules (parsed_grammar.c_rules ());
107
+ llama_grammar* grammar = llama_grammar_init (
108
+ grammar_rules.data (), grammar_rules.size (), parsed_grammar.symbol_ids .at (" root" ));
109
+
110
+ // Save the original grammar stacks so that we can reset after every new string we want to test
111
+ auto original_stacks = grammar->stacks ;
112
+
113
+ // Test a few strings
114
+ std::vector<std::string> test_strings_pass = {
115
+ " 42" ,
116
+ " 1*2*3*4*5" ,
117
+ " x" ,
118
+ " x+10" ,
119
+ " x1+y2" ,
120
+ " (a+b)*(c-d)" ,
121
+ " func()" ,
122
+ " func(x,y+2)" ,
123
+ " a*(b+c)-d/e" ,
124
+ " f(g(x),h(y,z))" ,
125
+ " x + 10" ,
126
+ " x1 + y2" ,
127
+ " (a + b) * (c - d)" ,
128
+ " func()" ,
129
+ " func(x, y + 2)" ,
130
+ " a * (b + c) - d / e" ,
131
+ " f(g(x), h(y, z))" ,
132
+ " 123+456" ,
133
+ " 123*456*789-123/456+789*123" ,
134
+ " 123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
135
+ };
136
+
137
+ std::vector<std::string> test_strings_fail = {
138
+ " +" ,
139
+ " / 3x" ,
140
+ " x + + y" ,
141
+ " a * / b" ,
142
+ " func(,)" ,
143
+ " func(x y)" ,
144
+ " (a + b" ,
145
+ " x + y)" ,
146
+ " a + b * (c - d" ,
147
+ " 42 +" ,
148
+ " x +" ,
149
+ " x + 10 +" ,
150
+ " (a + b) * (c - d" ,
151
+ " func(" ,
152
+ " func(x, y + 2" ,
153
+ " a * (b + c) - d /" ,
154
+ " f(g(x), h(y, z)" ,
155
+ " 123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/" ,
156
+ };
157
+
158
+ for (const auto & test_string : test_strings_pass) {
159
+ auto decoded = decode_utf8 (test_string, {});
160
+
161
+ const auto & code_points = decoded.first ;
162
+
163
+ int pos = 0 ;
164
+ for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
165
+ ++pos;
166
+ auto prev_stacks = grammar->stacks ;
167
+ grammar->stacks = llama_grammar_accept (grammar->rules , grammar->stacks , *it);
168
+
169
+ // Expect that each code point will not cause the grammar to fail
170
+ if (grammar->stacks .empty ()) {
171
+ fprintf (stdout, " Error at position %d\n " , pos);
172
+ fprintf (stderr, " Unexpected character '%s'\n " , unicode_cpt_to_utf8 (*it).c_str ());
173
+ fprintf (stderr, " Input string is %s:\n " , test_string.c_str ());
174
+ }
175
+ assert (!grammar->stacks .empty ());
176
+ }
177
+
178
+ bool completed_grammar = false ;
179
+
180
+ for (const auto & stack : grammar->stacks ) {
181
+ if (stack.empty ()) {
182
+ completed_grammar = true ;
183
+ break ;
184
+ }
185
+ }
186
+
187
+ assert (completed_grammar);
188
+
189
+ // Reset the grammar stacks
190
+ grammar->stacks = original_stacks;
191
+ }
192
+
193
+ for (const auto & test_string : test_strings_fail) {
194
+ auto decoded = decode_utf8 (test_string, {});
195
+
196
+ const auto & code_points = decoded.first ;
197
+ bool parse_failed = false ;
198
+
199
+ for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
200
+ auto prev_stacks = grammar->stacks ;
201
+ grammar->stacks = llama_grammar_accept (grammar->rules , grammar->stacks , *it);
202
+ if (grammar->stacks .empty ()) {
203
+ parse_failed = true ;
204
+ break ;
205
+ }
206
+ assert (!grammar->stacks .empty ());
207
+ }
208
+
209
+ bool completed_grammar = false ;
210
+
211
+ for (const auto & stack : grammar->stacks ) {
212
+ if (stack.empty ()) {
213
+ completed_grammar = true ;
214
+ break ;
215
+ }
216
+ }
217
+
218
+ // Ensure that the grammar is not completed, or that each string failed to match as-expected
219
+ assert ((!completed_grammar) || parse_failed);
220
+
221
+ // Reset the grammar stacks
222
+ grammar->stacks = original_stacks;
223
+ }
224
+
225
+ // Clean up allocated memory
226
+ llama_grammar_free (grammar);
227
+ }
228
+
229
+ int main () {
230
+ test_simple_grammar ();
231
+ test_complex_grammar ();
232
+ test_failure_missing_root ();
233
+ test_failure_missing_reference ();
234
+ // Add more test cases as needed
235
+ return 0 ;
236
+ }
0 commit comments