
Commit 929da81

NickLucche authored and shreyankg committed
[Frontend][Docs] Transcription API streaming (vllm-project#13301)
Signed-off-by: NickLucche <nlucches@redhat.com>
1 parent b1c3871 commit 929da81

File tree

5 files changed: +297 −26 lines


docs/source/serving/openai_compatible_server.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -379,6 +379,10 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
 Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
 you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
 
+:::{note}
+To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`.
+:::
+
 <!-- TODO: api enforced limits + uploading audios -->
 
 Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
```
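For readers trying the endpoint out, a minimal non-streaming request might look like the sketch below. It assumes a local server started with `vllm serve openai/whisper-small` and the usual local-serving placeholders for base URL and API key; the audio file name is illustrative, not part of this commit.

```python
# Minimal sketch. Assumptions: server started with
# `vllm serve openai/whisper-small`, default base URL, any local audio file.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

with open("sample.wav", "rb") as f:  # hypothetical audio file
    result = client.audio.transcriptions.create(
        model="openai/whisper-small",
        file=f,
        language="en",
        response_format="json",
        temperature=0.0,
    )
print(result.text)
```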
examples/online_serving/openai_transcription_client.py

Lines changed: 51 additions & 8 deletions
```diff
@@ -1,4 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
+import asyncio
+import json
+
+import httpx
 from openai import OpenAI
 
 from vllm.assets.audio import AudioAsset
@@ -13,11 +17,50 @@
     api_key=openai_api_key,
     base_url=openai_api_base,
 )
-with open(str(mary_had_lamb), "rb") as f:
-    transcription = client.audio.transcriptions.create(
-        file=f,
-        model="openai/whisper-large-v3",
-        language="en",
-        response_format="text",
-        temperature=0.0)
-print("transcription result:", transcription)
+
+
+def sync_openai():
+    with open(str(mary_had_lamb), "rb") as f:
+        transcription = client.audio.transcriptions.create(
+            file=f,
+            model="openai/whisper-small",
+            language="en",
+            response_format="json",
+            temperature=0.0)
+        print("transcription result:", transcription.text)
+
+
+sync_openai()
+
+
+# OpenAI Transcription API client does not support streaming.
+async def stream_openai_response():
+    data = {
+        "language": "en",
+        'stream': True,
+        "model": "openai/whisper-large-v3",
+    }
+    url = openai_api_base + "/audio/transcriptions"
+    print("transcription result:", end=' ')
+    async with httpx.AsyncClient() as client:
+        with open(str(winning_call), "rb") as f:
+            async with client.stream('POST', url, files={'file': f},
+                                     data=data) as response:
+                async for line in response.aiter_lines():
+                    # Each line is a JSON object prefixed with 'data: '
+                    if line:
+                        if line.startswith('data: '):
+                            line = line[len('data: '):]
+                        # Last chunk, stream ends
+                        if line.strip() == '[DONE]':
+                            break
+                        # Parse the JSON response
+                        chunk = json.loads(line)
+                        # Extract and print the content
+                        content = chunk['choices'][0].get('delta',
+                                                          {}).get('content')
+                        print(content, end='')
+
+
+# Run the asynchronous function
+asyncio.run(stream_openai_response())
```
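The committed example uses httpx for the streaming path because the OpenAI client does not expose streaming for transcriptions. For comparison, a synchronous sketch with `requests` parses the same `data: ...` SSE lines; it is not part of this commit, and the URL, model, and audio path are placeholder assumptions.

```python
# Hypothetical synchronous variant of the streaming request above.
# URL, model, and audio path are placeholder assumptions.
import json

import requests

url = "http://localhost:8000/v1/audio/transcriptions"
data = {"language": "en", "stream": True, "model": "openai/whisper-small"}

with open("sample.wav", "rb") as f:
    with requests.post(url, files={"file": f}, data=data,
                       stream=True) as response:
        for line in response.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue
            payload = line[len("data: "):]
            if payload.strip() == "[DONE]":
                break  # end-of-stream sentinel
            chunk = json.loads(payload)
            print(chunk["choices"][0].get("delta", {}).get("content"),
                  end="")
```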

tests/entrypoints/openai/test_transcription_validation.py

Lines changed: 72 additions & 0 deletions
```diff
@@ -3,12 +3,14 @@
 # imports for guided decoding tests
 import io
 import json
+from unittest.mock import patch
 
 import librosa
 import numpy as np
 import openai
 import pytest
 import soundfile as sf
+from openai._base_client import AsyncAPIClient
 
 from vllm.assets.audio import AudioAsset
 
@@ -120,3 +122,73 @@ async def test_completion_endpoints():
     res = await client.completions.create(model=model_name, prompt="Hello")
     assert res.code == 400
     assert res.message == "The model does not support Completions API"
+
+
+@pytest.mark.asyncio
+async def test_streaming_response(winning_call):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    transcription = ""
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        res_no_stream = await client.audio.transcriptions.create(
+            model=model_name,
+            file=winning_call,
+            response_format="json",
+            language="en",
+            temperature=0.0)
+        # Unfortunately this only works when the openai client is patched
+        # to use streaming mode, not exposed in the transcription api.
+        original_post = AsyncAPIClient.post
+
+        async def post_with_stream(*args, **kwargs):
+            kwargs['stream'] = True
+            return await original_post(*args, **kwargs)
+
+        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
+            client = remote_server.get_async_client()
+            res = await client.audio.transcriptions.create(
+                model=model_name,
+                file=winning_call,
+                language="en",
+                temperature=0.0,
+                extra_body=dict(stream=True))
+            # Reconstruct from chunks and validate
+            async for chunk in res:
+                # just a chunk
+                text = chunk.choices[0]['delta']['content']
+                transcription += text
+
+        assert transcription == res_no_stream.text
+
+
+@pytest.mark.asyncio
+async def test_stream_options(winning_call):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        original_post = AsyncAPIClient.post
+
+        async def post_with_stream(*args, **kwargs):
+            kwargs['stream'] = True
+            return await original_post(*args, **kwargs)
+
+        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
+            client = remote_server.get_async_client()
+            res = await client.audio.transcriptions.create(
+                model=model_name,
+                file=winning_call,
+                language="en",
+                temperature=0.0,
+                extra_body=dict(stream=True,
+                                stream_include_usage=True,
+                                stream_continuous_usage_stats=True))
+            final = False
+            continuous = True
+            async for chunk in res:
+                if not len(chunk.choices):
+                    # final usage sent
+                    final = True
+                else:
+                    continuous = continuous and hasattr(chunk, 'usage')
+            assert final and continuous
```
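Outside the test harness, the same stream options can be exercised over raw HTTP. Per the test above, delta chunks carry text in `choices[0].delta.content`, and a final chunk with an empty `choices` list carries aggregate usage. A hedged sketch follows; the endpoint URL, model, and audio path are assumptions, not values from this commit.

```python
# Sketch only: consume delta chunks, then the trailing usage-only chunk.
# URL, model, and audio path are assumed placeholders.
import json

import requests

url = "http://localhost:8000/v1/audio/transcriptions"
data = {
    "model": "openai/whisper-small",
    "language": "en",
    "stream": True,
    "stream_include_usage": True,
    "stream_continuous_usage_stats": True,
}

with open("winning_call.wav", "rb") as f:
    resp = requests.post(url, files={"file": f}, data=data, stream=True)

for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload.strip() == "[DONE]":
        break
    chunk = json.loads(payload)
    if chunk["choices"]:
        # Delta chunk; with continuous stats on, per-chunk usage is attached.
        print(chunk["choices"][0]["delta"].get("content"), end="")
    else:
        # Final chunk: empty `choices`, aggregate usage only.
        print("\nusage:", chunk.get("usage"))
```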

vllm/entrypoints/openai/protocol.py

Lines changed: 39 additions & 1 deletion
```diff
@@ -1285,6 +1285,21 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)
 
 
+class TranscriptionResponseStreamChoice(OpenAIBaseModel):
+    delta: DeltaMessage
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
+
+
+class TranscriptionStreamResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
+    object: Literal["transcription.chunk"] = "transcription.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: list[TranscriptionResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(default=None)
+
+
 class BatchRequestInput(OpenAIBaseModel):
     """
     The per-line object of the batch input file.
```
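These models define the chunk schema; on the wire, each chunk is serialized as one SSE `data:` event and the stream ends with a literal `data: [DONE]` line, which is exactly what the example client parses. A standalone sketch of that framing, with plain dicts standing in for the pydantic models:

```python
# Illustrative only: plain dicts stand in for TranscriptionStreamResponse.
import json
import time
import uuid


def to_sse_event(chunk: dict) -> str:
    # One `data:` line per chunk, blank-line terminated, as SSE requires.
    return f"data: {json.dumps(chunk)}\n\n"


chunk = {
    "id": f"trsc-{uuid.uuid4().hex}",
    "object": "transcription.chunk",
    "created": int(time.time()),
    "model": "openai/whisper-small",
    "choices": [{"delta": {"content": "Mary had a little lamb"},
                 "finish_reason": None, "stop_reason": None}],
    "usage": None,
}
print(to_sse_event(chunk), end="")
print("data: [DONE]")
```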
```diff
@@ -1510,6 +1525,15 @@ class TranscriptionRequest(OpenAIBaseModel):
     timestamps incurs additional latency.
     """
 
+    stream: Optional[bool] = False
+    """Custom field not present in the original OpenAI definition. When set,
+    it will enable output to be streamed in a similar fashion as the Chat
+    Completion endpoint.
+    """
+    # Flattened stream option to simplify form data.
+    stream_include_usage: Optional[bool] = False
+    stream_continuous_usage_stats: Optional[bool] = False
+
     # Default sampling parameters for transcription requests.
     _DEFAULT_SAMPLING_PARAMS: dict = {
         "temperature": 0,
@@ -1530,7 +1554,21 @@ def to_sampling_params(
             "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
 
         return SamplingParams.from_optional(temperature=temperature,
-                                            max_tokens=max_tokens)
+                                            max_tokens=max_tokens,
+                                            output_kind=RequestOutputKind.DELTA
+                                            if self.stream \
+                                            else RequestOutputKind.FINAL_ONLY)
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_stream_options(cls, data):
+        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
+        stream = data.get("stream", False)
+        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
+            raise ValueError(
+                "Stream options can only be defined when `stream=True`.")
+
+        return data
 
 
 # Transcription response objects
```
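The before-mode validator rejects the flattened stream options unless `stream=True`. A minimal standalone reproduction of the pattern, using a demo pydantic model rather than vLLM's actual class so it runs without vLLM installed:

```python
# Demo model mirroring the validator above; not the vLLM class itself.
from typing import Optional

from pydantic import BaseModel, ValidationError, model_validator


class StreamOptionsDemo(BaseModel):
    stream: Optional[bool] = False
    stream_include_usage: Optional[bool] = False
    stream_continuous_usage_stats: Optional[bool] = False

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")
        return data


StreamOptionsDemo(stream=True, stream_include_usage=True)  # accepted
try:
    StreamOptionsDemo(stream_include_usage=True)  # stream defaults to False
except ValidationError as e:
    print(e)  # Stream options can only be defined when `stream=True`.
```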
