
[Bugfix] Update run_batch.py to handle larger numbers of batches #5774

Closed · wants to merge 1 commit
32 changes: 25 additions & 7 deletions vllm/entrypoints/openai/run_batch.py
@@ -1,7 +1,8 @@
import asyncio
from io import StringIO
from typing import Awaitable, List

from asyncio import Semaphore
from tqdm.asyncio import tqdm
import aiohttp

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
@@ -100,6 +101,28 @@ async def run_request(chat_serving: OpenAIServingChat,

    return batch_output

MAX_CONCURRENT_REQUESTS = 10000

async def process_requests(input_file: str, openai_serving_chat):
    sem = Semaphore(MAX_CONCURRENT_REQUESTS)

    async def run_request_with_semaphore(request):
        async with sem:
            return await run_request(openai_serving_chat, request)

    response_futures: List[Awaitable[BatchRequestOutput]] = []

    # Read file contents asynchronously
    file_contents = (await read_file(input_file)).strip().split("\n")
    total_lines = len(file_contents)

Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's figure out whether the engine queuing should be fixed instead, but if we do decide to limit the requests we submit, I wonder if we should simply avoid creating all of the asyncio requests up front.

Something like

finished_requests = []
response_futures: List[Awaitable[BatchRequestOutput]] = []

for request_json in tqdm(file_contents, total=total_lines, desc="Running requests"):
        request = BatchRequestInput.model_validate_json(request_json.strip())
        response_futures.append(run_request_with_semaphore(request))
        if len(response_futures) >= MAX_CONCURRENT_REQUESTS:
            recently_finished, response_futures = await asyncio.wait(response_futures)
            finished_requests.extend(recently_finished)

finished_requests.extend(await asyncio.gather(*response_futures))

This way we can more accurately track the progress of the job.
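
A minimal, self-contained sketch of that bounded-in-flight idea (the names `run_bounded` and `run_one` are hypothetical stand-ins for the batch-file parsing and `run_request` above, not the vLLM API). Note that `asyncio.wait` defaults to waiting for every future, so `return_when=asyncio.FIRST_COMPLETED` is what lets new requests be submitted as earlier ones finish, and wrapping the coroutines in tasks avoids passing bare coroutines to `asyncio.wait`, which is deprecated and rejected on newer Python versions:

```python
import asyncio

MAX_CONCURRENT_REQUESTS = 10000


async def run_bounded(requests, run_one):
    """Keep at most MAX_CONCURRENT_REQUESTS requests in flight at a time.

    `requests` is any iterable of parsed requests and `run_one` is an async
    callable that handles a single request.
    """
    finished = []
    in_flight = set()

    for request in requests:
        in_flight.add(asyncio.create_task(run_one(request)))
        if len(in_flight) >= MAX_CONCURRENT_REQUESTS:
            # Block only until at least one task finishes, then keep submitting.
            done, in_flight = await asyncio.wait(
                in_flight, return_when=asyncio.FIRST_COMPLETED)
            finished.extend(task.result() for task in done)

    # Drain whatever is still outstanding at the end of the file.
    finished.extend(await asyncio.gather(*in_flight))
    return finished
```

One trade-off: unlike `asyncio.gather`, `asyncio.wait` does not preserve submission order, so outputs would need to be matched back to their inputs afterwards (for example via each request's `custom_id`).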

    for request_json in tqdm(file_contents, total=total_lines, desc="Processing requests"):
        request = BatchRequestInput.model_validate_json(request_json.strip())
        response_futures.append(run_request_with_semaphore(request))

    responses = await tqdm.gather(*response_futures, desc="Gathering responses")
    return responses

async def main(args):
    if args.served_model_name is not None:
@@ -122,12 +145,7 @@ async def main(args):
    )

    # Submit all requests in the file to the engine "concurrently".
    response_futures: List[Awaitable[BatchRequestOutput]] = []
    for request_json in (await read_file(args.input_file)).strip().split("\n"):
        request = BatchRequestInput.model_validate_json(request_json)
        response_futures.append(run_request(openai_serving_chat, request))

    responses = await asyncio.gather(*response_futures)
    responses = await process_requests(args.input_file, openai_serving_chat)

    output_buffer = StringIO()
    for response in responses:
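
For readers who want to see the semaphore pattern from this diff in isolation, here is a small, self-contained sketch (the `handle` coroutine, `run_all`, and the integer request IDs are toy stand-ins, not vLLM calls): every coroutine is still created up front, but only `MAX_CONCURRENT_REQUESTS` of them can be past the `async with sem:` line at any moment.

```python
import asyncio

MAX_CONCURRENT_REQUESTS = 10000


async def handle(request_id: int) -> str:
    # Toy stand-in for run_request(): pretend the engine takes a moment.
    await asyncio.sleep(0.01)
    return f"response-{request_id}"


async def run_all(num_requests: int) -> list:
    sem = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)

    async def handle_with_semaphore(request_id: int) -> str:
        # At most MAX_CONCURRENT_REQUESTS coroutines hold the semaphore at
        # once; the rest wait here instead of piling up inside the engine.
        async with sem:
            return await handle(request_id)

    futures = [handle_with_semaphore(i) for i in range(num_requests)]
    # gather preserves order, so responses line up with the input requests.
    return await asyncio.gather(*futures)


if __name__ == "__main__":
    print(len(asyncio.run(run_all(100_000))))
```

The trade-off raised in the review comment above still applies: the coroutine objects for all pending requests exist in memory from the start; the semaphore only limits how many are actively awaiting the engine.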