Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CI] Expand test_guided_generate to test all backends #11313

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 69 additions & 43 deletions tests/entrypoints/llm/test_guided_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]


@pytest.fixture(scope="module")
Expand All @@ -26,11 +27,13 @@ def llm():


@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex))
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend=guided_decoding_backend))
outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
Expand All @@ -50,11 +53,14 @@ def test_guided_regex(sample_regex, llm):


@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
Expand All @@ -77,11 +83,14 @@ def test_guided_json_completion(sample_json_schema, llm):


@pytest.mark.skip_global_cleanup
def test_guided_complex_json_completion(sample_complex_json_schema, llm):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=sample_complex_json_schema))
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_complex_json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(prompts=[
f"Give an example JSON for an assignment grade "
f"that fits this schema: {sample_complex_json_schema}"
Expand All @@ -105,11 +114,14 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm):


@pytest.mark.skip_global_cleanup
def test_guided_definition_json_completion(sample_definition_json_schema, llm):
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_definition_json_schema))
json=sample_definition_json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(prompts=[
f"Give an example JSON for solving 8x + 7 = -23 "
f"that fits this schema: {sample_definition_json_schema}"
Expand All @@ -133,11 +145,14 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm):


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
choice=sample_guided_choice,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
Expand All @@ -156,13 +171,20 @@ def test_guided_choice_completion(sample_guided_choice, llm):


@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):

sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements))
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str):
if guided_decoding_backend == "outlines":
pytest.skip("Outlines backend fails in this test case with:\n"
"AttributeError: Error in model execution: 'ParserConf' "
"object has no attribute 'deterministic'")

sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_statements,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
Expand Down Expand Up @@ -218,15 +240,18 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):


@pytest.mark.skip_global_cleanup
def test_guided_json_object(llm):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=100,
guided_decoding=GuidedDecodingParams(json_object=True))
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(
json_object=True,
backend=guided_decoding_backend))

outputs = llm.generate(
prompts=("Generate a JSON object describing a person with name "
"and age for John Smith who is 31 years old."),
prompts=("Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old."),
sampling_params=sampling_params,
use_tqdm=True)

Expand All @@ -235,10 +260,11 @@ def test_guided_json_object(llm):
assert output is not None
assert isinstance(output, RequestOutput)

generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
for i in range(2):
generated_text = output.outputs[i].text
print(generated_text)
assert generated_text is not None

# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
4 changes: 2 additions & 2 deletions tests/model_executor/test_guided_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vllm.sampling_params import GuidedDecodingParams

MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]


def test_guided_logits_processors(sample_regex, sample_json_schema):
Expand Down Expand Up @@ -42,8 +43,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):


@pytest.mark.asyncio
@pytest.mark.parametrize("backend",
["outlines", "lm-format-enforcer", "xgrammar"])
@pytest.mark.parametrize("backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("is_local", [True, False])
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
sample_regex,
Expand Down
64 changes: 58 additions & 6 deletions vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,60 @@ def check_object(obj: dict) -> bool:
return check_object(schema)


def has_lmf_unsupported_json_features(schema: dict) -> bool:
"""
Check if JSON schema contains features unsupported
by lm_format_enforcer.

Known issues:
- Regex patterns:
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"""

def check_object(obj: dict) -> bool:
if not isinstance(obj, dict):
return False

# Check for pattern restrictions
if "pattern" in obj:
return True

# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True

return False

return check_object(schema)


def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
# lm-format-enforce doesn't support grammar, fallback to xgrammar
if (guided_params.backend == "lm-format-enforcer"
and guided_params.grammar is not None):
logger.warning(
"lm-format-enforcer does not support grammar guided decoding. "
"Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar"
if guided_params.backend == "lm-format-enforcer":
if guided_params.grammar is not None:
logger.warning(
"lm-format-enforcer does not support grammar guided decoding. "
"Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar"

# lm-format-enforcer doesn't support some JSON schema features
elif (guided_params.json is not None
and has_lmf_unsupported_json_features(guided_params.json)):
logger.warning(
"lm-format-enforcer does not support advanced JSON schema "
"features like patterns or numeric ranges. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

if guided_params.backend == "xgrammar":
# xgrammar only has x86 wheels for linux, fallback to outlines
Expand All @@ -82,6 +127,13 @@ def maybe_backend_fallback(
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

if (guided_params.backend == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to xgrammar
logger.warning("outlines does not support json_object. "
"Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar"

return guided_params


Expand Down
Loading