diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index a5b2bf4c0f0c9..86d9a85af80b1 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -660,5 +660,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
         extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
 
 
+async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
+    resp = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": ('what is 1+1? please respond with a JSON object, '
+                        'the format is {"result": 2}')
+        }],
+        response_format={"type": "json_object"})
+
+    content = resp.choices[0].message.content
+    loaded = json.loads(content)
+    assert loaded == {"result": 2}, loaded
+
+
+async def test_guided_grammar(server, client: openai.AsyncOpenAI):
+    simple_sql_grammar = """
+start: select_statement
+
+select_statement: "SELECT" column "from" table "where" condition
+
+column: "col_1" | "col_2"
+table: "table_1" | "table_2"
+condition: column "=" number
+
+number: "1" | "2"
+"""
+
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=("Generate a SQL statement that selects col_1 from "
+                "table_1 where it equals 1"),
+        temperature=1.0,
+        max_tokens=500,
+        extra_body=dict(guided_grammar=simple_sql_grammar))
+
+    content = completion.choices[0].text
+
+    # use Lark to parse the output, and make sure it's a valid parse tree
+    from lark import Lark
+    parser = Lark(simple_sql_grammar)
+    parser.parse(content)
+
+    # strip spaces from the expected string: the grammar above leaves no
+    # room for whitespace, so the generated text contains none either
+    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
+
+    assert content.strip() == ground_truth
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 26499b8d7a66f..9421880411611 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -55,6 +55,11 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0
 
 
+class ResponseFormat(BaseModel):
+    # type must be "json_object" or "text"
+    type: Literal["text", "json_object"] = "text"
+
+
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Dict[str, str]]
@@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel):
     guided_json: Optional[Union[str, dict, BaseModel]] = None
     guided_regex: Optional[str] = None
     guided_choice: Optional[List[str]] = None
+    guided_grammar: Optional[str] = None
+    response_format: Optional[ResponseFormat] = None
 
     def to_sampling_params(self) -> SamplingParams:
         if self.logprobs and not self.top_logprobs:
@@ -183,6 +190,8 @@ class CompletionRequest(BaseModel):
     guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
     guided_choice: Optional[List[str]] = None
+    guided_grammar: Optional[str] = None
+    response_format: Optional[ResponseFormat] = None
 
     def to_sampling_params(self):
         echo_without_generation = self.echo and self.max_tokens == 0
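
Note: for reference, the two request fields added above can be exercised from the standard OpenAI Python client. This is a minimal sketch, assuming a vLLM OpenAI-compatible server already running on localhost:8000; the base URL, API key, and model name are placeholders, not part of this diff:

# Minimal sketch: exercising response_format and guided_grammar from the
# OpenAI client. base_url, api_key, and the model name are assumptions.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# response_format switches the server into JSON-object mode
chat = client.chat.completions.create(
    model="some-model",
    messages=[{"role": "user", "content": "Give me one prime as JSON."}],
    response_format={"type": "json_object"})
print(chat.choices[0].message.content)

# guided_grammar is a vLLM extension, so it is passed via extra_body
completion = client.completions.create(
    model="some-model",
    prompt="Answer yes or no:",
    extra_body=dict(guided_grammar='start: "yes" | "no"'))
print(completion.choices[0].text)
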
diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py
index 00984460d79a6..bd09cf9cb6ee3 100644
--- a/vllm/model_executor/guided_decoding.py
+++ b/vllm/model_executor/guided_decoding.py
@@ -6,19 +6,50 @@
 from json import dumps as json_dumps
 from re import escape as regex_escape
 from typing import Union, Tuple
+
 from pydantic import BaseModel
+from transformers import PreTrainedTokenizerBase
 
 from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               ChatCompletionRequest)
 from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
-                                                          RegexLogitsProcessor)
+                                                          RegexLogitsProcessor,
+                                                          CFGLogitsProcessor)
 
 
 class GuidedDecodingMode(Enum):
     JSON = "json"
     REGEX = "regex"
     CHOICE = "choice"
+    GRAMMAR = "grammar"
+
+
+# Adapted from
+# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
+# The main difference is that we changed `start: value` to
+# `start: object | array`, so scalar values are rejected as the root of the
+# JSON document. Allowing a scalar at the root seems to cause Llama to
+# generate without stopping.
+JSON_GRAMMAR = r"""
+?start: object | array
+
+?value: object
+| array
+| UNESCAPED_STRING
+| SIGNED_NUMBER -> number
+| "true" -> true
+| "false" -> false
+| "null" -> null
+
+array : "[" [value ("," value)*] "]"
+object : "{" [pair ("," pair)*] "}"
+pair : UNESCAPED_STRING ":" value
+
+%import common.UNESCAPED_STRING
+%import common.SIGNED_NUMBER
+%import common.WS
+%ignore WS
+"""
 
 global_thread_pool = None  # used for generating logits processor fsm
@@ -57,9 +88,6 @@ def _get_guide_and_mode(
 ) -> Tuple[str, GuidedDecodingMode]:
 
     if request.guided_json:
-        if not isinstance(request.guided_json, (str, dict, BaseModel)):
-            raise TypeError("JSON schema must be str, dict, or BaseModel")
-
         json = request.guided_json
         if isinstance(json, dict):
             # turn dict into hashable string
@@ -69,33 +97,33 @@
             # with the same fields will get hashed the same
             json = str(json.__signature__)
         return json, GuidedDecodingMode.JSON
-
     elif request.guided_regex:
-        if not isinstance(request.guided_regex, str):
-            raise TypeError("Regex must be string")
         return request.guided_regex, GuidedDecodingMode.REGEX
-
     elif request.guided_choice:
-        if not isinstance(request.guided_choice, list):
-            raise TypeError("Choices must be a list")
-
         # choice just uses regex
         choices = [
             regex_escape(str(choice)) for choice in request.guided_choice
         ]
         choices_regex = "(" + "|".join(choices) + ")"
         return choices_regex, GuidedDecodingMode.CHOICE
-
+    elif request.guided_grammar:
+        return request.guided_grammar, GuidedDecodingMode.GRAMMAR
+    elif (request.response_format is not None
+          and request.response_format.type == "json_object"):
+        return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
     else:
         return None, None
 
 
 @lru_cache(maxsize=32)
-def _get_cached_logits_processor(guide: str, tokenizer,
+def _get_cached_logits_processor(guide: str,
+                                 tokenizer: PreTrainedTokenizerBase,
                                  mode: GuidedDecodingMode):
     if mode == GuidedDecodingMode.JSON:
         return JSONLogitsProcessor(guide, tokenizer)
     elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
         return RegexLogitsProcessor(guide, tokenizer)
+    elif mode == GuidedDecodingMode.GRAMMAR:
+        return CFGLogitsProcessor(guide, tokenizer)
     else:
         raise ValueError(f"Unknown guided decoding mode {mode}")
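
Note: the dispatch above resolves each request to a (guide, mode) pair before any FSM is compiled, so the lru_cache in _get_cached_logits_processor can key on the guide string itself. A sketch of the two new paths, assuming the module-private helper is imported directly for illustration (not something the server does):

# Sketch: both guided_grammar and response_format resolve to GRAMMAR mode.
from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (GuidedDecodingMode,
                                                 JSON_GRAMMAR,
                                                 _get_guide_and_mode)

req = CompletionRequest(model="m", prompt="p",
                        guided_grammar='start: "SELECT"')
assert _get_guide_and_mode(req) == ('start: "SELECT"',
                                    GuidedDecodingMode.GRAMMAR)

# response_format={"type": "json_object"} falls back to the JSON_GRAMMAR CFG
req = CompletionRequest(model="m", prompt="p",
                        response_format={"type": "json_object"})
assert _get_guide_and_mode(req) == (JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR)
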
diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_logits_processors.py
index 76d41aa37dd7b..2cd1ae1571065 100644
--- a/vllm/model_executor/guided_logits_processors.py
+++ b/vllm/model_executor/guided_logits_processors.py
@@ -16,30 +16,60 @@
 import json
 import math
 from collections import defaultdict
-from typing import Union, DefaultDict, Dict, List, Optional
+from typing import Union, DefaultDict, Dict, List, Optional, Callable
 
 import torch
 from pydantic import BaseModel
-from outlines.fsm.fsm import RegexFSM
+from transformers import PreTrainedTokenizerBase
+from outlines.fsm.fsm import RegexFSM, CFGFSM
 from outlines.fsm.json_schema import build_regex_from_schema
 
 
-class RegexLogitsProcessor:
+class BaseLogitsProcessor:
 
-    def __init__(self, regex_string: str, tokenizer):
-        """Compile the FSM that drives the regex-structured generation.
+    def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
+        """Adapt vLLM's tokenizer so it can be used to compile the FSM.
 
-        Parameters
-        ----------
-        regex_string
-            A string that represents a regular expression
-        tokenizer
-            The model's tokenizer
+        The API of Outlines tokenizers is slightly different from that of
+        `transformers`: Outlines' decoder returns a list, whereas vLLM's
+        decoder returns a str. The decoder is therefore wrapped so that
+        vLLM matches Outlines' internal API. In addition, we need to work
+        around the missing spaces in Llama's tokenizer to be able to
+        compile FSMs for this model.
 
         """
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
+        if getattr(tokenizer, "_outlines_adapted", False):
+            return tokenizer
+
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        def change_decoder(
+            decoder: Callable[[List[int]], str]
+        ) -> Callable[[List[int]], List[str]]:
+            """Sync vLLM's decoder with Outlines' by returning a list."""
+
+            def new_decoder(inp_tokens: List[int]) -> List[str]:
+                return [decoder(inp_tokens)]
+
+            return new_decoder
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+        tokenizer.decode = change_decoder(tokenizer.decode)
+        setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010
+
+        return tokenizer
 
     def init_state(self):
         """Initialize the FSM states."""
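
Note on the hunk above: adapt_tokenizer mutates a shared tokenizer in place, and the cached processors can route the same instance through this method more than once, so the _outlines_adapted flag makes the adaptation idempotent. A self-contained sketch of the failure the guard prevents (stand-in functions, not vLLM code):

# Sketch: wrapping decode twice would nest the list and break Outlines'
# expectation of receiving List[str].
from typing import Callable, List


def change_decoder(decoder: Callable) -> Callable[[List[int]], list]:
    return lambda ids: [decoder(ids)]


decode = lambda ids: "abc"      # stand-in for tokenizer.decode
once = change_decoder(decode)   # first adaptation
twice = change_decoder(once)    # what a second, unguarded adaptation does

print(once([1]))   # ['abc']   -> what Outlines expects
print(twice([1]))  # [['abc']] -> nested list, subtly wrong
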
@@ -69,38 +99,30 @@ def __call__(self, input_ids: List[int],
 
         return scores
 
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-            string = tokenizer.convert_tokens_to_string([token])
+class RegexLogitsProcessor(BaseLogitsProcessor):
 
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
+    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
+        """Compile the FSM that drives the regex-structured generation.
 
-        tokenizer.convert_token_to_string = convert_token_to_string
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression
+        tokenizer
+            The model's tokenizer
 
-        return tokenizer
+        """
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        fsm = RegexFSM(regex_string, tokenizer)
+        self.fsm = fsm
 
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
 
     def __init__(self,
                  schema: Union[str, Dict, BaseModel],
-                 tokenizer,
+                 tokenizer: PreTrainedTokenizerBase,
                  whitespace_pattern: Optional[str] = None):
         """Compile the FSM that drives the JSON-guided generation.
@@ -130,3 +152,21 @@ def __init__(self,
                 f"the JSON Schema specification")
         regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
         super().__init__(regex_string, tokenizer)
+
+
+class CFGLogitsProcessor(BaseLogitsProcessor):
+
+    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
+        """Compile the FSM that drives the context-free grammar generation.
+
+        Parameters
+        ----------
+        cfg
+            A string that represents a context-free grammar
+        tokenizer
+            The model's tokenizer
+
+        """
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        fsm = CFGFSM(cfg, tokenizer)
+        self.fsm = fsm
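
Taken together, CFGLogitsProcessor plugs into the same machinery as the regex and JSON processors: it is compiled once per (guide, mode) pair by the lru_cache'd factory and then attached to the request's sampling parameters. Below is a hedged sketch of driving it through vLLM's offline API instead of the OpenAI server; the model name is a placeholder and the manual init_state() call mirrors what the server-side helper does before sampling:

# Sketch: using CFGLogitsProcessor with the offline LLM API. Assumes a
# small placeholder model; not a documented recipe from this diff.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.model_executor.guided_logits_processors import CFGLogitsProcessor

grammar = 'start: "yes" | "no"'
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
processor = CFGLogitsProcessor(grammar, tokenizer)
processor.init_state()  # the server normally calls this before sampling

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=3, logits_processors=[processor])
outputs = llm.generate("Answer yes or no:", params)
print(outputs[0].outputs[0].text)
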