
Commit a9b113d

Etelis authored and jimpang committed
[Feature][Frontend]: Continued stream_options implementation also in CompletionRequest (vllm-project#5319)
1 parent 5616ee3 commit a9b113d

File tree

4 files changed: +180 -126 lines changed

tests/entrypoints/test_openai_server.py
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_completion.py
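
This commit lets streaming requests to the OpenAI-compatible completions endpoint ask for token usage via stream_options, matching what the chat endpoint already supported. A minimal usage sketch (not part of the commit): it assumes a vLLM OpenAI-compatible server is already running at the given base URL, and the base_url, api_key value, model name, and prompt are placeholders borrowed from the tests below.

import asyncio

import openai


async def main():
    # Assumed deployment details: adjust base_url and model for your server.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    stream = await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        prompt="What is the capital of France?",
        max_tokens=5,
        stream=True,
        stream_options={"include_usage": True})
    text = []
    async for chunk in stream:
        if chunk.choices:
            # Regular content chunks carry no usage information.
            text.append(chunk.choices[0].text)
        else:
            # Trailing chunk added by this feature: empty choices,
            # aggregated prompt/completion/total token counts.
            print("usage:", chunk.usage)
    print("completion:", "".join(text))


asyncio.run(main())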

tests/entrypoints/test_openai_server.py

Lines changed: 132 additions & 104 deletions
@@ -478,8 +478,6 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
         temperature=0.0,
     )
     single_output = single_completion.choices[0].text
-    single_usage = single_completion.usage
-
     stream = await client.completions.create(model=model_name,
                                              prompt=prompt,
                                              max_tokens=5,
@@ -495,7 +493,6 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
     assert finish_reason_count == 1
     assert chunk.choices[0].finish_reason == "length"
     assert chunk.choices[0].text
-    assert chunk.usage == single_usage
     assert "".join(chunks) == single_output
 
 
@@ -550,6 +547,138 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI,
     assert "".join(chunks) == output
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
+)
+async def test_chat_completion_stream_options(server,
+                                              client: openai.AsyncOpenAI,
+                                              model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "You are a helpful assistant."
+    }, {
+        "role": "user",
+        "content": "What is the capital of France?"
+    }]
+
+    # Test stream=True, stream_options={"include_usage": False}
+    stream = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+        temperature=0.0,
+        stream=True,
+        stream_options={"include_usage": False})
+    async for chunk in stream:
+        assert chunk.usage is None
+
+    # Test stream=True, stream_options={"include_usage": True}
+    stream = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+        temperature=0.0,
+        stream=True,
+        stream_options={"include_usage": True})
+
+    async for chunk in stream:
+        if chunk.choices[0].finish_reason is None:
+            assert chunk.usage is None
+        else:
+            assert chunk.usage is None
+            final_chunk = await stream.__anext__()
+            assert final_chunk.usage is not None
+            assert final_chunk.usage.prompt_tokens > 0
+            assert final_chunk.usage.completion_tokens > 0
+            assert final_chunk.usage.total_tokens == (
+                final_chunk.usage.prompt_tokens +
+                final_chunk.usage.completion_tokens)
+            assert final_chunk.choices == []
+
+    # Test stream=False, stream_options={"include_usage": None}
+    with pytest.raises(BadRequestError):
+        await client.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            max_tokens=10,
+            temperature=0.0,
+            stream=False,
+            stream_options={"include_usage": None})
+
+    # Test stream=False, stream_options={"include_usage": True}
+    with pytest.raises(BadRequestError):
+        await client.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            max_tokens=10,
+            temperature=0.0,
+            stream=False,
+            stream_options={"include_usage": True})
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
+)
+async def test_completion_stream_options(server, client: openai.AsyncOpenAI,
+                                         model_name: str):
+    prompt = "What is the capital of France?"
+
+    # Test stream=True, stream_options={"include_usage": False}
+    stream = await client.completions.create(
+        model=model_name,
+        prompt=prompt,
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+        stream_options={"include_usage": False})
+    async for chunk in stream:
+        assert chunk.usage is None
+
+    # Test stream=True, stream_options={"include_usage": True}
+    stream = await client.completions.create(
+        model=model_name,
+        prompt=prompt,
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+        stream_options={"include_usage": True})
+    async for chunk in stream:
+        if chunk.choices[0].finish_reason is None:
+            assert chunk.usage is None
+        else:
+            assert chunk.usage is None
+            final_chunk = await stream.__anext__()
+            assert final_chunk.usage is not None
+            assert final_chunk.usage.prompt_tokens > 0
+            assert final_chunk.usage.completion_tokens > 0
+            assert final_chunk.usage.total_tokens == (
+                final_chunk.usage.prompt_tokens +
+                final_chunk.usage.completion_tokens)
+            assert final_chunk.choices == []
+
+    # Test stream=False, stream_options={"include_usage": None}
+    with pytest.raises(BadRequestError):
+        await client.completions.create(model=model_name,
+                                        prompt=prompt,
+                                        max_tokens=5,
+                                        temperature=0.0,
+                                        stream=False,
+                                        stream_options={"include_usage": None})
+
+    # Test stream=False, stream_options={"include_usage": True}
+    with pytest.raises(BadRequestError):
+        await client.completions.create(model=model_name,
+                                        prompt=prompt,
+                                        max_tokens=5,
+                                        temperature=0.0,
+                                        stream=False,
+                                        stream_options={"include_usage": True})
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     # just test 1 lora hereafter
@@ -1343,106 +1472,5 @@ async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
     assert embeddings.usage.total_tokens == 17
 
 
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_stream_options(server, client: openai.AsyncOpenAI,
-                              model_name: str):
-    prompt = "What is the capital of France?"
-
-    # Test stream=True, stream_options=None
-    stream = await client.completions.create(
-        model=model_name,
-        prompt=prompt,
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-        stream_options=None,
-    )
-    chunks = []
-    async for chunk in stream:
-        chunks.append(chunk.choices[0].text)
-    assert len(chunks) > 0
-    assert "usage" not in chunk
-
-    # Test stream=True, stream_options={"include_usage": False}
-    stream = await client.completions.create(
-        model=model_name,
-        prompt=prompt,
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-        stream_options={"include_usage": False},
-    )
-    chunks = []
-    async for chunk in stream:
-        chunks.append(chunk.choices[0].text)
-    assert len(chunks) > 0
-    assert "usage" not in chunk
-
-    # Test stream=True, stream_options={"include_usage": True}
-    stream = await client.completions.create(
-        model=model_name,
-        prompt=prompt,
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-        stream_options={"include_usage": True},
-    )
-    chunks = []
-    finish_reason_count = 0
-    async for chunk in stream:
-        if chunk.choices[0].finish_reason is None:
-            assert chunk.usage is None
-            chunks.append(chunk.choices[0].text)
-        else:
-            assert chunk.usage is None
-            finish_reason_count += 1
-
-    # The last message should have usage and no choices
-    last_message = await stream.__anext__()
-    assert last_message.usage is not None
-    assert last_message.usage.prompt_tokens > 0
-    assert last_message.usage.completion_tokens > 0
-    assert last_message.usage.total_tokens == (
-        last_message.usage.prompt_tokens +
-        last_message.usage.completion_tokens)
-    assert last_message.choices == []
-
-    # Test stream=False, stream_options={"include_usage": None}
-    with pytest.raises(BadRequestError):
-        await client.completions.create(
-            model=model_name,
-            prompt=prompt,
-            max_tokens=5,
-            temperature=0.0,
-            stream=False,
-            stream_options={"include_usage": None},
-        )
-
-    # Test stream=False, stream_options={"include_usage": False}
-    with pytest.raises(BadRequestError):
-        await client.completions.create(
-            model=model_name,
-            prompt=prompt,
-            max_tokens=5,
-            temperature=0.0,
-            stream=False,
-            stream_options={"include_usage": False},
-        )
-
-    # Test stream=False, stream_options={"include_usage": True}
-    with pytest.raises(BadRequestError):
-        await client.completions.create(
-            model=model_name,
-            prompt=prompt,
-            max_tokens=5,
-            temperature=0.0,
-            stream=False,
-            stream_options={"include_usage": True},
-        )
-
-
 if __name__ == "__main__":
     pytest.main([__file__])

vllm/entrypoints/openai/protocol.py

Lines changed: 9 additions & 0 deletions
@@ -346,6 +346,7 @@ class CompletionRequest(OpenAIBaseModel):
                                      le=torch.iinfo(torch.long).max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
@@ -482,6 +483,14 @@ def check_logprobs(cls, data):
                 " in the interval [0, 5]."))
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def validate_stream_options(cls, data):
+        if data.get("stream_options") and not data.get("stream"):
+            raise ValueError(
+                "Stream options can only be defined when stream is True.")
+        return data
+
 
 class EmbeddingRequest(BaseModel):
     # Ordered by official OpenAI API documentation
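
The new validator uses pydantic's "before"-mode model validator, so it sees the raw request payload and can reject stream_options whenever stream is not set. Below is a standalone sketch of the same pattern; it is illustrative only, with simplified field types, and ToyCompletionRequest is a made-up stand-in for vLLM's CompletionRequest/StreamOptions models.

from typing import Optional

import pydantic
from pydantic import BaseModel, model_validator


class ToyCompletionRequest(BaseModel):
    stream: Optional[bool] = False
    stream_options: Optional[dict] = None

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        # Runs on the raw input dict before field parsing, like the validator above.
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data


ToyCompletionRequest(stream=True, stream_options={"include_usage": True})  # accepted

try:
    ToyCompletionRequest(stream_options={"include_usage": True})
except pydantic.ValidationError as err:
    print(err)  # rejected: stream_options without stream=True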

vllm/entrypoints/openai/serving_chat.py

Lines changed: 17 additions & 18 deletions
@@ -441,25 +441,24 @@ async def chat_completion_stream_generator(
                         yield f"data: {data}\n\n"
                         finish_reason_sent[i] = True
 
-                if (request.stream_options
-                        and request.stream_options.include_usage):
-                    final_usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=previous_num_tokens[i],
-                        total_tokens=prompt_tokens +
-                        previous_num_tokens[i],
-                    )
+            if (request.stream_options
+                    and request.stream_options.include_usage):
+                final_usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=previous_num_tokens[i],
+                    total_tokens=prompt_tokens + previous_num_tokens[i],
+                )
 
-                    final_usage_chunk = ChatCompletionStreamResponse(
-                        id=request_id,
-                        object=chunk_object_type,
-                        created=created_time,
-                        choices=[],
-                        model=model_name,
-                        usage=final_usage)
-                    final_usage_data = (final_usage_chunk.model_dump_json(
-                        exclude_unset=True, exclude_none=True))
-                    yield f"data: {final_usage_data}\n\n"
+                final_usage_chunk = ChatCompletionStreamResponse(
+                    id=request_id,
+                    object=chunk_object_type,
+                    created=created_time,
+                    choices=[],
+                    model=model_name,
+                    usage=final_usage)
+                final_usage_data = (final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True))
+                yield f"data: {final_usage_data}\n\n"
 
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error

vllm/entrypoints/openai/serving_completion.py

Lines changed: 22 additions & 4 deletions
@@ -264,7 +264,8 @@ async def completion_stream_generator(
                         )
                     else:
                         final_usage = None
-                    response_json = CompletionStreamResponse(
+
+                    chunk = CompletionStreamResponse(
                         id=request_id,
                         created=created_time,
                         model=model_name,
@@ -276,10 +277,27 @@ async def completion_stream_generator(
                                 finish_reason=finish_reason,
                                 stop_reason=stop_reason,
                             )
-                        ],
-                        usage=final_usage,
-                    ).model_dump_json(exclude_unset=True)
+                        ])
+                    if (request.stream_options
+                            and request.stream_options.include_usage):
+                        chunk.usage = None
+
+                    response_json = chunk.model_dump_json(exclude_unset=True)
                     yield f"data: {response_json}\n\n"
+
+            if (request.stream_options
+                    and request.stream_options.include_usage):
+                final_usage_chunk = CompletionStreamResponse(
+                    id=request_id,
+                    created=created_time,
+                    model=model_name,
+                    choices=[],
+                    usage=final_usage,
+                )
+                final_usage_data = (final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True))
+                yield f"data: {final_usage_data}\n\n"
+
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             data = self.create_streaming_error_response(str(e))
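
For readers poking at the wire format, here is a hedged sketch of consuming the raw SSE frames this generator yields, without the OpenAI client. The base URL and the /v1/completions path are assumptions about a typical vLLM OpenAI-compatible deployment, and the trailing "[DONE]" sentinel is the usual OpenAI-style stream terminator rather than something introduced by this commit.

import json

import httpx

payload = {
    "model": "HuggingFaceH4/zephyr-7b-beta",  # placeholder model name
    "prompt": "What is the capital of France?",
    "max_tokens": 5,
    "stream": True,
    "stream_options": {"include_usage": True},
}

with httpx.Client(base_url="http://localhost:8000") as client:  # assumed server address
    with client.stream("POST", "/v1/completions", json=payload,
                       timeout=60.0) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank separator lines between events
            body = line[len("data: "):]
            if body.strip() == "[DONE]":
                break
            chunk = json.loads(body)
            if chunk["choices"]:
                print("text:", chunk["choices"][0]["text"])
            else:
                # Usage-only chunk emitted after the last content chunk when
                # stream_options={"include_usage": True}: choices is empty and
                # usage holds the aggregated token counts.
                print("usage:", chunk["usage"])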
