Skip to content
This repository has been archived by the owner on Aug 23, 2022. It is now read-only.

Commit

Permalink
poll_data_channel: implement polling instead of proxying to stream (#9)
Browse files Browse the repository at this point in the history
Since DataChannel has its own read & write implementations (which call
Stream::read and Stream::write), we can't poll stream directly, but have
to construct futures ourselves and poll them.

Most of the code was copied from sctp::Stream.
  • Loading branch information
melekes authored Jun 17, 2022
1 parent 714deda commit d2c65a2
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 22 deletions.
213 changes: 191 additions & 22 deletions src/data_channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use util::marshal::*;
use bytes::{Buf, Bytes};
use derive_builder::Builder;
use std::fmt;
use std::future::Future;
use std::io;
use std::net::Shutdown;
use std::pin::Pin;
Expand Down Expand Up @@ -340,23 +341,68 @@ impl DataChannel {
}
}

/// Default capacity of the temporary read buffer used by [`PollStream`].
const DEFAULT_READ_BUF_SIZE: usize = 8192;

/// State of the read `Future` in [`PollStream`].
enum ReadFut {
/// Nothing in progress.
Idle,
/// Reading data from the underlying stream.
Reading(Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>),
/// Finished reading, but there's unread data in the temporary buffer.
RemainingData(Vec<u8>),
}

impl ReadFut {
/// Gets a mutable reference to the future stored inside `Reading(future)`.
///
/// # Panics
///
/// Panics if `ReadFut` variant is not `Reading`.
fn get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
match self {
ReadFut::Reading(ref mut fut) => fut,
_ => panic!("expected ReadFut to be Reading"),
}
}
}

/// A wrapper around around [`DataChannel`], which implements [`AsyncRead`] and
/// [`AsyncWrite`].
///
/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an
/// additional overhead.
pub struct PollDataChannel {
data_channel: Arc<DataChannel>,
poll_stream: PollStream,

read_fut: ReadFut,
write_fut: Option<Pin<Box<dyn Future<Output = Result<usize>> + Send>>>,
shutdown_fut: Option<Pin<Box<dyn Future<Output = Result<()>> + Send>>>,

read_buf_cap: usize,
}

impl PollDataChannel {
/// Constructs a new `PollDataChannel`.
///
/// # Examples
///
/// ```
/// use webrtc_data::data_channel::{DataChannel, PollDataChannel, Config};
/// use sctp::stream::Stream;
/// use std::sync::Arc;
///
/// let dc = Arc::new(DataChannel::new(Arc::new(Stream::default()), Config::default()));
/// let poll_dc = PollDataChannel::new(dc);
/// ```
pub fn new(data_channel: Arc<DataChannel>) -> Self {
let stream = data_channel.stream.clone();
Self {
data_channel,
poll_stream: PollStream::new(stream),
read_fut: ReadFut::Idle,
write_fut: None,
shutdown_fut: None,
read_buf_cap: DEFAULT_READ_BUF_SIZE,
}
}

Expand All @@ -372,44 +418,44 @@ impl PollDataChannel {

/// MessagesSent returns the number of messages sent
pub fn messages_sent(&self) -> usize {
self.data_channel.messages_sent.load(Ordering::SeqCst)
self.data_channel.messages_sent()
}

/// MessagesReceived returns the number of messages received
pub fn messages_received(&self) -> usize {
self.data_channel.messages_received.load(Ordering::SeqCst)
self.data_channel.messages_received()
}

/// BytesSent returns the number of bytes sent
pub fn bytes_sent(&self) -> usize {
self.data_channel.bytes_sent.load(Ordering::SeqCst)
self.data_channel.bytes_sent()
}

/// BytesReceived returns the number of bytes received
pub fn bytes_received(&self) -> usize {
self.data_channel.bytes_received.load(Ordering::SeqCst)
self.data_channel.bytes_received()
}

/// StreamIdentifier returns the Stream identifier associated to the stream.
pub fn stream_identifier(&self) -> u16 {
self.poll_stream.stream_identifier()
self.data_channel.stream_identifier()
}

/// BufferedAmount returns the number of bytes of data currently queued to be
/// sent over this stream.
pub fn buffered_amount(&self) -> usize {
self.poll_stream.buffered_amount()
self.data_channel.buffered_amount()
}

/// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
/// data that is considered "low." Defaults to 0.
pub fn buffered_amount_low_threshold(&self) -> usize {
self.poll_stream.buffered_amount_low_threshold()
self.data_channel.buffered_amount_low_threshold()
}

/// Set the capacity of the temporary read buffer (default: 8192).
pub fn set_read_buf_capacity(&mut self, capacity: usize) {
self.poll_stream.set_read_buf_capacity(capacity)
self.read_buf_cap = capacity
}
}

Expand All @@ -419,7 +465,68 @@ impl AsyncRead for PollDataChannel {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.poll_stream).poll_read(cx, buf)
if buf.remaining() == 0 {
return Poll::Ready(Ok(()));
}

let fut = match self.read_fut {
ReadFut::Idle => {
// read into a temporary buffer because `buf` has an unonymous lifetime, which can
// be shorter than the lifetime of `read_fut`.
let data_channel = self.data_channel.clone();
let mut temp_buf = vec![0; self.read_buf_cap];
self.read_fut = ReadFut::Reading(Box::pin(async move {
data_channel.read(temp_buf.as_mut_slice()).await.map(|n| {
temp_buf.truncate(n);
temp_buf
})
}));
self.read_fut.get_reading_mut()
}
ReadFut::Reading(ref mut fut) => fut,
ReadFut::RemainingData(ref mut data) => {
let remaining = buf.remaining();
let len = std::cmp::min(data.len(), remaining);
buf.put_slice(&data[..len]);
if data.len() > remaining {
// ReadFut remains to be RemainingData
data.drain(..len);
} else {
self.read_fut = ReadFut::Idle;
}
return Poll::Ready(Ok(()));
}
};

loop {
match fut.as_mut().poll(cx) {
Poll::Pending => return Poll::Pending,
// retry immediately upon empty data or incomplete chunks
// since there's no way to setup a waker.
Poll::Ready(Err(Error::Sctp(sctp::Error::ErrTryAgain))) => {}
// EOF has been reached => don't touch buf and just return Ok
Poll::Ready(Err(Error::Sctp(sctp::Error::ErrEof))) => {
self.read_fut = ReadFut::Idle;
return Poll::Ready(Ok(()));
}
Poll::Ready(Err(e)) => {
self.read_fut = ReadFut::Idle;
return Poll::Ready(Err(e.into()));
}
Poll::Ready(Ok(mut temp_buf)) => {
let remaining = buf.remaining();
let len = std::cmp::min(temp_buf.len(), remaining);
buf.put_slice(&temp_buf[..len]);
if temp_buf.len() > remaining {
temp_buf.drain(..len);
self.read_fut = ReadFut::RemainingData(temp_buf);
} else {
self.read_fut = ReadFut::Idle;
}
return Poll::Ready(Ok(()));
}
}
}
}
}

Expand All @@ -429,23 +536,84 @@ impl AsyncWrite for PollDataChannel {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.poll_stream).poll_write(cx, buf)
if buf.len() == 0 {
return Poll::Ready(Ok(0));
}

let (fut, fut_is_new) = match self.write_fut.as_mut() {
Some(fut) => (fut, false),
None => {
let data_channel = self.data_channel.clone();
let bytes = Bytes::copy_from_slice(buf);
(
self.write_fut
.get_or_insert(Box::pin(async move { data_channel.write(&bytes).await })),
true,
)
}
};

match fut.as_mut().poll(cx) {
Poll::Pending => {
// If it's the first time we're polling the future, `Poll::Pending` can't be
// returned because that would mean the `PollStream` is not ready for writing. And
// this is not true since we've just created a future, which is going to write the
// buf to the underlying stream.
//
// It's okay to return `Poll::Ready` if the data is buffered (this is what the
// buffered writer and `File` do).
if fut_is_new {
Poll::Ready(Ok(buf.len()))
} else {
// If it's the subsequent poll, it's okay to return `Poll::Pending` as it
// indicates that the `PollStream` is not ready for writing. Only one future
// can be in progress at the time.
Poll::Pending
}
}
Poll::Ready(Err(e)) => {
self.write_fut = None;
Poll::Ready(Err(e.into()))
}
Poll::Ready(Ok(n)) => {
self.write_fut = None;
Poll::Ready(Ok(n))
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.poll_stream).poll_flush(cx)
match self.write_fut.as_mut() {
Some(fut) => match fut.as_mut().poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
self.write_fut = None;
Poll::Ready(Err(e.into()))
}
Poll::Ready(Ok(_)) => {
self.write_fut = None;
Poll::Ready(Ok(()))
}
},
None => Poll::Ready(Ok(())),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.poll_stream).poll_shutdown(cx)
}
let fut = match self.shutdown_fut.as_mut() {
Some(fut) => fut,
None => {
let data_channel = self.data_channel.clone();
self.shutdown_fut
.get_or_insert(Box::pin(async move { data_channel.close().await }))
}
};

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.poll_stream).poll_write_vectored(cx, bufs)
match fut.as_mut().poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
}
}
}

Expand All @@ -459,6 +627,7 @@ impl fmt::Debug for PollDataChannel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PollDataChannel")
.field("data_channel", &self.data_channel)
.field("read_buf_cap", &self.read_buf_cap)
.finish()
}
}
Expand Down
15 changes: 15 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::io;
use std::string::FromUtf8Error;
use thiserror::Error;

Expand Down Expand Up @@ -37,6 +38,20 @@ impl From<Error> for util::Error {
}
}

impl From<Error> for io::Error {
fn from(error: Error) -> Self {
match error {
e @ Error::Sctp(sctp::Error::ErrEof) => {
io::Error::new(io::ErrorKind::UnexpectedEof, e.to_string())
}
e @ Error::ErrStreamClosed => {
io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string())
}
e => io::Error::new(io::ErrorKind::Other, e.to_string()),
}
}
}

impl PartialEq<util::Error> for Error {
fn eq(&self, other: &util::Error) -> bool {
if let Some(down) = other.downcast_ref::<Error>() {
Expand Down

0 comments on commit d2c65a2

Please sign in to comment.