[Frontend] Support OpenAI batch file format #4794

Merged · 34 commits · May 15, 2024
Changes from 27 commits
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -8,6 +8,7 @@ py-cpuinfo
transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
aiohttp
openai
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
29 changes: 29 additions & 0 deletions tests/entrypoints/test_openai_run_batch.py
@@ -0,0 +1,29 @@
import subprocess
import sys
import tempfile

from vllm.entrypoints.openai.protocol import BatchRequestOutput

# ruff: noqa: E501
INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""


def test_e2e():

Collaborator: Is there any way we can add a test that the input/output formats conform to the OpenAI API?

Contributor (author): Sure (a sketch of such a check follows this file's diff).

with tempfile.NamedTemporaryFile(
"w") as input_file, tempfile.NamedTemporaryFile(
"r") as output_file:
input_file.write(INPUT_BATCH)
input_file.flush()
proc = subprocess.Popen([
sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i",
input_file.name, "-o", output_file.name, "--model",
"NousResearch/Meta-Llama-3-8B-Instruct"
], )
proc.communicate()
proc.wait()
assert proc.returncode == 0, f"{proc=}"

contents = output_file.read()
for line in contents.strip().split("\n"):
BatchRequestOutput.model_validate_json(line)
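
Following up on the review thread above about API conformance, here is a minimal sketch of one way such a check could look: parse the "response" field of each output line with the openai SDK's ChatCompletion pydantic model. This is illustrative only and not part of the PR; it assumes the openai package (already listed in requirements-common.txt) exposes ChatCompletion as a pydantic model, and strictness may vary by SDK version.

import json

from openai.types.chat import ChatCompletion


def check_openai_conformance(output_contents: str) -> None:
    # Each output line is a BatchRequestOutput; successful requests should
    # carry an OpenAI-shaped chat completion in their "response" field.
    for line in output_contents.strip().split("\n"):
        entry = json.loads(line)
        if entry.get("response") is not None:
            ChatCompletion.model_validate(entry["response"])
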
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -101,7 +101,7 @@ async def show_version():
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
request, raw_request)
request, raw_request.is_disconnected)

Collaborator: I would prefer if this PR did not touch this file.

if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
41 changes: 41 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -492,3 +492,44 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.

NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
"""

# A developer-provided per-request id that will be used to match outputs to
# inputs. Must be unique for each request in a batch.
custom_id: str

# The HTTP method to be used for the request. Currently only POST is
# supported.
method: str

# The OpenAI API relative URL to be used for the request. Currently
# /v1/chat/completions is supported.

Collaborator: Is there a reason we cannot support /v1/embeddings? https://platform.openai.com/docs/guides/batch

Collaborator: [ This could be done in a follow-up PR ]

Contributor (author): Yes, we can support the other 2 endpoints in follow-ups.

url: str

    # The parameters of the request.
body: Union[ChatCompletionRequest, ]


class BatchRequestOutput(OpenAIBaseModel):
"""
The per-line object of the batch output and error files
"""

id: str

# A developer-provided per-request id that will be used to match outputs to
# inputs.
custom_id: str

response: Optional[ChatCompletionResponse]

# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error: Optional[Any]
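
To make the two models above concrete, here is a minimal sketch (not part of the PR) of how one line of the batch input file parses into BatchRequestInput; the JSON mirrors the fixture in tests/entrypoints/test_openai_run_batch.py.

import json

from vllm.entrypoints.openai.protocol import BatchRequestInput

# One JSONL line of the batch input file, shaped like the test fixture above.
line = json.dumps({
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "NousResearch/Meta-Llama-3-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello world!"}],
        "max_tokens": 1000,
    },
})

request = BatchRequestInput.model_validate_json(line)
assert request.custom_id == "request-1"
# The body field is coerced into a ChatCompletionRequest, per the Union above.
print(type(request.body).__name__)  # ChatCompletionRequest
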
140 changes: 140 additions & 0 deletions vllm/entrypoints/openai/run_batch.py
@@ -0,0 +1,140 @@
import argparse
import asyncio
import sys
from io import StringIO

import aiohttp

import vllm
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
ChatCompletionResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid

logger = init_logger(__name__)


def parse_args():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible batch runner.")
parser.add_argument(
"-i",
"--input-file",
required=True,
type=str,
help=
"The path or url to a single input file. Currently supports local file "
"paths, or the http protocol (http or https). If a URL is specified, "
"the file should be available via HTTP GET.")
parser.add_argument(
"-o",
"--output-file",
required=True,
type=str,
help="The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT.")
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")

parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args()


async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \
session.get(path_or_url) as resp:
return await resp.text()
else:
with open(path_or_url, "r") as f:
return f.read()


async def write_file(path_or_url: str, data: str) -> None:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \
session.put(path_or_url, data=data.encode("utf-8")):
pass
else:
# We should make this async, but as long as this is always run as a
        # standalone program, blocking the event loop won't affect performance
# in this particular case.
with open(path_or_url, "w") as f:
f.write(data)


async def run_request(chat_serving: OpenAIServingChat,
request: BatchRequestInput) -> BatchRequestOutput:
chat_request = request.body
chat_response = await chat_serving.create_chat_completion(chat_request)
if isinstance(chat_response, ChatCompletionResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=chat_response,
error=None,
)
else:
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=None,
error=chat_response,
)
return batch_output


async def main(args):
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]

engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)

# When using single vLLM without engine_use_ray
model_config = await engine.get_model_config()

openai_serving_chat = OpenAIServingChat(
engine,
model_config,
served_model_names,
args.response_role,
)

response_futures = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
request = BatchRequestInput.model_validate_json(request_json)
response_futures.append(run_request(openai_serving_chat, request))

responses = await asyncio.gather(*response_futures)

Collaborator: I assume this lets each request run concurrently, right?

Contributor (author): Yep, that way the async engine can decide how to queue/batch things (a standalone sketch of this pattern follows this file's diff).


output_buffer = StringIO()
for response in responses:
print(response.model_dump_json(), file=output_buffer)

output_buffer.seek(0)
await write_file(args.output_file, output_buffer.read().strip())

# Temporary workaround for https://github.com/vllm-project/vllm/issues/4789
sys.exit(0)


if __name__ == "__main__":
args = parse_args()

logger.info("vLLM API server version %s", vllm.__version__)
logger.info("args: %s", args)

asyncio.run(main(args))
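
As noted in the review thread above, asyncio.gather awaits coroutines that were all created up front, so every request is in flight at once and the async engine is free to decide how to queue and batch them. A standalone sketch of that pattern (illustrative only, not vLLM code):

import asyncio
import time


async def fake_request(i: int) -> str:
    # Stand-in for run_request(); it yields control while "waiting" on the engine.
    await asyncio.sleep(1)
    return f"response-{i}"


async def demo() -> None:
    start = time.perf_counter()
    # All ten coroutines are created before any is awaited, so they overlap.
    responses = await asyncio.gather(*(fake_request(i) for i in range(10)))
    print(responses[0], f"total: {time.perf_counter() - start:.1f}s")  # ~1s, not ~10s


asyncio.run(demo())
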
24 changes: 16 additions & 8 deletions vllm/entrypoints/openai/serving_chat.py
@@ -1,9 +1,8 @@
import codecs
import time
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
Optional, Tuple, TypedDict, Union, final)
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Callable,
Iterable, List, Optional, Tuple, TypedDict, Union, final)

from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartParam,
ChatCompletionRole)

@@ -100,9 +99,17 @@ def _parse_chat_message_content(
return [ConversationMessage(role=role, content="\n".join(texts))], []

async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
self,

Collaborator: I do not think we should change the signature of this function in this PR. Is there a way we can refactor this so that run_batch.py can use the raw Request in the same way as api_server.py does?

Contributor (author): Sure, are you just trying to keep this method signature backwards compatible? Two possibilities:

  1. Make raw_request: Optional[Request] and default to never aborting if raw_request is None.
  2. Make a fake Request in run_batch.py to pass in (a sketch of this option follows this file's diff).

I'm fine either way, or if you had something else in mind. Let me know what you prefer.

request: ChatCompletionRequest,
is_aborted: Optional[Callable[[], Awaitable[bool]]] = None,
) -> Union[ErrorResponse, AsyncGenerator[str, None],
ChatCompletionResponse]:
if is_aborted is None:

async def always_false():
return False

is_aborted = always_false
"""Completion API similar to OpenAI's API.

See https://platform.openai.com/docs/api-reference/chat/create
@@ -166,7 +173,7 @@ async def create_chat_completion(
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
request, is_aborted, result_generator, request_id,
conversation)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
@@ -319,7 +326,8 @@ async def chat_completion_stream_generator(
yield "data: [DONE]\n\n"

async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
self, request: ChatCompletionRequest,
is_aborted: Callable[[], Awaitable[bool]],
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
) -> Union[ErrorResponse, ChatCompletionResponse]:
@@ -329,7 +337,7 @@ async def chat_completion_full_generator(
final_res: Optional[RequestOutput] = None

async for res in result_generator:
if await raw_request.is_disconnected():
if await is_aborted():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
return self.create_error_response("Client disconnected")
@@ -387,4 +395,4 @@ async def chat_completion_full_generator(
usage=usage,
)

return response
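
For completeness, here is a minimal sketch of the second option from the review thread above: a stand-in for starlette's Request that run_batch.py could pass to the unchanged create_chat_completion(request, raw_request) signature. The class name and usage are hypothetical, not part of the PR.

class NeverDisconnectedRequest:
    """Stand-in for a client Request in a batch context.

    It only needs to satisfy the one call site used above:
    `await raw_request.is_disconnected()`.
    """

    async def is_disconnected(self) -> bool:
        # A batch job has no client connection that can drop.
        return False


# Hypothetical usage inside run_batch.py, keeping the original signature:
# chat_response = await chat_serving.create_chat_completion(
#     chat_request, NeverDisconnectedRequest())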