
Commit 58a6a04

russellb, JC1DA, and mmoskal committed
[V1] guidance backend and auto mode for structured output
This is the V1 integration for [guidance](https://github.com/guidance-ai/llguidance) as a backend for structured output. There is a V0 integration in #14589.

This backend provides some key benefits to V1:

* Broader jsonschema support
* Quick startup performance for large schemas

Instead of precomputing the masks for all states, this is done on the fly, so we see very fast request startup times, even for large schemas. This should make V1 roughly feature-equivalent to V0 in terms of the types of schemas it can support.

An `auto` mode is also included, which applies opinionated fallback behavior based on our current understanding of the varying feature support and performance characteristics across different scenarios. More technical details are available in the llguidance git repo.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Loc Huynh <jc1da.3011@gmail.com>
Co-authored-by: Michal Moskal <michal@moskal.me>
1 parent 038de04 commit 58a6a04
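As a quick orientation before the diff: the backend is now chosen at the engine level, and "auto" handles per-request fallback. Below is a minimal usage sketch distilled from the updated tests; the model name and schema are illustrative, not taken from the commit.

    import os
    from vllm import LLM
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    os.environ["VLLM_USE_V1"] = "1"  # the tests opt into the V1 engine this way

    # Engine-level backend selection: "xgrammar", "guidance", or "auto".
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend="auto")

    # Illustrative schema; "auto" uses xgrammar when it can express the
    # schema and falls back to guidance otherwise.
    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
    sampling_params = SamplingParams(
        max_tokens=200,
        guided_decoding=GuidedDecodingParams(json=schema))
    outputs = llm.generate(
        prompts=f"Give an example JSON that fits this schema: {schema}",
        sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)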

File tree

9 files changed: +337 −110 lines changed


requirements/common.txt

Lines changed: 1 addition & 1 deletion
@@ -18,10 +18,10 @@ pillow  # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0  # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
-llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
+llguidance >= 0.7.9, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
 outlines == 0.1.11
 lark == 1.2.2
 xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
 typing_extensions >= 4.10
 filelock >= 3.16.1  # need to contain https://github.com/tox-dev/filelock/pull/317
 partial-json-parser # used for parsing partial JSON outputs
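The llguidance requirement is gated by PEP 508 environment markers, so pip only installs it on x86_64 and arm64/aarch64 hosts. A quick way to see which marker value a given machine reports (a sketch using only the standard library; pip derives platform_machine from this same call):

    import platform

    # pip compares this value against the marker on the llguidance line,
    # e.g. platform_machine == "x86_64" or platform_machine == "aarch64"
    print(platform.machine())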

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 93 additions & 66 deletions
@@ -13,7 +13,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
+GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
 MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
@@ -30,12 +30,13 @@ def test_guided_json_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=sample_json_schema,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
     outputs = llm.generate(prompts=[
         f"Give an example JSON for an employee profile "
         f"that fits this schema: {sample_json_schema}"
@@ -111,13 +112,14 @@ def test_guided_json_object(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=100,
-                                     n=2,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json_object=True,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=100,
+        n=2,
+        guided_decoding=GuidedDecodingParams(json_object=True))
 
     outputs = llm.generate(
         prompts=("Generate a JSON object with curly braces for a person with "
@@ -142,7 +144,7 @@ def test_guided_json_object(
 
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
+                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
 @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
 def test_guided_json_unsupported_schema(
     monkeypatch: pytest.MonkeyPatch,
@@ -151,21 +153,43 @@ def test_guided_json_unsupported_schema(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=unsupported_json_schema,
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="The provided JSON schema contains features "
-                       "not supported by xgrammar."):
-        llm.generate(prompts=[
-            f"Give an example JSON for an employee profile "
-            f"that fits this schema: {unsupported_json_schema}"
-        ] * 2,
-                     sampling_params=sampling_params,
-                     use_tqdm=True)
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
+    if guided_decoding_backend == "xgrammar":
+        with pytest.raises(ValueError,
+                           match="The provided JSON schema contains features "
+                           "not supported by xgrammar."):
+            llm.generate(prompts=[
+                f"Give an example JSON for an employee profile "
+                f"that fits this schema: {unsupported_json_schema}"
+            ] * 2,
+                         sampling_params=sampling_params,
+                         use_tqdm=True)
+    else:
+        # This should work for both "guidance" and "auto".
+
+        outputs = llm.generate(
+            prompts=("Give an example JSON object for a grade "
+                     "that fits this schema: "
+                     f"{unsupported_json_schema}"),
+            sampling_params=sampling_params,
+            use_tqdm=True)
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            generated_text = output.outputs[0].text
+            assert generated_text is not None
+            print(generated_text)
+
+            # Parse to verify it is valid JSON
+            parsed_json = json.loads(generated_text)
+            assert isinstance(parsed_json, dict)
 
 
 @pytest.mark.skip_global_cleanup
@@ -179,13 +203,14 @@ def test_guided_grammar_ebnf(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_ebnf,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -222,13 +247,14 @@ def test_guided_grammar_lark(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_lark,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -269,16 +295,15 @@ def test_guided_grammar_ebnf_invalid(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar="not a grammar",
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="Failed to convert the grammar "
-                       "from Lark to EBNF."):
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
+    with pytest.raises(ValueError, match="Failed to convert the grammar "):
         llm.generate(
             prompts=("Generate a sql statement that selects col_1 from "
                      "table_1 where it is equal to 1"),
@@ -298,12 +323,13 @@ def test_guided_regex(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         regex=sample_regex,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(regex=sample_regex))
     outputs = llm.generate(
         prompts=[
             f"Give an example IPv4 address with this regex: {sample_regex}"
@@ -335,12 +361,13 @@ def test_guided_choice_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         choice=sample_guided_choice,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
     outputs = llm.generate(
         prompts="The best language for type-safe systems programming is ",
         sampling_params=sampling_params,
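The unsupported_json_schema fixture used above is defined in the test suite's conftest and is not shown in this diff. As a purely hypothetical illustration of the kind of schema that exercises this path, constraint keywords that xgrammar rejected at the time (string length bounds, for example) would raise the ValueError on the xgrammar backend and succeed via the guidance fallback:

    # Hypothetical stand-in only; the real fixture lives in conftest.py.
    unsupported_json_schema = {
        "type": "object",
        "properties": {
            "grade": {
                "type": "string",
                "minLength": 1,  # length bounds are the "unsupported" feature
                "maxLength": 2,
            },
        },
        "required": ["grade"],
    }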

vllm/config.py

Lines changed: 7 additions & 2 deletions
@@ -2805,12 +2805,17 @@ def compute_hash(self) -> str:
         return hash_str
 
     def __post_init__(self):
-        valid_guided_backends = [
-            'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
+        v0_valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'xgrammar'
         ]
+        v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
 
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
+        if envs.VLLM_USE_V1:
+            valid_guided_backends = v1_valid_guided_backends
+        else:
+            valid_guided_backends = v0_valid_guided_backends
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
                              f" must be one of {valid_guided_backends}")

vllm/engine/arg_utils.py

Lines changed: 9 additions & 12 deletions
@@ -391,16 +391,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default='xgrammar',
             help='Which engine will be used for guided decoding'
             ' (JSON schema / regex etc) by default. Currently support '
-            'https://github.com/outlines-dev/outlines, '
-            'https://github.com/mlc-ai/xgrammar, and '
-            'https://github.com/noamgat/lm-format-enforcer.'
-            ' Can be overridden per request via guided_decoding_backend'
-            ' parameter.\n'
-            'Backend-specific options can be supplied in a comma-separated '
-            'list following a colon after the backend name. Valid backends and '
-            'all available options are: [xgrammar:no-fallback, '
-            'xgrammar:disable-any-whitespace, '
-            'outlines:no-fallback, lm-format-enforcer:no-fallback]')
+            'https://github.com/mlc-ai/xgrammar and '
+            'https://github.com/guidance-ai/llguidance. '
+            'Valid backend values are "xgrammar", "guidance", and "auto". '
+            'With "auto", we will make opinionated choices based on request '
+            'contents and what the backend libraries currently support, so '
+            'the behavior is subject to change in each release. '
+            'The default is xgrammar.')
         parser.add_argument(
             '--logits-processor-pattern',
             type=nullable_str,
@@ -1539,9 +1536,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False
 
-        # Only support Xgrammar for guided decoding so far.
+        # Xgrammar and Guidance are supported.
         SUPPORTED_GUIDED_DECODING = [
-            "xgrammar", "xgrammar:disable-any-whitespace"
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
         ]
         if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
             _raise_or_fallback(feature_name="--guided-decoding-backend",
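The same knob is exposed when launching the OpenAI-compatible server; for example (a sketch, with an illustrative model name):

    vllm serve Qwen/Qwen2.5-1.5B-Instruct --guided-decoding-backend auto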

vllm/v1/engine/processor.py

Lines changed: 32 additions & 7 deletions
@@ -4,7 +4,6 @@
 from collections.abc import Mapping
 from typing import Optional, Union
 
-import vllm.platforms
 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
@@ -20,7 +19,10 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.structured_output.utils import validate_structured_output_request
+from vllm.v1.structured_output.backend_guidance import (
+    validate_guidance_grammar)
+from vllm.v1.structured_output.utils import (
+    validate_structured_output_request_xgrammar)
 
 
 class Processor:
@@ -120,7 +122,9 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         if not params.guided_decoding or not self.decoding_config:
             return
 
-        supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"]
+        supported_backends = [
+            "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
+        ]
         engine_level_backend = self.decoding_config.guided_decoding_backend
         if engine_level_backend not in supported_backends:
             raise ValueError(f"Only {supported_backends} structured output is "
@@ -134,10 +138,31 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         else:
             params.guided_decoding.backend = engine_level_backend
 
-        if vllm.platforms.current_platform.is_tpu():
-            raise ValueError("Structured output is not supported on TPU.")
-
-        validate_structured_output_request(params)
+        # Request content validation
+
+        if engine_level_backend == "xgrammar":
+            # xgrammar with no fallback
+            validate_structured_output_request_xgrammar(params)
+            params.guided_decoding.backend = "xgrammar"
+        elif engine_level_backend == "auto":
+            # "auto" is an opt-in to opinionated behavior where we try to
+            # choose a backend based on request contents. This is not the
+            # default as it is less predictable and subject to change
+            # between releases as feature support changes.
+            try:
+                validate_structured_output_request_xgrammar(params)
+                params.guided_decoding.backend = "xgrammar"
+            except ValueError:
+                # The request includes some jsonschema feature(s) that
+                # are not supported in xgrammar. Fall back to guidance.
+                params.guided_decoding.backend = "guidance"
+
+        if params.guided_decoding.backend == "guidance":
+            # TODO: ideally we would have the LLTokenizer here, as Lark syntax
+            # allows <|special_token|> and similar, see
+            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
+            # Without a tokenizer, these are disallowed in grammars.
+            validate_guidance_grammar(params, tokenizer=None)
 
     def process_inputs(
         self,
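The "auto" branch above is the heart of the change. Restated as a standalone sketch (validate_xgrammar stands in for validate_structured_output_request_xgrammar, which raises ValueError on requests it cannot express):

    def choose_backend_auto(params, validate_xgrammar) -> str:
        """Pick the backend that "auto" would use for one request."""
        try:
            # Prefer xgrammar whenever it can handle the request.
            validate_xgrammar(params)
            return "xgrammar"
        except ValueError:
            # Otherwise fall back to guidance, which has broader
            # jsonschema coverage.
            return "guidance"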

vllm/v1/structured_output/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,7 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.structured_output.backend_guidance import GuidanceBackend
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar)
 
@@ -50,6 +51,8 @@ def grammar_init(self, request: Request) -> None:
                                              XgrammarBackend)
 
             self.backend = XgrammarBackend(self.vllm_config)
+        elif backend_name == "guidance":
+            self.backend = GuidanceBackend(self.vllm_config)
         else:
             raise ValueError(
                 f"Unsupported structured output backend: {backend_name}")
