From 05434764cd99990035779cf9a4ed86623b528825 Mon Sep 17 00:00:00 2001
From: Noam Gat
Date: Tue, 16 Apr 2024 08:54:57 +0300
Subject: [PATCH] LM Format Enforcer Guided Decoding Support (#3868)

Co-authored-by: Simon Mo
---
 requirements-common.txt                       |   1 +
 tests/entrypoints/test_guided_processors.py   |  42 +++++++-
 tests/entrypoints/test_openai_server.py       |  69 ++++++++----
 vllm/config.py                                |  26 ++++-
 vllm/engine/arg_utils.py                      |  18 +++-
 vllm/engine/llm_engine.py                     |  10 +-
 vllm/entrypoints/openai/protocol.py           |  12 +++
 vllm/entrypoints/openai/serving_chat.py       |   6 +-
 vllm/entrypoints/openai/serving_completion.py |   6 +-
 .../guided_decoding/__init__.py               |  25 +++++
 .../lm_format_enforcer_decoding.py            |  69 ++++++++++++
 .../outlines_decoding.py}                     |   7 +-
 .../outlines_logits_processors.py}            | 100 +++++++++---------
 13 files changed, 304 insertions(+), 87 deletions(-)
 create mode 100644 vllm/model_executor/guided_decoding/__init__.py
 create mode 100644 vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
 rename vllm/model_executor/{guided_decoding.py => guided_decoding/outlines_decoding.py} (93%)
 rename vllm/model_executor/{guided_logits_processors.py => guided_decoding/outlines_logits_processors.py} (70%)

diff --git a/requirements-common.txt b/requirements-common.txt
index 90a3bc8abc1db..c1614d2537b25 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -11,6 +11,7 @@ uvicorn[standard]
 pydantic >= 2.0 # Required for OpenAI server.
 prometheus_client >= 0.18.0
 tiktoken == 0.6.0 # Required for DBRX tokenizer
+lm-format-enforcer == 0.9.3
 outlines == 0.0.34 # Requires torch >= 2.1.0
 typing_extensions
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py
index 5622744566bcc..30f0ad5d8272f 100644
--- a/tests/entrypoints/test_guided_processors.py
+++ b/tests/entrypoints/test_guided_processors.py
@@ -1,11 +1,14 @@
 # This unit test should be moved to a new
 # tests/test_guided_decoding directory.
- +import pytest import torch from transformers import AutoTokenizer -from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor, - RegexLogitsProcessor) +from vllm.entrypoints.openai.protocol import CompletionRequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + JSONLogitsProcessor, RegexLogitsProcessor) TEST_SCHEMA = { "type": "object", @@ -73,3 +76,36 @@ def test_guided_logits_processors(): json_LP(token_ids, tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"]) +async def test_guided_logits_processor_black_box(backend: str): + tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') + token_ids = tokenizer.encode( + f"Give an example IPv4 address with this regex: {TEST_REGEX}") + regex_request = CompletionRequest(model='test', + prompt=token_ids, + guided_regex=TEST_REGEX) + regex_lp = await get_guided_decoding_logits_processor( + backend, regex_request, tokenizer) + assert regex_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = regex_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) + + token_ids = tokenizer.encode( + f"Give an employee profile that fits this schema: {TEST_SCHEMA}") + json_request = CompletionRequest(model='test', + prompt=token_ids, + guided_json=TEST_SCHEMA) + json_lp = await get_guided_decoding_logits_processor( + backend, json_request, tokenizer) + assert json_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = json_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 7940430b8b654..14e6ee0ffe9d9 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): assert first_response != completion.choices[0].text -async def test_guided_json_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example JSON for an employee profile " @@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI): n=3, temperature=1.0, max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI): jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) -async def test_guided_json_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", 
"content": "you are a helpful assistant" @@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + max_tokens=1000, + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) message = chat_completion.choices[0].message assert message.content is not None json1 = json.loads(message.content) @@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + max_tokens=1000, + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) message = chat_completion.choices[0].message assert message.content is not None json2 = json.loads(message.content) @@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): assert json1["age"] != json2["age"] -async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}", n=3, temperature=1.0, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None -async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) ip1 = chat_completion.choices[0].message.content assert ip1 is not None assert re.fullmatch(TEST_REGEX, ip1) is not None @@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) ip2 = chat_completion.choices[0].message.content assert ip2 is not None assert re.fullmatch(TEST_REGEX, ip2) is not None assert ip1 != ip2 -async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt="The best language for type-safe systems programming is ", n=2, 
temperature=1.0, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 2 @@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): assert completion.choices[i].text in TEST_CHOICE -async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) choice1 = chat_completion.choices[0].message.content assert choice1 in TEST_CHOICE @@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) choice2 = chat_completion.choices[0].message.content assert choice2 in TEST_CHOICE assert choice1 != choice2 -async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42)) + extra_body=dict(guided_json=42, + guided_decoding_backend=guided_decoding_backend)) messages = [{ "role": "system", diff --git a/vllm/config.py b/vllm/config.py index dce2944b2ee8a..bf31b03b7c6c4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -66,8 +66,8 @@ class ModelConfig: weights. If None, we assume the model weights are not quantized. quantization_param_path: Path to JSON file containing scaling factors. Used to load KV cache scaling factors into the model when KV cache - type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also - be used to load activation and weight scaling factors when the + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the model dtype is FP8_E4M3 on ROCm. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. @@ -422,7 +422,7 @@ def verify_with_parallel_config( @dataclass class TokenizerPoolConfig: """Configuration for the tokenizer pool. - + Args: pool_size: Number of tokenizer workers in the pool. pool_type: Type of the pool. @@ -446,9 +446,9 @@ def create_config( tokenizer_pool_extra_config: Optional[Union[str, dict]] ) -> Optional["TokenizerPoolConfig"]: """Create a TokenizerPoolConfig from the given parameters. - + If tokenizer_pool_size is 0, return None. - + Args: tokenizer_pool_size: Number of tokenizer workers in the pool. tokenizer_pool_type: Type of the pool. 
@@ -1079,6 +1079,21 @@ def _get_and_verify_max_len( return int(max_model_len) +@dataclass +class DecodingConfig: + """Dataclass which contains the decoding strategy of the engine""" + + # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer' + guided_decoding_backend: str = 'outlines' + + def __post_init__(self): + valid_guided_backends = ['outlines', 'lm-format-enforcer'] + backend = self.guided_decoding_backend + if backend not in valid_guided_backends: + raise ValueError(f"Invalid guided_decoding_backend '{backend}," + f"must be one of {valid_guided_backends}") + + @dataclass(frozen=True) class EngineConfig: """Dataclass which contains all engine-related configuration. This @@ -1093,6 +1108,7 @@ class EngineConfig: lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] + decoding_config: Optional[DecodingConfig] tensorizer_config: Optional[TensorizerConfig] def __post_init__(self): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 831a03be65f61..3de74b0ac28b9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -5,9 +5,9 @@ from dataclasses import dataclass from typing import BinaryIO, Optional, Union -from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, TensorizerConfig, +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, + EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, SpeculativeConfig, TensorizerConfig, TokenizerPoolConfig, VisionLanguageConfig) from vllm.model_executor.tensorizer_loader import TensorizerArgs from vllm.utils import str_to_int_tuple @@ -80,6 +80,7 @@ class EngineArgs: scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False + guided_decoding_backend: str = 'outlines' # Speculative decoding configuration. speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None @@ -200,6 +201,13 @@ def add_cli_args( default=EngineArgs.max_model_len, help='model context length. 
If unspecified, ' 'will be automatically derived from the model.') + parser.add_argument( + '--guided-decoding-backend', + type=str, + default='outlines', + choices=['outlines', 'lm-format-enforcer'], + help='Which engine will be used for guided decoding' + ' (JSON schema / regex etc)') # Parallel arguments parser.add_argument('--worker-use-ray', action='store_true', @@ -511,6 +519,9 @@ def create_engine_config(self, ) -> EngineConfig: else: vision_language_config = None + decoding_config = DecodingConfig( + guided_decoding_backend=self.guided_decoding_backend) + return EngineConfig(model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, @@ -519,6 +530,7 @@ def create_engine_config(self, ) -> EngineConfig: lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, + decoding_config=decoding_config, tensorizer_config=tensorizer_config) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8c37c5a9d6ee9..f06c1d18ace4b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,9 +4,10 @@ from transformers import PreTrainedTokenizer import vllm -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -74,6 +75,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], tensorizer_config: Optional[TensorizerConfig], executor_class: Type[ExecutorBase], log_stats: bool, @@ -100,6 +102,7 @@ def __init__( f"kv_cache_dtype={cache_config.cache_dtype}, " f"quantization_param_path={model_config.quantization_param_path}, " f"device_config={device_config.device}, " + f"decoding_config={decoding_config!r}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. @@ -111,6 +114,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.speculative_config = speculative_config + self.decoding_config = decoding_config or DecodingConfig() self.tensorizer_config = tensorizer_config self.log_stats = log_stats diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f94d22d279cc4..cf779d44c816b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel): description=( "If specified, the output will follow the context free grammar."), ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. 
If set, must be either " + "'outlines' / 'lm-format-enforcer'")) # doc: end-chat-completion-extra-params @@ -265,6 +271,12 @@ class CompletionRequest(BaseModel): description=( "If specified, the output will follow the context free grammar."), ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'")) # doc: end-completion-extra-params diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a03c5dc88108f..c9ed4a9de20f4 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -68,9 +68,13 @@ async def create_chat_completion( request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + decoding_config = self.engine.engine.decoding_config + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( await get_guided_decoding_logits_processor( - request, await self.engine.get_tokenizer())) + guided_decoding_backend, request, await + self.engine.get_tokenizer())) if guided_decode_logits_processor: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index e24aa2489a80f..a71f2d6a4426a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -88,9 +88,13 @@ async def create_completion(self, request: CompletionRequest, try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + decoding_config = self.engine.engine.decoding_config + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend guided_decode_logit_processor = ( await get_guided_decoding_logits_processor( - request, await self.engine.get_tokenizer())) + guided_decoding_backend, request, await + self.engine.get_tokenizer())) if guided_decode_logit_processor is not None: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py new file mode 100644 index 0000000000000..0558d6c95d97b --- /dev/null +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -0,0 +1,25 @@ +from typing import Optional, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) +from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( + get_lm_format_enforcer_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) +from vllm.sampling_params import LogitsProcessor + + +async def get_guided_decoding_logits_processor( + guided_decoding_backend: str, request: Union[CompletionRequest, + ChatCompletionRequest], + tokenizer) -> Optional[LogitsProcessor]: + if guided_decoding_backend == 'outlines': + return await get_outlines_guided_decoding_logits_processor( + request, tokenizer) + if guided_decoding_backend == 'lm-format-enforcer': + return await get_lm_format_enforcer_guided_decoding_logits_processor( + request, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend 
'{guided_decoding_backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer'") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py new file mode 100644 index 0000000000000..0d74a5f8e81ff --- /dev/null +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -0,0 +1,69 @@ +from functools import lru_cache +from json import loads as json_loads +from typing import Optional, Union + +from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser, + RegexParser, StringParser, + TokenEnforcerTokenizerData, UnionParser) +from lmformatenforcer.integrations.vllm import ( + build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) +from pydantic import BaseModel +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) +from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) +from vllm.sampling_params import LogitsProcessor + + +async def get_lm_format_enforcer_guided_decoding_logits_processor( + request: Union[CompletionRequest, ChatCompletionRequest], + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + + tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer) + character_level_parser: CharacterLevelParser + if request.guided_json: + schema = _normalize_json_schema_object(request.guided_json) + character_level_parser = JsonSchemaParser(schema) + elif request.guided_choice: + character_level_parser = UnionParser( + [StringParser(choice) for choice in request.guided_choice]) + elif request.guided_regex: + character_level_parser = RegexParser(request.guided_regex) + elif request.guided_grammar: + # CFG grammar not supported by LMFE, revert to outlines + return await get_outlines_guided_decoding_logits_processor( + request, tokenizer) + elif (request.response_format is not None + and request.response_format.type == "json_object"): + character_level_parser = JsonSchemaParser( + None) # None means any json object + else: + return None + + logits_processor = build_vllm_logits_processor(tokenizer_data, + character_level_parser) + return logits_processor + + +def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: + if isinstance(schema, str): + return json_loads(schema) + if isinstance(schema, dict): + return schema + if isinstance(schema, BaseModel): + return schema.model_json_schema() + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData: + return build_vllm_token_enforcer_tokenizer_data(tokenizer) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py similarity index 93% rename from vllm/model_executor/guided_decoding.py rename to vllm/model_executor/guided_decoding/outlines_decoding.py index 8e710f1ac2b53..bd4564a36e1ed 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -12,9 +12,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) -from 
vllm.model_executor.guided_logits_processors import (CFGLogitsProcessor, - JSONLogitsProcessor, - RegexLogitsProcessor) +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) class GuidedDecodingMode(Enum): @@ -54,7 +53,7 @@ class GuidedDecodingMode(Enum): global_thread_pool = None # used for generating logits processor fsm -async def get_guided_decoding_logits_processor( +async def get_outlines_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: """ diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py similarity index 70% rename from vllm/model_executor/guided_logits_processors.py rename to vllm/model_executor/guided_decoding/outlines_logits_processors.py index 035fe00037328..28041695546dc 100644 --- a/vllm/model_executor/guided_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -13,9 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import json import math from collections import defaultdict +from functools import lru_cache from typing import Callable, DefaultDict, Dict, List, Optional, Union import torch @@ -27,50 +29,6 @@ class BaseLogitsProcessor: - def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase): - """Adapt vLLM's tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of - `transformers`. The decoder of outlines, returns a list whereas - the decode of vLLM returns an str. To sync the vLLM decoder with - outlines internal api, the decoder should be adapted. In addition - we need to handle the missing spaces to Llama's tokenizer to be - able to compile FSMs for this model. 
- - """ - if getattr(tokenizer, "_outlines_adapted", False): - return tokenizer - - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) - - def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = tokenizer.convert_tokens_to_string([token]) - - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - def change_decoder( - decoder: Callable[[List[int]], str] - ) -> Callable[[List[int]], List[str]]: - """Sync vLLM's decoder with the outlines by returning list.""" - - def new_decoder(inp_tokens: List[int]) -> List[str]: - return [decoder(inp_tokens)] - - return new_decoder - - tokenizer.convert_token_to_string = convert_token_to_string - tokenizer.decode = change_decoder(tokenizer.decode) - setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 - - return tokenizer - def init_state(self): """Initialize the FSM states.""" self.fsm_state: DefaultDict[int, int] = defaultdict(int) @@ -78,7 +36,6 @@ def init_state(self): def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" - seq_id = hash(tuple(input_ids)) if len(input_ids) == 0: @@ -96,7 +53,6 @@ def __call__(self, input_ids: List[int], device=scores.device) mask[allowed_tokens] = 0 scores.add_(mask) - return scores @@ -113,7 +69,7 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - tokenizer = self.adapt_tokenizer(tokenizer) + tokenizer = _adapt_tokenizer(tokenizer) fsm = RegexFSM(regex_string, tokenizer) self.fsm = fsm @@ -167,6 +123,54 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - tokenizer = self.adapt_tokenizer(tokenizer) + tokenizer = _adapt_tokenizer(tokenizer) fsm = CFGFSM(cfg, tokenizer) self.fsm = fsm + + +@lru_cache +def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): + """Adapt vLLM's tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. The decoder of outlines, returns a list whereas + the decode of vLLM returns an str. To sync the vLLM decoder with + outlines internal api, the decoder should be adapted. In addition + we need to handle the missing spaces to Llama's tokenizer to be + able to compile FSMs for this model. + + """ + if getattr(tokenizer, "_outlines_adapted", False): + return tokenizer + + tokenizer = copy.deepcopy(tokenizer) + + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + def change_decoder( + decoder: Callable[[List[int]], + str]) -> Callable[[List[int]], List[str]]: + """Sync vLLM's decoder with the outlines by returning list.""" + + def new_decoder(inp_tokens: List[int]) -> List[str]: + return [decoder(inp_tokens)] + + return new_decoder + + tokenizer.convert_token_to_string = convert_token_to_string + tokenizer.decode = change_decoder(tokenizer.decode) + setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 + + return tokenizer
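Example usage (an illustrative sketch, not part of the patch): the snippet below shows how a client of the OpenAI-compatible server could exercise the new per-request guided_decoding_backend field added to CompletionRequest and ChatCompletionRequest. It mirrors the extra_body usage in tests/entrypoints/test_openai_server.py; the base URL, API key, model name, schema, and the server launch command in the comments are placeholder assumptions rather than values taken from this change.

    # Assumed setup (not shown in this patch): a vLLM OpenAI-compatible server
    # already running locally, e.g. started with the new flag added here:
    #   python -m vllm.entrypoints.openai.api_server \
    #       --model HuggingFaceH4/zephyr-7b-beta \
    #       --guided-decoding-backend lm-format-enforcer
    from openai import OpenAI  # openai>=1.0 client, same style as the tests above

    # Placeholder endpoint and credentials; adjust for your deployment.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    example_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
        },
        "required": ["name", "age"],
    }

    completion = client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model name
        prompt=f"Give an example JSON for an employee profile that fits this schema: {example_schema}",
        max_tokens=200,
        # vLLM-specific extra parameters defined in vllm/entrypoints/openai/protocol.py:
        # guided_json constrains the output to the schema, and
        # guided_decoding_backend overrides the server-wide default
        # ('outlines') for this single request.
        extra_body=dict(guided_json=example_schema,
                        guided_decoding_backend="lm-format-enforcer"),
    )
    print(completion.choices[0].text)

When a request omits guided_decoding_backend, serving_chat.py and serving_completion.py fall back to the engine's DecodingConfig, whose default remains 'outlines', so existing clients are unaffected by this change.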