Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import base64
import json
import os
import time
import weakref
from dataclasses import dataclass
from typing import Any, Literal, TypedDict
from typing import Any, Callable, Generic, Literal, TypedDict, TypeVar

import aiohttp

Expand All @@ -42,6 +43,33 @@
from .log import logger
from .models import STTRealtimeSampleRates

T = TypeVar("T")


class _PeriodicCollector(Generic[T]):
"""Accumulates values and calls a callback after a specified duration."""

def __init__(self, callback: Callable[[T], None], *, duration: float) -> None:
self._duration = duration
self._callback = callback
self._last_flush_time = time.monotonic()
self._total: T | None = None

def push(self, value: T) -> None:
if self._total is None:
self._total = value
else:
self._total += value # type: ignore[operator]
if time.monotonic() - self._last_flush_time >= self._duration:
self.flush()

def flush(self) -> None:
if self._total is not None:
self._callback(self._total)
self._total = None
self._last_flush_time = time.monotonic()


API_BASE_URL_V1 = "https://api.elevenlabs.io/v1"
AUTHORIZATION_HEADER = "xi-api-key"

Expand Down Expand Up @@ -311,6 +339,10 @@ def __init__(
self._session = http_session
self._reconnect_event = asyncio.Event()
self._speaking = False # Track if we're currently in a speech segment
self._audio_duration_collector = _PeriodicCollector(
callback=self._on_audio_duration_report,
duration=5.0,
)

def update_options(
self,
Expand Down Expand Up @@ -354,6 +386,7 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
frames.extend(audio_bstream.flush())

for frame in frames:
self._audio_duration_collector.push(frame.duration)
audio_b64 = base64.b64encode(frame.data.tobytes()).decode("utf-8")
await ws.send_str(
json.dumps(
Expand All @@ -366,6 +399,7 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
)
)

self._audio_duration_collector.flush()
closing_ws = True

@utils.log_exceptions(logger=logger)
Expand Down Expand Up @@ -571,3 +605,11 @@ def _process_stream_event(self, data: dict) -> None:
raise APIConnectionError(f"{message_type}: {error_msg}{details_suffix}")
else:
logger.warning("ElevenLabs STT unknown message type: %s, data: %s", message_type, data)

def _on_audio_duration_report(self, duration: float) -> None:
    """Send a RECOGNITION_USAGE speech event carrying ``duration`` as its audio duration.

    The event has no alternatives and is pushed onto ``self._event_ch`` without waiting.
    """
    usage = stt.RecognitionUsage(audio_duration=duration)
    self._event_ch.send_nowait(
        stt.SpeechEvent(
            type=stt.SpeechEventType.RECOGNITION_USAGE,
            alternatives=[],
            recognition_usage=usage,
        )
    )