Implement a structured engine for parsing JSON grammar token by token.
pathorn committed Mar 14, 2024
1 parent 2efce05 commit 473f87b
Showing 5 changed files with 1,680 additions and 0 deletions.
24 changes: 24 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -24,6 +24,8 @@
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import LoRA

from vllm.model_executor.structure_logits_processors import JSONStructureLogitsProcessor

TIMEOUT_KEEP_ALIVE = 5 # seconds

openai_serving_chat: OpenAIServingChat = None
@@ -141,6 +143,10 @@ def parse_args():
"If a function is provided, vLLM will add it to the server using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server using app.add_middleware(). "
)
parser.add_argument(
"--enable-json-mode",
action="store_true",
help="Enables JSON mode by passing response_format=\{\"type\":\"json_object\"\}")

    parser = AsyncEngineArgs.add_cli_args(parser)
    return parser.parse_args()
@@ -204,6 +210,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
        return JSONResponse(content=generator.model_dump())


async def _post_init():
    engine_model_config = await engine.get_model_config()
    if args.enable_json_mode:
        JSONStructureLogitsProcessor.init_static(engine_model_config, openai_serving_chat.tokenizer)


if __name__ == "__main__":
    args = parse_args()

@@ -248,13 +260,25 @@ async def authentication(request: Request, call_next):

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    openai_serving_chat = OpenAIServingChat(engine, served_model,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model, args.lora_modules)

    try:
        event_loop = asyncio.get_running_loop()
    except RuntimeError:
        event_loop = None

    if event_loop is not None and event_loop.is_running():
        # If this process was started by Ray Serve, there is already a running event loop.
        event_loop.create_task(_post_init())
    else:
        # Single vLLM instance without engine_use_ray.
        asyncio.run(_post_init())

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
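Taken together, the api_server.py changes wire JSON mode end to end: --enable-json-mode makes the server call JSONStructureLogitsProcessor.init_static() once after the engine and serving objects exist, and individual requests then opt in through response_format. The following client-side sketch is illustrative only, assuming the server was launched with --enable-json-mode on the default host and port and that the openai Python client (v1+) is installed; the model name is a placeholder.

# Hypothetical usage sketch -- not part of this commit.
# Server launch, for reference:
#   python -m vllm.entrypoints.openai.api_server --model <served-model> --enable-json-mode
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="<served-model>",  # placeholder: the model name the server was started with
    messages=[{"role": "user", "content": "Return a JSON object with keys name and age."}],
    response_format={"type": "json_object"},  # activates JSONStructureLogitsProcessor
)
print(completion.choices[0].message.content)

The same response_format field is also accepted by the /v1/completions route via CompletionRequest in protocol.py below.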
12 changes: 12 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -3,10 +3,12 @@
import time
from typing import Dict, List, Literal, Optional, Union

from enum import Enum
from pydantic import BaseModel, Field, model_validator

from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
from vllm.model_executor.structure_logits_processors import JSONStructureLogitsProcessor

import torch

@@ -55,6 +57,10 @@ class UsageInfo(BaseModel):
    completion_tokens: Optional[int] = 0


class ResponseFormat(BaseModel):
    type: Literal["text", "json_object"] = "text"


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
@@ -89,6 +95,7 @@ class ChatCompletionRequest(BaseModel):
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    response_format: Optional[ResponseFormat] = None

    def to_sampling_params(self) -> SamplingParams:
        if self.logprobs and not self.top_logprobs:
@@ -107,6 +114,8 @@ def logit_bias_logits_processor(
                return logits

            logits_processors = [logit_bias_logits_processor]
        if self.response_format and self.response_format.type == "json_object":
            logits_processors = (logits_processors or []) + [JSONStructureLogitsProcessor()]

        return SamplingParams(
            n=self.n,
@@ -183,6 +192,7 @@ class CompletionRequest(BaseModel):
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    response_format: Optional[ResponseFormat] = None

    def to_sampling_params(self):
        echo_without_generation = self.echo and self.max_tokens == 0
@@ -200,6 +210,8 @@ def logit_bias_logits_processor(
                return logits

            logits_processors = [logit_bias_logits_processor]
        if self.response_format and self.response_format.type == "json_object":
            logits_processors = (logits_processors or []) + [JSONStructureLogitsProcessor()]

        return SamplingParams(
            n=self.n,
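The processor these protocol changes attach is implemented in vllm/model_executor/structure_logits_processors.py, one of the changed files that did not load in this view, so its internals are not shown here. Purely as a hedged sketch of the general pattern, not the actual implementation: a logits processor in SamplingParams is a callable that receives the token ids generated so far plus the next-token logits and returns modified logits, and a token-by-token JSON engine typically masks every token whose decoded text could not extend a valid JSON prefix. Every name below is hypothetical, and the grammar check is deliberately much cruder than a real structured engine.

# Hypothetical sketch only -- the real JSONStructureLogitsProcessor is not shown in this diff.
from typing import List

import torch


class JsonStructureSketch:
    """Coarse token-by-token JSON filter (illustrative, not optimized)."""

    _token_strings: List[str] = []  # decoded text of every vocabulary id

    @classmethod
    def init_static(cls, model_config, tokenizer) -> None:
        # Mirror of the one-time init_static() hook called from api_server.py:
        # decode each vocabulary id up front so the per-step filter never has
        # to touch the tokenizer on the hot path.
        del model_config  # the real code presumably reads vocab/model details from here
        cls._token_strings = [
            tokenizer.decode([token_id]) for token_id in range(len(tokenizer))
        ]

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        # Called with the ids generated so far and the logits for the next token;
        # must return (possibly modified) logits of the same shape.
        prefix = "".join(self._token_strings[i] for i in token_ids)
        mask = torch.full_like(logits, float("-inf"))
        for token_id, text in enumerate(self._token_strings):
            if self._could_extend_json(prefix + text):
                mask[token_id] = 0.0  # keep tokens that still fit the structure
        return logits + mask

    @staticmethod
    def _could_extend_json(text: str) -> bool:
        # Coarse structural check: track string/escape state and bracket nesting,
        # rejecting any continuation that closes a bracket which was never opened.
        stack: List[str] = []
        in_string = escaped = False
        for ch in text.lstrip():
            if in_string:
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == '"':
                    in_string = False
            elif ch == '"':
                in_string = True
            elif ch in "{[":
                stack.append(ch)
            elif ch in "}]":
                if not stack or "{[".index(stack.pop()) != "}]".index(ch):
                    return False
        return True

A per-request instance is appended to logits_processors exactly as to_sampling_params() does above, and init_static() runs once at server start-up from _post_init().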
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_engine.py
@@ -4,6 +4,7 @@
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from vllm.logger import init_logger
from vllm.model_executor.structure_logits_processors import JSONStructureLogitsProcessor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (CompletionRequest,
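For context on what attaching logits_processors to SamplingParams means at decode time: vLLM's sampler invokes each processor on a sequence's logits before sampling, passing the tokens generated so far. A minimal, hedged sketch of that calling contract (the names are illustrative, not vLLM internals):

# Illustrative only: the per-step contract a SamplingParams logits processor satisfies.
from typing import Callable, List

import torch

LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]


def apply_logits_processors(processors: List[LogitsProcessor],
                            generated_token_ids: List[int],
                            logits: torch.Tensor) -> torch.Tensor:
    # Each processor sees the ids generated so far and may rewrite the logits;
    # JSONStructureLogitsProcessor would mask grammar-violating tokens at this point.
    for processor in processors:
        logits = processor(generated_token_ids, logits)
    return logits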