-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend] Support OpenAI batch file format #4794
Changes from 27 commits
e9fbe7e
5d27883
219eb24
a62c606
2f83da0
6f8b35c
d59023a
30e63fe
e02c46f
530fd70
3890872
36a2a95
2d94844
ecb5c5f
d4919ff
b1b29d0
b3eaa49
e159be0
3c26e47
459ea17
1aa656a
bbdd51f
d928050
8dcefe9
2546fbe
d8f20a7
89d059e
d792956
36339ea
0e2bf89
0aee415
6c3a5e2
07f47d8
301d53b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import subprocess | ||
import sys | ||
import tempfile | ||
|
||
from vllm.entrypoints.openai.protocol import BatchRequestOutput | ||
|
||
# ruff: noqa: E501 | ||
INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} | ||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" | ||
|
||
|
||
def test_e2e(): | ||
with tempfile.NamedTemporaryFile( | ||
"w") as input_file, tempfile.NamedTemporaryFile( | ||
"r") as output_file: | ||
input_file.write(INPUT_BATCH) | ||
input_file.flush() | ||
proc = subprocess.Popen([ | ||
sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i", | ||
input_file.name, "-o", output_file.name, "--model", | ||
"NousResearch/Meta-Llama-3-8B-Instruct" | ||
], ) | ||
proc.communicate() | ||
proc.wait() | ||
assert proc.returncode == 0, f"{proc=}" | ||
|
||
contents = output_file.read() | ||
for line in contents.strip().split("\n"): | ||
BatchRequestOutput.model_validate_json(line) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -101,7 +101,7 @@ async def show_version(): | |
async def create_chat_completion(request: ChatCompletionRequest, | ||
raw_request: Request): | ||
generator = await openai_serving_chat.create_chat_completion( | ||
request, raw_request) | ||
request, raw_request.is_disconnected) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would prefer if this PR did not touch this file |
||
if isinstance(generator, ErrorResponse): | ||
return JSONResponse(content=generator.model_dump(), | ||
status_code=generator.code) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -492,3 +492,44 @@ class ChatCompletionStreamResponse(OpenAIBaseModel): | |
model: str | ||
choices: List[ChatCompletionResponseStreamChoice] | ||
usage: Optional[UsageInfo] = Field(default=None) | ||
|
||
|
||
class BatchRequestInput(OpenAIBaseModel): | ||
""" | ||
The per-line object of the batch input file. | ||
|
||
NOTE: Currently only the `/v1/chat/completions` endpoint is supported. | ||
""" | ||
|
||
# A developer-provided per-request id that will be used to match outputs to | ||
# inputs. Must be unique for each request in a batch. | ||
custom_id: str | ||
|
||
# The HTTP method to be used for the request. Currently only POST is | ||
# supported. | ||
method: str | ||
|
||
# The OpenAI API relative URL to be used for the request. Currently | ||
# /v1/chat/completions is supported. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason we cannot support There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [ This could be done in a follow up PR ] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we can support the other 2 endpoints in follow ups. |
||
url: str | ||
|
||
# The parameteters of the request. | ||
body: Union[ChatCompletionRequest, ] | ||
|
||
|
||
class BatchRequestOutput(OpenAIBaseModel): | ||
""" | ||
The per-line object of the batch output and error files | ||
""" | ||
|
||
id: str | ||
|
||
# A developer-provided per-request id that will be used to match outputs to | ||
# inputs. | ||
custom_id: str | ||
|
||
response: Optional[ChatCompletionResponse] | ||
|
||
# For requests that failed with a non-HTTP error, this will contain more | ||
# information on the cause of the failure. | ||
error: Optional[Any] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import argparse | ||
import asyncio | ||
import sys | ||
from io import StringIO | ||
|
||
import aiohttp | ||
|
||
import vllm | ||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str | ||
from vllm.engine.async_llm_engine import AsyncLLMEngine | ||
from vllm.entrypoints.openai.protocol import (BatchRequestInput, | ||
BatchRequestOutput, | ||
ChatCompletionResponse) | ||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat | ||
from vllm.logger import init_logger | ||
from vllm.usage.usage_lib import UsageContext | ||
from vllm.utils import random_uuid | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description="vLLM OpenAI-Compatible batch runner.") | ||
parser.add_argument( | ||
"-i", | ||
"--input-file", | ||
required=True, | ||
type=str, | ||
help= | ||
"The path or url to a single input file. Currently supports local file " | ||
"paths, or the http protocol (http or https). If a URL is specified, " | ||
"the file should be available via HTTP GET.") | ||
parser.add_argument( | ||
"-o", | ||
"--output-file", | ||
required=True, | ||
type=str, | ||
help="The path or url to a single output file. Currently supports " | ||
"local file paths, or web (http or https) urls. If a URL is specified," | ||
" the file should be available via HTTP PUT.") | ||
parser.add_argument("--response-role", | ||
type=nullable_str, | ||
default="assistant", | ||
help="The role name to return if " | ||
"`request.add_generation_prompt=true`.") | ||
|
||
parser = AsyncEngineArgs.add_cli_args(parser) | ||
return parser.parse_args() | ||
|
||
|
||
async def read_file(path_or_url: str) -> str: | ||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"): | ||
async with aiohttp.ClientSession() as session, \ | ||
session.get(path_or_url) as resp: | ||
return await resp.text() | ||
else: | ||
with open(path_or_url, "r") as f: | ||
return f.read() | ||
|
||
|
||
async def write_file(path_or_url: str, data: str) -> None: | ||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"): | ||
async with aiohttp.ClientSession() as session, \ | ||
session.put(path_or_url, data=data.encode("utf-8")): | ||
pass | ||
else: | ||
# We should make this async, but as long as this is always run as a | ||
# standalone program, blocking the event loop won't effect performance | ||
# in this particular case. | ||
with open(path_or_url, "w") as f: | ||
f.write(data) | ||
|
||
|
||
async def run_request(chat_serving: OpenAIServingChat, | ||
request: BatchRequestInput) -> BatchRequestOutput: | ||
chat_request = request.body | ||
chat_response = await chat_serving.create_chat_completion(chat_request) | ||
if isinstance(chat_response, ChatCompletionResponse): | ||
batch_output = BatchRequestOutput( | ||
id=f"vllm-{random_uuid()}", | ||
custom_id=request.custom_id, | ||
response=chat_response, | ||
error=None, | ||
) | ||
else: | ||
batch_output = BatchRequestOutput( | ||
id=f"vllm-{random_uuid()}", | ||
custom_id=request.custom_id, | ||
response=None, | ||
error=chat_response, | ||
) | ||
return batch_output | ||
|
||
|
||
async def main(args): | ||
if args.served_model_name is not None: | ||
served_model_names = args.served_model_name | ||
else: | ||
served_model_names = [args.model] | ||
|
||
engine_args = AsyncEngineArgs.from_cli_args(args) | ||
engine = AsyncLLMEngine.from_engine_args( | ||
engine_args, usage_context=UsageContext.OPENAI_API_SERVER) | ||
|
||
# When using single vLLM without engine_use_ray | ||
model_config = await engine.get_model_config() | ||
|
||
openai_serving_chat = OpenAIServingChat( | ||
engine, | ||
model_config, | ||
served_model_names, | ||
args.response_role, | ||
) | ||
|
||
response_futures = [] | ||
for request_json in (await read_file(args.input_file)).strip().split("\n"): | ||
request = BatchRequestInput.model_validate_json(request_json) | ||
response_futures.append(run_request(openai_serving_chat, request)) | ||
|
||
responses = await asyncio.gather(*response_futures) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I assume this lets each request run concurrently right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yep, that way the async engine can decide how to queue/batch things. |
||
|
||
output_buffer = StringIO() | ||
for response in responses: | ||
print(response.model_dump_json(), file=output_buffer) | ||
|
||
output_buffer.seek(0) | ||
await write_file(args.output_file, output_buffer.read().strip()) | ||
|
||
# Temporary workaround for https://github.com/vllm-project/vllm/issues/4789 | ||
sys.exit(0) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
|
||
logger.info("vLLM API server version %s", vllm.__version__) | ||
logger.info("args: %s", args) | ||
|
||
asyncio.run(main(args)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,8 @@ | ||
import codecs | ||
import time | ||
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, | ||
Optional, Tuple, TypedDict, Union, final) | ||
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Callable, | ||
Iterable, List, Optional, Tuple, TypedDict, Union, final) | ||
|
||
from fastapi import Request | ||
from openai.types.chat import (ChatCompletionContentPartParam, | ||
ChatCompletionRole) | ||
|
||
|
@@ -100,9 +99,17 @@ def _parse_chat_message_content( | |
return [ConversationMessage(role=role, content="\n".join(texts))], [] | ||
|
||
async def create_chat_completion( | ||
self, request: ChatCompletionRequest, raw_request: Request | ||
self, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not think we should change the signature of this function in the PR. Is there a way we can refactor this such that the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, are you just trying to keep this method signature backwards compatible? 2 possibilities
I'm fine either way, or if you had something else in mind. Let me know what you prefer. |
||
request: ChatCompletionRequest, | ||
is_aborted: Optional[Callable[[], Awaitable[bool]]] = None, | ||
) -> Union[ErrorResponse, AsyncGenerator[str, None], | ||
ChatCompletionResponse]: | ||
if is_aborted is None: | ||
|
||
async def always_false(): | ||
wuisawesome marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return False | ||
|
||
is_aborted = always_false | ||
"""Completion API similar to OpenAI's API. | ||
|
||
See https://platform.openai.com/docs/api-reference/chat/create | ||
|
@@ -166,7 +173,7 @@ async def create_chat_completion( | |
else: | ||
try: | ||
return await self.chat_completion_full_generator( | ||
request, raw_request, result_generator, request_id, | ||
request, is_aborted, result_generator, request_id, | ||
conversation) | ||
except ValueError as e: | ||
# TODO: Use a vllm-specific Validation Error | ||
|
@@ -319,7 +326,8 @@ async def chat_completion_stream_generator( | |
yield "data: [DONE]\n\n" | ||
|
||
async def chat_completion_full_generator( | ||
self, request: ChatCompletionRequest, raw_request: Request, | ||
self, request: ChatCompletionRequest, | ||
is_aborted: Callable[[], Awaitable[bool]], | ||
result_generator: AsyncIterator[RequestOutput], request_id: str, | ||
conversation: List[ConversationMessage] | ||
) -> Union[ErrorResponse, ChatCompletionResponse]: | ||
|
@@ -329,7 +337,7 @@ async def chat_completion_full_generator( | |
final_res: Optional[RequestOutput] = None | ||
|
||
async for res in result_generator: | ||
if await raw_request.is_disconnected(): | ||
if await is_aborted(): | ||
# Abort the request if the client disconnects. | ||
await self.engine.abort(request_id) | ||
return self.create_error_response("Client disconnected") | ||
|
@@ -387,4 +395,4 @@ async def chat_completion_full_generator( | |
usage=usage, | ||
) | ||
|
||
return response | ||
return response |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It there any way we can add a test that the input / output formats conform to the openai api
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure