@@ -13,7 +13,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
+GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
 MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
@@ -30,12 +30,13 @@ def test_guided_json_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=sample_json_schema,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
     outputs = llm.generate(prompts=[
         f"Give an example JSON for an employee profile "
         f"that fits this schema: {sample_json_schema}"
@@ -111,13 +112,14 @@ def test_guided_json_object(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=100,
-                                     n=2,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json_object=True,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=100,
+        n=2,
+        guided_decoding=GuidedDecodingParams(json_object=True))
 
     outputs = llm.generate(
         prompts=("Generate a JSON object with curly braces for a person with "
@@ -142,7 +144,7 @@ def test_guided_json_object(
 
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
+                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
 @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
 def test_guided_json_unsupported_schema(
     monkeypatch: pytest.MonkeyPatch,
@@ -151,21 +153,43 @@ def test_guided_json_unsupported_schema(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=unsupported_json_schema,
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="The provided JSON schema contains features "
-                       "not supported by xgrammar."):
-        llm.generate(prompts=[
-            f"Give an example JSON for an employee profile "
-            f"that fits this schema: {unsupported_json_schema}"
-        ] * 2,
-                     sampling_params=sampling_params,
-                     use_tqdm=True)
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
+    if guided_decoding_backend == "xgrammar":
+        with pytest.raises(ValueError,
+                           match="The provided JSON schema contains features "
+                           "not supported by xgrammar."):
+            llm.generate(prompts=[
+                f"Give an example JSON for an employee profile "
+                f"that fits this schema: {unsupported_json_schema}"
+            ] * 2,
+                         sampling_params=sampling_params,
+                         use_tqdm=True)
+    else:
+        # This should work for both "guidance" and "auto".
+
+        outputs = llm.generate(
+            prompts=("Give an example JSON object for a grade "
+                     "that fits this schema: "
+                     f"{unsupported_json_schema}"),
+            sampling_params=sampling_params,
+            use_tqdm=True)
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            generated_text = output.outputs[0].text
+            assert generated_text is not None
+            print(generated_text)
+
+            # Parse to verify it is valid JSON
+            parsed_json = json.loads(generated_text)
+            assert isinstance(parsed_json, dict)
 
 
 @pytest.mark.skip_global_cleanup
@@ -179,13 +203,14 @@ def test_guided_grammar_ebnf(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_ebnf,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -222,13 +247,14 @@ def test_guided_grammar_lark(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_lark,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -269,16 +295,15 @@ def test_guided_grammar_ebnf_invalid(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar="not a grammar",
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="Failed to convert the grammar "
-                       "from Lark to EBNF."):
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
+    with pytest.raises(ValueError, match="Failed to convert the grammar "):
         llm.generate(
             prompts=("Generate a sql statement that selects col_1 from "
                      "table_1 where it is equal to 1"),
@@ -298,12 +323,13 @@ def test_guided_regex(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         regex=sample_regex,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(regex=sample_regex))
     outputs = llm.generate(
         prompts=[
             f"Give an example IPv4 address with this regex: {sample_regex}"
@@ -335,12 +361,13 @@ def test_guided_choice_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         choice=sample_guided_choice,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
     outputs = llm.generate(
         prompts="The best language for type-safe systems programming is ",
         sampling_params=sampling_params,
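
For reference, a minimal sketch of the call pattern this diff converges on: the structured-output backend is chosen once on the LLM constructor via guided_decoding_backend, and GuidedDecodingParams carries only the per-request constraint. The model name and prompt below are illustrative placeholders, not taken from the tests:

    from vllm import LLM
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    # Engine-wide backend choice: "xgrammar", "guidance", or "auto".
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend="guidance")

    # Request-level params no longer name a backend, only the constraint.
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=100,
        guided_decoding=GuidedDecodingParams(json_object=True))

    outputs = llm.generate(
        prompts=["Give an example JSON object describing a person."],
        sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)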