[V1] guidance backend for structured output + auto fallback mode #14779

Merged · 3 commits · Mar 25, 2025
2 changes: 1 addition & 1 deletion requirements/common.txt
@@ -18,7 +18,7 @@ pillow # Required for image processing
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
llguidance >= 0.7.9, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines == 0.1.11
lark == 1.2.2
xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
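The llguidance floor moves from 0.7.2 to 0.7.9 within the same 0.7.x band. A quick stdlib sketch (not part of the diff) to confirm an installed wheel satisfies the new constraint:

```python
# Check the installed llguidance against the new ">= 0.7.9, < 0.8.0" bound.
from importlib.metadata import version

print(version("llguidance"))  # expect a 0.7.x release at or above 0.7.9
```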
169 changes: 102 additions & 67 deletions tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -13,7 +13,7 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
MODELS_TO_TEST = [
"Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
]
@@ -30,12 +30,13 @@ def test_guided_json_completion(
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_json_schema,
backend=guided_decoding_backend))
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
@@ -111,13 +112,14 @@ def test_guided_json_object(
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(
json_object=True,
backend=guided_decoding_backend))
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(json_object=True))

outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with "
Expand All @@ -137,12 +139,20 @@ def test_guided_json_object(

# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
allowed_types: tuple[type, ...] = (dict, )
if guided_decoding_backend == "xgrammar":
# TODO - we are currently too permissive with xgrammar and
# allow any valid json (typically comes back as a list or
# object). We can fix this by specifying a jsonschema of
# {"type": "object"}, but we need this fix in a release
# first: https://github.com/mlc-ai/xgrammar/pull/264
allowed_types = (dict, list)
assert isinstance(parsed_json, allowed_types)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
GUIDED_DECODING_BACKENDS_V1 + ["auto"])
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_unsupported_schema(
monkeypatch: pytest.MonkeyPatch,
@@ -151,21 +161,43 @@ def test_guided_json_unsupported_schema(
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=unsupported_json_schema,
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="The provided JSON schema contains features "
"not supported by xgrammar."):
llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {unsupported_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
if guided_decoding_backend == "xgrammar":
with pytest.raises(ValueError,
match="The provided JSON schema contains features "
"not supported by xgrammar."):
llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {unsupported_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
else:
# This should work for both "guidance" and "auto".

outputs = llm.generate(
prompts=("Give an example JSON object for a grade "
"that fits this schema: "
f"{unsupported_json_schema}"),
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
assert generated_text is not None
print(generated_text)

# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)


@pytest.mark.skip_global_cleanup
@@ -179,13 +211,14 @@ def test_guided_grammar_ebnf(
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_ebnf,
backend=guided_decoding_backend))
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
@@ -222,13 +255,14 @@ def test_guided_grammar_lark(
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_lark,
backend=guided_decoding_backend))
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
@@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid(
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar="not a grammar",
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="Failed to convert the grammar "
"from Lark to EBNF."):
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
@@ -298,12 +331,13 @@ def test_guided_regex(
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend=guided_decoding_backend))
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex))
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
@@ -335,12 +369,13 @@ def test_guided_choice_completion(
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
choice=sample_guided_choice,
backend=guided_decoding_backend))
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
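As the test changes above show, the backend is now selected once at engine construction instead of per request. A minimal usage sketch, not part of the diff, using one of the test models and a made-up schema:

```python
from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Backend selection moved from GuidedDecodingParams(backend=...) to an
# engine-level argument; "auto" prefers xgrammar and falls back to guidance.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          max_model_len=1024,
          guided_decoding_backend="auto")

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(json=schema))
outputs = llm.generate(
    prompts=[f"Give an example JSON that fits this schema: {schema}"],
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
```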
9 changes: 7 additions & 2 deletions vllm/config.py
@@ -2805,12 +2805,17 @@ def compute_hash(self) -> str:
return hash_str

def __post_init__(self):
valid_guided_backends = [
'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
v0_valid_guided_backends = [
'outlines', 'lm-format-enforcer', 'xgrammar'
]
v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']

backend = GuidedDecodingParams(
backend=self.guided_decoding_backend).backend_name
if envs.VLLM_USE_V1:
valid_guided_backends = v1_valid_guided_backends
else:
valid_guided_backends = v0_valid_guided_backends
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
f" must be one of {valid_guided_backends}")
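A condensed, standalone sketch of the split validation above (not the vLLM source; `use_v1` stands in for `envs.VLLM_USE_V1`). Note that "guidance" is no longer accepted under V0, and "auto" is V1-only:

```python
def check_guided_backend(backend: str, use_v1: bool) -> None:
    # V0 keeps its original backends; V1 accepts xgrammar, guidance,
    # and the new "auto" fallback mode.
    v0_valid = ['outlines', 'lm-format-enforcer', 'xgrammar']
    v1_valid = ['xgrammar', 'guidance', 'auto']
    valid = v1_valid if use_v1 else v0_valid
    if backend not in valid:
        raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
                         f" must be one of {valid}")

check_guided_backend("auto", use_v1=True)       # OK under V1
try:
    check_guided_backend("outlines", use_v1=True)
except ValueError as e:
    print(e)  # outlines is V0-only
```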
21 changes: 9 additions & 12 deletions vllm/engine/arg_utils.py
@@ -391,16 +391,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default='xgrammar',
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently support '
'https://github.com/outlines-dev/outlines, '
'https://github.com/mlc-ai/xgrammar, and '
'https://github.com/noamgat/lm-format-enforcer.'
' Can be overridden per request via guided_decoding_backend'
' parameter.\n'
'Backend-specific options can be supplied in a comma-separated '
'list following a colon after the backend name. Valid backends and '
'all available options are: [xgrammar:no-fallback, '
'xgrammar:disable-any-whitespace, '
'outlines:no-fallback, lm-format-enforcer:no-fallback]')
'https://github.com/mlc-ai/xgrammar and '
'https://github.com/guidance-ai/llguidance. '
'Valid backend values are "xgrammar", "guidance", and "auto". '
'With "auto", we will make opinionated choices based on request'
'contents and what the backend libraries currently support, so '
'the behavior is subject to change in each release. '
'The default is xgrammar.')
parser.add_argument(
'--logits-processor-pattern',
type=nullable_str,
@@ -1539,9 +1536,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

# Only support Xgrammar for guided decoding so far.
# Xgrammar and Guidance are supported.
SUPPORTED_GUIDED_DECODING = [
"xgrammar", "xgrammar:disable-any-whitespace"
"xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
]
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
_raise_or_fallback(feature_name="--guided-decoding-backend",
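For completeness, the programmatic equivalent of the CLI flag — a sketch under the assumption that `EngineArgs` mirrors `--guided-decoding-backend`, as the parser definition above suggests:

```python
from vllm.engine.arg_utils import EngineArgs

# Same effect as `--guided-decoding-backend auto` on the command line.
engine_args = EngineArgs(model="Qwen/Qwen2.5-1.5B-Instruct",
                         guided_decoding_backend="auto")
```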
39 changes: 32 additions & 7 deletions vllm/v1/engine/processor.py
@@ -4,7 +4,6 @@
from collections.abc import Mapping
from typing import Optional, Union

import vllm.platforms
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
@@ -20,7 +19,10 @@
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.utils import validate_structured_output_request
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.utils import (
validate_structured_output_request_xgrammar)


class Processor:
@@ -120,7 +122,9 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
return

supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"]
supported_backends = [
"xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
]
engine_level_backend = self.decoding_config.guided_decoding_backend
if engine_level_backend not in supported_backends:
raise ValueError(f"Only {supported_backends} structured output is "
@@ -134,10 +138,31 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
else:
params.guided_decoding.backend = engine_level_backend

if vllm.platforms.current_platform.is_tpu():
raise ValueError("Structured output is not supported on TPU.")

validate_structured_output_request(params)
# Request content validation

if engine_level_backend == "xgrammar":
# xgrammar with no fallback
validate_structured_output_request_xgrammar(params)
params.guided_decoding.backend = "xgrammar"
elif engine_level_backend == "auto":
# "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the
# default as it is less predictable and subject to change
# between releases as feature support changes.
try:
validate_structured_output_request_xgrammar(params)
params.guided_decoding.backend = "xgrammar"
except ValueError:
# The request includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance.
params.guided_decoding.backend = "guidance"

if params.guided_decoding.backend == "guidance":
# TODO ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None)

def process_inputs(
self,
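The "auto" branch above boils down to a try/except around the xgrammar validator. A condensed standalone sketch using the names imported in this file (simplified; the real method rewrites `params.guided_decoding.backend` in place rather than returning a value):

```python
from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.utils import (
    validate_structured_output_request_xgrammar)


def resolve_auto_backend(params: SamplingParams) -> str:
    # Prefer xgrammar; if the request uses jsonschema features xgrammar
    # cannot compile, the validator raises ValueError and we fall back
    # to guidance.
    try:
        validate_structured_output_request_xgrammar(params)
        return "xgrammar"
    except ValueError:
        return "guidance"
```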
3 changes: 3 additions & 0 deletions vllm/v1/structured_output/__init__.py
@@ -7,6 +7,7 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar)

@@ -50,6 +51,8 @@ def grammar_init(self, request: Request) -> None:
XgrammarBackend)

self.backend = XgrammarBackend(self.vllm_config)
elif backend_name == "guidance":
self.backend = GuidanceBackend(self.vllm_config)
else:
raise ValueError(
f"Unsupported structured output backend: {backend_name}")