Skip to content

Commit 7cc84a5

Browse files
authored
chore: sync sdk code with DeepLearning repo (#119)
1 parent 8860c92 commit 7cc84a5

File tree

5 files changed

+540
-1
lines changed

5 files changed

+540
-1
lines changed

assemblyai/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.40.2"
1+
__version__ = "0.41.0b1"

assemblyai/streaming/v3/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from .client import StreamingClient
2+
from .models import (
3+
BeginEvent,
4+
EventMessage,
5+
StreamingClientOptions,
6+
StreamingError,
7+
StreamingEvents,
8+
StreamingParameters,
9+
StreamingSessionParameters,
10+
TerminationEvent,
11+
TurnEvent,
12+
Word,
13+
)
14+
15+
__all__ = [
16+
"BeginEvent",
17+
"EventMessage",
18+
"StreamingClient",
19+
"StreamingClientOptions",
20+
"StreamingError",
21+
"StreamingEvents",
22+
"StreamingParameters",
23+
"StreamingSessionParameters",
24+
"TerminationEvent",
25+
"TurnEvent",
26+
"Word",
27+
]

assemblyai/streaming/v3/client.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
import json
2+
import logging
3+
import queue
4+
import sys
5+
import threading
6+
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
7+
from urllib.parse import urlencode
8+
9+
import httpx
10+
import websockets
11+
from pydantic import BaseModel
12+
from websockets.sync.client import connect as websocket_connect
13+
14+
from assemblyai import __version__
15+
16+
from .models import (
17+
BeginEvent,
18+
ErrorEvent,
19+
EventMessage,
20+
OperationMessage,
21+
StreamingClientOptions,
22+
StreamingError,
23+
StreamingErrorCodes,
24+
StreamingEvents,
25+
StreamingParameters,
26+
StreamingSessionParameters,
27+
TerminateSession,
28+
TerminationEvent,
29+
TurnEvent,
30+
UpdateConfiguration,
31+
)
32+
33+
logger = logging.getLogger(__name__)
34+
35+
36+
def _user_agent() -> str:
37+
vi = sys.version_info
38+
python_version = f"{vi.major}.{vi.minor}.{vi.micro}"
39+
return (
40+
f"AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})"
41+
)
42+
43+
44+
class StreamingClient:
45+
def __init__(self, options: StreamingClientOptions):
46+
self._options = options
47+
48+
self._handlers: Dict[StreamingEvents, List[Callable]] = {}
49+
50+
for event in StreamingEvents.__members__.values():
51+
self._handlers[event] = []
52+
53+
self._write_queue: queue.Queue[OperationMessage] = queue.Queue()
54+
self._write_thread = threading.Thread(target=self._write_message)
55+
self._read_thread = threading.Thread(target=self._read_message)
56+
self._stop_event = threading.Event()
57+
58+
def connect(self, params: StreamingParameters) -> None:
59+
params_dict = params.model_dump(exclude_none=True)
60+
params_encoded = urlencode(params_dict)
61+
62+
uri = f"wss://{self._options.api_host}/v3/ws?{params_encoded}"
63+
headers = {
64+
"Authorization": self._options.api_key,
65+
"User-Agent": _user_agent(),
66+
"AssemblyAI-Version": "2025-05-12",
67+
}
68+
69+
try:
70+
self._websocket = websocket_connect(
71+
uri,
72+
additional_headers=headers,
73+
open_timeout=15,
74+
)
75+
except websockets.exceptions.ConnectionClosed as exc:
76+
self._handle_error(exc)
77+
return
78+
79+
self._write_thread.start()
80+
self._read_thread.start()
81+
82+
logger.debug("Connected to WebSocket server")
83+
84+
def disconnect(self, terminate: bool = False) -> None:
85+
if terminate and not self._stop_event.is_set():
86+
self._write_queue.put(TerminateSession())
87+
88+
try:
89+
self._read_thread.join()
90+
self._write_thread.join()
91+
92+
if self._websocket:
93+
self._websocket.close()
94+
except Exception:
95+
pass
96+
97+
def stream(
98+
self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]]
99+
) -> None:
100+
if isinstance(data, bytes):
101+
self._write_queue.put(data)
102+
return
103+
104+
for chunk in data:
105+
self._write_queue.put(chunk)
106+
107+
def set_params(self, params: StreamingSessionParameters):
108+
message = UpdateConfiguration(**params.model_dump())
109+
self._write_queue.put(message)
110+
111+
def on(self, event: StreamingEvents, handler: Callable) -> None:
112+
if event in StreamingEvents.__members__.values() and callable(handler):
113+
self._handlers[event].append(handler)
114+
115+
def _write_message(self) -> None:
116+
while not self._stop_event.is_set():
117+
if not self._websocket:
118+
raise ValueError("Not connected to the WebSocket server")
119+
120+
try:
121+
data = self._write_queue.get(timeout=1)
122+
except queue.Empty:
123+
continue
124+
125+
try:
126+
if isinstance(data, bytes):
127+
self._websocket.send(data)
128+
elif isinstance(data, BaseModel):
129+
message = data.model_dump_json(exclude_none=True)
130+
self._websocket.send(message)
131+
else:
132+
raise ValueError(f"Attempted to send invalid message: {type(data)}")
133+
except websockets.exceptions.ConnectionClosed as exc:
134+
self._handle_error(exc)
135+
return
136+
137+
def _read_message(self) -> None:
138+
while not self._stop_event.is_set():
139+
if not self._websocket:
140+
raise ValueError("Not connected to the WebSocket server")
141+
142+
try:
143+
message_data = self._websocket.recv(timeout=1)
144+
except TimeoutError:
145+
continue
146+
except websockets.exceptions.ConnectionClosed as exc:
147+
self._handle_error(exc)
148+
return
149+
150+
try:
151+
message_json = json.loads(message_data)
152+
except json.JSONDecodeError as exc:
153+
logger.warning(f"Failed to decode message: {exc}")
154+
continue
155+
156+
message = self._parse_message(message_json)
157+
158+
if isinstance(message, ErrorEvent):
159+
self._handle_error(message)
160+
elif message:
161+
self._handle_message(message)
162+
else:
163+
logger.warning(f"Unsupported event type: {message_json['type']}")
164+
165+
def _handle_message(self, message: EventMessage) -> None:
166+
if isinstance(message, TerminationEvent):
167+
self._stop_event.set()
168+
169+
event_type = StreamingEvents[message.type]
170+
171+
for handler in self._handlers[event_type]:
172+
handler(self, message)
173+
174+
def _parse_message(self, data: Dict[str, Any]) -> Optional[EventMessage]:
175+
if "type" in data:
176+
message_type = data.get("type")
177+
178+
event_type = self._parse_event_type(message_type)
179+
180+
if event_type == StreamingEvents.Begin:
181+
return BeginEvent.model_validate(data)
182+
elif event_type == StreamingEvents.Termination:
183+
return TerminationEvent.model_validate(data)
184+
elif event_type == StreamingEvents.Turn:
185+
return TurnEvent.model_validate(data)
186+
else:
187+
return None
188+
elif "error" in data:
189+
return ErrorEvent.model_validate(data)
190+
191+
return None
192+
193+
@staticmethod
194+
def _parse_event_type(message_type: Optional[Any]) -> Optional[StreamingEvents]:
195+
if not isinstance(message_type, str):
196+
return None
197+
198+
try:
199+
return StreamingEvents[message_type]
200+
except KeyError:
201+
return None
202+
203+
def _handle_error(
204+
self,
205+
error: Union[
206+
ErrorEvent,
207+
websockets.exceptions.ConnectionClosed,
208+
],
209+
):
210+
parsed_error = self._parse_error(error)
211+
212+
for handler in self._handlers[StreamingEvents.Error]:
213+
handler(self, parsed_error)
214+
215+
self.disconnect()
216+
217+
def _parse_error(
218+
self,
219+
error: Union[
220+
ErrorEvent,
221+
websockets.exceptions.ConnectionClosed,
222+
],
223+
) -> StreamingError:
224+
if isinstance(error, ErrorEvent):
225+
return StreamingError(
226+
message=error.error,
227+
)
228+
elif isinstance(error, websockets.exceptions.ConnectionClosed):
229+
if (
230+
error.code >= 4000
231+
and error.code <= 4999
232+
and error.code in StreamingErrorCodes
233+
):
234+
error_message = StreamingErrorCodes[error.code]
235+
else:
236+
error_message = error.reason
237+
238+
if error.code != 1000:
239+
return StreamingError(message=error_message, code=error.code)
240+
241+
return StreamingError(
242+
message=f"Unknown error: {error}",
243+
)
244+
245+
246+
class HTTPClient:
247+
def __init__(self, options: StreamingClientOptions):
248+
headers = {
249+
"Authorization": options.api_key,
250+
"User-Agent": _user_agent(),
251+
}
252+
253+
base_url = f"https://{options.api_host}"
254+
255+
self._http_client = httpx.Client(base_url=base_url, headers=headers, timeout=30)

0 commit comments

Comments
 (0)