Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions sdk/rt/speechmatics/rt/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ class StaticKeyAuth(AuthBase):
def __init__(self, api_key: Optional[str] = None):
self._api_key = api_key or os.environ.get("SPEECHMATICS_API_KEY")

if not self._api_key:
raise ValueError("API key required: provide api_key or set SPEECHMATICS_API_KEY")

async def get_auth_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self._api_key}"}

Expand Down
29 changes: 25 additions & 4 deletions sdk/rt/speechmatics/rt/_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import ssl
from dataclasses import asdict
from dataclasses import dataclass
from dataclasses import field
Expand Down Expand Up @@ -462,9 +463,9 @@ class ConnectionConfig:
close_timeout: Timeout for closing WebSocket connection.
max_size: Maximum message size in bytes.
max_queue: Maximum number of messages in receive queue.
read_limit: Maximum number of bytes to read from WebSocket.
write_limit: Maximum number of bytes to write to WebSocket.

read_limit: Maximum number of bytes to read from WebSocket (legacy websockets only).
write_limit: Maximum number of bytes to write to WebSocket (legacy websockets only).
ssl_context: SSL context for the WebSocket connection.
Returns:
Websocket connection configuration as a dict while excluding None values.
"""
Expand All @@ -477,9 +478,29 @@ class ConnectionConfig:
max_queue: Optional[int] = None
read_limit: Optional[int] = None
write_limit: Optional[int] = None
ssl_context: ssl.SSLContext = field(default_factory=ssl.create_default_context)

def to_dict(self) -> dict[str, Any]:
return asdict(self, dict_factory=lambda x: {k: v for (k, v) in x if v is not None})
"""Convert to dict, excluding ssl field to avoid pickle errors."""
result = {}
if self.open_timeout is not None:
result["open_timeout"] = self.open_timeout
if self.ping_interval is not None:
result["ping_interval"] = self.ping_interval
if self.ping_timeout is not None:
result["ping_timeout"] = self.ping_timeout
if self.close_timeout is not None:
result["close_timeout"] = self.close_timeout
if self.max_size is not None:
result["max_size"] = self.max_size
if self.max_queue is not None:
result["max_queue"] = self.max_queue
if self.read_limit is not None:
result["read_limit"] = self.read_limit
if self.write_limit is not None:
result["write_limit"] = self.write_limit

return result


@dataclass
Expand Down
8 changes: 8 additions & 0 deletions sdk/rt/speechmatics/rt/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
from websockets.asyncio.client import connect

WS_HEADERS_KEY = "additional_headers"
IS_LEGACY_WEBSOCKETS = False
except ImportError:
# Fall back to legacy websockets
from websockets.legacy.client import WebSocketClientProtocol
from websockets.legacy.client import connect # type: ignore

WS_HEADERS_KEY = "extra_headers"
IS_LEGACY_WEBSOCKETS = True


class Transport:
Expand Down Expand Up @@ -116,8 +118,14 @@ async def connect(self, ws_headers: Optional[dict] = None) -> None:
ws_kwargs: dict = {
WS_HEADERS_KEY: ws_headers,
**self._conn_config.to_dict(),
"ssl": self._conn_config.ssl_context,
}

# Filter out parameters not supported by new websockets >=13.0
if not IS_LEGACY_WEBSOCKETS:
ws_kwargs.pop("read_limit", None)
ws_kwargs.pop("write_limit", None)

self._websocket = await connect(
url_with_params,
**ws_kwargs,
Expand Down