Merged
23 commits
d0f4998
Add StreamOptions Class
Etelis Jun 6, 2024
8fee154
- Modified the initial response generation to conditionally include `…
Etelis Jun 6, 2024
03c309e
Tests added for the following scenarios:
Etelis Jun 6, 2024
a592cd8
Fixed issues related to formatting.
Etelis Jun 6, 2024
a7319e1
Fixing testing file.
Etelis Jun 7, 2024
4d33e28
Ran format.sh to fix formatting issues.
Etelis Jun 7, 2024
0fc7392
Merge branch 'main' into Add-support-for-stream_options-in-Completion…
Etelis Jun 7, 2024
d6ac891
Fixed formatting.
Etelis Jun 7, 2024
1473a7f
Tests fixture:
Etelis Jun 8, 2024
3d987d0
Fixing testing.
Etelis Jun 8, 2024
d25d055
Tests Related:
Etelis Jun 9, 2024
2e96a59
Fixing Testing.
Etelis Jun 9, 2024
5c1c0c6
Formatter fixture.
Etelis Jun 9, 2024
23aa903
Fixed testing.
Etelis Jun 9, 2024
776c73b
Reorder tests in server test file to resolve conflicts
Etelis Jun 9, 2024
4c828f4
Reorder tests in server test file to resolve conflicts
Etelis Jun 9, 2024
86d6d7a
feat(tests): Update test cases for chat and completion streaming options
Etelis Jun 9, 2024
fb6ae02
Formatting related issues resolved.
Etelis Jun 9, 2024
22cc139
fix(tests): Update streaming tests for correct handling of 'usage' at…
Etelis Jun 9, 2024
0d9b6b1
Fix: Incorrect indentation causing empty `choices` entries
Etelis Jun 10, 2024
184e7d9
Fix `serving_chat.py`:
Etelis Jun 10, 2024
e713488
Running format.sh.
Etelis Jun 10, 2024
fd35380
Merge branch 'main' into Add-support-for-stream_options-in-Completion…
Etelis Jun 10, 2024
236 changes: 132 additions & 104 deletions tests/entrypoints/test_openai_server.py
@@ -478,8 +478,6 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
temperature=0.0,
)
single_output = single_completion.choices[0].text
single_usage = single_completion.usage

stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
@@ -495,7 +493,6 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == "length"
assert chunk.choices[0].text
assert chunk.usage == single_usage
assert "".join(chunks) == single_output


@@ -550,6 +547,138 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI,
assert "".join(chunks) == output


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
)
async def test_chat_completion_stream_options(server,
client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "What is the capital of France?"
}]

# Test stream=True, stream_options={"include_usage": False}
stream = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
stream=True,
stream_options={"include_usage": False})
async for chunk in stream:
assert chunk.usage is None

# Test stream=True, stream_options={"include_usage": True}
stream = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
stream=True,
stream_options={"include_usage": True})

async for chunk in stream:
if chunk.choices[0].finish_reason is None:
assert chunk.usage is None
else:
assert chunk.usage is None
final_chunk = await stream.__anext__()
assert final_chunk.usage is not None
assert final_chunk.usage.prompt_tokens > 0
assert final_chunk.usage.completion_tokens > 0
assert final_chunk.usage.total_tokens == (
final_chunk.usage.prompt_tokens +
final_chunk.usage.completion_tokens)
assert final_chunk.choices == []

# Test stream=False, stream_options={"include_usage": None}
with pytest.raises(BadRequestError):
await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
stream=False,
stream_options={"include_usage": None})

# Test stream=False, stream_options={"include_usage": True}
with pytest.raises(BadRequestError):
await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
stream=False,
stream_options={"include_usage": True})


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
)
async def test_completion_stream_options(server, client: openai.AsyncOpenAI,
model_name: str):
prompt = "What is the capital of France?"

# Test stream=True, stream_options={"include_usage": False}
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={"include_usage": False})
async for chunk in stream:
assert chunk.usage is None

# Test stream=True, stream_options={"include_usage": True}
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={"include_usage": True})
async for chunk in stream:
if chunk.choices[0].finish_reason is None:
assert chunk.usage is None
else:
assert chunk.usage is None
final_chunk = await stream.__anext__()
assert final_chunk.usage is not None
assert final_chunk.usage.prompt_tokens > 0
assert final_chunk.usage.completion_tokens > 0
assert final_chunk.usage.total_tokens == (
final_chunk.usage.prompt_tokens +
final_chunk.usage.completion_tokens)
assert final_chunk.choices == []

# Test stream=False, stream_options={"include_usage": None}
with pytest.raises(BadRequestError):
await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": None})

# Test stream=False, stream_options={"include_usage": True}
with pytest.raises(BadRequestError):
await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": True})


@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
@@ -1343,106 +1472,5 @@ async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
assert embeddings.usage.total_tokens == 17


@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_stream_options(server, client: openai.AsyncOpenAI,
model_name: str):
prompt = "What is the capital of France?"

# Test stream=True, stream_options=None
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options=None,
)
chunks = []
async for chunk in stream:
chunks.append(chunk.choices[0].text)
assert len(chunks) > 0
assert "usage" not in chunk

# Test stream=True, stream_options={"include_usage": False}
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={"include_usage": False},
)
chunks = []
async for chunk in stream:
chunks.append(chunk.choices[0].text)
assert len(chunks) > 0
assert "usage" not in chunk

# Test stream=True, stream_options={"include_usage": True}
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={"include_usage": True},
)
chunks = []
finish_reason_count = 0
async for chunk in stream:
if chunk.choices[0].finish_reason is None:
assert chunk.usage is None
chunks.append(chunk.choices[0].text)
else:
assert chunk.usage is None
finish_reason_count += 1

# The last message should have usage and no choices
last_message = await stream.__anext__()
assert last_message.usage is not None
assert last_message.usage.prompt_tokens > 0
assert last_message.usage.completion_tokens > 0
assert last_message.usage.total_tokens == (
last_message.usage.prompt_tokens +
last_message.usage.completion_tokens)
assert last_message.choices == []

# Test stream=False, stream_options={"include_usage": None}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": None},
)

# Test stream=False, stream_options={"include_usage": False}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": False},
)

# Test stream=False, stream_options={"include_usage": True}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": True},
)


if __name__ == "__main__":
pytest.main([__file__])
9 changes: 9 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -346,6 +346,7 @@ class CompletionRequest(OpenAIBaseModel):
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
@@ -482,6 +483,14 @@ def check_logprobs(cls, data):
" in the interval [0, 5]."))
return data

@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when stream is True.")
return data


class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
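The `StreamOptions` model referenced by `CompletionRequest` above was added in commit d0f4998, but its definition is not part of this hunk. A minimal sketch consistent with the validator and the new tests might look like the following; the base class and the default value are assumptions, not the PR's actual code:

# Hypothetical sketch only: the real StreamOptions added in d0f4998 is not
# shown in this diff, so the base class and default below are assumptions.
from typing import Optional

from pydantic import BaseModel


class StreamOptions(BaseModel):
    # When true, one extra chunk with an empty `choices` list and a
    # populated `usage` field is appended to the stream, mirroring the
    # OpenAI API's stream_options behavior.
    include_usage: Optional[bool] = False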
35 changes: 17 additions & 18 deletions vllm/entrypoints/openai/serving_chat.py
@@ -441,25 +441,24 @@ async def chat_completion_stream_generator(
yield f"data: {data}\n\n"
finish_reason_sent[i] = True

if (request.stream_options
and request.stream_options.include_usage):
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens +
previous_num_tokens[i],
)
if (request.stream_options
and request.stream_options.include_usage):
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
)

final_usage_chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"
final_usage_chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"

except ValueError as e:
# TODO: Use a vllm-specific Validation Error
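The re-indentation above moves the usage block out of the per-choice streaming loop, so a request with `stream_options={"include_usage": True}` ends with exactly one extra chunk whose `choices` list is empty and whose `usage` is populated. A minimal client-side sketch of consuming that final chunk, mirroring the new tests; the base URL, API key, and model name are placeholders, and it assumes an openai client recent enough to accept `stream_options`:

# Sketch of reading the final usage chunk from a chat completion stream.
# Base URL, API key, and model name are placeholders, not values from this PR.
import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")


async def main() -> None:
    stream = await client.chat.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        messages=[{
            "role": "user",
            "content": "What is the capital of France?"
        }],
        max_tokens=10,
        stream=True,
        stream_options={"include_usage": True})
    async for chunk in stream:
        if chunk.usage is not None:
            # Final chunk: empty `choices`, populated `usage`.
            print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
        elif chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")


asyncio.run(main())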
26 changes: 22 additions & 4 deletions vllm/entrypoints/openai/serving_completion.py
@@ -264,7 +264,8 @@ async def completion_stream_generator(
)
else:
final_usage = None
response_json = CompletionStreamResponse(

chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
@@ -276,10 +277,27 @@
finish_reason=finish_reason,
stop_reason=stop_reason,
)
],
usage=final_usage,
).model_dump_json(exclude_unset=True)
])
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None

response_json = chunk.model_dump_json(exclude_unset=True)
yield f"data: {response_json}\n\n"

if (request.stream_options
and request.stream_options.include_usage):
final_usage_chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[],
usage=final_usage,
)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"

except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
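For the completions endpoint, the same behavior can be observed on the wire: regular chunks carry no usage, and the new block above appends one final `data:` event with an empty `choices` list and the accumulated `usage` before the stream terminates. A rough sketch reading the raw SSE stream with httpx; the server address, model name, and the `[DONE]` sentinel handling are assumptions for illustration:

# Rough illustration only: reads the raw SSE stream from a locally running
# OpenAI-compatible server; URL and model name below are placeholders.
import json

import httpx

payload = {
    "model": "HuggingFaceH4/zephyr-7b-beta",
    "prompt": "What is the capital of France?",
    "max_tokens": 5,
    "stream": True,
    "stream_options": {"include_usage": True},
}

with httpx.Client(base_url="http://localhost:8000") as http:
    with http.stream("POST", "/v1/completions", json=payload,
                     timeout=None) as response:
        for line in response.iter_lines():
            if not line.startswith("data: ") or line == "data: [DONE]":
                continue
            chunk = json.loads(line[len("data: "):])
            if not chunk["choices"]:
                # Final chunk produced by the block added above:
                # empty choices, populated usage.
                print(chunk["usage"])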