Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework data stream API #352

Merged
merged 12 commits into from
Feb 1, 2025
6 changes: 3 additions & 3 deletions examples/data-streams/data_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def greetParticipant(identity: str):
await room.local_participant.send_file(
"./green_tree_python.jpg",
destination_identities=[identity],
topic="welcome",
topic="files",
)

async def on_chat_message_received(
Expand Down Expand Up @@ -55,17 +55,17 @@ def on_participant_connected(participant: rtc.RemoteParticipant):
asyncio.create_task(greetParticipant(participant.identity))

room.set_text_stream_handler(
"chat",
lambda reader, participant_identity: asyncio.create_task(
on_chat_message_received(reader, participant_identity)
),
"chat",
)

room.set_byte_stream_handler(
"files",
lambda reader, participant_identity: asyncio.create_task(
on_welcome_image_received(reader, participant_identity)
),
"welcome",
)

# By default, autosubscribe is enabled. The participant will be subscribed to
Expand Down
14 changes: 7 additions & 7 deletions livekit-rtc/livekit/rtc/data_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import datetime
from collections.abc import Callable
from dataclasses import dataclass
from typing import AsyncIterator, Optional, TypedDict, Dict, List
from typing import AsyncIterator, Optional, Dict, List
from ._proto.room_pb2 import DataStream as proto_DataStream
from ._proto import ffi_pb2 as proto_ffi
from ._proto import room_pb2 as proto_room
Expand All @@ -35,13 +35,13 @@


@dataclass
class BaseStreamInfo(TypedDict):
class BaseStreamInfo:
stream_id: str
mime_type: str
topic: str
timestamp: int
size: Optional[int]
attributes: Optional[Dict[str, str]] # Optional for the extensions dictionary
attributes: Optional[Dict[str, str]] # Optional for the attributes dictionary


@dataclass
Expand Down Expand Up @@ -259,7 +259,7 @@ def __init__(
local_participant: LocalParticipant,
*,
topic: str = "",
extensions: Optional[Dict[str, str]] = {},
attributes: Optional[Dict[str, str]] = {},
stream_id: str | None = None,
total_size: int | None = None,
reply_to_id: str | None = None,
Expand All @@ -268,7 +268,7 @@ def __init__(
super().__init__(
local_participant,
topic,
extensions,
attributes,
stream_id,
total_size,
mime_type="text/plain",
Expand Down Expand Up @@ -313,7 +313,7 @@ def __init__(
*,
name: str,
topic: str = "",
extensions: Optional[Dict[str, str]] = None,
attributes: Optional[Dict[str, str]] = None,
stream_id: str | None = None,
total_size: int | None = None,
mime_type: str = "application/octet-stream",
Expand All @@ -322,7 +322,7 @@ def __init__(
super().__init__(
local_participant,
topic,
extensions,
attributes,
stream_id,
total_size,
mime_type=mime_type,
Expand Down
20 changes: 10 additions & 10 deletions livekit-rtc/livekit/rtc/participant.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import List, Union, Callable, Dict, Awaitable, Optional, Mapping, cast
from abc import abstractmethod, ABC


from ._ffi_client import FfiClient, FfiHandle
from ._proto import ffi_pb2 as proto_ffi
from ._proto import participant_pb2 as proto_participant
Expand Down Expand Up @@ -552,9 +551,9 @@ async def set_attributes(self, attributes: dict[str, str]) -> None:
async def stream_text(
self,
*,
destination_identities: List[str] = [],
destination_identities: Optional[List[str]] = None,
topic: str = "",
extensions: Dict[str, str] = {},
attributes: Optional[Dict[str, str]] = None,
reply_to_id: str | None = None,
total_size: int | None = None,
) -> TextStreamWriter:
Expand All @@ -565,7 +564,7 @@ async def stream_text(
writer = TextStreamWriter(
self,
topic=topic,
extensions=extensions,
attributes=attributes,
reply_to_id=reply_to_id,
destination_identities=destination_identities,
total_size=total_size,
Expand All @@ -579,16 +578,16 @@ async def send_text(
self,
text: str,
*,
destination_identities: List[str] = [],
destination_identities: Optional[List[str]] = None,
topic: str = "",
extensions: Dict[str, str] = {},
attributes: Optional[Dict[str, str]] = None,
reply_to_id: str | None = None,
):
total_size = len(text.encode())
writer = await self.stream_text(
destination_identities=destination_identities,
topic=topic,
extensions=extensions,
attributes=attributes,
reply_to_id=reply_to_id,
total_size=total_size,
)
Expand All @@ -605,7 +604,7 @@ async def stream_bytes(
*,
total_size: int | None = None,
mime_type: str = "application/octet-stream",
extensions: Optional[Dict[str, str]] = None,
attributes: Optional[Dict[str, str]] = None,
stream_id: str | None = None,
destination_identities: Optional[List[str]] = None,
topic: str = "",
Expand All @@ -617,7 +616,7 @@ async def stream_bytes(
writer = ByteStreamWriter(
self,
name=name,
extensions=extensions,
attributes=attributes,
total_size=total_size,
stream_id=stream_id,
mime_type=mime_type,
Expand All @@ -632,6 +631,7 @@ async def stream_bytes(
async def send_file(
self,
file_path: str,
*,
topic: str = "",
destination_identities: Optional[List[str]] = None,
attributes: Optional[Dict[str, str]] = None,
Expand All @@ -649,7 +649,7 @@ async def send_file(
name=file_name,
total_size=file_size,
mime_type=mime_type,
extensions=attributes,
attributes=attributes,
stream_id=stream_id,
destination_identities=destination_identities,
topic=topic,
Expand Down
124 changes: 71 additions & 53 deletions livekit-rtc/livekit/rtc/room.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uff, thanks, that was a debug left over, removed

Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ByteStreamHandler,
)


EventTypes = Literal[
"participant_connected",
"participant_disconnected",
Expand Down Expand Up @@ -138,11 +139,13 @@ def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self._room_queue = BroadcastQueue[proto_ffi.FfiEvent]()
self._info = proto_room.RoomInfo()
self._rpc_invocation_tasks: set[asyncio.Task] = set()
self._data_stream_tasks: set[asyncio.Task] = set()

self._remote_participants: Dict[str, RemoteParticipant] = {}
self._connection_state = ConnectionState.CONN_DISCONNECTED
self._first_sid_future = asyncio.Future[str]()
self._local_participant: LocalParticipant | None = None

self._text_stream_readers: Dict[str, TextStreamReader] = {}
self._byte_stream_readers: Dict[str, ByteStreamReader] = {}
self._text_stream_handlers: Dict[str, TextStreamHandler] = {}
Expand Down Expand Up @@ -406,12 +409,35 @@ def on_participant_connected(participant):
# start listening to room events
self._task = self._loop.create_task(self._listen_task())

def register_byte_stream_handler(self, topic: str, handler: ByteStreamHandler):
existing_handler = self._byte_stream_handlers.get(topic)
if existing_handler is None:
self._byte_stream_handlers[topic] = handler
else:
raise ValueError("byte stream handler for topic '%s' already set" % topic)

def unregister_byte_stream_handler(self, topic: str):
if self._byte_stream_handlers.get(topic):
self._byte_stream_handlers.pop(topic)

def register_text_stream_handler(self, topic: str, handler: TextStreamHandler):
existing_handler = self._text_stream_handlers.get(topic)
if existing_handler is None:
self._text_stream_handlers[topic] = handler
else:
raise ValueError("text stream handler for topic '%s' already set" % topic)

def unregister_text_stream_handler(self, topic: str):
if self._text_stream_handlers.get(topic):
self._text_stream_handlers.pop(topic)

async def disconnect(self) -> None:
"""Disconnects from the room."""
if not self.isconnected():
return

await self._drain_rpc_invocation_tasks()
await self._drain_data_stream__tasks()

req = proto_ffi.FfiRequest()
req.disconnect.room_handle = self._ffi_handle.handle # type: ignore
Expand All @@ -426,28 +452,6 @@ async def disconnect(self) -> None:
await self._task
FfiClient.instance.queue.unsubscribe(self._ffi_queue)

def set_byte_stream_handler(self, handler: ByteStreamHandler, topic: str = ""):
existing_handler = self._byte_stream_handlers.get(topic)
if existing_handler is None:
self._byte_stream_handlers[topic] = handler
else:
raise TypeError("byte stream handler for topic '%s' already set" % topic)

def remove_byte_stream_handler(self, topic: str = ""):
if self._byte_stream_handlers.get(topic):
self._byte_stream_handlers.pop(topic)

def set_text_stream_handler(self, handler: TextStreamHandler, topic: str = ""):
existing_handler = self._text_stream_handlers.get(topic)
if existing_handler is None:
self._text_stream_handlers[topic] = handler
else:
raise TypeError("text stream handler for topic '%s' already set" % topic)

def remove_text_stream_handler(self, topic: str = ""):
if self._text_stream_handlers.get(topic):
self._text_stream_handlers.pop(topic)

async def _listen_task(self) -> None:
# listen to incoming room events
while True:
Expand All @@ -474,6 +478,7 @@ async def _listen_task(self) -> None:

# Clean up any pending RPC invocation tasks
await self._drain_rpc_invocation_tasks()
await self._drain_data_stream__tasks()

def _on_rpc_method_invocation(self, rpc_invocation: RpcMethodInvocationEvent):
if self._local_participant is None:
Expand Down Expand Up @@ -747,40 +752,18 @@ def _on_room_event(self, event: proto_room.RoomEvent):
event.stream_header_received.participant_identity,
)
elif which == "stream_chunk_received":
asyncio.gather(self._handle_stream_chunk(event.stream_chunk_received.chunk))
task = asyncio.create_task(
self._handle_stream_chunk(event.stream_chunk_received.chunk)
)
self._data_stream_tasks.add(task)
task.add_done_callback(self._data_stream_tasks.discard)

elif which == "stream_trailer_received":
asyncio.gather(
task = asyncio.create_task(
self._handle_stream_trailer(event.stream_trailer_received.trailer)
)

async def _drain_rpc_invocation_tasks(self) -> None:
if self._rpc_invocation_tasks:
for task in self._rpc_invocation_tasks:
task.cancel()
await asyncio.gather(*self._rpc_invocation_tasks, return_exceptions=True)

def _retrieve_remote_participant(
self, identity: str
) -> Optional[RemoteParticipant]:
"""Retrieve a remote participant by identity"""
return self._remote_participants.get(identity, None)

def _retrieve_participant(self, identity: str) -> Optional[Participant]:
"""Retrieve a local or remote participant by identity"""
if identity and identity == self.local_participant.identity:
return self.local_participant

return self._retrieve_remote_participant(identity)

def _create_remote_participant(
self, owned_info: proto_participant.OwnedParticipant
) -> RemoteParticipant:
if owned_info.info.identity in self._remote_participants:
raise Exception("participant already exists")

participant = RemoteParticipant(owned_info)
self._remote_participants[participant.identity] = participant
return participant
self._data_stream_tasks.add(task)
task.add_done_callback(self._data_stream_tasks.discard)

def _handle_stream_header(
self, header: proto_room.DataStream.Header, participant_identity: str
Expand Down Expand Up @@ -835,6 +818,41 @@ async def _handle_stream_trailer(self, trailer: proto_room.DataStream.Trailer):
await file_reader._on_stream_close(trailer)
self._byte_stream_readers.pop(trailer.stream_id)

async def _drain_rpc_invocation_tasks(self) -> None:
if self._rpc_invocation_tasks:
for task in self._rpc_invocation_tasks:
task.cancel()
await asyncio.gather(*self._rpc_invocation_tasks, return_exceptions=True)

async def _drain_data_stream__tasks(self) -> None:
lukasIO marked this conversation as resolved.
Show resolved Hide resolved
if self._data_stream_tasks:
for task in self._data_stream_tasks:
task.cancel()
await asyncio.gather(*self._data_stream_tasks, return_exceptions=True)

def _retrieve_remote_participant(
self, identity: str
) -> Optional[RemoteParticipant]:
"""Retrieve a remote participant by identity"""
return self._remote_participants.get(identity, None)

def _retrieve_participant(self, identity: str) -> Optional[Participant]:
"""Retrieve a local or remote participant by identity"""
if identity and identity == self.local_participant.identity:
return self.local_participant

return self._retrieve_remote_participant(identity)

def _create_remote_participant(
self, owned_info: proto_participant.OwnedParticipant
) -> RemoteParticipant:
if owned_info.info.identity in self._remote_participants:
raise Exception("participant already exists")

participant = RemoteParticipant(owned_info)
self._remote_participants[participant.identity] = participant
return participant

def __repr__(self) -> str:
sid = "unknown"
if self._first_sid_future.done():
Expand Down
Loading