[Frontend][Docs] Transcription API streaming #13301

Merged

4 changes: 4 additions & 0 deletions docs/source/serving/openai_compatible_server.md
@@ -379,6 +379,10 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are supported:
Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.

:::{note}
To use the Transcriptions API, install vLLM with the extra audio dependencies using `pip install vllm[audio]`.
:::

<!-- TODO: api enforced limits + uploading audios -->

Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
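
For quick orientation, a minimal non-streaming call might look like the sketch below; the server address, API key, audio file, and model name are placeholders and should match whatever your vLLM server is actually serving:

from openai import OpenAI

# Placeholder values; point these at your running vLLM server.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

with open("sample.wav", "rb") as f:
    transcription = client.audio.transcriptions.create(
        file=f,
        model="openai/whisper-large-v3",
        language="en",
        response_format="text",
        temperature=0.0)
print("transcription result:", transcription)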
59 changes: 51 additions & 8 deletions examples/online_serving/openai_transcription_client.py
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json

import httpx
from openai import OpenAI

from vllm.assets.audio import AudioAsset
@@ -13,11 +17,50 @@
api_key=openai_api_key,
base_url=openai_api_base,
)
with open(str(mary_had_lamb), "rb") as f:
transcription = client.audio.transcriptions.create(
file=f,
model="openai/whisper-large-v3",
language="en",
response_format="text",
temperature=0.0)
print("transcription result:", transcription)


def sync_openai():
with open(str(mary_had_lamb), "rb") as f:
transcription = client.audio.transcriptions.create(
file=f,
model="openai/whisper-small",
language="en",
response_format="json",
temperature=0.0)
print("transcription result:", transcription.text)


sync_openai()


# The OpenAI Transcriptions API client does not support streaming, so we
# issue the request with httpx and read the server-sent events directly.
async def stream_openai_response():
data = {
"language": "en",
'stream': True,
"model": "openai/whisper-large-v3",
}
url = openai_api_base + "/audio/transcriptions"
print("transcription result:", end=' ')
async with httpx.AsyncClient() as client:
with open(str(winning_call), "rb") as f:
async with client.stream('POST', url, files={'file': f},
data=data) as response:
async for line in response.aiter_lines():
# Each line is a JSON object prefixed with 'data: '
if line:
if line.startswith('data: '):
line = line[len('data: '):]
# Last chunk, stream ends
if line.strip() == '[DONE]':
break
# Parse the JSON response
chunk = json.loads(line)
# Extract and print the content
content = chunk['choices'][0].get('delta',
{}).get('content')
print(content, end='')


# Run the asynchronous function
asyncio.run(stream_openai_response())
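
For reference, each line the streaming loop above reads is a server-sent event of the form `data: <json>`, and the stream ends with a `data: [DONE]` sentinel. A rough sketch of decoding one such line follows; the field values are invented, while the field names mirror the TranscriptionStreamResponse schema added later in this PR:

import json

# Illustrative SSE line; values are made up, field names follow
# TranscriptionStreamResponse / TranscriptionResponseStreamChoice.
line = ('data: {"id": "trsc-123", "object": "transcription.chunk", '
        '"created": 1700000000, "model": "openai/whisper-large-v3", '
        '"choices": [{"delta": {"content": " Mary"}, "finish_reason": null}]}')

payload = line[len('data: '):]
if payload.strip() != '[DONE]':
    chunk = json.loads(payload)
    print(chunk['choices'][0]['delta']['content'])  # -> " Mary"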
72 changes: 72 additions & 0 deletions tests/entrypoints/openai/test_transcription_validation.py
@@ -3,12 +3,14 @@
# imports for guided decoding tests
import io
import json
from unittest.mock import patch

import librosa
import numpy as np
import openai
import pytest
import soundfile as sf
from openai._base_client import AsyncAPIClient

from vllm.assets.audio import AudioAsset

@@ -120,3 +122,73 @@ async def test_completion_endpoints():
res = await client.completions.create(model=model_name, prompt="Hello")
assert res.code == 400
assert res.message == "The model does not support Completions API"


@pytest.mark.asyncio
async def test_streaming_response(winning_call):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
transcription = ""
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
res_no_stream = await client.audio.transcriptions.create(
model=model_name,
file=winning_call,
response_format="json",
language="en",
temperature=0.0)
# Unfortunately this only works when the openai client is patched to use
# streaming mode, which is not exposed in the Transcriptions API.
original_post = AsyncAPIClient.post

async def post_with_stream(*args, **kwargs):
kwargs['stream'] = True
return await original_post(*args, **kwargs)

with patch.object(AsyncAPIClient, "post", new=post_with_stream):
client = remote_server.get_async_client()
res = await client.audio.transcriptions.create(
model=model_name,
file=winning_call,
language="en",
temperature=0.0,
extra_body=dict(stream=True))
# Reconstruct from chunks and validate
async for chunk in res:
# Each streamed chunk carries a text delta.
text = chunk.choices[0]['delta']['content']
transcription += text

assert transcription == res_no_stream.text


@pytest.mark.asyncio
async def test_stream_options(winning_call):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
original_post = AsyncAPIClient.post

async def post_with_stream(*args, **kwargs):
kwargs['stream'] = True
return await original_post(*args, **kwargs)

with patch.object(AsyncAPIClient, "post", new=post_with_stream):
client = remote_server.get_async_client()
res = await client.audio.transcriptions.create(
model=model_name,
file=winning_call,
language="en",
temperature=0.0,
extra_body=dict(stream=True,
stream_include_usage=True,
stream_continuous_usage_stats=True))
final = False
continuous = True
async for chunk in res:
if not len(chunk.choices):
# final usage sent
final = True
else:
continuous = continuous and hasattr(chunk, 'usage')
assert final and continuous
40 changes: 39 additions & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -1289,6 +1289,21 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None)


class TranscriptionResponseStreamChoice(OpenAIBaseModel):
delta: DeltaMessage
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None


class TranscriptionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
object: Literal["transcription.chunk"] = "transcription.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[TranscriptionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.
@@ -1514,6 +1529,15 @@ class TranscriptionRequest(OpenAIBaseModel):
timestamps incurs additional latency.
"""

stream: Optional[bool] = False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion to the Chat
Completions endpoint.
"""
# Flattened stream options to simplify form data.
stream_include_usage: Optional[bool] = False
stream_continuous_usage_stats: Optional[bool] = False

# Default sampling parameters for transcription requests.
_DEFAULT_SAMPLING_PARAMS: dict = {
"temperature": 0,
@@ -1534,7 +1558,21 @@ def to_sampling_params(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])

return SamplingParams.from_optional(temperature=temperature,
max_tokens=max_tokens)
max_tokens=max_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream \
else RequestOutputKind.FINAL_ONLY)

@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
stream = data.get("stream", False)
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
raise ValueError(
"Stream options can only be defined when `stream=True`.")

return data
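
Because the transcription endpoint accepts multipart form data rather than a JSON body, the stream options above are flattened into top-level fields instead of the nested `stream_options` object used by the Chat Completions API. A rough sketch of a raw request that exercises them (the URL and audio file are placeholders):

import httpx

data = {
    "model": "openai/whisper-small",
    "language": "en",
    "stream": True,
    "stream_include_usage": True,
    "stream_continuous_usage_stats": True,
}
# Placeholder URL and file; httpx sends these fields as multipart form data.
with open("sample.wav", "rb") as f:
    with httpx.stream("POST", "http://localhost:8000/v1/audio/transcriptions",
                      files={"file": f}, data=data) as response:
        for line in response.iter_lines():
            print(line)

The tests above exercise the same fields by passing them through the OpenAI client's `extra_body`.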


# Transcription response objects