[Feat] Add audio benchmarking support /v1/audio/transcriptions #99

Closed · wants to merge 1 commit
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -15,14 +15,19 @@ dependencies = [
"sentencepiece",
"aiohttp",
"pydantic",
"matplotlib"
"matplotlib",
"librosa",
"soundfile",
"datasets",
]

classifiers = [
"Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]

[tool.setuptools]
7 changes: 6 additions & 1 deletion requirements-dev.txt
@@ -16,4 +16,9 @@ pytest-timeout
aiohttp
openai
httpx
vllm

# audio
librosa
soundfile
datasets
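
The three new dependencies map onto distinct audio needs: librosa for decoding and resampling, soundfile for writing WAV files, and datasets for pulling test corpora. A minimal sketch of how they might be combined to materialize a local audio file for the new backend; the dataset name and helper function are illustrative, not part of this diff:

# Hypothetical helper (not in this PR): turn one ASR dataset sample into a WAV
# file that the transcription backend can upload.
import librosa
import soundfile as sf
from datasets import load_dataset

def prepare_sample_wav(out_path: str = "sample.wav", target_sr: int = 16000) -> str:
    # Stream a single example from a public ASR corpus (name is illustrative).
    ds = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
    example = next(iter(ds))
    audio = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]
    # Resample with librosa if the serving model expects a different rate.
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    sf.write(out_path, audio, target_sr)
    return out_path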
106 changes: 105 additions & 1 deletion src/flexible_inference_benchmark/engine/backend_functions.py
@@ -21,7 +21,9 @@ class bcolors:

class RequestFuncInput(BaseModel):
    prompt: str
    media: List[str] = Field(default_factory=list)
    audio_file_path: Optional[str] = None
    language: Optional[str] = None
    api_url: str
    prompt_len: int
    output_len: int
@@ -633,6 +635,107 @@ def remove_prefix(text: str, prefix: str) -> str:
    return text


async def async_request_openai_audio_transcriptions(
    idx: int, request_func_input: RequestFuncInput, pbar: Optional[tqdm], verbose: bool, wait_time: float
) -> RequestFuncOutput:
    """
    Handle API calls to an OpenAI-compatible audio transcription endpoint.

    This function manages the interaction with audio transcription APIs that follow
    the OpenAI API format for audio endpoints (/v1/audio/transcriptions).
    """
    api_url = request_func_input.api_url

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    if not request_func_input.audio_file_path or not os.path.exists(request_func_input.audio_file_path):
        output.success = False
        output.error = f"Audio file not provided or not found: {request_func_input.audio_file_path}"
        if pbar:
            pbar.update(1)
        return output

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT, cookies=request_func_input.cookies) as session:
        form = aiohttp.FormData()
        form.add_field('model', request_func_input.model)
        if request_func_input.language:
            form.add_field('language', request_func_input.language)

        try:
            audio_file_handle = open(request_func_input.audio_file_path, "rb")
        except IOError as e:
            output.success = False
            output.error = f"Could not open audio file {request_func_input.audio_file_path}: {e}"
            if pbar:
                pbar.update(1)
            return output
        form.add_field(
            'file',
            audio_file_handle,
            filename=os.path.basename(request_func_input.audio_file_path),
            content_type='application/octet-stream',
        )

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        latency = 0.0

        try:
            # Only attach an Authorization header when a key is actually set,
            # to avoid sending "Bearer None" to unauthenticated servers.
            api_key = os.environ.get('OPENAI_API_KEY')
            headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
            async with session.post(
                url=api_url, data=form, headers=headers, verify_ssl=request_func_input.ssl
            ) as response:
                if response.status == 200:
                    if request_func_input.stream:
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            chunk_text = chunk_bytes.decode("utf-8")
                            timestamp = time.perf_counter()
                            if ttft == 0.0 and chunk_text:
                                ttft = timestamp - st
                                output.ttft = ttft
                            elif chunk_text:
                                output.itl.append(timestamp - most_recent_timestamp)

                            generated_text += chunk_text
                            most_recent_timestamp = timestamp
                            latency = time.perf_counter() - st
                    else:
                        resp_json = await response.json()
                        generated_text = resp_json.get("text", "")
                        latency = time.perf_counter() - st
                        output.ttft = latency  # For non-streaming, TTFT is the full latency.

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    # Whitespace word count as a rough proxy for output length in tokens.
                    output.output_len = len(generated_text.split())

                else:
                    output.success = False
                    error_detail = await response.text()
                    output.error = f"API error: {response.status} {response.reason}. Detail: {error_detail}"

        except aiohttp.ClientConnectorError:
            output.success = False
            output.error = "Connection error; please verify that the server is running and the endpoint is correct."
        except Exception:  # pylint: disable=broad-except
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))
        finally:
            # The handle is guaranteed to be open here; the early return above
            # covers the case where open() failed.
            audio_file_handle.close()

    if pbar:
        pbar.update(1)
    return output
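
A hedged usage sketch for the new coroutine. The field values are placeholders, and any RequestFuncInput fields hidden by the collapsed diff are assumed to have defaults:

# Illustrative driver, not part of this PR.
import asyncio

async def main() -> None:
    req = RequestFuncInput(
        prompt="",                     # unused by the transcription backend
        audio_file_path="sample.wav",  # placeholder; e.g. prepare_sample_wav() above
        language="en",
        api_url="http://localhost:8000/v1/audio/transcriptions",
        prompt_len=0,
        output_len=0,
        model="openai/whisper-large-v3",  # placeholder model id
        ssl=False,
        stream=False,
    )
    result = await async_request_openai_audio_transcriptions(
        idx=0, request_func_input=req, pbar=None, verbose=False, wait_time=0.0
    )
    print(result.success, result.latency, result.generated_text)

asyncio.run(main())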


def print_verbose(
    idx: int, request_func_input: RequestFuncInput, send_time: float, rcv_time: float, latency: float, sending: bool
) -> None:
@@ -663,5 +766,6 @@ def print_verbose(
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "openai-audio": async_request_openai_audio_transcriptions,  # New backend
    "tensorrt-llm": async_request_trt_llm,
}
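
The new registry entry makes the backend selectable by name. A dispatch sketch, assuming the mapping above is bound to a dict named ASYNC_REQUEST_FUNCS (the actual binding is outside this hunk):

# Illustrative dispatch; must run inside a coroutine.
request_func = ASYNC_REQUEST_FUNCS["openai-audio"]  # assumed dict name
output = await request_func(0, req, None, False, 0.0)  # req: a RequestFuncInput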
57 changes: 19 additions & 38 deletions src/flexible_inference_benchmark/engine/client.py
@@ -127,29 +127,16 @@ async def send_wave_request(
        return request_result

    async def benchmark(
        self,
        prepared_requests_data: List[Dict[str, Any]],
        request_times: List[Union[float, int]]
    ) -> list[Union[RequestFuncOutput, Any, None]]:
        assert len(prepared_requests_data) == len(request_times), "Prepared requests data and request times must have the same length"
        pbar = None if self.disable_tqdm else tqdm(total=len(prepared_requests_data), desc="Sending Requests")

        request_func_inputs = [
            RequestFuncInput(**req_data)
            for req_data in prepared_requests_data
        ]

        if self.wave:
@@ -171,22 +158,16 @@
            ]
        )
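
After the refactor, callers assemble plain dicts that benchmark() splats into RequestFuncInput. A sketch of one prepared entry for the audio backend; keys mirror the model fields shown earlier, and all values are placeholders:

# Illustrative, not part of this diff: one dict per request.
prepared_requests_data = [
    {
        "prompt": "",
        "media": [],
        "audio_file_path": "sample.wav",  # placeholder path
        "language": "en",
        "api_url": "http://localhost:8000/v1/audio/transcriptions",
        "prompt_len": 0,
        "output_len": 0,
        "model": "openai/whisper-large-v3",  # placeholder model id
        "ssl": False,
        "stream": False,
    }
]
request_times = [0.0]  # one send offset per request
# results = await client.benchmark(prepared_requests_data, request_times)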

    async def validate_request_func_input(
        self, request_input: RequestFuncInput
    ) -> Union[RequestFuncOutput, Any]:
        """
        Validate a request by sending it to the endpoint.

        Args:
            request_input: A fully prepared RequestFuncInput instance.

        Returns:
            The response from the server.
        """
        return await self.send_request(0, request_input, 0, None, None)
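
A one-line usage sketch, reusing an entry from the prepared data above (illustrative; runs inside a coroutine):

probe = await client.validate_request_func_input(RequestFuncInput(**prepared_requests_data[0]))
print(probe.success, probe.error)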