  1 |     | -import time
    |   1 | +import os
  2 |   2 |  import subprocess
    |   3 | +import time
  3 |   4 |  
  4 |   5 |  import sys
  5 |   6 |  import pytest

 17 |  18 |  class ServerRunner:
 18 |  19 |  
 19 |  20 |      def __init__(self, args):
     |  21 | +        env = os.environ.copy()
     |  22 | +        env["PYTHONUNBUFFERED"] = "1"
 20 |  23 |          self.proc = subprocess.Popen(
 21 |  24 |              ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
     |  25 | +            env=env,
 22 |  26 |              stdout=sys.stdout,
 23 |  27 |              stderr=sys.stderr,
 24 |  28 |          )
@@ -58,7 +62,8 @@ def server():
 58 |  62 |          "--dtype",
 59 |  63 |          "bfloat16",  # use half precision for speed and memory savings in CI environment
 60 |  64 |          "--max-model-len",
 61 |     | -        "8192"
     |  65 | +        "8192",
     |  66 | +        "--enforce-eager",
 62 |  67 |      ])
 63 |  68 |      ray.get(server_runner.ready.remote())
 64 |  69 |      yield server_runner
@@ -199,5 +204,51 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
199 | 204 |      assert "".join(chunks) == output
200 | 205 |  
201 | 206 |  
    | 207 | +async def test_batch_completions(server, client: openai.AsyncOpenAI):
    | 208 | +    # test simple list
    | 209 | +    batch = await client.completions.create(
    | 210 | +        model=MODEL_NAME,
    | 211 | +        prompt=["Hello, my name is", "Hello, my name is"],
    | 212 | +        max_tokens=5,
    | 213 | +        temperature=0.0,
    | 214 | +    )
    | 215 | +    assert len(batch.choices) == 2
    | 216 | +    assert batch.choices[0].text == batch.choices[1].text
    | 217 | +
    | 218 | +    # test n = 2
    | 219 | +    batch = await client.completions.create(
    | 220 | +        model=MODEL_NAME,
    | 221 | +        prompt=["Hello, my name is", "Hello, my name is"],
    | 222 | +        n=2,
    | 223 | +        max_tokens=5,
    | 224 | +        temperature=0.0,
    | 225 | +        extra_body=dict(
    | 226 | +            # NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client.
    | 227 | +            use_beam_search=True),
    | 228 | +    )
    | 229 | +    assert len(batch.choices) == 4
    | 230 | +    assert batch.choices[0].text != batch.choices[
    | 231 | +        1].text, "beam search should be different"
    | 232 | +    assert batch.choices[0].text == batch.choices[
    | 233 | +        2].text, "two copies of the same prompt should be the same"
    | 234 | +    assert batch.choices[1].text == batch.choices[
    | 235 | +        3].text, "two copies of the same prompt should be the same"
    | 236 | +
    | 237 | +    # test streaming
    | 238 | +    batch = await client.completions.create(
    | 239 | +        model=MODEL_NAME,
    | 240 | +        prompt=["Hello, my name is", "Hello, my name is"],
    | 241 | +        max_tokens=5,
    | 242 | +        temperature=0.0,
    | 243 | +        stream=True,
    | 244 | +    )
    | 245 | +    texts = [""] * 2
    | 246 | +    async for chunk in batch:
    | 247 | +        assert len(chunk.choices) == 1
    | 248 | +        choice = chunk.choices[0]
    | 249 | +        texts[choice.index] += choice.text
    | 250 | +    assert texts[0] == texts[1]
    | 251 | +
    | 252 | +
202 | 253 |  if __name__ == "__main__":
203 | 254 |      pytest.main([__file__])