
LM Format Enforcer Guided Decoding Support #3868

Merged: 29 commits merged into main from lmfe-generation on Apr 16, 2024
Commits (29)
ff46b1d
WIP: Refactoring before lm-format-enforcer integration
noamgat Mar 31, 2024
8fa1d0d
Integrated LM Format enforcer decoding
noamgat Apr 5, 2024
d2ef17c
Merge branch 'main' into lmfe-generation
noamgat Apr 5, 2024
0c3ab87
format.sh
noamgat Apr 5, 2024
04c457f
Making pip install work
noamgat Apr 5, 2024
3c2f45b
LMFE integration fixes
noamgat Apr 5, 2024
657a5d5
Ruff error
noamgat Apr 5, 2024
2b797ce
isort fix
noamgat Apr 5, 2024
d637a3e
Fixing test import
noamgat Apr 5, 2024
bb078dc
Merge branch 'main' into lmfe-generation
noamgat Apr 5, 2024
f39ec42
Refactor : introduced DecodingConfig
noamgat Apr 5, 2024
04062b5
Added tests for get_guided_processors wrapper and LMFE implementation
noamgat Apr 5, 2024
11e11be
format.sh
noamgat Apr 5, 2024
ae21046
Fixing test_guided_logits_processor_black_box()
noamgat Apr 5, 2024
13d792d
Merge branch 'main' into lmfe-generation
noamgat Apr 5, 2024
0cfbb1b
Update vllm/config.py based on PR review
noamgat Apr 6, 2024
11cb2db
Update vllm/config.py based on PR review
noamgat Apr 6, 2024
8b9bfc1
format.sh
noamgat Apr 6, 2024
ddb1a8b
Added the option to override guided_decoding_backend on a per-request…
noamgat Apr 6, 2024
dc6c919
Modifying outlines logits processor to return the tokenizer to its no…
noamgat Apr 6, 2024
bee7147
Ruff fixes
noamgat Apr 6, 2024
87bb454
yapf fixes
noamgat Apr 6, 2024
88a200a
Test fixes
noamgat Apr 6, 2024
1baedd3
Ruff fixes
noamgat Apr 6, 2024
2d35e61
yapf
noamgat Apr 6, 2024
041a16b
isort
noamgat Apr 6, 2024
3e5eaa0
Merge branch 'main' into lmfe-generation
noamgat Apr 10, 2024
105c3d0
Doc change for CI retrigger
noamgat Apr 10, 2024
79f060a
Merge branch 'main' of github.com:vllm-project/vllm into lmfe-generation
simon-mo Apr 16, 2024
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -11,6 +11,7 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.9.3
outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
42 changes: 39 additions & 3 deletions tests/entrypoints/test_guided_processors.py
@@ -1,11 +1,14 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.

import pytest
import torch
from transformers import AutoTokenizer

from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
RegexLogitsProcessor)
from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)

TEST_SCHEMA = {
"type": "object",
@@ -73,3 +76,36 @@ def test_guided_logits_processors():
json_LP(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)


@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
async def test_guided_logits_processor_black_box(backend: str):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {TEST_REGEX}")
regex_request = CompletionRequest(model='test',
prompt=token_ids,
guided_regex=TEST_REGEX)
regex_lp = await get_guided_decoding_logits_processor(
backend, regex_request, tokenizer)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
json_request = CompletionRequest(model='test',
prompt=token_ids,
guided_json=TEST_SCHEMA)
json_lp = await get_guided_decoding_logits_processor(
backend, json_request, tokenizer)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
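
For readers who want to exercise the new code path outside pytest, here is a minimal standalone sketch mirroring the black-box test above. The regex is a stand-in (the file's real TEST_REGEX constant is not shown in this diff) and the zephyr tokenizer is only used as a convenient example.

import asyncio

import torch
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)

TEST_REGEX = r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}"  # placeholder pattern


async def demo(backend: str) -> torch.Tensor:
    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
    token_ids = tokenizer.encode(
        f"Give an example IPv4 address with this regex: {TEST_REGEX}")
    request = CompletionRequest(model='test',
                                prompt=token_ids,
                                guided_regex=TEST_REGEX)
    # The same wrapper works for 'outlines' and 'lm-format-enforcer'.
    logits_processor = await get_guided_decoding_logits_processor(
        backend, request, tokenizer)
    logits = torch.rand(32000)  # one vocab-sized row of dummy logits
    # Tokens that cannot continue a match of TEST_REGEX are masked out.
    return logits_processor(token_ids, logits)


masked = asyncio.run(demo("lm-format-enforcer"))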
69 changes: 50 additions & 19 deletions tests/entrypoints/test_openai_server.py
@@ -506,15 +506,19 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text


async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example JSON for an employee profile "
f"that fits this schema: {TEST_SCHEMA}",
n=3,
temperature=1.0,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
@@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)


async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json1 = json.loads(message.content)
@@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json2 = json.loads(message.content)
@@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
assert json1["age"] != json2["age"]


async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
n=3,
temperature=1.0,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
@@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None


async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))
ip1 = chat_completion.choices[0].message.content
assert ip1 is not None
assert re.fullmatch(TEST_REGEX, ip1) is not None
@@ -606,29 +623,37 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))
ip2 = chat_completion.choices[0].message.content
assert ip2 is not None
assert re.fullmatch(TEST_REGEX, ip2) is not None
assert ip1 != ip2


async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt="The best language for type-safe systems programming is ",
n=2,
temperature=1.0,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 2
for i in range(2):
assert completion.choices[i].text in TEST_CHOICE


async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
choice1 = chat_completion.choices[0].message.content
assert choice1 in TEST_CHOICE

@@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
choice2 = chat_completion.choices[0].message.content
assert choice2 in TEST_CHOICE
assert choice1 != choice2


async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example JSON that fits this schema: 42",
extra_body=dict(guided_json=42))
extra_body=dict(guided_json=42,
guided_decoding_backend=guided_decoding_backend))

messages = [{
"role": "system",
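
The parametrized server tests above exercise the user-facing side of this PR: guided_decoding_backend can be set per request through extra_body, overriding the engine-wide default. A minimal client-side sketch, with placeholder endpoint, API key, model name, and schema:

import asyncio
import json

import openai

TEST_SCHEMA = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"}
    },
    "required": ["name", "age"]
}

# Placeholder endpoint and key; point these at your own vLLM server.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")


async def main():
    completion = await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model name
        prompt=f"Give an example JSON for an employee profile "
        f"that fits this schema: {TEST_SCHEMA}",
        max_tokens=500,
        temperature=1.0,
        # New in this PR: choose the backend per request instead of relying
        # on the server-wide --guided-decoding-backend default.
        extra_body=dict(guided_json=TEST_SCHEMA,
                        guided_decoding_backend="lm-format-enforcer"))
    return json.loads(completion.choices[0].text)


print(asyncio.run(main()))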
26 changes: 21 additions & 5 deletions vllm/config.py
@@ -66,8 +66,8 @@ class ModelConfig:
weights. If None, we assume the model weights are not quantized.
quantization_param_path: Path to JSON file containing scaling factors.
Used to load KV cache scaling factors into the model when KV cache
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
model dtype is FP8_E4M3 on ROCm.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
@@ -422,7 +422,7 @@ def verify_with_parallel_config(
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool.

Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
@@ -446,9 +446,9 @@ def create_config(
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.

If tokenizer_pool_size is 0, return None.

Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
@@ -1079,6 +1079,21 @@ def _get_and_verify_max_len(
return int(max_model_len)


@dataclass
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine"""

# Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
guided_decoding_backend: str = 'outlines'

def __post_init__(self):
valid_guided_backends = ['outlines', 'lm-format-enforcer']
backend = self.guided_decoding_backend
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                 f"must be one of {valid_guided_backends}")


@dataclass(frozen=True)
class EngineConfig:
"""Dataclass which contains all engine-related configuration. This
@@ -1093,6 +1108,7 @@ class EngineConfig:
lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig]
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
tensorizer_config: Optional[TensorizerConfig]

def __post_init__(self):
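
The new DecodingConfig carries only the backend name and validates it in __post_init__, so a typo fails at engine construction rather than on the first guided request. A small usage sketch:

from vllm.config import DecodingConfig

# The default preserves existing behaviour.
assert DecodingConfig().guided_decoding_backend == 'outlines'

# The new backend is selected by name.
config = DecodingConfig(guided_decoding_backend='lm-format-enforcer')

# Anything else is rejected in __post_init__, before any request is served.
try:
    DecodingConfig(guided_decoding_backend='no-such-backend')
except ValueError as err:
    print(err)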
18 changes: 15 additions & 3 deletions vllm/engine/arg_utils.py
@@ -5,9 +5,9 @@
from dataclasses import dataclass
from typing import BinaryIO, Optional, Union

from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, TensorizerConfig,
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig, TensorizerConfig,
TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.tensorizer_loader import TensorizerArgs
from vllm.utils import str_to_int_tuple
@@ -80,6 +80,7 @@ class EngineArgs:
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False

guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
@@ -200,6 +201,13 @@ def add_cli_args(
default=EngineArgs.max_model_len,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
parser.add_argument(
'--guided-decoding-backend',
type=str,
default='outlines',
choices=['outlines', 'lm-format-enforcer'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc)')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
@@ -511,6 +519,9 @@ def create_engine_config(self, ) -> EngineConfig:
else:
vision_language_config = None

decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)

return EngineConfig(model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
@@ -519,6 +530,7 @@ def create_engine_config(self, ) -> EngineConfig:
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
decoding_config=decoding_config,
tensorizer_config=tensorizer_config)


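
The backend default is exposed twice: as the --guided-decoding-backend CLI flag and as an EngineArgs field, and create_engine_config() threads it into EngineConfig. A short sketch of the programmatic path (the model name is only a placeholder):

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m",  # placeholder model
                         guided_decoding_backend="lm-format-enforcer")
engine_config = engine_args.create_engine_config()

# The backend choice now travels with the rest of the engine configuration.
print(engine_config.decoding_config.guided_decoding_backend)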
10 changes: 7 additions & 3 deletions vllm/engine/llm_engine.py
@@ -4,9 +4,10 @@
from transformers import PreTrainedTokenizer

import vllm
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
TensorizerConfig, VisionLanguageConfig)
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, TensorizerConfig,
VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
@@ -74,6 +75,7 @@ def __init__(
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
tensorizer_config: Optional[TensorizerConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
@@ -100,6 +102,7 @@ def __init__(
f"kv_cache_dtype={cache_config.cache_dtype}, "
f"quantization_param_path={model_config.quantization_param_path}, "
f"device_config={device_config.device}, "
f"decoding_config={decoding_config!r}, "
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.

@@ -111,6 +114,7 @@ def __init__(
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.decoding_config = decoding_config or DecodingConfig()
self.tensorizer_config = tensorizer_config
self.log_stats = log_stats

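
LLMEngine accepts decoding_config as Optional and falls back to the default when it is None, so callers that predate this PR keep the existing outlines behaviour. In effect:

from vllm.config import DecodingConfig

decoding_config = None  # e.g. a caller written before this PR
decoding_config = decoding_config or DecodingConfig()
assert decoding_config.guided_decoding_backend == 'outlines'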