Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support arbitrary json_object in OpenAI and Context Free Grammar #3211

Merged
merged 12 commits into from
Mar 16, 2024
50 changes: 50 additions & 0 deletions tests/entrypoints/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,5 +660,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))


async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role":
"user",
"content": ('what is 1+1? please respond with a JSON object, '
'the format is {"result": 2}')
}],
response_format={"type": "json_object"})

content = resp.choices[0].message.content
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded


async def test_guided_grammar(server, client: openai.AsyncOpenAI):
simple_sql_grammar = """
start: select_statement

select_statement: "SELECT" column "from" table "where" condition

column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number

number: "1" | "2"
"""

completion = await client.completions.create(
model=MODEL_NAME,
prompt=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
temperature=1.0,
max_tokens=500,
extra_body=dict(guided_grammar=simple_sql_grammar))

content = completion.choices[0].text

# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(simple_sql_grammar)
parser.parse(content)

# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")

assert content.strip() == ground_truth


if __name__ == "__main__":
pytest.main([__file__])
4 changes: 2 additions & 2 deletions tests/kernels/test_prefix_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def test_contexted_kv_attention(
torch.cuda.manual_seed(0)
torch.set_default_device(device)

# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
# Need this, otherwise when we capture the graph the process for GPU 1 would
# run on both GPU0 and GPU1 and things would hang
#
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device)
Expand Down
9 changes: 9 additions & 0 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0


class ResponseFormat(BaseModel):
# type must be "json_object" or "text"
type: str = Literal["text", "json_object"]


class ChatCompletionRequest(BaseModel):
model: str
messages: List[Dict[str, str]]
Expand Down Expand Up @@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel):
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
guided_grammar: Optional[str] = None
response_format: Optional[ResponseFormat] = None

def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
Expand Down Expand Up @@ -183,6 +190,8 @@ class CompletionRequest(BaseModel):
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
guided_grammar: Optional[str] = None
response_format: Optional[ResponseFormat] = None

def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0
Expand Down
54 changes: 41 additions & 13 deletions vllm/model_executor/guided_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,50 @@
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Union, Tuple

from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (CompletionRequest,
ChatCompletionRequest)
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
RegexLogitsProcessor)
RegexLogitsProcessor,
CFGLogitsProcessor)


class GuidedDecodingMode(Enum):
JSON = "json"
REGEX = "regex"
CHOICE = "choice"
GRAMMAR = "grammar"


# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# the main difference is that we changed the start: value to
# start: object | array, so we are denying scalar values as the root of the
# JSON. Starting with scalars as the root seems to cause llama to generate
# without stop.
JSON_GRAMMAR = r"""
?start: object | array

?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null

array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value

%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""

global_thread_pool = None # used for generating logits processor fsm

Expand Down Expand Up @@ -57,9 +88,6 @@ def _get_guide_and_mode(
) -> Tuple[str, GuidedDecodingMode]:

if request.guided_json:
if not isinstance(request.guided_json, (str, dict, BaseModel)):
raise TypeError("JSON schema must be str, dict, or BaseModel")

json = request.guided_json
if isinstance(json, dict):
# turn dict into hashable string
Expand All @@ -69,33 +97,33 @@ def _get_guide_and_mode(
# with the same fields will get hashed the same
json = str(json.__signature__)
return json, GuidedDecodingMode.JSON

elif request.guided_regex:
if not isinstance(request.guided_regex, str):
raise TypeError("Regex must be string")
return request.guided_regex, GuidedDecodingMode.REGEX

elif request.guided_choice:
if not isinstance(request.guided_choice, list):
raise TypeError("Choices must be a list")

# choice just uses regex
choices = [
regex_escape(str(choice)) for choice in request.guided_choice
]
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE

elif request.guided_grammar:
return request.guided_grammar, GuidedDecodingMode.GRAMMAR
elif (request.response_format is not None
and request.response_format.type == "json_object"):
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
else:
return None, None


@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str, tokenizer,
def _get_cached_logits_processor(guide: str,
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode):
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.GRAMMAR:
return CFGLogitsProcessor(guide, tokenizer)
else:
raise ValueError(f"Unknown guided decoding mode {mode}")
112 changes: 76 additions & 36 deletions vllm/model_executor/guided_logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,60 @@
import json
import math
from collections import defaultdict
from typing import Union, DefaultDict, Dict, List, Optional
from typing import Union, DefaultDict, Dict, List, Optional, Callable

import torch
from pydantic import BaseModel
from outlines.fsm.fsm import RegexFSM
from transformers import PreTrainedTokenizerBase
from outlines.fsm.fsm import RegexFSM, CFGFSM
from outlines.fsm.json_schema import build_regex_from_schema


class RegexLogitsProcessor:
class BaseLogitsProcessor:

def __init__(self, regex_string: str, tokenizer):
"""Compile the FSM that drives the regex-structured generation.
def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
"""Adapt vLLM's tokenizer to use to compile the FSM.

Parameters
----------
regex_string
A string that represents a regular expression
tokenizer
The model's tokenizer
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.

"""
tokenizer = self.adapt_tokenizer(tokenizer)
fsm = RegexFSM(regex_string, tokenizer)
self.fsm = fsm
if getattr(tokenizer, "_outlines_adapted", False):
return tokenizer

tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)

def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE

string = tokenizer.convert_tokens_to_string([token])

# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string

return string

def change_decoder(
decoder: Callable[[List[int]], str]
) -> Callable[[List[int]], List[str]]:
"""Sync vLLM's decoder with the outlines by returning list."""

def new_decoder(inp_tokens: List[int]) -> List[str]:
return [decoder(inp_tokens)]

return new_decoder

tokenizer.convert_token_to_string = convert_token_to_string
tokenizer.decode = change_decoder(tokenizer.decode)
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010

return tokenizer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need return this?


def init_state(self):
"""Initialize the FSM states."""
Expand Down Expand Up @@ -69,38 +99,30 @@ def __call__(self, input_ids: List[int],

return scores

def adapt_tokenizer(self, tokenizer):
"""Adapt vLLM's tokenizer to use to compile the FSM.

The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.

"""
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)

def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE

string = tokenizer.convert_tokens_to_string([token])
class RegexLogitsProcessor(BaseLogitsProcessor):

# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string

return string
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
"""Compile the FSM that drives the regex-structured generation.

tokenizer.convert_token_to_string = convert_token_to_string
Parameters
----------
regex_string
A string that represents a regular expression
tokenizer
The model's tokenizer

return tokenizer
"""
tokenizer = self.adapt_tokenizer(tokenizer)
fsm = RegexFSM(regex_string, tokenizer)
self.fsm = fsm


class JSONLogitsProcessor(RegexLogitsProcessor):

def __init__(self,
schema: Union[str, Dict, BaseModel],
tokenizer,
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Optional[str] = None):
"""Compile the FSM that drives the JSON-guided generation.

Expand Down Expand Up @@ -130,3 +152,21 @@ def __init__(self,
f"the JSON Schema specification")
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer)


class CFGLogitsProcessor(BaseLogitsProcessor):

def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
"""Compile the FSM that drives the context free grammar generation.

Parameters
----------
cfg
A string that represents a context-free grammar
tokenizer
The model's tokenizer

"""
tokenizer = self.adapt_tokenizer(tokenizer)
fsm = CFGFSM(cfg, tokenizer)
self.fsm = fsm
Loading