Support arbitrary json_object in OpenAI and Context Free Grammar #3211

Merged: 12 commits, Mar 16, 2024
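
For orientation, a minimal client-side sketch of the two request surfaces this PR adds, assuming a vLLM OpenAI-compatible server on localhost:8000 (the model name is a placeholder):

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# OpenAI-style response_format: constrain output to an arbitrary JSON object.
chat = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",  # placeholder model name
    messages=[{"role": "user", "content": "Give me a JSON object."}],
    response_format={"type": "json_object"},
)
print(chat.choices[0].message.content)

# vLLM extension: constrain output with an arbitrary context-free grammar.
completion = client.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",  # placeholder model name
    prompt="Answer yes or no:",
    extra_body={"guided_cfg": 'start: "yes" | "no"'},
)
print(completion.choices[0].text)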
tests/entrypoints/test_openai_server.py (+50, -0)

@@ -650,5 +650,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
            extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))


async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
    resp = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content":
            'what is 1+1? please respond with a JSON object, the format is {"result": 2}'
        }],
        response_format={"type": "json_object"})

    content = resp.choices[0].message.content
    loaded = json.loads(content)
    assert loaded == {"result": 2}, loaded


async def test_guided_grammar(server, client: openai.AsyncOpenAI):
    simple_sql_grammar = """
start: select_statement

select_statement: "SELECT" column "from" table "where" condition

column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number

number: "1" | "2"
"""

    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=("Generate a SQL statement that selects col_1 "
                "from table_1 where it equals 1"),
        temperature=1.0,
        max_tokens=500,
        extra_body=dict(guided_cfg=simple_sql_grammar))

    content = completion.choices[0].text

    # use Lark to parse the output and make sure it is a valid parse tree
    from lark import Lark
    parser = Lark(simple_sql_grammar)
    parser.parse(content)

    # remove spaces for comparison, because the grammar defines no
    # whitespace rule and the generated text therefore contains none
    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")

    assert content.strip() == ground_truth


if __name__ == "__main__":
    pytest.main([__file__])
vllm/entrypoints/openai/protocol.py (+9, -0)

@@ -55,6 +55,11 @@ class UsageInfo(BaseModel):
    completion_tokens: Optional[int] = 0


class ResponseFormat(BaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"] = "text"


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
@@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel):
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_cfg: Optional[str] = None
    response_format: Optional[ResponseFormat] = None

    def to_sampling_params(self) -> SamplingParams:
        if self.logprobs and not self.top_logprobs:

@@ -183,6 +190,8 @@ class CompletionRequest(BaseModel):
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_cfg: Optional[str] = None
    response_format: Optional[ResponseFormat] = None

    def to_sampling_params(self):
        echo_without_generation = self.echo and self.max_tokens == 0
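
As a standalone illustration (not part of the diff) of the Literal-typed field above, pydantic now rejects any response_format type other than the two allowed strings:

from typing import Literal

from pydantic import BaseModel, ValidationError

class ResponseFormat(BaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"] = "text"

print(ResponseFormat(type="json_object"))  # accepted
try:
    ResponseFormat(type="yaml")
except ValidationError as err:
    print(err)  # rejected: 'type' must be 'text' or 'json_object'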
vllm/model_executor/guided_decoding.py (+34, -12)

@@ -9,15 +9,42 @@
from pydantic import BaseModel

from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest
from vllm.model_executor.guided_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor


class GuidedDecodingMode(Enum):
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"
    CFG = "cfg"
Collaborator: A dumb question: does CFG mean config? Only for grammar config?

Author: Context-free grammar. Good point, I should probably rename it to "grammar" to disambiguate.



# Adapted from https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# The main difference is that "start: value" was changed to "start: object | array",
# denying scalar values as the root of the JSON document. Allowing a scalar root
# seems to cause Llama to generate without stopping. (A standalone Lark check of
# this root-rule restriction follows this file's diff.)
JSON_CFG_GRAMMAR = r"""
?start: object | array

?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null

array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value

%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""

global_thread_pool = None # used for generating logits processor fsm


@@ -55,9 +82,6 @@ def _get_guide_and_mode(

) -> Tuple[str, GuidedDecodingMode]:

    if request.guided_json:
        json = request.guided_json
        if isinstance(json, dict):
            # turn dict into hashable string

@@ -67,23 +91,19 @@

            # with the same fields will get hashed the same
            json = str(json.__signature__)
        return json, GuidedDecodingMode.JSON

    elif request.guided_regex:
        return request.guided_regex, GuidedDecodingMode.REGEX

    elif request.guided_choice:
        # choice just uses regex
        choices = [
            regex_escape(str(choice)) for choice in request.guided_choice
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE

    elif request.guided_cfg:
        return request.guided_cfg, GuidedDecodingMode.CFG

    elif (request.response_format is not None
          and request.response_format.type == "json_object"):
        return JSON_CFG_GRAMMAR, GuidedDecodingMode.CFG

    else:
        return None, None

@@ -95,5 +115,7 @@ def _get_cached_logits_processor(guide: str, tokenizer,

        return JSONLogitsProcessor(guide, tokenizer)
    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
        return RegexLogitsProcessor(guide, tokenizer)
    elif mode == GuidedDecodingMode.CFG:
        return CFGLogitsProcessor(guide, tokenizer)
    else:
        raise ValueError(f"Unknown guided decoding mode {mode}")
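
To make the root-rule restriction concrete, here is a standalone check (not part of the diff) with plain Lark. ESCAPED_STRING from Lark's bundled common grammar stands in for UNESCAPED_STRING, which Outlines defines in its own grammar files; the structure otherwise mirrors JSON_CFG_GRAMMAR above:

from lark import Lark
from lark.exceptions import UnexpectedInput

json_grammar = r"""
?start: object | array

?value: object
      | array
      | ESCAPED_STRING
      | SIGNED_NUMBER   -> number
      | "true"          -> true
      | "false"         -> false
      | "null"          -> null

array  : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair   : ESCAPED_STRING ":" value

%import common.ESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""

parser = Lark(json_grammar)
parser.parse('{"result": 2}')   # object root: accepted
parser.parse('[1, 2, 3]')       # array root: accepted
try:
    parser.parse('42')          # bare scalar root: rejected
except UnexpectedInput:
    print("scalar roots are denied, as intended")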
vllm/model_executor/guided_logits_processors.py (+74, -35)

@@ -16,30 +16,59 @@
import json
import math
from collections import defaultdict
from typing import Union, DefaultDict, Dict, List, Optional, Callable

import torch
from pydantic import BaseModel
from outlines.fsm.fsm import RegexFSM, CFGFSM
from outlines.fsm.json_schema import build_regex_from_schema


class BaseLogitsProcessor:

    def adapt_tokenizer(self, tokenizer):
        """Adapt vLLM's tokenizer so it can be used to compile the FSM.

        The API of Outlines tokenizers is slightly different from that of
        `transformers`: Outlines' decode returns a list, whereas vLLM's
        decode returns a str. To sync vLLM's decoder with Outlines'
        internal API, the decoder is wrapped here. In addition, we need
        to handle the missing spaces in Llama's tokenizer to be able to
        compile FSMs for this model.

        """
        if getattr(tokenizer, "_outlines_adapted", False):
            return tokenizer

        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            string = tokenizer.convert_tokens_to_string([token])

            # A hack to handle missing spaces in HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

            return string

        def change_decoder(
            decoder: Callable[[List[int]], str]
        ) -> Callable[[List[int]], List[str]]:
            """Sync vLLM's decoder with Outlines' expectations by returning a list."""

            def new_decoder(inp_tokens: List[int]) -> List[str]:
                return [decoder(inp_tokens)]

            return new_decoder

        tokenizer.convert_token_to_string = convert_token_to_string
        tokenizer.decode = change_decoder(tokenizer.decode)
        setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010

        return tokenizer
Collaborator: Do we need to return this?


    def init_state(self):
        """Initialize the FSM states."""

@@ -69,31 +98,23 @@ def __call__(self, input_ids: List[int],

        return scores


class RegexLogitsProcessor(BaseLogitsProcessor):

    def __init__(self, regex_string: str, tokenizer):
        """Compile the FSM that drives the regex-structured generation.

        Parameters
        ----------
        regex_string
            A string that represents a regular expression
        tokenizer
            The model's tokenizer

        """
        tokenizer = self.adapt_tokenizer(tokenizer)
        fsm = RegexFSM(regex_string, tokenizer)
        self.fsm = fsm


class JSONLogitsProcessor(RegexLogitsProcessor):
@@ -127,3 +148,21 @@ def __init__(self,

                + "Schema specification")
        regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
        super().__init__(regex_string, tokenizer)


class CFGLogitsProcessor(BaseLogitsProcessor):

    def __init__(self, cfg: str, tokenizer):
        """Compile the FSM that drives the CFG-structured generation.

        Parameters
        ----------
        cfg
            A string that represents a context-free grammar
        tokenizer
            The model's tokenizer

        """
        tokenizer = self.adapt_tokenizer(tokenizer)
        fsm = CFGFSM(cfg, tokenizer)
        self.fsm = fsm

Collaborator: Adding a type annotation for the tokenizer throughout this file would be better.
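
For completeness, a hedged sketch (not from this PR) of driving the new CFGLogitsProcessor through vLLM's offline API instead of the OpenAI server; the model name is a placeholder, and the processor rides in SamplingParams.logits_processors like any other callable:

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.model_executor.guided_logits_processors import CFGLogitsProcessor

# Tiny grammar: the model may only emit "yes" or "no".
yes_no_grammar = 'start: "yes" | "no"'

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = CFGLogitsProcessor(yes_no_grammar, tokenizer)
processor.init_state()  # reset the per-sequence FSM state before generating

llm = LLM(model=model_name)
outputs = llm.generate(
    "Is the sky blue? Answer:",
    SamplingParams(max_tokens=5, logits_processors=[processor]),
)
print(outputs[0].outputs[0].text)  # constrained to "yes" or "no"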