Skip to content

Commit 14dbfd8

Browse files
authored
Rework data stream API (#352)
1 parent a412b3d commit 14dbfd8

File tree

4 files changed

+91
-74
lines changed

4 files changed

+91
-74
lines changed

examples/data-streams/data_streams.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ async def greetParticipant(identity: str):
2424
await room.local_participant.send_file(
2525
"./green_tree_python.jpg",
2626
destination_identities=[identity],
27-
topic="welcome",
27+
topic="files",
2828
)
2929

3030
async def on_chat_message_received(
@@ -55,17 +55,17 @@ def on_participant_connected(participant: rtc.RemoteParticipant):
5555
asyncio.create_task(greetParticipant(participant.identity))
5656

5757
room.set_text_stream_handler(
58+
"chat",
5859
lambda reader, participant_identity: asyncio.create_task(
5960
on_chat_message_received(reader, participant_identity)
6061
),
61-
"chat",
6262
)
6363

6464
room.set_byte_stream_handler(
65+
"files",
6566
lambda reader, participant_identity: asyncio.create_task(
6667
on_welcome_image_received(reader, participant_identity)
6768
),
68-
"welcome",
6969
)
7070

7171
# By default, autosubscribe is enabled. The participant will be subscribed to

livekit-rtc/livekit/rtc/data_stream.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import datetime
2020
from collections.abc import Callable
2121
from dataclasses import dataclass
22-
from typing import AsyncIterator, Optional, TypedDict, Dict, List
22+
from typing import AsyncIterator, Optional, Dict, List
2323
from ._proto.room_pb2 import DataStream as proto_DataStream
2424
from ._proto import ffi_pb2 as proto_ffi
2525
from ._proto import room_pb2 as proto_room
@@ -35,13 +35,13 @@
3535

3636

3737
@dataclass
38-
class BaseStreamInfo(TypedDict):
38+
class BaseStreamInfo:
3939
stream_id: str
4040
mime_type: str
4141
topic: str
4242
timestamp: int
4343
size: Optional[int]
44-
attributes: Optional[Dict[str, str]] # Optional for the extensions dictionary
44+
attributes: Optional[Dict[str, str]] # Optional for the attributes dictionary
4545

4646

4747
@dataclass
@@ -259,7 +259,7 @@ def __init__(
259259
local_participant: LocalParticipant,
260260
*,
261261
topic: str = "",
262-
extensions: Optional[Dict[str, str]] = {},
262+
attributes: Optional[Dict[str, str]] = {},
263263
stream_id: str | None = None,
264264
total_size: int | None = None,
265265
reply_to_id: str | None = None,
@@ -268,7 +268,7 @@ def __init__(
268268
super().__init__(
269269
local_participant,
270270
topic,
271-
extensions,
271+
attributes,
272272
stream_id,
273273
total_size,
274274
mime_type="text/plain",
@@ -313,7 +313,7 @@ def __init__(
313313
*,
314314
name: str,
315315
topic: str = "",
316-
extensions: Optional[Dict[str, str]] = None,
316+
attributes: Optional[Dict[str, str]] = None,
317317
stream_id: str | None = None,
318318
total_size: int | None = None,
319319
mime_type: str = "application/octet-stream",
@@ -322,7 +322,7 @@ def __init__(
322322
super().__init__(
323323
local_participant,
324324
topic,
325-
extensions,
325+
attributes,
326326
stream_id,
327327
total_size,
328328
mime_type=mime_type,

livekit-rtc/livekit/rtc/participant.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from typing import List, Union, Callable, Dict, Awaitable, Optional, Mapping, cast
2323
from abc import abstractmethod, ABC
2424

25-
2625
from ._ffi_client import FfiClient, FfiHandle
2726
from ._proto import ffi_pb2 as proto_ffi
2827
from ._proto import participant_pb2 as proto_participant
@@ -552,9 +551,9 @@ async def set_attributes(self, attributes: dict[str, str]) -> None:
552551
async def stream_text(
553552
self,
554553
*,
555-
destination_identities: List[str] = [],
554+
destination_identities: Optional[List[str]] = None,
556555
topic: str = "",
557-
extensions: Dict[str, str] = {},
556+
attributes: Optional[Dict[str, str]] = None,
558557
reply_to_id: str | None = None,
559558
total_size: int | None = None,
560559
) -> TextStreamWriter:
@@ -565,7 +564,7 @@ async def stream_text(
565564
writer = TextStreamWriter(
566565
self,
567566
topic=topic,
568-
extensions=extensions,
567+
attributes=attributes,
569568
reply_to_id=reply_to_id,
570569
destination_identities=destination_identities,
571570
total_size=total_size,
@@ -579,16 +578,16 @@ async def send_text(
579578
self,
580579
text: str,
581580
*,
582-
destination_identities: List[str] = [],
581+
destination_identities: Optional[List[str]] = None,
583582
topic: str = "",
584-
extensions: Dict[str, str] = {},
583+
attributes: Optional[Dict[str, str]] = None,
585584
reply_to_id: str | None = None,
586585
):
587586
total_size = len(text.encode())
588587
writer = await self.stream_text(
589588
destination_identities=destination_identities,
590589
topic=topic,
591-
extensions=extensions,
590+
attributes=attributes,
592591
reply_to_id=reply_to_id,
593592
total_size=total_size,
594593
)
@@ -605,7 +604,7 @@ async def stream_bytes(
605604
*,
606605
total_size: int | None = None,
607606
mime_type: str = "application/octet-stream",
608-
extensions: Optional[Dict[str, str]] = None,
607+
attributes: Optional[Dict[str, str]] = None,
609608
stream_id: str | None = None,
610609
destination_identities: Optional[List[str]] = None,
611610
topic: str = "",
@@ -617,7 +616,7 @@ async def stream_bytes(
617616
writer = ByteStreamWriter(
618617
self,
619618
name=name,
620-
extensions=extensions,
619+
attributes=attributes,
621620
total_size=total_size,
622621
stream_id=stream_id,
623622
mime_type=mime_type,
@@ -632,6 +631,7 @@ async def stream_bytes(
632631
async def send_file(
633632
self,
634633
file_path: str,
634+
*,
635635
topic: str = "",
636636
destination_identities: Optional[List[str]] = None,
637637
attributes: Optional[Dict[str, str]] = None,
@@ -649,7 +649,7 @@ async def send_file(
649649
name=file_name,
650650
total_size=file_size,
651651
mime_type=mime_type,
652-
extensions=attributes,
652+
attributes=attributes,
653653
stream_id=stream_id,
654654
destination_identities=destination_identities,
655655
topic=topic,

livekit-rtc/livekit/rtc/room.py

Lines changed: 71 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
ByteStreamHandler,
4141
)
4242

43+
4344
EventTypes = Literal[
4445
"participant_connected",
4546
"participant_disconnected",
@@ -138,11 +139,13 @@ def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
138139
self._room_queue = BroadcastQueue[proto_ffi.FfiEvent]()
139140
self._info = proto_room.RoomInfo()
140141
self._rpc_invocation_tasks: set[asyncio.Task] = set()
142+
self._data_stream_tasks: set[asyncio.Task] = set()
141143

142144
self._remote_participants: Dict[str, RemoteParticipant] = {}
143145
self._connection_state = ConnectionState.CONN_DISCONNECTED
144146
self._first_sid_future = asyncio.Future[str]()
145147
self._local_participant: LocalParticipant | None = None
148+
146149
self._text_stream_readers: Dict[str, TextStreamReader] = {}
147150
self._byte_stream_readers: Dict[str, ByteStreamReader] = {}
148151
self._text_stream_handlers: Dict[str, TextStreamHandler] = {}
@@ -406,12 +409,35 @@ def on_participant_connected(participant):
406409
# start listening to room events
407410
self._task = self._loop.create_task(self._listen_task())
408411

412+
def register_byte_stream_handler(self, topic: str, handler: ByteStreamHandler):
413+
existing_handler = self._byte_stream_handlers.get(topic)
414+
if existing_handler is None:
415+
self._byte_stream_handlers[topic] = handler
416+
else:
417+
raise ValueError("byte stream handler for topic '%s' already set" % topic)
418+
419+
def unregister_byte_stream_handler(self, topic: str):
420+
if self._byte_stream_handlers.get(topic):
421+
self._byte_stream_handlers.pop(topic)
422+
423+
def register_text_stream_handler(self, topic: str, handler: TextStreamHandler):
424+
existing_handler = self._text_stream_handlers.get(topic)
425+
if existing_handler is None:
426+
self._text_stream_handlers[topic] = handler
427+
else:
428+
raise ValueError("text stream handler for topic '%s' already set" % topic)
429+
430+
def unregister_text_stream_handler(self, topic: str):
431+
if self._text_stream_handlers.get(topic):
432+
self._text_stream_handlers.pop(topic)
433+
409434
async def disconnect(self) -> None:
410435
"""Disconnects from the room."""
411436
if not self.isconnected():
412437
return
413438

414439
await self._drain_rpc_invocation_tasks()
440+
await self._drain_data_stream_tasks()
415441

416442
req = proto_ffi.FfiRequest()
417443
req.disconnect.room_handle = self._ffi_handle.handle # type: ignore
@@ -426,28 +452,6 @@ async def disconnect(self) -> None:
426452
await self._task
427453
FfiClient.instance.queue.unsubscribe(self._ffi_queue)
428454

429-
def set_byte_stream_handler(self, handler: ByteStreamHandler, topic: str = ""):
430-
existing_handler = self._byte_stream_handlers.get(topic)
431-
if existing_handler is None:
432-
self._byte_stream_handlers[topic] = handler
433-
else:
434-
raise TypeError("byte stream handler for topic '%s' already set" % topic)
435-
436-
def remove_byte_stream_handler(self, topic: str = ""):
437-
if self._byte_stream_handlers.get(topic):
438-
self._byte_stream_handlers.pop(topic)
439-
440-
def set_text_stream_handler(self, handler: TextStreamHandler, topic: str = ""):
441-
existing_handler = self._text_stream_handlers.get(topic)
442-
if existing_handler is None:
443-
self._text_stream_handlers[topic] = handler
444-
else:
445-
raise TypeError("text stream handler for topic '%s' already set" % topic)
446-
447-
def remove_text_stream_handler(self, topic: str = ""):
448-
if self._text_stream_handlers.get(topic):
449-
self._text_stream_handlers.pop(topic)
450-
451455
async def _listen_task(self) -> None:
452456
# listen to incoming room events
453457
while True:
@@ -474,6 +478,7 @@ async def _listen_task(self) -> None:
474478

475479
# Clean up any pending RPC invocation tasks
476480
await self._drain_rpc_invocation_tasks()
481+
await self._drain_data_stream_tasks()
477482

478483
def _on_rpc_method_invocation(self, rpc_invocation: RpcMethodInvocationEvent):
479484
if self._local_participant is None:
@@ -747,40 +752,18 @@ def _on_room_event(self, event: proto_room.RoomEvent):
747752
event.stream_header_received.participant_identity,
748753
)
749754
elif which == "stream_chunk_received":
750-
asyncio.gather(self._handle_stream_chunk(event.stream_chunk_received.chunk))
755+
task = asyncio.create_task(
756+
self._handle_stream_chunk(event.stream_chunk_received.chunk)
757+
)
758+
self._data_stream_tasks.add(task)
759+
task.add_done_callback(self._data_stream_tasks.discard)
760+
751761
elif which == "stream_trailer_received":
752-
asyncio.gather(
762+
task = asyncio.create_task(
753763
self._handle_stream_trailer(event.stream_trailer_received.trailer)
754764
)
755-
756-
async def _drain_rpc_invocation_tasks(self) -> None:
757-
if self._rpc_invocation_tasks:
758-
for task in self._rpc_invocation_tasks:
759-
task.cancel()
760-
await asyncio.gather(*self._rpc_invocation_tasks, return_exceptions=True)
761-
762-
def _retrieve_remote_participant(
763-
self, identity: str
764-
) -> Optional[RemoteParticipant]:
765-
"""Retrieve a remote participant by identity"""
766-
return self._remote_participants.get(identity, None)
767-
768-
def _retrieve_participant(self, identity: str) -> Optional[Participant]:
769-
"""Retrieve a local or remote participant by identity"""
770-
if identity and identity == self.local_participant.identity:
771-
return self.local_participant
772-
773-
return self._retrieve_remote_participant(identity)
774-
775-
def _create_remote_participant(
776-
self, owned_info: proto_participant.OwnedParticipant
777-
) -> RemoteParticipant:
778-
if owned_info.info.identity in self._remote_participants:
779-
raise Exception("participant already exists")
780-
781-
participant = RemoteParticipant(owned_info)
782-
self._remote_participants[participant.identity] = participant
783-
return participant
765+
self._data_stream_tasks.add(task)
766+
task.add_done_callback(self._data_stream_tasks.discard)
784767

785768
def _handle_stream_header(
786769
self, header: proto_room.DataStream.Header, participant_identity: str
@@ -799,7 +782,6 @@ def _handle_stream_header(
799782
self._text_stream_readers[header.stream_id] = text_reader
800783
text_stream_handler(text_reader, participant_identity)
801784
elif stream_type == "byte_header":
802-
logging.warning("received byte header, %s", header.stream_id)
803785
byte_stream_handler = self._byte_stream_handlers.get(header.topic)
804786
if byte_stream_handler is None:
805787
logging.info(
@@ -835,6 +817,41 @@ async def _handle_stream_trailer(self, trailer: proto_room.DataStream.Trailer):
835817
await file_reader._on_stream_close(trailer)
836818
self._byte_stream_readers.pop(trailer.stream_id)
837819

820+
async def _drain_rpc_invocation_tasks(self) -> None:
821+
if self._rpc_invocation_tasks:
822+
for task in self._rpc_invocation_tasks:
823+
task.cancel()
824+
await asyncio.gather(*self._rpc_invocation_tasks, return_exceptions=True)
825+
826+
async def _drain_data_stream_tasks(self) -> None:
827+
if self._data_stream_tasks:
828+
for task in self._data_stream_tasks:
829+
task.cancel()
830+
await asyncio.gather(*self._data_stream_tasks, return_exceptions=True)
831+
832+
def _retrieve_remote_participant(
833+
self, identity: str
834+
) -> Optional[RemoteParticipant]:
835+
"""Retrieve a remote participant by identity"""
836+
return self._remote_participants.get(identity, None)
837+
838+
def _retrieve_participant(self, identity: str) -> Optional[Participant]:
839+
"""Retrieve a local or remote participant by identity"""
840+
if identity and identity == self.local_participant.identity:
841+
return self.local_participant
842+
843+
return self._retrieve_remote_participant(identity)
844+
845+
def _create_remote_participant(
846+
self, owned_info: proto_participant.OwnedParticipant
847+
) -> RemoteParticipant:
848+
if owned_info.info.identity in self._remote_participants:
849+
raise Exception("participant already exists")
850+
851+
participant = RemoteParticipant(owned_info)
852+
self._remote_participants[participant.identity] = participant
853+
return participant
854+
838855
def __repr__(self) -> str:
839856
sid = "unknown"
840857
if self._first_sid_future.done():

0 commit comments

Comments
 (0)