From 73b49d16451b15586d427731c70a2129829601e6 Mon Sep 17 00:00:00 2001 From: lukasIO Date: Wed, 29 Jan 2025 10:51:39 +0100 Subject: [PATCH 01/12] Raise ValueError instead of TypeError for previously set stream handlers --- livekit-rtc/livekit/rtc/room.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/livekit-rtc/livekit/rtc/room.py b/livekit-rtc/livekit/rtc/room.py index 78386083..c20c5e11 100644 --- a/livekit-rtc/livekit/rtc/room.py +++ b/livekit-rtc/livekit/rtc/room.py @@ -431,7 +431,7 @@ def set_byte_stream_handler(self, handler: ByteStreamHandler, topic: str = ""): if existing_handler is None: self._byte_stream_handlers[topic] = handler else: - raise TypeError("byte stream handler for topic '%s' already set" % topic) + raise ValueError("byte stream handler for topic '%s' already set" % topic) def remove_byte_stream_handler(self, topic: str = ""): if self._byte_stream_handlers.get(topic): @@ -442,7 +442,7 @@ def set_text_stream_handler(self, handler: TextStreamHandler, topic: str = ""): if existing_handler is None: self._text_stream_handlers[topic] = handler else: - raise TypeError("text stream handler for topic '%s' already set" % topic) + raise ValueError("text stream handler for topic '%s' already set" % topic) def remove_text_stream_handler(self, topic: str = ""): if self._text_stream_handlers.get(topic): From 2b83337747172ec49f1766303a271f27ece1fcca Mon Sep 17 00:00:00 2001 From: lukasIO Date: Wed, 29 Jan 2025 16:07:30 +0100 Subject: [PATCH 02/12] use asyncio.create_task --- livekit-rtc/livekit/rtc/room.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/livekit-rtc/livekit/rtc/room.py b/livekit-rtc/livekit/rtc/room.py index c20c5e11..fe9245e3 100644 --- a/livekit-rtc/livekit/rtc/room.py +++ b/livekit-rtc/livekit/rtc/room.py @@ -747,9 +747,11 @@ def _on_room_event(self, event: proto_room.RoomEvent): event.stream_header_received.participant_identity, ) elif which == "stream_chunk_received": - asyncio.gather(self._handle_stream_chunk(event.stream_chunk_received.chunk)) + asyncio.create_task( + self._handle_stream_chunk(event.stream_chunk_received.chunk) + ) elif which == "stream_trailer_received": - asyncio.gather( + asyncio.create_task( self._handle_stream_trailer(event.stream_trailer_received.trailer) ) From 6cad5a4a54b01865fdea5febd770ef2d23692436 Mon Sep 17 00:00:00 2001 From: lukasIO Date: Wed, 29 Jan 2025 16:09:36 +0100 Subject: [PATCH 03/12] kwargs --- livekit-rtc/livekit/rtc/participant.py | 1 + 1 file changed, 1 insertion(+) diff --git a/livekit-rtc/livekit/rtc/participant.py b/livekit-rtc/livekit/rtc/participant.py index c31b120f..2f121170 100644 --- a/livekit-rtc/livekit/rtc/participant.py +++ b/livekit-rtc/livekit/rtc/participant.py @@ -632,6 +632,7 @@ async def stream_bytes( async def send_file( self, file_path: str, + *, topic: str = "", destination_identities: Optional[List[str]] = None, attributes: Optional[Dict[str, str]] = None, From 62325968bb382e616ffdbaa4c26ba45611765b76 Mon Sep 17 00:00:00 2001 From: lukasIO Date: Wed, 29 Jan 2025 16:10:45 +0100 Subject: [PATCH 04/12] optional for consistency --- livekit-rtc/livekit/rtc/participant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/livekit-rtc/livekit/rtc/participant.py b/livekit-rtc/livekit/rtc/participant.py index 2f121170..1ae69058 100644 --- a/livekit-rtc/livekit/rtc/participant.py +++ b/livekit-rtc/livekit/rtc/participant.py @@ -579,9 +579,9 @@ async def send_text( self, text: str, *, - destination_identities: List[str] = [], + destination_identities: Optional[List[str]] = None, topic: str = "", - extensions: Dict[str, str] = {}, + extensions: Optional[Dict[str, str]] = None, reply_to_id: str | None = None, ): total_size = len(text.encode()) From 490e4dbbb6f51e8eb35075c5a3218c12d594304a Mon Sep 17 00:00:00 2001 From: lukasIO Date: Wed, 29 Jan 2025 16:18:19 +0100 Subject: [PATCH 05/12] fix optional --- livekit-rtc/livekit/rtc/participant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/livekit-rtc/livekit/rtc/participant.py b/livekit-rtc/livekit/rtc/participant.py index 1ae69058..e20be2de 100644 --- a/livekit-rtc/livekit/rtc/participant.py +++ b/livekit-rtc/livekit/rtc/participant.py @@ -552,9 +552,9 @@ async def set_attributes(self, attributes: dict[str, str]) -> None: async def stream_text( self, *, - destination_identities: List[str] = [], + destination_identities: Optional[List[str]] = None, topic: str = "", - extensions: Dict[str, str] = {}, + extensions: Optional[Dict[str, str]] = None, reply_to_id: str | None = None, total_size: int | None = None, ) -> TextStreamWriter: From cde41c2480d0c46ac3808eb43cfe2773f8b78d80 Mon Sep 17 00:00:00 2001 From: lukasIO Date: Thu, 30 Jan 2025 17:38:57 +0100 Subject: [PATCH 06/12] align with rpc API --- examples/data-streams/data_streams.py | 10 +-- livekit-rtc/livekit/rtc/data_stream.py | 10 +-- livekit-rtc/livekit/rtc/participant.py | 99 ++++++++++++++++++++-- livekit-rtc/livekit/rtc/room.py | 110 +++++-------------------- 4 files changed, 122 insertions(+), 107 deletions(-) diff --git a/examples/data-streams/data_streams.py b/examples/data-streams/data_streams.py index e9f483ca..9e97fc28 100644 --- a/examples/data-streams/data_streams.py +++ b/examples/data-streams/data_streams.py @@ -24,7 +24,7 @@ async def greetParticipant(identity: str): await room.local_participant.send_file( "./green_tree_python.jpg", destination_identities=[identity], - topic="welcome", + topic="files", ) async def on_chat_message_received( @@ -54,18 +54,18 @@ def on_participant_connected(participant: rtc.RemoteParticipant): ) asyncio.create_task(greetParticipant(participant.identity)) - room.set_text_stream_handler( + room.local_participant.set_text_stream_handler( + "chat", lambda reader, participant_identity: asyncio.create_task( on_chat_message_received(reader, participant_identity) ), - "chat", ) - room.set_byte_stream_handler( + room.local_participant.set_byte_stream_handler( + "files", lambda reader, participant_identity: asyncio.create_task( on_welcome_image_received(reader, participant_identity) ), - "welcome", ) # By default, autosubscribe is enabled. The participant will be subscribed to diff --git a/livekit-rtc/livekit/rtc/data_stream.py b/livekit-rtc/livekit/rtc/data_stream.py index ed2c0ad2..f76fffeb 100644 --- a/livekit-rtc/livekit/rtc/data_stream.py +++ b/livekit-rtc/livekit/rtc/data_stream.py @@ -41,7 +41,7 @@ class BaseStreamInfo(TypedDict): topic: str timestamp: int size: Optional[int] - attributes: Optional[Dict[str, str]] # Optional for the extensions dictionary + attributes: Optional[Dict[str, str]] # Optional for the attributes dictionary @dataclass @@ -259,7 +259,7 @@ def __init__( local_participant: LocalParticipant, *, topic: str = "", - extensions: Optional[Dict[str, str]] = {}, + attributes: Optional[Dict[str, str]] = {}, stream_id: str | None = None, total_size: int | None = None, reply_to_id: str | None = None, @@ -268,7 +268,7 @@ def __init__( super().__init__( local_participant, topic, - extensions, + attributes, stream_id, total_size, mime_type="text/plain", @@ -313,7 +313,7 @@ def __init__( *, name: str, topic: str = "", - extensions: Optional[Dict[str, str]] = None, + attributes: Optional[Dict[str, str]] = None, stream_id: str | None = None, total_size: int | None = None, mime_type: str = "application/octet-stream", @@ -322,7 +322,7 @@ def __init__( super().__init__( local_participant, topic, - extensions, + attributes, stream_id, total_size, mime_type=mime_type, diff --git a/livekit-rtc/livekit/rtc/participant.py b/livekit-rtc/livekit/rtc/participant.py index e20be2de..d451b700 100644 --- a/livekit-rtc/livekit/rtc/participant.py +++ b/livekit-rtc/livekit/rtc/participant.py @@ -21,10 +21,12 @@ import aiofiles from typing import List, Union, Callable, Dict, Awaitable, Optional, Mapping, cast from abc import abstractmethod, ABC +import logging from ._ffi_client import FfiClient, FfiHandle from ._proto import ffi_pb2 as proto_ffi +from ._proto import room_pb2 as proto_room from ._proto import participant_pb2 as proto_participant from ._proto.room_pb2 import ( TrackPublishOptions, @@ -50,6 +52,10 @@ ByteStreamWriter, ByteStreamInfo, STREAM_CHUNK_SIZE, + TextStreamReader, + ByteStreamReader, + TextStreamHandler, + ByteStreamHandler, ) @@ -159,6 +165,10 @@ def __init__( self._rpc_handlers: Dict[ str, Callable[[RpcInvocationData], Union[Awaitable[str], str]] ] = {} + self._text_stream_readers: Dict[str, TextStreamReader] = {} + self._byte_stream_readers: Dict[str, ByteStreamReader] = {} + self._text_stream_handlers: Dict[str, TextStreamHandler] = {} + self._byte_stream_handlers: Dict[str, ByteStreamHandler] = {} @property def track_publications(self) -> Mapping[str, LocalTrackPublication]: @@ -549,12 +559,65 @@ async def set_attributes(self, attributes: dict[str, str]) -> None: finally: FfiClient.instance.queue.unsubscribe(queue) + def _handle_stream_header( + self, header: proto_room.DataStream.Header, participant_identity: str + ): + stream_type = header.WhichOneof("content_header") + if stream_type == "text_header": + text_stream_handler = self._text_stream_handlers.get(header.topic) + if text_stream_handler is None: + logging.info( + "ignoring text stream with topic '%s', no callback attached", + header.topic, + ) + return + + text_reader = TextStreamReader(header) + self._text_stream_readers[header.stream_id] = text_reader + text_stream_handler(text_reader, participant_identity) + elif stream_type == "byte_header": + logging.warning("received byte header, %s", header.stream_id) + byte_stream_handler = self._byte_stream_handlers.get(header.topic) + if byte_stream_handler is None: + logging.info( + "ignoring byte stream with topic '%s', no callback attached", + header.topic, + ) + return + + byte_reader = ByteStreamReader(header) + self._byte_stream_readers[header.stream_id] = byte_reader + byte_stream_handler(byte_reader, participant_identity) + else: + logging.warning("received unknown header type, %s", stream_type) + pass + + async def _handle_stream_chunk(self, chunk: proto_room.DataStream.Chunk): + text_reader = self._text_stream_readers.get(chunk.stream_id) + file_reader = self._byte_stream_readers.get(chunk.stream_id) + + if text_reader: + await text_reader._on_chunk_update(chunk) + elif file_reader: + await file_reader._on_chunk_update(chunk) + + async def _handle_stream_trailer(self, trailer: proto_room.DataStream.Trailer): + text_reader = self._text_stream_readers.get(trailer.stream_id) + file_reader = self._byte_stream_readers.get(trailer.stream_id) + + if text_reader: + await text_reader._on_stream_close(trailer) + self._text_stream_readers.pop(trailer.stream_id) + elif file_reader: + await file_reader._on_stream_close(trailer) + self._byte_stream_readers.pop(trailer.stream_id) + async def stream_text( self, *, destination_identities: Optional[List[str]] = None, topic: str = "", - extensions: Optional[Dict[str, str]] = None, + attributes: Optional[Dict[str, str]] = None, reply_to_id: str | None = None, total_size: int | None = None, ) -> TextStreamWriter: @@ -565,7 +628,7 @@ async def stream_text( writer = TextStreamWriter( self, topic=topic, - extensions=extensions, + attributes=attributes, reply_to_id=reply_to_id, destination_identities=destination_identities, total_size=total_size, @@ -581,14 +644,14 @@ async def send_text( *, destination_identities: Optional[List[str]] = None, topic: str = "", - extensions: Optional[Dict[str, str]] = None, + attributes: Optional[Dict[str, str]] = None, reply_to_id: str | None = None, ): total_size = len(text.encode()) writer = await self.stream_text( destination_identities=destination_identities, topic=topic, - extensions=extensions, + attributes=attributes, reply_to_id=reply_to_id, total_size=total_size, ) @@ -605,7 +668,7 @@ async def stream_bytes( *, total_size: int | None = None, mime_type: str = "application/octet-stream", - extensions: Optional[Dict[str, str]] = None, + attributes: Optional[Dict[str, str]] = None, stream_id: str | None = None, destination_identities: Optional[List[str]] = None, topic: str = "", @@ -617,7 +680,7 @@ async def stream_bytes( writer = ByteStreamWriter( self, name=name, - extensions=extensions, + attributes=attributes, total_size=total_size, stream_id=stream_id, mime_type=mime_type, @@ -650,7 +713,7 @@ async def send_file( name=file_name, total_size=file_size, mime_type=mime_type, - extensions=attributes, + attributes=attributes, stream_id=stream_id, destination_identities=destination_identities, topic=topic, @@ -663,6 +726,28 @@ async def send_file( return writer.info + def register_byte_stream_handler(self, topic: str, handler: ByteStreamHandler): + existing_handler = self._byte_stream_handlers.get(topic) + if existing_handler is None: + self._byte_stream_handlers[topic] = handler + else: + raise ValueError("byte stream handler for topic '%s' already set" % topic) + + def unregister_byte_stream_handler(self, topic: str): + if self._byte_stream_handlers.get(topic): + self._byte_stream_handlers.pop(topic) + + def register_text_stream_handler(self, topic: str, handler: TextStreamHandler): + existing_handler = self._text_stream_handlers.get(topic) + if existing_handler is None: + self._text_stream_handlers[topic] = handler + else: + raise ValueError("text stream handler for topic '%s' already set" % topic) + + def unregister_text_stream_handler(self, topic: str): + if self._text_stream_handlers.get(topic): + self._text_stream_handlers.pop(topic) + async def publish_track( self, track: LocalTrack, options: TrackPublishOptions = TrackPublishOptions() ) -> LocalTrackPublication: diff --git a/livekit-rtc/livekit/rtc/room.py b/livekit-rtc/livekit/rtc/room.py index fe9245e3..79327021 100644 --- a/livekit-rtc/livekit/rtc/room.py +++ b/livekit-rtc/livekit/rtc/room.py @@ -33,12 +33,7 @@ from .track import RemoteAudioTrack, RemoteVideoTrack from .track_publication import RemoteTrackPublication, TrackPublication from .transcription import TranscriptionSegment -from .data_stream import ( - TextStreamReader, - ByteStreamReader, - TextStreamHandler, - ByteStreamHandler, -) + EventTypes = Literal[ "participant_connected", @@ -138,15 +133,12 @@ def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: self._room_queue = BroadcastQueue[proto_ffi.FfiEvent]() self._info = proto_room.RoomInfo() self._rpc_invocation_tasks: set[asyncio.Task] = set() + self._data_stream_tasks: set[asyncio.Task] = set() self._remote_participants: Dict[str, RemoteParticipant] = {} self._connection_state = ConnectionState.CONN_DISCONNECTED self._first_sid_future = asyncio.Future[str]() self._local_participant: LocalParticipant | None = None - self._text_stream_readers: Dict[str, TextStreamReader] = {} - self._byte_stream_readers: Dict[str, ByteStreamReader] = {} - self._text_stream_handlers: Dict[str, TextStreamHandler] = {} - self._byte_stream_handlers: Dict[str, ByteStreamHandler] = {} def __del__(self) -> None: if self._ffi_handle is not None: @@ -426,28 +418,6 @@ async def disconnect(self) -> None: await self._task FfiClient.instance.queue.unsubscribe(self._ffi_queue) - def set_byte_stream_handler(self, handler: ByteStreamHandler, topic: str = ""): - existing_handler = self._byte_stream_handlers.get(topic) - if existing_handler is None: - self._byte_stream_handlers[topic] = handler - else: - raise ValueError("byte stream handler for topic '%s' already set" % topic) - - def remove_byte_stream_handler(self, topic: str = ""): - if self._byte_stream_handlers.get(topic): - self._byte_stream_handlers.pop(topic) - - def set_text_stream_handler(self, handler: TextStreamHandler, topic: str = ""): - existing_handler = self._text_stream_handlers.get(topic) - if existing_handler is None: - self._text_stream_handlers[topic] = handler - else: - raise ValueError("text stream handler for topic '%s' already set" % topic) - - def remove_text_stream_handler(self, topic: str = ""): - if self._text_stream_handlers.get(topic): - self._text_stream_handlers.pop(topic) - async def _listen_task(self) -> None: # listen to incoming room events while True: @@ -742,18 +712,25 @@ def _on_room_event(self, event: proto_room.RoomEvent): elif which == "reconnected": self.emit("reconnected") elif which == "stream_header_received": - self._handle_stream_header( + self.local_participant._handle_stream_header( event.stream_header_received.header, event.stream_header_received.participant_identity, ) elif which == "stream_chunk_received": - asyncio.create_task( - self._handle_stream_chunk(event.stream_chunk_received.chunk) + task = asyncio.create_task( + self.local_participant._handle_stream_chunk( + event.stream_chunk_received.chunk + ) ) + self._data_stream_tasks.add(task) + elif which == "stream_trailer_received": - asyncio.create_task( - self._handle_stream_trailer(event.stream_trailer_received.trailer) + task = asyncio.create_task( + self.local_participant._handle_stream_trailer( + event.stream_trailer_received.trailer + ) ) + self._data_stream_tasks.add(task) async def _drain_rpc_invocation_tasks(self) -> None: if self._rpc_invocation_tasks: @@ -761,6 +738,12 @@ async def _drain_rpc_invocation_tasks(self) -> None: task.cancel() await asyncio.gather(*self._rpc_invocation_tasks, return_exceptions=True) + async def _drain_data_stream__tasks(self) -> None: + if self._data_stream_tasks: + for task in self._data_stream_tasks: + task.cancel() + await asyncio.gather(*self._data_stream_tasks, return_exceptions=True) + def _retrieve_remote_participant( self, identity: str ) -> Optional[RemoteParticipant]: @@ -784,59 +767,6 @@ def _create_remote_participant( self._remote_participants[participant.identity] = participant return participant - def _handle_stream_header( - self, header: proto_room.DataStream.Header, participant_identity: str - ): - stream_type = header.WhichOneof("content_header") - if stream_type == "text_header": - text_stream_handler = self._text_stream_handlers.get(header.topic) - if text_stream_handler is None: - logging.info( - "ignoring text stream with topic '%s', no callback attached", - header.topic, - ) - return - - text_reader = TextStreamReader(header) - self._text_stream_readers[header.stream_id] = text_reader - text_stream_handler(text_reader, participant_identity) - elif stream_type == "byte_header": - logging.warning("received byte header, %s", header.stream_id) - byte_stream_handler = self._byte_stream_handlers.get(header.topic) - if byte_stream_handler is None: - logging.info( - "ignoring byte stream with topic '%s', no callback attached", - header.topic, - ) - return - - byte_reader = ByteStreamReader(header) - self._byte_stream_readers[header.stream_id] = byte_reader - byte_stream_handler(byte_reader, participant_identity) - else: - logging.warning("received unknown header type, %s", stream_type) - pass - - async def _handle_stream_chunk(self, chunk: proto_room.DataStream.Chunk): - text_reader = self._text_stream_readers.get(chunk.stream_id) - file_reader = self._byte_stream_readers.get(chunk.stream_id) - - if text_reader: - await text_reader._on_chunk_update(chunk) - elif file_reader: - await file_reader._on_chunk_update(chunk) - - async def _handle_stream_trailer(self, trailer: proto_room.DataStream.Trailer): - text_reader = self._text_stream_readers.get(trailer.stream_id) - file_reader = self._byte_stream_readers.get(trailer.stream_id) - - if text_reader: - await text_reader._on_stream_close(trailer) - self._text_stream_readers.pop(trailer.stream_id) - elif file_reader: - await file_reader._on_stream_close(trailer) - self._byte_stream_readers.pop(trailer.stream_id) - def __repr__(self) -> str: sid = "unknown" if self._first_sid_future.done(): From 98511ea36ad21c13339fcd5dae5eddb5e0b6df8e Mon Sep 17 00:00:00 2001 From: lukasIO Date: Fri, 31 Jan 2025 00:01:28 +0100 Subject: [PATCH 07/12] fixes --- livekit-rtc/livekit/rtc/room.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/livekit-rtc/livekit/rtc/room.py b/livekit-rtc/livekit/rtc/room.py index 79327021..a7709193 100644 --- a/livekit-rtc/livekit/rtc/room.py +++ b/livekit-rtc/livekit/rtc/room.py @@ -404,6 +404,7 @@ async def disconnect(self) -> None: return await self._drain_rpc_invocation_tasks() + await self._data_stream_tasks() req = proto_ffi.FfiRequest() req.disconnect.room_handle = self._ffi_handle.handle # type: ignore @@ -444,6 +445,7 @@ async def _listen_task(self) -> None: # Clean up any pending RPC invocation tasks await self._drain_rpc_invocation_tasks() + await self._data_stream_tasks() def _on_rpc_method_invocation(self, rpc_invocation: RpcMethodInvocationEvent): if self._local_participant is None: @@ -723,6 +725,7 @@ def _on_room_event(self, event: proto_room.RoomEvent): ) ) self._data_stream_tasks.add(task) + task.add_done_callback(self._data_stream_tasks.discard) elif which == "stream_trailer_received": task = asyncio.create_task( @@ -731,6 +734,7 @@ def _on_room_event(self, event: proto_room.RoomEvent): ) ) self._data_stream_tasks.add(task) + task.add_done_callback(self._data_stream_tasks.discard) async def _drain_rpc_invocation_tasks(self) -> None: if self._rpc_invocation_tasks: From fc0ccedacff368238c36a8be18de6790847ef73b Mon Sep 17 00:00:00 2001 From: lukasIO Date: Fri, 31 Jan 2025 00:02:18 +0100 Subject: [PATCH 08/12] typo --- livekit-rtc/livekit/rtc/room.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/livekit-rtc/livekit/rtc/room.py b/livekit-rtc/livekit/rtc/room.py index a7709193..4a20aa43 100644 --- a/livekit-rtc/livekit/rtc/room.py +++ b/livekit-rtc/livekit/rtc/room.py @@ -404,7 +404,7 @@ async def disconnect(self) -> None: return await self._drain_rpc_invocation_tasks() - await self._data_stream_tasks() + await self._drain_data_stream__tasks() req = proto_ffi.FfiRequest() req.disconnect.room_handle = self._ffi_handle.handle # type: ignore @@ -445,7 +445,7 @@ async def _listen_task(self) -> None: # Clean up any pending RPC invocation tasks await self._drain_rpc_invocation_tasks() - await self._data_stream_tasks() + await self._drain_data_stream__tasks() def _on_rpc_method_invocation(self, rpc_invocation: RpcMethodInvocationEvent): if self._local_participant is None: From 580ab7c0c92f810416e3f3bd908de3ca32129e1d Mon Sep 17 00:00:00 2001 From: lukasIO Date: Fri, 31 Jan 2025 09:36:54 +0100 Subject: [PATCH 09/12] move handler registration back to room --- livekit-rtc/livekit/rtc/participant.py | 84 ---------------------- livekit-rtc/livekit/rtc/room.py | 96 ++++++++++++++++++++++++-- 2 files changed, 89 insertions(+), 91 deletions(-) diff --git a/livekit-rtc/livekit/rtc/participant.py b/livekit-rtc/livekit/rtc/participant.py index d451b700..aa889a7b 100644 --- a/livekit-rtc/livekit/rtc/participant.py +++ b/livekit-rtc/livekit/rtc/participant.py @@ -26,7 +26,6 @@ from ._ffi_client import FfiClient, FfiHandle from ._proto import ffi_pb2 as proto_ffi -from ._proto import room_pb2 as proto_room from ._proto import participant_pb2 as proto_participant from ._proto.room_pb2 import ( TrackPublishOptions, @@ -52,10 +51,6 @@ ByteStreamWriter, ByteStreamInfo, STREAM_CHUNK_SIZE, - TextStreamReader, - ByteStreamReader, - TextStreamHandler, - ByteStreamHandler, ) @@ -165,10 +160,6 @@ def __init__( self._rpc_handlers: Dict[ str, Callable[[RpcInvocationData], Union[Awaitable[str], str]] ] = {} - self._text_stream_readers: Dict[str, TextStreamReader] = {} - self._byte_stream_readers: Dict[str, ByteStreamReader] = {} - self._text_stream_handlers: Dict[str, TextStreamHandler] = {} - self._byte_stream_handlers: Dict[str, ByteStreamHandler] = {} @property def track_publications(self) -> Mapping[str, LocalTrackPublication]: @@ -559,59 +550,6 @@ async def set_attributes(self, attributes: dict[str, str]) -> None: finally: FfiClient.instance.queue.unsubscribe(queue) - def _handle_stream_header( - self, header: proto_room.DataStream.Header, participant_identity: str - ): - stream_type = header.WhichOneof("content_header") - if stream_type == "text_header": - text_stream_handler = self._text_stream_handlers.get(header.topic) - if text_stream_handler is None: - logging.info( - "ignoring text stream with topic '%s', no callback attached", - header.topic, - ) - return - - text_reader = TextStreamReader(header) - self._text_stream_readers[header.stream_id] = text_reader - text_stream_handler(text_reader, participant_identity) - elif stream_type == "byte_header": - logging.warning("received byte header, %s", header.stream_id) - byte_stream_handler = self._byte_stream_handlers.get(header.topic) - if byte_stream_handler is None: - logging.info( - "ignoring byte stream with topic '%s', no callback attached", - header.topic, - ) - return - - byte_reader = ByteStreamReader(header) - self._byte_stream_readers[header.stream_id] = byte_reader - byte_stream_handler(byte_reader, participant_identity) - else: - logging.warning("received unknown header type, %s", stream_type) - pass - - async def _handle_stream_chunk(self, chunk: proto_room.DataStream.Chunk): - text_reader = self._text_stream_readers.get(chunk.stream_id) - file_reader = self._byte_stream_readers.get(chunk.stream_id) - - if text_reader: - await text_reader._on_chunk_update(chunk) - elif file_reader: - await file_reader._on_chunk_update(chunk) - - async def _handle_stream_trailer(self, trailer: proto_room.DataStream.Trailer): - text_reader = self._text_stream_readers.get(trailer.stream_id) - file_reader = self._byte_stream_readers.get(trailer.stream_id) - - if text_reader: - await text_reader._on_stream_close(trailer) - self._text_stream_readers.pop(trailer.stream_id) - elif file_reader: - await file_reader._on_stream_close(trailer) - self._byte_stream_readers.pop(trailer.stream_id) - async def stream_text( self, *, @@ -726,28 +664,6 @@ async def send_file( return writer.info - def register_byte_stream_handler(self, topic: str, handler: ByteStreamHandler): - existing_handler = self._byte_stream_handlers.get(topic) - if existing_handler is None: - self._byte_stream_handlers[topic] = handler - else: - raise ValueError("byte stream handler for topic '%s' already set" % topic) - - def unregister_byte_stream_handler(self, topic: str): - if self._byte_stream_handlers.get(topic): - self._byte_stream_handlers.pop(topic) - - def register_text_stream_handler(self, topic: str, handler: TextStreamHandler): - existing_handler = self._text_stream_handlers.get(topic) - if existing_handler is None: - self._text_stream_handlers[topic] = handler - else: - raise ValueError("text stream handler for topic '%s' already set" % topic) - - def unregister_text_stream_handler(self, topic: str): - if self._text_stream_handlers.get(topic): - self._text_stream_handlers.pop(topic) - async def publish_track( self, track: LocalTrack, options: TrackPublishOptions = TrackPublishOptions() ) -> LocalTrackPublication: diff --git a/livekit-rtc/livekit/rtc/room.py b/livekit-rtc/livekit/rtc/room.py index 4a20aa43..790320d2 100644 --- a/livekit-rtc/livekit/rtc/room.py +++ b/livekit-rtc/livekit/rtc/room.py @@ -33,6 +33,12 @@ from .track import RemoteAudioTrack, RemoteVideoTrack from .track_publication import RemoteTrackPublication, TrackPublication from .transcription import TranscriptionSegment +from .data_stream import ( + TextStreamReader, + ByteStreamReader, + TextStreamHandler, + ByteStreamHandler, +) EventTypes = Literal[ @@ -140,6 +146,11 @@ def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: self._first_sid_future = asyncio.Future[str]() self._local_participant: LocalParticipant | None = None + self._text_stream_readers: Dict[str, TextStreamReader] = {} + self._byte_stream_readers: Dict[str, ByteStreamReader] = {} + self._text_stream_handlers: Dict[str, TextStreamHandler] = {} + self._byte_stream_handlers: Dict[str, ByteStreamHandler] = {} + def __del__(self) -> None: if self._ffi_handle is not None: FfiClient.instance.queue.unsubscribe(self._ffi_queue) @@ -398,6 +409,28 @@ def on_participant_connected(participant): # start listening to room events self._task = self._loop.create_task(self._listen_task()) + def register_byte_stream_handler(self, topic: str, handler: ByteStreamHandler): + existing_handler = self._byte_stream_handlers.get(topic) + if existing_handler is None: + self._byte_stream_handlers[topic] = handler + else: + raise ValueError("byte stream handler for topic '%s' already set" % topic) + + def unregister_byte_stream_handler(self, topic: str): + if self._byte_stream_handlers.get(topic): + self._byte_stream_handlers.pop(topic) + + def register_text_stream_handler(self, topic: str, handler: TextStreamHandler): + existing_handler = self._text_stream_handlers.get(topic) + if existing_handler is None: + self._text_stream_handlers[topic] = handler + else: + raise ValueError("text stream handler for topic '%s' already set" % topic) + + def unregister_text_stream_handler(self, topic: str): + if self._text_stream_handlers.get(topic): + self._text_stream_handlers.pop(topic) + async def disconnect(self) -> None: """Disconnects from the room.""" if not self.isconnected(): @@ -714,28 +747,77 @@ def _on_room_event(self, event: proto_room.RoomEvent): elif which == "reconnected": self.emit("reconnected") elif which == "stream_header_received": - self.local_participant._handle_stream_header( + self._handle_stream_header( event.stream_header_received.header, event.stream_header_received.participant_identity, ) elif which == "stream_chunk_received": task = asyncio.create_task( - self.local_participant._handle_stream_chunk( - event.stream_chunk_received.chunk - ) + self._handle_stream_chunk(event.stream_chunk_received.chunk) ) self._data_stream_tasks.add(task) task.add_done_callback(self._data_stream_tasks.discard) elif which == "stream_trailer_received": task = asyncio.create_task( - self.local_participant._handle_stream_trailer( - event.stream_trailer_received.trailer - ) + self._handle_stream_trailer(event.stream_trailer_received.trailer) ) self._data_stream_tasks.add(task) task.add_done_callback(self._data_stream_tasks.discard) + def _handle_stream_header( + self, header: proto_room.DataStream.Header, participant_identity: str + ): + stream_type = header.WhichOneof("content_header") + if stream_type == "text_header": + text_stream_handler = self._text_stream_handlers.get(header.topic) + if text_stream_handler is None: + logging.info( + "ignoring text stream with topic '%s', no callback attached", + header.topic, + ) + return + + text_reader = TextStreamReader(header) + self._text_stream_readers[header.stream_id] = text_reader + text_stream_handler(text_reader, participant_identity) + elif stream_type == "byte_header": + logging.warning("received byte header, %s", header.stream_id) + byte_stream_handler = self._byte_stream_handlers.get(header.topic) + if byte_stream_handler is None: + logging.info( + "ignoring byte stream with topic '%s', no callback attached", + header.topic, + ) + return + + byte_reader = ByteStreamReader(header) + self._byte_stream_readers[header.stream_id] = byte_reader + byte_stream_handler(byte_reader, participant_identity) + else: + logging.warning("received unknown header type, %s", stream_type) + pass + + async def _handle_stream_chunk(self, chunk: proto_room.DataStream.Chunk): + text_reader = self._text_stream_readers.get(chunk.stream_id) + file_reader = self._byte_stream_readers.get(chunk.stream_id) + + if text_reader: + await text_reader._on_chunk_update(chunk) + elif file_reader: + await file_reader._on_chunk_update(chunk) + + async def _handle_stream_trailer(self, trailer: proto_room.DataStream.Trailer): + text_reader = self._text_stream_readers.get(trailer.stream_id) + file_reader = self._byte_stream_readers.get(trailer.stream_id) + + if text_reader: + await text_reader._on_stream_close(trailer) + self._text_stream_readers.pop(trailer.stream_id) + elif file_reader: + await file_reader._on_stream_close(trailer) + self._byte_stream_readers.pop(trailer.stream_id) + async def _drain_rpc_invocation_tasks(self) -> None: if self._rpc_invocation_tasks: for task in self._rpc_invocation_tasks: From 0e6291aa7fdf6c7d857df563e5e24a241447a4c7 Mon Sep 17 00:00:00 2001 From: lukasIO Date: Fri, 31 Jan 2025 09:38:23 +0100 Subject: [PATCH 10/12] update example --- examples/data-streams/data_streams.py | 4 ++-- livekit-rtc/livekit/rtc/participant.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/data-streams/data_streams.py b/examples/data-streams/data_streams.py index 9e97fc28..997055bd 100644 --- a/examples/data-streams/data_streams.py +++ b/examples/data-streams/data_streams.py @@ -54,14 +54,14 @@ def on_participant_connected(participant: rtc.RemoteParticipant): ) asyncio.create_task(greetParticipant(participant.identity)) - room.local_participant.set_text_stream_handler( + room.set_text_stream_handler( "chat", lambda reader, participant_identity: asyncio.create_task( on_chat_message_received(reader, participant_identity) ), ) - room.local_participant.set_byte_stream_handler( + room.set_byte_stream_handler( "files", lambda reader, participant_identity: asyncio.create_task( on_welcome_image_received(reader, participant_identity) diff --git a/livekit-rtc/livekit/rtc/participant.py b/livekit-rtc/livekit/rtc/participant.py index aa889a7b..f57d0296 100644 --- a/livekit-rtc/livekit/rtc/participant.py +++ b/livekit-rtc/livekit/rtc/participant.py @@ -21,8 +21,6 @@ import aiofiles from typing import List, Union, Callable, Dict, Awaitable, Optional, Mapping, cast from abc import abstractmethod, ABC -import logging - from ._ffi_client import FfiClient, FfiHandle from ._proto import ffi_pb2 as proto_ffi From 923d9c3e36e2ddd1d154c3b922a4863abcd4d89f Mon Sep 17 00:00:00 2001 From: lukasIO Date: Fri, 31 Jan 2025 15:27:15 +0100 Subject: [PATCH 11/12] textstream info should not inherit from typeddict --- livekit-rtc/livekit/rtc/data_stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/livekit-rtc/livekit/rtc/data_stream.py b/livekit-rtc/livekit/rtc/data_stream.py index f76fffeb..778c96e7 100644 --- a/livekit-rtc/livekit/rtc/data_stream.py +++ b/livekit-rtc/livekit/rtc/data_stream.py @@ -19,7 +19,7 @@ import datetime from collections.abc import Callable from dataclasses import dataclass -from typing import AsyncIterator, Optional, TypedDict, Dict, List +from typing import AsyncIterator, Optional, Dict, List from ._proto.room_pb2 import DataStream as proto_DataStream from ._proto import ffi_pb2 as proto_ffi from ._proto import room_pb2 as proto_room @@ -35,7 +35,7 @@ @dataclass -class BaseStreamInfo(TypedDict): +class BaseStreamInfo: stream_id: str mime_type: str topic: str From c01c1d91ddbfebc45145da7e5ed88f6027dfecfb Mon Sep 17 00:00:00 2001 From: lukasIO Date: Sat, 1 Feb 2025 12:00:28 +0100 Subject: [PATCH 12/12] address comments --- livekit-rtc/livekit/rtc/room.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/livekit-rtc/livekit/rtc/room.py b/livekit-rtc/livekit/rtc/room.py index 790320d2..cb937840 100644 --- a/livekit-rtc/livekit/rtc/room.py +++ b/livekit-rtc/livekit/rtc/room.py @@ -437,7 +437,7 @@ async def disconnect(self) -> None: return await self._drain_rpc_invocation_tasks() - await self._drain_data_stream__tasks() + await self._drain_data_stream_tasks() req = proto_ffi.FfiRequest() req.disconnect.room_handle = self._ffi_handle.handle # type: ignore @@ -478,7 +478,7 @@ async def _listen_task(self) -> None: # Clean up any pending RPC invocation tasks await self._drain_rpc_invocation_tasks() - await self._drain_data_stream__tasks() + await self._drain_data_stream_tasks() def _on_rpc_method_invocation(self, rpc_invocation: RpcMethodInvocationEvent): if self._local_participant is None: @@ -782,7 +782,6 @@ def _handle_stream_header( self._text_stream_readers[header.stream_id] = text_reader text_stream_handler(text_reader, participant_identity) elif stream_type == "byte_header": - logging.warning("received byte header, %s", header.stream_id) byte_stream_handler = self._byte_stream_handlers.get(header.topic) if byte_stream_handler is None: logging.info( @@ -824,7 +823,7 @@ async def _drain_rpc_invocation_tasks(self) -> None: task.cancel() await asyncio.gather(*self._rpc_invocation_tasks, return_exceptions=True) - async def _drain_data_stream__tasks(self) -> None: + async def _drain_data_stream_tasks(self) -> None: if self._data_stream_tasks: for task in self._data_stream_tasks: task.cancel()