[Frontend] Add /v1/audio/transcriptions OpenAI API endpoint #12909
Changes from all commits
New file: OpenAI transcription client example (@@ -0,0 +1,23 @@)
# SPDX-License-Identifier: Apache-2.0
from openai import OpenAI

from vllm.assets.audio import AudioAsset

mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path()
winning_call = AudioAsset('winning_call').get_local_path()

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
with open(str(mary_had_lamb), "rb") as f:
    transcription = client.audio.transcriptions.create(
        file=f,
        model="openai/whisper-large-v3",
        language="en",
        response_format="text",
        temperature=0.0)
    print("transcription result:", transcription)
Requirements update (@@ -8,12 +8,11 @@ py-cpuinfo)
transformers >= 4.48.2  # Required for Bamba model and Transformers backend.
tokenizers >= 0.19.1  # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
-fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
-fastapi >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9'
+fastapi[standard] >= 0.107.0, < 0.113.0; python_version < '3.9'
+fastapi[standard] >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9'
aiohttp
openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
uvicorn[standard]
-pydantic >= 2.9 # Required for fastapi >= 0.113.0
+pydantic >= 2.9
prometheus_client >= 0.18.0
pillow # Required for image processing
prometheus-fastapi-instrumentator >= 7.0.0

Review comment on the fastapi[standard] lines (truncated): "By including the …"
Review comment on the pydantic >= 2.9 line (truncated): "This is also required by …"
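A quick sanity check of the updated bounds, sketched under the assumption that importlib.metadata and packaging are available in the environment (neither is referenced by this PR):

# Hypothetical environment check for the new requirement bounds.
from importlib.metadata import version

from packaging.version import Version

assert Version(version("fastapi")) >= Version("0.107.0")
assert Version(version("pydantic")) >= Version("2.9")
assert Version(version("openai")) >= Version("1.52.0")
print("fastapi/pydantic/openai satisfy the updated constraints")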
New file: transcription API correctness test (@@ -0,0 +1,166 @@)
# SPDX-License-Identifier: Apache-2.0
"""
Evaluate Transcription API correctness by computing the Word Error Rate (WER)
on a given ASR dataset. When provided, it will also compare the WER against
a baseline.
This simulates real-world usage of the API and makes sure that the frontend
and AsyncLLMEngine are working correctly.
"""
import asyncio
import io
import time
from statistics import mean, median
from typing import List

import librosa
import pytest
import soundfile
import torch
from datasets import load_dataset
from evaluate import load
from transformers import AutoTokenizer

from ....utils import RemoteOpenAIServer


def to_bytes(y, sr):
    buffer = io.BytesIO()
    soundfile.write(buffer, y, sr, format="WAV")
    buffer.seek(0)
    return buffer


async def transcribe_audio(client, tokenizer, y, sr):
    # Send loaded audio directly instead of loading from disk,
    # but don't account for that time.
    with to_bytes(y, sr) as f:
        start_time = time.perf_counter()
        transcription = await client.audio.transcriptions.create(
            file=f,
            model=tokenizer.name_or_path,
            language="en",
            temperature=0.0,
        )
        end_time = time.perf_counter()
    # NOTE there's no streaming in transcriptions, so TTFT can't be measured.
    latency = end_time - start_time
    num_output_tokens = len(
        tokenizer(transcription.text, add_special_tokens=False).input_ids)
    return latency, num_output_tokens, transcription.text


async def bound_transcribe(model_name, sem, client, audio, reference):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Use semaphore to limit concurrent requests.
    async with sem:
        result = await transcribe_audio(client, tokenizer, *audio)
    # Normalize *English* output/reference for evaluation.
    out = tokenizer.normalize(result[2])
    ref = tokenizer.normalize(reference)
    return result[:2] + (out, ref)


async def process_dataset(model, client, data, concurrent_request):
    sem = asyncio.Semaphore(concurrent_request)

    # Warmup call as the first `librosa.load` server-side is quite slow.
    audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
    _ = await bound_transcribe(model, sem, client, (audio, sr), "")

    tasks: List[asyncio.Task] = []
    for sample in data:
        audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
        task = asyncio.create_task(
            bound_transcribe(model, sem, client, (audio, sr), sample["text"]))
        tasks.append(task)
    return await asyncio.gather(*tasks)


def print_performance_metrics(results, total_time):
    latencies = [res[0] for res in results]
    total_tokens = sum([res[1] for res in results])

    total = len(results)
    print(f"Total Requests: {total}")
    print(f"Successful Requests: {len(latencies)}")
    print(f"Average Latency: {mean(latencies):.4f} seconds")
    print(f"Median Latency: {median(latencies):.4f} seconds")
    perc = sorted(latencies)[int(len(latencies) * 0.95) - 1]
    print(f"95th Percentile Latency: {perc:.4f} seconds")
    # Throughput
    req_throughput = len(latencies) / total_time
    print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s")
    throughput = total_tokens / total_time
    print(f"Estimated Throughput: {throughput:.2f} tok/s")


def add_duration(sample):
    y, sr = sample['audio']["array"], sample['audio']["sampling_rate"]
    sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000
    return sample


def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs):
    ## Load and filter the dataset
    dataset = load_dataset(dataset_repo, split=split, **hf_kwargs)
    if 'duration_ms' not in dataset[0]:
        # compute duration to filter
        dataset = dataset.map(add_duration)

    # Whisper max supported duration
    dataset = dataset.filter(lambda example: example['duration_ms'] < 30000)
    return dataset


def run_evaluation(model: str,
                   client,
                   dataset,
                   max_concurrent_reqs: int,
                   n_examples: int = -1,
                   print_metrics: bool = True):
    if n_examples > 0:
        dataset = dataset.select(range(n_examples))
    start = time.perf_counter()
    results = asyncio.run(
        process_dataset(model, client, dataset, max_concurrent_reqs))
    end = time.perf_counter()
    total_time = end - start
    print(f"Total Test Time: {total_time:.4f} seconds")
    if print_metrics:
        print_performance_metrics(results, total_time)
    # Compute WER
    predictions = [res[2] for res in results]
    references = [res[3] for res in results]
    wer = load("wer")
    wer_score = 100 * wer.compute(references=references,
                                  predictions=predictions)
    print("WER:", wer_score)
    return wer_score


# Alternatives: "openai/whisper-large-v2", "openai/whisper-large-v3-turbo", ...
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
@pytest.mark.parametrize(
    "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"])
# NOTE: Expected WER measured with equivalent hf.transformers args:
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
@pytest.mark.parametrize("expected_wer", [12.744980])
def test_wer_correctness(model_name,
                         dataset_repo,
                         expected_wer,
                         n_examples=-1,
                         max_concurrent_request=None):
    with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
        dataset = load_hf_dataset(dataset_repo)

        if not max_concurrent_request:
            # No max concurrency
            max_concurrent_request = n_examples if n_examples > 0 \
                else len(dataset)

        client = remote_server.get_async_client()
        wer = run_evaluation(model_name, client, dataset,
                             max_concurrent_request, n_examples)
        if expected_wer:
            torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
Review comment: nit: we should put the instructions to host the model with the server so this client code will work "out of the box".
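One way to address the nit, sketched here under the assumption that the existing vllm.entrypoints.openai.api_server entrypoint serves Whisper models once this endpoint lands; the module path and --model flag below are the standard vLLM server invocation, not something this PR defines.

# Hypothetical launcher for the client example above: start the
# OpenAI-compatible server on the default port 8000, wait for it to come up,
# then run the client script against http://localhost:8000/v1.
import subprocess
import sys

server = subprocess.Popen([
    sys.executable, "-m", "vllm.entrypoints.openai.api_server",
    "--model", "openai/whisper-large-v3",
])
# ... run the client example once the server is reachable, then shut it down:
# server.terminate()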