Commit ea61266

simon-mo and NikolaBorisov authored and committed
Support Batch Completion in Server (vllm-project#2529)
1 parent fe30a77 commit ea61266
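
For context on what this change enables: the OpenAI-compatible completions endpoint accepts a list of prompts in a single request, which is what the new test below exercises. A minimal client-side sketch of such a batch request (the base URL, API key, and model name are illustrative placeholders, not values taken from this commit):

import asyncio

import openai


async def main():
    # Assumes a vLLM OpenAI-compatible server is already running locally;
    # URL, key, and model name below are placeholders for illustration.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    # A list-valued prompt turns this into a batch request: the response
    # carries one choice per prompt, ordered by choice index.
    batch = await client.completions.create(
        model="facebook/opt-125m",
        prompt=["Hello, my name is", "The capital of France is"],
        max_tokens=5,
        temperature=0.0,
    )
    for choice in batch.choices:
        print(choice.index, repr(choice.text))


asyncio.run(main())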


2 files changed: +215 -105 lines changed

tests/entrypoints/test_openai_server.py

Lines changed: 53 additions & 2 deletions
@@ -1,5 +1,6 @@
-import time
+import os
 import subprocess
+import time
 
 import sys
 import pytest
@@ -17,8 +18,11 @@
 class ServerRunner:
 
     def __init__(self, args):
+        env = os.environ.copy()
+        env["PYTHONUNBUFFERED"] = "1"
         self.proc = subprocess.Popen(
             ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
+            env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,
         )
@@ -58,7 +62,8 @@ def server():
         "--dtype",
         "bfloat16", # use half precision for speed and memory savings in CI environment
         "--max-model-len",
-        "8192"
+        "8192",
+        "--enforce-eager",
     ])
     ray.get(server_runner.ready.remote())
     yield server_runner
@@ -199,5 +204,51 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
     assert "".join(chunks) == output
 
 
+async def test_batch_completions(server, client: openai.AsyncOpenAI):
+    # test simple list
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert len(batch.choices) == 2
+    assert batch.choices[0].text == batch.choices[1].text
+
+    # test n = 2
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        n=2,
+        max_tokens=5,
+        temperature=0.0,
+        extra_body=dict(
+            # NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client.
+            use_beam_search=True),
+    )
+    assert len(batch.choices) == 4
+    assert batch.choices[0].text != batch.choices[
+        1].text, "beam search should be different"
+    assert batch.choices[0].text == batch.choices[
+        2].text, "two copies of the same prompt should be the same"
+    assert batch.choices[1].text == batch.choices[
+        3].text, "two copies of the same prompt should be the same"
+
+    # test streaming
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+    )
+    texts = [""] * 2
+    async for chunk in batch:
+        assert len(chunk.choices) == 1
+        choice = chunk.choices[0]
+        texts[choice.index] += choice.text
+    assert texts[0] == texts[1]
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
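
In the streaming branch of the test above, chunks for the two prompts arrive interleaved, and each chunk's single choice carries the index of the prompt it belongs to, so per-prompt text is reassembled by choice.index. A standalone sketch of that pattern outside the test harness (server URL and model name are again placeholders):

import asyncio

import openai


async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    prompts = ["Hello, my name is", "Hello, my name is"]
    stream = await client.completions.create(
        model="facebook/opt-125m",
        prompt=prompts,
        max_tokens=5,
        temperature=0.0,
        stream=True,
    )
    # Accumulate text per prompt: chunks arrive interleaved, so route each
    # piece by choice.index rather than by arrival order.
    texts = [""] * len(prompts)
    async for chunk in stream:
        choice = chunk.choices[0]
        texts[choice.index] += choice.text
    for i, text in enumerate(texts):
        print(i, repr(text))


asyncio.run(main())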

0 commit comments
