diff --git a/tarpc/src/serde_transport.rs b/tarpc/src/serde_transport.rs index 1773e846..a8eedd52 100644 --- a/tarpc/src/serde_transport.rs +++ b/tarpc/src/serde_transport.rs @@ -210,7 +210,19 @@ pub mod tcp { Codec: Serializer + Deserializer, CodecFn: Fn() -> Codec, { - let listener = TcpListener::bind(addr).await?; + listen_on(TcpListener::bind(addr).await?, codec_fn).await + } + + /// Wrap accepted connections from `listener` in TCP transports. + pub async fn listen_on( + listener: TcpListener, + codec_fn: CodecFn, + ) -> io::Result> + where + Item: for<'de> Deserialize<'de>, + Codec: Serializer + Deserializer, + CodecFn: Fn() -> Codec, + { let local_addr = listener.local_addr()?; Ok(Incoming { listener, @@ -364,7 +376,19 @@ pub mod unix { Codec: Serializer + Deserializer, CodecFn: Fn() -> Codec, { - let listener = UnixListener::bind(path)?; + listen_on(UnixListener::bind(path)?, codec_fn).await + } + + /// Wrap accepted connections from `listener` in Unix Domain Socket transports. + pub async fn listen_on( + listener: UnixListener, + codec_fn: CodecFn, + ) -> io::Result> + where + Item: for<'de> Deserialize<'de>, + Codec: Serializer + Deserializer, + CodecFn: Fn() -> Codec, + { let local_addr = listener.local_addr()?; Ok(Incoming { listener, @@ -650,6 +674,26 @@ mod tests { Ok(()) } + #[cfg(tcp)] + #[tokio::test] + async fn tcp_on_existing_transport() -> io::Result<()> { + use super::tcp; + + let transport = TcpListener::bind("0.0.0.0:0").await?; + let mut listener = tcp::listen_on(transport, SymmetricalJson::::default).await?; + let addr = listener.local_addr(); + tokio::spawn(async move { + let mut transport = listener.next().await.unwrap().unwrap(); + let message = transport.next().await.unwrap().unwrap(); + transport.send(message).await.unwrap(); + }); + let mut transport = tcp::connect(addr, SymmetricalJson::::default).await?; + transport.send(String::from("test")).await?; + assert_matches!(transport.next().await, Some(Ok(s)) if s == "test"); + assert_matches!(transport.next().await, None); + Ok(()) + } + #[cfg(all(unix, feature = "unix"))] #[tokio::test] async fn uds() -> io::Result<()> { @@ -669,4 +713,25 @@ mod tests { assert_matches!(transport.next().await, None); Ok(()) } + + #[cfg(all(unix, feature = "unix"))] + #[tokio::test] + async fn uds_on_existing_transport() -> io::Result<()> { + use super::unix; + use super::*; + + let sock = unix::TempPathBuf::with_random("uds"); + let transport = tokio::net::UnixListener::bind(&sock)?; + let mut listener = unix::listen_on(transport, SymmetricalJson::::default).await?; + tokio::spawn(async move { + let mut transport = listener.next().await.unwrap().unwrap(); + let message = transport.next().await.unwrap().unwrap(); + transport.send(message).await.unwrap(); + }); + let mut transport = unix::connect(&sock, SymmetricalJson::::default).await?; + transport.send(String::from("test")).await?; + assert_matches!(transport.next().await, Some(Ok(s)) if s == "test"); + assert_matches!(transport.next().await, None); + Ok(()) + } }