Commit

Support arbitrary json_object in OpenAI and Context Free Grammar (vll…
simon-mo authored Mar 16, 2024
1 parent 8e67598 commit 120157f
Showing 4 changed files with 176 additions and 49 deletions.
50 changes: 50 additions & 0 deletions tests/entrypoints/test_openai_server.py
@@ -660,5 +660,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))


async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": ('what is 1+1? please respond with a JSON object, '
                        'the format is {"result": 2}')
        }],
response_format={"type": "json_object"})

content = resp.choices[0].message.content
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded


async def test_guided_grammar(server, client: openai.AsyncOpenAI):
simple_sql_grammar = """
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
"""

completion = await client.completions.create(
model=MODEL_NAME,
        prompt=("Generate a SQL statement that selects col_1 from "
                "table_1 where it equals 1"),
temperature=1.0,
max_tokens=500,
extra_body=dict(guided_grammar=simple_sql_grammar))

content = completion.choices[0].text

# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(simple_sql_grammar)
parser.parse(content)

    # the grammar defines no whitespace, so the output contains no spaces;
    # strip spaces from the expected string to match
    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")

assert content.strip() == ground_truth


if __name__ == "__main__":
pytest.main([__file__])
9 changes: 9 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -55,6 +55,11 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0


class ResponseFormat(BaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]

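# Editor's sketch, not part of this diff: with the Literal annotation,
# pydantic validates the field up front, so an unsupported format type is
# rejected at request-parsing time.
from pydantic import ValidationError

ResponseFormat(type="json_object")  # accepted
ResponseFormat(type="text")  # accepted
try:
    ResponseFormat(type="json_schema")  # unsupported -> ValidationError
except ValidationError as exc:
    print(exc)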

class ChatCompletionRequest(BaseModel):
model: str
messages: List[Dict[str, str]]
@@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel):
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
guided_grammar: Optional[str] = None
response_format: Optional[ResponseFormat] = None

def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
@@ -183,6 +190,8 @@ class CompletionRequest(BaseModel):
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
guided_grammar: Optional[str] = None
response_format: Optional[ResponseFormat] = None

def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0
54 changes: 41 additions & 13 deletions vllm/model_executor/guided_decoding.py
@@ -6,19 +6,50 @@
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Union, Tuple

from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (CompletionRequest,
ChatCompletionRequest)
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
                                                           RegexLogitsProcessor,
                                                           CFGLogitsProcessor)


class GuidedDecodingMode(Enum):
JSON = "json"
REGEX = "regex"
CHOICE = "choice"
GRAMMAR = "grammar"


# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# The main difference is that we changed `start: value` to
# `start: object | array`, so scalar values are disallowed as the root of
# the JSON. Starting with a scalar as the root seems to cause llama to
# generate without stopping.
JSON_GRAMMAR = r"""
?start: object | array
?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null
array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value
%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""

global_thread_pool = None # used for generating logits processor fsm

@@ -57,9 +88,6 @@ def _get_guide_and_mode(
) -> Tuple[str, GuidedDecodingMode]:

if request.guided_json:
if not isinstance(request.guided_json, (str, dict, BaseModel)):
raise TypeError("JSON schema must be str, dict, or BaseModel")

json = request.guided_json
if isinstance(json, dict):
# turn dict into hashable string
@@ -69,33 +97,33 @@
            json = json_dumps(json, sort_keys=True)
        elif isinstance(json, BaseModel):
            # use pydantic signature so that different model classes
            # with the same fields will get hashed the same
json = str(json.__signature__)
return json, GuidedDecodingMode.JSON

elif request.guided_regex:
if not isinstance(request.guided_regex, str):
raise TypeError("Regex must be string")
return request.guided_regex, GuidedDecodingMode.REGEX

elif request.guided_choice:
if not isinstance(request.guided_choice, list):
raise TypeError("Choices must be a list")

# choice just uses regex
choices = [
regex_escape(str(choice)) for choice in request.guided_choice
]
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE

elif request.guided_grammar:
return request.guided_grammar, GuidedDecodingMode.GRAMMAR
elif (request.response_format is not None
and request.response_format.type == "json_object"):
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
else:
return None, None
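# Editor's sketch, not part of this diff: with the new branches above, a
# request carrying response_format={"type": "json_object"} is routed to
# grammar mode with JSON_GRAMMAR, exactly like an explicit guided_grammar
# request. ResponseFormat is imported here only for the demo.
from vllm.entrypoints.openai.protocol import ResponseFormat

_demo_request = CompletionRequest(
    model="m",
    prompt="hi",
    response_format=ResponseFormat(type="json_object"))
assert _get_guide_and_mode(_demo_request) == (JSON_GRAMMAR,
                                              GuidedDecodingMode.GRAMMAR)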


@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str,
                                 tokenizer: PreTrainedTokenizerBase,
                                 mode: GuidedDecodingMode):
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.GRAMMAR:
return CFGLogitsProcessor(guide, tokenizer)
else:
raise ValueError(f"Unknown guided decoding mode {mode}")
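# Editor's sketch, not part of this diff: because of lru_cache, repeated
# requests with the same guide string, tokenizer, and mode reuse a single
# compiled processor, so the expensive FSM compilation runs once. The
# GPT-2 tokenizer below is a stand-in for the server's tokenizer.
from transformers import AutoTokenizer

_tok = AutoTokenizer.from_pretrained("gpt2")
_p1 = _get_cached_logits_processor(JSON_GRAMMAR, _tok,
                                   GuidedDecodingMode.GRAMMAR)
_p2 = _get_cached_logits_processor(JSON_GRAMMAR, _tok,
                                   GuidedDecodingMode.GRAMMAR)
assert _p1 is _p2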
112 changes: 76 additions & 36 deletions vllm/model_executor/guided_logits_processors.py
@@ -16,30 +16,60 @@
import json
import math
from collections import defaultdict
from typing import Union, DefaultDict, Dict, List, Optional, Callable

import torch
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from outlines.fsm.fsm import RegexFSM, CFGFSM
from outlines.fsm.json_schema import build_regex_from_schema


class BaseLogitsProcessor:

    def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
        """Adapt vLLM's tokenizer so it can be used to compile the FSM.

        The API of Outlines tokenizers is slightly different to that of
        `transformers`: the outlines decoder returns a list whereas vLLM's
        decode returns a str. To sync vLLM's decoder with the outlines
        internal API, the decoder is adapted here. In addition we need to
        handle the missing spaces in Llama's tokenizer to be able to
        compile FSMs for this model.
        """
        if getattr(tokenizer, "_outlines_adapted", False):
            return tokenizer

tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)

def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE

string = tokenizer.convert_tokens_to_string([token])

# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string

return string

def change_decoder(
decoder: Callable[[List[int]], str]
) -> Callable[[List[int]], List[str]]:
"""Sync vLLM's decoder with the outlines by returning list."""

def new_decoder(inp_tokens: List[int]) -> List[str]:
return [decoder(inp_tokens)]

return new_decoder

tokenizer.convert_token_to_string = convert_token_to_string
tokenizer.decode = change_decoder(tokenizer.decode)
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010

return tokenizer

def init_state(self):
"""Initialize the FSM states."""
@@ -69,38 +99,30 @@ def __call__(self, input_ids: List[int],

return scores

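# Editor's sketch, not part of this diff: what change_decoder above does.
# Outlines expects `decode` to return a list of strings, while vLLM's
# tokenizer returns a single string, so the shim wraps the result.
def _str_decode(token_ids):
    """Stand-in for the real tokenizer.decode, which returns a str."""
    return "decoded"

_list_decode = (lambda decoder: (lambda ids: [decoder(ids)]))(_str_decode)
assert _list_decode([1, 2, 3]) == ["decoded"]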
class RegexLogitsProcessor(BaseLogitsProcessor):

    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
        """Compile the FSM that drives the regex-structured generation.

        Parameters
        ----------
        regex_string
            A string that represents a regular expression
        tokenizer
            The model's tokenizer

        """
        tokenizer = self.adapt_tokenizer(tokenizer)
        fsm = RegexFSM(regex_string, tokenizer)
        self.fsm = fsm


class JSONLogitsProcessor(RegexLogitsProcessor):

def __init__(self,
schema: Union[str, Dict, BaseModel],
                 tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Optional[str] = None):
"""Compile the FSM that drives the JSON-guided generation.
@@ -130,3 +152,21 @@ def __init__(self,
f"the JSON Schema specification")
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer)
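# Editor's sketch, not part of this diff: JSONLogitsProcessor rides on the
# regex path; build_regex_from_schema (imported above) turns the schema
# into one large regular expression, which the RegexFSM then compiles.
_demo_schema = ('{"type": "object", '
                '"properties": {"result": {"type": "integer"}}}')
print(build_regex_from_schema(_demo_schema)[:80])  # a long regex string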


class CFGLogitsProcessor(BaseLogitsProcessor):

    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
        """Compile the FSM that drives the context-free-grammar generation.

        Parameters
----------
cfg
A string that represents a context-free grammar
tokenizer
The model's tokenizer
"""
tokenizer = self.adapt_tokenizer(tokenizer)
fsm = CFGFSM(cfg, tokenizer)
self.fsm = fsm
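# Editor's sketch, not part of this diff: wiring a grammar processor into
# vLLM sampling. Assumes SamplingParams' existing logits_processors hook;
# the GPT-2 tokenizer is a stand-in for the engine's tokenizer.
from transformers import AutoTokenizer
from vllm import SamplingParams

_tok = AutoTokenizer.from_pretrained("gpt2")
_proc = CFGLogitsProcessor('start: "yes" | "no"', _tok)
_proc.init_state()
_params = SamplingParams(max_tokens=2, logits_processors=[_proc])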
